网站首页 > 博客文章 正文
接上篇:模型评估与模型监控——混淆矩阵
ROC曲线
ROC(Receiver Operating Char)曲线尝试在不同的cut-off下计算出模型的TPR和FPR,以TPR为纵轴,FPR为横轴绘制出ROC曲线。
ROC曲线体现了模型对好客户和坏客户的覆盖程度,首先当cut-off取0时,模型将全部客户预测为坏客户,因故坏客户召回率为1,好客户为0;当cut-off取1时,模型将全部客户预测为好客户,因故坏客户召回率为0,好客户为1。当cut-off在0-1之间变化时,好坏客户召回率会不断变化,对应图中曲线。若随机预测,则好坏客户召回率呈线性变化,对应图中对角线。所以使用曲线下面积AUC来衡量模型预测的覆盖情况。AUC越大说明模型效果越好,反之则不好。
计算示例:
假设我们通过训练集训练了一个二分类模型,在测试集上进行预测每个样本所属的类别,输出了属于类别”1“的概率。现在假设当P>=0.5时,预测的类标签为”1“。
1.导入相关库
import pandas as pd import matplotlib.pyplot as plt import numpy as np %matplotlib inline #测试样本的数量 parameter=40
2.随机生成结果集
data=pd.DataFrame(index=range(0,parameter),columns=('probability','The true label')) data['The true label']=np.random.randint(0,2,size=len(data)) data['probability']=np.random.choice(np.arange(0.1,1,0.1),len(data['probability']))
结果如下:
3.计算混淆矩阵
cm=np.arange(4).reshape(2,2) cm[0,0]=len(data[data['The ture label']==0][data['probability']<0.5]) #TN cm[0,1]=len(data[data['The ture label']==0][data['probability']>=0.5])#FP cm[1,0]=len(data[data['The ture label']==1][data['probability']<0.5]) #FN cm[1,1]=len(data[data['The ture label']==1][data['probability']>=0.5])#TP
4.计算假正率和真正率
首先,画出混淆矩阵。
import itertools classes = [0,1] plt.figure() plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) plt.title('Confusion matrix') tick_marks = np.arange(len(classes)) plt.xticks(tick_marks, classes, rotation=0) plt.yticks(tick_marks, classes) thresh = cm.max() / 2. for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): plt.text(j, i, cm[i, j], horizontalalignment="center", color="white" if cm[i, j] > thresh else "black") plt.tight_layout() plt.ylabel('True label') plt.xlabel('Predicted label')
然后,threshold=0.5上的假正率和真正率容易计算,为: FPR=6/(5+6)=0.55,TPR=13/(13+6)=0.68
5.ROC曲线和AUC值
ROC曲线是一系列threshold下的(FPR,TPR)数值点的连线。此时的threshold的取值分别为测试数据集中各样本的预测概率。但取各个概率的顺序是从大到小的。
5.1 按概率值排序
首先,按预测概率从大到小的顺序排序:
data.sort_values('probability',inplace=True,ascending=False)
排序结果如下:
此时,threshold依次取0.9,0.9,0.9,0.9,0.9,0.9,0.8,0.8,0.7,...。 比如,当threshold=0.9(第3个0.9),一个”0“预测错误,两个”1“预测正确,FPR=1/11=0.09,TPR=2/19=0.11。 当threshold=0.9(第5个0.9),一个”0“预测错误,四个”1“预测正确,FPR=1/11=0.09,TPR=4/19=0.21。 当threshold=0.6(第1个0.6),三个”0“预测错误,九个”1“预测正确,FPR=3/11=0.27,TPR=9/19=0.47。
5.2 计算全部概率值下的FPR和TPR
TPRandFPR=pd.DataFrame(index=range(len(data)),columns=('TP','FP')) for j in range(len(data)): data1=data.head(n=j+1) FP=len(data1[data1['The ture label']==0] [data1['probability']>=data1.head(len(data1))['probabi lity']]) /float(len(data[data['The ture label']==0])) TP=len(data1[data1['The ture label']==1][data1['probability']>=data1.head(len(data1))['probability']]) /float(len(data[data['The ture label']==1])) TPRandFPR.iloc[j]=[TP,FP]
最后,(FPR,TPR)点矩阵如下:
5.3 画出最终的ROC曲线和计算AUC值
from sklearn.metrics import auc AUC= auc(TPRandFPR['FP'],TPRandFPR['TP']) plt.scatter(x=TPRandFPR['FP'],y=TPRandFPR['TP'],label='(FPR,TPR)',color='k') plt.plot(TPRandFPR['FP'], TPRandFPR['TP'], 'k',label='AUC = %0.2f'% AUC) plt.legend(loc='lower right') plt.title('Receiver Operating Characteristic') plt.plot([(0,0),(1,1)],'r--') plt.xlim([-0.01,1.01]) plt.ylim([-0.01,01.01]) plt.ylabel('True Positive Rate') plt.xlabel('False Positive Rate') plt.show()
下图的黑色线即为ROC曲线,测试样本中的数据点越多,曲线越平滑:
AUC(Area Under roc Cure),顾名思义,其就是ROC曲线小的面积,在此例子中AUC=0.62。AUC越大,说明分类效果越好。
碎片时间,关注收藏。
- 上一篇: 为机器学习模型选择正确的度量评估(第二部分)
- 下一篇: R数据分析:ROC曲线与模型评价实例
猜你喜欢
- 2024-10-29 模型评估(一)(你应该知道的模型评估的五个方法)
- 2024-10-29 精度是远远不够的:如何最好地评估一个分类器?
- 2024-10-29 Python机器学习理论与实战 第二章 Logistic回归模型(下)
- 2024-10-29 R数据分析:ROC曲线与模型评价实例
- 2024-10-29 为机器学习模型选择正确的度量评估(第二部分)
- 2024-10-29 从另外一个角度解释AUC(从另一个角度去看问题)
- 2024-10-29 机器学习算法评估方法+Spring学习笔记
- 2024-10-29 你真的了解模型评估与选择嘛(模型评价的术语)
- 2024-10-29 机器学习中的评价指标(机器学习模型的评价指标有哪些)
- 2024-10-29 python机器学习:分类问题学习模型的评价方法及代码实现
你 发表评论:
欢迎- 06-23MySQL合集-mysql5.7及mysql8的一些特性
- 06-23MySQL CREATE TABLE 简单设计模板交流
- 06-23MYSQL表设计规范(mysql设计表注意事项)
- 06-23MySQL数据库入门(四)数据类型简介
- 06-23数据丢失?别慌!MySQL备份恢复攻略
- 06-23MySQL设计规范(mysql 设计)
- 06-23MySQL数据实时增量同步到Elasticsearch
- 06-23MySQL 避坑指南之隐式数据类型转换
- 最近发表
- 标签列表
-
- powershellfor (55)
- messagesource (56)
- aspose.pdf破解版 (56)
- promise.race (63)
- 2019cad序列号和密钥激活码 (62)
- window.performance (66)
- qt删除文件夹 (72)
- mysqlcaching_sha2_password (64)
- ubuntu升级gcc (58)
- nacos启动失败 (64)
- ssh-add (70)
- jwt漏洞 (58)
- macos14下载 (58)
- yarnnode (62)
- abstractqueuedsynchronizer (64)
- source~/.bashrc没有那个文件或目录 (65)
- springboot整合activiti工作流 (70)
- jmeter插件下载 (61)
- 抓包分析 (60)
- idea创建mavenweb项目 (65)
- vue回到顶部 (57)
- qcombobox样式表 (68)
- vue数组concat (56)
- tomcatundertow (58)
- pastemac (61)
本文暂时没有评论,来添加一个吧(●'◡'●)