def draw_RR_curve(
        valid_txt='/media/omnisky/D4T/JSH/faceFenlei/eye/mbhlk_hl_0128/mix_valid.txt',
        weight_path="/media/omnisky/D4T/JSH/faceFenlei/Projects/hul_eye_class/weight/relabel_04_mix_SGD_mutillabel_24_24_20210302/Mixnet_epoch_49.pth",
        num_classes=3):
    """Score the validation list with a trained 24x24 MixNet and call ``plot_RR``.

    Every validation sample is run through the network. The softmax
    probabilities of each output row are collected together with that sample's
    ground-truth label. Both are then passed to ``plot_RR`` (defined outside
    this block) with class names ``'0'``, ``'1'``, ...

    Args:
        valid_txt: list file consumed by ``mbhk_get_signal_eye``.
        weight_path: checkpoint to load. A leading ``module.`` on its keys is
            stripped when present.
        num_classes: number of network outputs (the plotted class names are
            ``str(0) .. str(num_classes - 1)``).
    """
    vaild_ttrans = Transforms.Compose([
        Transforms.Resize((24, 24)),
        Transforms.ToTensor(),
        Transforms.Normalize((0.45, 0.448, 0.455), (0.082, 0.082, 0.082)),
    ])
    # Honour the detected device everywhere instead of hard-coding cuda:0,
    # so the function also runs on CPU-only machines.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    mixnet = MixNet(input_size=(24, 24), num_classes=num_classes)
    # map_location allows a GPU-saved checkpoint to be loaded on CPU.
    weight_dict = torch.load(weight_path, map_location=device)
    # Keys saved from a wrapped model (presumably nn.DataParallel) start with
    # "module."; strip it only when present so plain keys stay intact.
    new_state_dict = OrderedDict(
        (k[len("module."):] if k.startswith("module.") else k, v)
        for k, v in weight_dict.items())
    mixnet.load_state_dict(new_state_dict)
    mixnet.to(device)
    mixnet.eval()

    vaild_data = mbhk_get_signal_eye(valid_txt, vaild_ttrans)
    valid_data_loader = DataLoader(vaild_data, batch_size=128, shuffle=False,
                                   num_workers=12)

    score_list = []  # predicted softmax scores, one entry per output row
    label_list = []  # ground-truth label for each entry of score_list
    with torch.no_grad():
        for imgs, labels, _ in valid_data_loader:
            for timg, label in zip(imgs, labels):
                probs = torch.nn.functional.softmax(
                    mixnet(timg.to(device)), dim=1)
                score_list.extend(probs.cpu().numpy())
                # Repeat the sample's label once per score row so the two
                # arrays stay aligned even if a sample yields several rows.
                label_list.extend([int(label)] * probs.shape[0])

    label_array = np.array(label_list).reshape((-1, 1))
    score_array = np.array(score_list).reshape((-1, num_classes))
    plot_RR(label_array, score_array, [str(c) for c in range(num_classes)])
