# 本文共 1850 字，阅读大约需要 6 分钟。(About 1850 characters; roughly a 6-minute read.)
import numpy as np


# Runs to compare: (legend label, log file, logged entries to plot, line colour).
# NOTE(review): the baseline is plotted from entry 0 while the other runs skip
# their first two entries (2:90); the reason is not visible here -- confirm
# against how the logs were produced before changing these windows.
RUNS = (
    ("baseline", "train_baseline_resnet50.txt", slice(0, 88), "b"),
    ("dgc", "train_dgc_resnet50.txt", slice(2, 90), "g"),
    ("ip", "train_ip_resnet50.txt", slice(2, 90), "r"),
    ("rand_4_dgc", "train_rand_4_dgc_resnet50.txt", slice(2, 90), "c"),
)


def get_file(acc_list, loss_list, filename):
    """Parse a training log and append its accuracy and loss values.

    Each non-blank line is split on ':'. The text after the last ':' is parsed
    as the accuracy, and the text between the second-to-last ':' and the next
    ',' is parsed as the loss -- so a line such as
    ``Epoch: 1, loss: 2.5, acc: 0.25`` yields loss 2.5 and accuracy 0.25.

    Args:
        acc_list: list the accuracy values are appended to, in file order.
        loss_list: list the loss values are appended to, in file order.
        filename: path of the log file to read.

    Raises:
        OSError: if the file cannot be opened.
        IndexError: if a non-blank line contains no ':'.
        ValueError: if a loss or accuracy field is not a number.
    """
    # ``with`` closes the file even if parsing a line raises.
    with open(filename, "r") as log:
        for raw in log:
            line = raw.strip()
            if not line:
                # Blank lines (e.g. a trailing newline at EOF) carry no data.
                continue
            fields = line.split(":")
            loss = float(fields[-2].split(",")[0])
            acc = float(fields[-1])
            acc_list.append(acc)
            loss_list.append(loss)


def plot_curves(metric="accuracy", runs=RUNS):
    """Plot one metric for every run and save the figure as ``<metric>.jpg``.

    Args:
        metric: ``"accuracy"`` or ``"loss"``; selects the plotted series, the
            y-axis label and the output file name.
        runs: sequence of ``(label, filename, entries, color)`` tuples, where
            ``entries`` is the slice of logged values to plot.

    Returns:
        The file name the figure was saved to.

    Raises:
        ValueError: if ``metric`` is not ``"accuracy"`` or ``"loss"``.
    """
    if metric not in ("accuracy", "loss"):
        raise ValueError("metric must be 'accuracy' or 'loss', got %r" % (metric,))

    # Imported here so that get_file can be imported and used without a
    # matplotlib installation.
    import matplotlib.pyplot as plt

    # Read every log before drawing, so a missing file fails before a figure exists.
    series = []
    for label, filename, entries, color in runs:
        acc, loss = [], []
        get_file(acc, loss, filename)
        values = acc if metric == "accuracy" else loss
        series.append((label, color, values[entries]))

    fig = plt.figure()
    for label, color, values in series:
        # The x axis follows each series' own length, so a log shorter than its
        # window is plotted as-is instead of failing with a length mismatch.
        plt.plot(range(len(values)), values, label=label, color=color,
                 linewidth=1.0, linestyle="-")
    plt.xlabel("epoch")
    plt.ylabel(metric)
    plt.legend(loc="best")
    plt.grid()
    out_file = metric + ".jpg"
    plt.savefig(out_file)
    # pyplot keeps figures alive until they are closed explicitly.
    plt.close(fig)
    return out_file


def main():
    """Draw the accuracy comparison; use plot_curves("loss") for loss curves."""
    plot_curves("accuracy")


if __name__ == "__main__":
    main()
# 转载地址（reposted from）：http://xfwmi.baihongyu.com/