# 混淆矩阵绘制 (Confusion matrix plotting)


from sklearn.metrics import classification_report,confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
#定义混淆矩阵绘制函数
def plot_confusion_matrix(cm, labels_name, title, *, normalize=False, cmap=None, fmt=None):
    """Draw a confusion matrix as an annotated heat map on a new figure.

    Rows are true labels and columns are predicted labels (the axis labels
    set below), matching the layout of ``sklearn.metrics.confusion_matrix``.

    Args:
        cm: Square 2-D array-like with shape (len(labels_name), len(labels_name)).
        labels_name: Sequence of tick labels, used for both axes.
        title: Figure title.
        normalize: If True, divide each row by its sum so every cell shows the
            fraction of that true class; rows that sum to zero stay 0.
        cmap: Colormap forwarded to ``plt.imshow``; None keeps matplotlib's default.
        fmt: Format spec for the cell text. Defaults to ``".2f"`` when
            normalizing, otherwise ``""`` (equivalent to plain ``format(value)``).

    Returns:
        The created ``matplotlib.figure.Figure``.

    Raises:
        ValueError: If ``cm`` is not square with one row/column per label.
    """
    cm = np.asarray(cm)
    n_classes = len(labels_name)
    if cm.shape != (n_classes, n_classes):
        raise ValueError(
            f"cm must have shape ({n_classes}, {n_classes}) to match "
            f"labels_name, got {cm.shape}"
        )
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Classes absent from the true labels give zero rows; leave them at 0
        # instead of producing NaN from 0/0.
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
    if fmt is None:
        fmt = ".2f" if normalize else ""

    fig = plt.figure(figsize=(8, 6))
    image = plt.imshow(cm, interpolation="nearest", cmap=cmap)
    plt.title(title)
    plt.colorbar()
    ticks = np.arange(n_classes)
    plt.xticks(ticks, labels_name, rotation=0)  # class labels along the x axis
    plt.yticks(ticks, labels_name)  # class labels along the y axis

    for i, j in np.ndindex(cm.shape):
        value = cm[i, j]
        # Pick black or white text from the cell's actual background colour,
        # so annotations stay readable whatever colormap is in use.
        r, g, b, _ = image.cmap(image.norm(value))
        luminance = 0.299 * r + 0.587 * g + 0.114 * b
        plt.text(j, i, format(value, fmt),
                 horizontalalignment="center",
                 verticalalignment="center",
                 color="black" if luminance > 0.5 else "white")
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    return fig
if __name__ == "__main__":
    # Demo data so the script runs as-is. Replace `true` (ground-truth labels)
    # and `predict` (model predictions) with your own label arrays.
    rng = np.random.default_rng(0)
    true = rng.integers(0, 10, size=500)
    # Simulate a classifier that is correct roughly 80% of the time.
    predict = np.where(rng.random(true.size) < 0.8, true,
                       rng.integers(0, 10, size=true.size))
    # Class names shown on both axes.
    labels_name = list(range(10))
    # Passing `labels` keeps the matrix len(labels_name) x len(labels_name)
    # even when some class never occurs in `true` or `predict`.
    confusion = confusion_matrix(true, predict, labels=labels_name)
    print(classification_report(true, predict, labels=labels_name))
    plot_confusion_matrix(confusion, labels_name, "Confusion Matrix")
    plt.show()
# 结果展示 (Result display):