def main():
    """Classify every sample of ``img_path`` and archive the misclassified ones.

    NOTE(review): a second ``def main()`` appears later in this file and
    rebinds the name, so this definition is shadowed after module load.

    Each sample's images are classified one at a time with a 24x24 MixNet.
    At the first image whose prediction differs from the sample label:
    - both entries of the sample's path pair are copied into
      ``./error_class/<source>/<folder>``;
    - a ``"<path> <pred> <true>"`` line is written to that source's log file.
    ``<source>`` is ``mbhk_img`` when the first path contains ``/imge`` and
    ``change_img`` otherwise.

    Depends on the module-level names ``img_path`` and ``label_idx``, which are
    defined outside this block.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_trans = Transforms.Compose([
        Transforms.Resize((24, 24)),
        Transforms.ToTensor(),
        Transforms.Normalize((0.45, 0.448, 0.455), (0.082, 0.082, 0.082)),
    ])
    test_data = mbhk_get_signal_eye(img_path, transform=data_trans)

    # Model for 24x24 input images.
    net = MixNet(input_size=(24, 24), num_classes=3)
    # map_location allows a GPU-saved checkpoint to be loaded on CPU.
    weight_dict = torch.load(
        "weight/mix_mbhk_change_signal_eye_24_24/Mixnet_epoch_59.pth",
        map_location=device)
    # Strip a leading "module." (presumably from nn.DataParallel) only when
    # present, so keys without it are not corrupted.
    new_state_dict = OrderedDict(
        (k[len("module."):] if k.startswith("module.") else k, v)
        for k, v in weight_dict.items())
    net.load_state_dict(new_state_dict)
    net.to(device)
    net.eval()

    # Create the folders that receive misclassified images.
    class_img = ["close_eye", "open_eye", "other"]
    error_class_mbhk_img_path = "./error_class/mbhk_img"
    error_class_change_img_path = "./error_class/change_img"
    for root in (error_class_mbhk_img_path, error_class_change_img_path):
        for tp in class_img:
            os.makedirs(os.path.join(root, tp), exist_ok=True)

    # Destination sub-folder keyed by the *predicted* class index.
    # NOTE(review): 0 -> open_eye / 1 -> close_eye does not follow the order of
    # class_img above; kept exactly as before - confirm against label_idx.
    pred_folder = {0: "open_eye", 1: "close_eye", 2: "other"}

    # Error logs, one per data source; closed even if an exception occurs.
    mbhk_log_path = os.path.join(error_class_mbhk_img_path, "error_mbhk_log.txt")
    change_log_path = os.path.join(error_class_change_img_path,
                                   "error_change_log.txt")
    with open(mbhk_log_path, 'w') as error_lod_mbhk, \
            open(change_log_path, 'w') as error_lod_change:
        with torch.no_grad():
            for i in tqdm(range(len(test_data))):
                # sample layout: img, label, [img_path, json_path]
                timg, label, tpath = test_data[i]
                for img in timg:
                    # Add a batch dimension for the single image.
                    outputs = net(img.unsqueeze(0).to(device))
                    pred = torch.max(outputs, 1)[1].item()
                    if pred != label:
                        folder = pred_folder.get(pred)
                        if folder is not None:
                            if "/imge" in tpath[0]:  # mbhk data
                                dst_root, log = error_class_mbhk_img_path, error_lod_mbhk
                            else:
                                dst_root, log = error_class_change_img_path, error_lod_change
                            log.write("{} {} {}\n".format(
                                tpath[0], label_idx[pred], label_idx[label]))
                            for src in (tpath[0], tpath[1]):
                                shutil.copy(src, os.path.join(dst_root, folder))
                        # Only the first misclassified image of a sample is
                        # recorded, so avoid duplicate copies/log lines.
                        break
def main():
    """Plot multi-class ROC curves for the 24x24 MixNet on the validation list.

    Per-class curves use sklearn's ``roc_curve``/``auc`` on one-hot labels
    versus softmax scores. Two averages are also drawn:
    - micro-average: all classes flattened into one binary problem;
    - macro-average: each class's TPR is interpolated onto the union of all FPR
      points, and the results are averaged.
    The figure is saved to ``set113_roc.jpg`` and then shown.

    NOTE(review): this ``def main()`` rebinds an earlier ``main`` defined in
    this file.
    """
    class_num = 3
    vaild_ttrans = Transforms.Compose([
        Transforms.Resize((24, 24)),
        Transforms.ToTensor(),
        Transforms.Normalize((0.45, 0.448, 0.455), (0.082, 0.082, 0.082)),
    ])
    # Honour the detected device instead of hard-coding cuda:0.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    vaild_txt = '/media/omnisky/D4T/JSH/faceFenlei/eye/mbhlk_hl_0128/mix_valid.txt'

    # model
    mixnet = MixNet(input_size=(24, 24), num_classes=class_num)
    # map_location allows a GPU-saved checkpoint to be loaded on CPU.
    weight_dict = torch.load(
        "/media/omnisky/D4T/JSH/faceFenlei/Projects/hul_eye_class/weight/relabel_04_mix_SGD_mutillabel_24_24_20210302/Mixnet_epoch_49.pth",
        map_location=device)
    # Strip a leading "module." (presumably from nn.DataParallel) only when
    # present, so keys without it are not corrupted.
    new_state_dict = OrderedDict(
        (k[len("module."):] if k.startswith("module.") else k, v)
        for k, v in weight_dict.items())
    mixnet.load_state_dict(new_state_dict)
    mixnet.to(device)
    mixnet.eval()

    vaild_data = mbhk_get_signal_eye(vaild_txt, vaild_ttrans)
    valid_data_loader = DataLoader(vaild_data, batch_size=128, shuffle=False,
                                   num_workers=12)

    score_list = []  # predicted softmax scores, one entry per output row
    label_list = []  # integer ground-truth label for each score entry
    with torch.no_grad():
        for imgs, labels, _ in valid_data_loader:
            for timg, label in zip(imgs, labels):
                probs = torch.nn.functional.softmax(
                    mixnet(timg.to(device)), dim=1)
                score_list.extend(probs.cpu().numpy())
                # One label per score row keeps both arrays aligned.
                label_list.extend([int(label)] * probs.shape[0])

    tlabel_list = torch.nn.functional.one_hot(
        torch.as_tensor(label_list), num_classes=class_num
    ).numpy().reshape((-1, class_num))
    tscore_list = np.array(score_list).reshape((-1, class_num))

    # Per-class fpr/tpr/AUC via sklearn.
    fpr_dict, tpr_dict, roc_auc_dict = {}, {}, {}
    for i in range(class_num):
        fpr_dict[i], tpr_dict[i], _ = roc_curve(tlabel_list[:, i], tscore_list[:, i])
        roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i])

    # Micro-average: treat every (sample, class) cell as one binary decision.
    fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(tlabel_list.ravel(),
                                                        tscore_list.ravel())
    roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"])

    # Macro-average: interpolate every class's curve at the union of all FPR
    # points, then average the TPRs. np.interp is used directly (scipy.interp
    # was only a deprecated alias of it).
    all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(class_num)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(class_num):
        mean_tpr += np.interp(all_fpr, fpr_dict[i], tpr_dict[i])
    mean_tpr /= class_num
    fpr_dict["macro"] = all_fpr
    tpr_dict["macro"] = mean_tpr
    roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"])

    plt.figure()
    lw = 2
    plt.plot(fpr_dict["micro"], tpr_dict["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc_dict["micro"]),
             color='deeppink', linestyle=':', linewidth=4)
    plt.plot(fpr_dict["macro"], tpr_dict["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc_dict["macro"]),
             color='navy', linestyle=':', linewidth=4)
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    for i, color in zip(range(class_num), colors):
        plt.plot(fpr_dict[i], tpr_dict[i], color=color, lw=lw,
                 label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc_dict[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Some extension of Receiver operating characteristic to multi-class')
    plt.legend(loc="lower right")
    plt.savefig('set113_roc.jpg')
    plt.show()