混淆矩阵绘制

作者: 小墨 分类: 数据预处理方法 发布时间: 2021-08-25 17:36 访问量:13,378
FavoriteLoading收藏
from sklearn.metrics import classification_report,confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
#定义混淆矩阵绘制函数
def plot_confusion_matrix(cm, labels_name, title):
    plt.figure(figsize=(8,6))
    #归一化
    #cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]  
    plt.imshow(cm, interpolation='nearest')    # 在特定的窗口上显示图像
    plt.title(title)    # 图像标题
    plt.colorbar()
    num_local = np.array(range(len(labels_name)))    
    plt.xticks(num_local, labels_name, rotation=0)    # 将标签印在x轴坐标上
    plt.yticks(num_local, labels_name)    # 将标签印在y轴坐标上
    thresh = cm.max() / 2.
    iters = np.reshape([[[i, j] for j in range(len(labels_name))] for i in range(len(labels_name))], (cm.size, 2))
    for i, j in iters:
        plt.text(j, i, format(cm[i, j]), horizontalalignment="center")  
    plt.ylabel('True label')    
    plt.xlabel('Predicted label')

if __name__ == "__main__":
    #true为真实标签,predict为预测标签
    confusion=confusion_matrix(true,predict)
    #labels_name为类别名
    labels_name=[i for i in range(10)]
    plot_confusion_matrix(confusion, labels_name, "Confusion Matrix")

结果展示:

     

如果觉得小墨的文章对您有用,请随意打赏。您的支持将鼓励我继续创作!

发表评论