def evaluate(data, model, device, class_no):
    """Average IoU, F1, recall and precision of ``model`` over ``data``.

    Args:
        data: iterable yielding ``(image, label, name)`` batches; the name is
            ignored.
        model: network to evaluate; it is put into eval mode.
        device: torch device the image and label tensors are moved to.
        class_no: number of classes. ``2`` selects the binary path (sigmoid,
            then threshold at 0.5); any other value takes the arg-max over
            dimension 1 of the model output.

    Returns:
        ``(mean_iou, mean_f1, mean_recall, mean_precision)``, each averaged
        over the batches in ``data``. All four are 0 when ``data`` is empty.

    NOTE(review): a later ``evaluate`` definition in this file (which takes an
    extra ``reverse_mode`` argument) shadows this one at import time.
    """
    model.eval()
    f1 = 0
    test_iou = 0
    recall = 0
    precision = 0
    n_batches = 0
    with torch.no_grad():
        for testimg, testlabel, _ in data:
            testimg = testimg.to(device=device, dtype=torch.float32)
            testlabel = testlabel.to(device=device, dtype=torch.float32)
            testoutput = model(testimg)
            if class_no == 2:
                # Map the raw output through a sigmoid before the 0.5 threshold.
                testoutput = (torch.sigmoid(testoutput) > 0.5).float()
            else:
                _, testoutput = torch.max(testoutput, dim=1)

            prediction = testoutput.cpu().detach()
            target = testlabel.cpu().detach()
            mean_iu_ = intersectionAndUnion(prediction, target, class_no)
            # Other call sites in this file unpack nine values from f1_score
            # while this one expects three; (f1, recall, precision) come first
            # in both forms, so slice rather than rely on an exact arity.
            f1_, recall_, precision_ = f1_score(target.numpy(), prediction.numpy(), class_no)[:3]
            f1 += f1_
            test_iou += mean_iu_
            recall += recall_
            precision += precision_
            n_batches += 1

    # Clamp so an empty loader returns zeros instead of hitting an unbound index.
    n_batches = max(n_batches, 1)
    return test_iou / n_batches, f1 / n_batches, recall / n_batches, precision / n_batches
def trainSingleModel(model, model_name, num_epochs, learning_rate, datasettag, train_dataset, train_batchsize,
                     trainloader, validateloader, testdata, reverse_mode, lr_schedule, class_no):
    """Train ``model``, validate each epoch, checkpoint the last epochs and test them.

    Run artefacts go to ``./Results/<model_name>_<datasettag>_e<num_epochs>_lr<learning_rate>``
    (checkpoints in its ``trained_models`` subfolder); TensorBoard logs go to
    ``./Results/Log_<datasettag>/<same run name>``.

    Args:
        model: network to train; moved to CUDA.
        model_name: prefix of the run name.
        num_epochs: number of training epochs.
        learning_rate: initial AdamW learning rate.
        datasettag: dataset identifier used in the run and log names.
        train_dataset: not used by this version; kept for signature compatibility.
        train_batchsize: not used by this version; kept for signature compatibility.
        trainloader: iterable of ``(images, labels, names)`` training batches.
        validateloader: iterable of ``(images, labels, names)`` validation batches.
        testdata: forwarded to ``test`` once training finishes.
        reverse_mode: forwarded to ``test``.
        lr_schedule: when exactly ``True``, multiply the learning rate by 0.1
            at 50% and 75% of ``num_epochs``.
        class_no: ``2`` for binary segmentation (sigmoid + Dice loss); any
            other value trains multi-class with cross-entropy.

    Returns:
        The trained ``model``.

    NOTE(review): a later ``trainSingleModel`` definition in this file shadows
    this one at import time.
    """
    device = torch.device('cuda')
    save_model_name = model_name + '_' + datasettag + '_e' + str(num_epochs) + '_lr' + str(learning_rate)
    saved_information_path = './Results/' + save_model_name
    saved_model_path = saved_information_path + '/trained_models'
    # Creates ./Results, the run folder and trained_models in one call.
    os.makedirs(saved_model_path, exist_ok=True)

    print('The current model is:')
    print(save_model_name)
    print('\n')

    writer = SummaryWriter('./Results/Log_' + datasettag + '/' + save_model_name)
    model.to(device)

    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-8,
                                  weight_decay=1e-5)
    if lr_schedule is True:
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[num_epochs // 2, 3 * num_epochs // 4],
                                             gamma=0.1)

    start = timeit.default_timer()
    for epoch in range(num_epochs):
        model.train()
        train_iou = []
        train_loss = []
        for images, labels, _ in trainloader:
            optimizer.zero_grad()
            images = images.to(device=device, dtype=torch.float32)
            if class_no == 2:
                labels = labels.to(device=device, dtype=torch.float32)
            else:
                labels = labels.to(device=device, dtype=torch.long)

            outputs = model(images)
            if class_no == 2:
                prob_outputs = torch.sigmoid(outputs)
                loss = dice_loss(prob_outputs, labels)
                class_outputs = torch.where(prob_outputs > threshold, upper, lower)
            else:
                # CrossEntropyLoss applies log-softmax internally and expects raw
                # logits; passing softmax probabilities would apply softmax twice.
                loss = nn.CrossEntropyLoss(reduction='mean')(outputs, labels.squeeze(1))
                _, class_outputs = torch.max(outputs, dim=1)
            loss.backward()
            optimizer.step()

            mean_iu_, _, __ = segmentation_scores(labels, class_outputs, class_no)
            train_iou.append(mean_iu_)
            train_loss.append(loss.item())

        if lr_schedule is True:
            scheduler.step()

        model.eval()
        validate_iou = []
        validate_f1 = []
        with torch.no_grad():
            for val_images, val_label, _ in validateloader:
                val_img = val_images.to(device=device, dtype=torch.float32)
                if class_no == 2:
                    val_label = val_label.to(device=device, dtype=torch.float32)
                else:
                    val_label = val_label.to(device=device, dtype=torch.long)
                assert torch.max(val_label) != 100.0

                val_outputs = model(val_img)
                if class_no == 2:
                    val_class_outputs = (torch.sigmoid(val_outputs) > 0.5).float()
                else:
                    val_class_outputs = torch.softmax(val_outputs, dim=1)
                    _, val_class_outputs = torch.max(val_class_outputs, dim=1)

                eval_mean_iu_, _, __ = segmentation_scores(val_label, val_class_outputs, class_no)
                eval_f1_, *_ = f1_score(val_label, val_class_outputs, class_no)
                validate_iou.append(eval_mean_iu_)
                validate_f1.append(eval_f1_)

        print('Step [{}/{}], '
              'Train loss: {:.4f}, '
              'Train iou: {:.4f}, '
              'val iou:{:.4f}, '.format(epoch + 1, num_epochs,
                                        np.nanmean(train_loss),
                                        np.nanmean(train_iou),
                                        np.nanmean(validate_iou)))

        writer.add_scalars('acc metrics', {'train iou': np.nanmean(train_iou),
                                           'val iou': np.nanmean(validate_iou),
                                           'val f1': np.nanmean(validate_f1)}, epoch + 1)

        if epoch > num_epochs - 10:
            torch.save(model, saved_model_path + '/epoch' + str(epoch) + '.pt')

    # Flush and release the TensorBoard event file.
    writer.close()

    test(testdata, saved_model_path, device, reverse_mode=reverse_mode, class_no=class_no,
         save_path=saved_model_path)

    stop = timeit.default_timer()
    print('Time: ', stop - start)
    print('\nTraining finished and model saved\n')
    return model
def test(testdata, models_path, device, reverse_mode, class_no, save_path):
    """Evaluate every ``*.pt`` checkpoint in ``models_path`` on clean and FGSM-perturbed data.

    Metrics from all checkpoints and all batches are pooled. Their mean and
    standard deviation are written as a stringified dict to
    ``<save_path>/Test/test_result_data.txt``, and a short summary is printed.

    Args:
        testdata: iterable yielding ``(image, label, name)`` batches.
        models_path: directory searched for ``*.pt`` files (whole pickled models).
        device: torch device the batches are moved to.
        reverse_mode: accepted for signature compatibility; not used by this version.
        class_no: ``2`` for binary (sigmoid + Dice); any other value is treated
            as multi-class (softmax + cross-entropy).
        save_path: directory under which the ``Test`` results folder is created.

    NOTE(review): a later ``test`` definition in this file shadows this one at
    import time.
    """
    all_models = glob.glob(os.path.join(models_path, '*.pt'))
    test_f1 = []
    test_iou = []
    test_h_dist = []
    test_acc = []
    test_w_acc = []
    test_recall = []
    test_precision = []
    test_bf = []
    test_iou_adv = []
    test_h_dist_adv = []

    for model_path in all_models:
        # NOTE(review): torch.load unpickles a whole module -- only load
        # checkpoints you produced yourself. Newer PyTorch releases default to
        # weights_only=True, which rejects full-module pickles; confirm version.
        model = torch.load(model_path)
        model.eval()
        for testimg, testlabel, _ in testdata:
            testimg = testimg.to(device=device, dtype=torch.float32)
            if class_no == 2:
                testlabel = testlabel.to(device=device, dtype=torch.float32)
            else:
                testlabel = testlabel.to(device=device, dtype=torch.long)
            if torch.max(testlabel) == 255.:
                testlabel = testlabel / 255.
            # Input gradients are needed for the FGSM perturbation below.
            testimg.requires_grad = True

            logits = model(testimg)
            if class_no == 2:
                prob_testoutput = torch.sigmoid(logits)
                testoutput = (prob_testoutput > 0.5).float()
                loss = dice_loss(prob_testoutput, testlabel)
            else:
                prob_testoutput = torch.softmax(logits, dim=1)
                _, testoutput = torch.max(prob_testoutput, dim=1)
                # CrossEntropyLoss expects raw logits (it applies log-softmax itself).
                loss = nn.CrossEntropyLoss(reduction='mean')(logits, testlabel.squeeze(1))

            # Attack the test inputs with FGSM.
            model.zero_grad()
            loss.backward()
            perturbed_data = fgsm_attack(testimg, 0.2, testimg.grad.data)
            output_attack = model(perturbed_data)
            if class_no == 2:
                output_attack = (torch.sigmoid(output_attack) > 0.5).float()
            else:
                output_attack = torch.softmax(output_attack, dim=1)
                _, output_attack = torch.max(output_attack, dim=1)

            mean_iu_, acc_, w_acc_ = segmentation_scores(testlabel, testoutput, class_no)
            test_iou.append(mean_iu_)
            test_acc.append(acc_)
            test_w_acc.append(w_acc_)
            mean_iu_adv_, _, __ = segmentation_scores(testlabel, output_attack, class_no)
            test_iou_adv.append(mean_iu_adv_)

            # Hausdorff distance only for binary masks with >1 foreground pixel.
            if (testoutput == 1).sum() > 1 and (testlabel == 1).sum() > 1 and class_no == 2:
                test_h_dist.append(hd95(testoutput, testlabel, class_no))
            if (output_attack == 1).sum() > 1 and (testlabel == 1).sum() > 1 and class_no == 2:
                test_h_dist_adv.append(hd95(output_attack, testlabel, class_no))

            f1_, recall_, precision_, TP, TN, FP, FN, P, N = f1_score(testlabel, testoutput, class_no)
            # Epsilon avoids 0/0 when precision and recall are both zero.
            bf_ = 2 * precision_ * recall_ / (recall_ + precision_ + 1e-10)
            test_f1.append(f1_)
            test_recall.append(recall_)
            test_precision.append(precision_)
            test_bf.append(bf_)

    prediction_map_path = save_path + '/Test'
    os.makedirs(prediction_map_path, exist_ok=True)

    result_dictionary = {
        'Test IoU mean': str(np.mean(test_iou)),
        'Test IoU std': str(np.std(test_iou)),
        'Test Acc mean': str(np.mean(test_acc)),
        'Test Acc std': str(np.std(test_acc)),
        'Test W ACC mean': str(np.mean(test_w_acc)),
        'Test W ACC std': str(np.std(test_w_acc)),
        'Test BF mean': str(np.mean(test_bf)),
        'Test BF std': str(np.std(test_bf)),
        'Test f1 mean': str(np.mean(test_f1)),
        'Test f1 std': str(np.std(test_f1)),
        'Test H-dist mean': str(np.mean(test_h_dist)),
        'Test H-dist std': str(np.std(test_h_dist)),
        'Test precision mean': str(np.mean(test_precision)),
        'Test precision std': str(np.std(test_precision)),
        'Test recall mean': str(np.mean(test_recall)),
        'Test recall std': str(np.std(test_recall)),
        'Test IoU attack mean': str(np.mean(test_iou_adv)),
        'Test IoU attack std': str(np.std(test_iou_adv)),
        'Test H-dist attack mean': str(np.mean(test_h_dist_adv)),
        'Test H-dist attack std': str(np.std(test_h_dist_adv)),
    }

    ff_path = prediction_map_path + '/test_result_data.txt'
    with open(ff_path, 'w') as ff:
        ff.write(str(result_dictionary))

    print('Test h-dist: {:.4f}, '
          'Test iou: {:.4f}, '.format(np.mean(test_h_dist), np.mean(test_iou)))
def evaluate(evaluatedata, model, device, reverse_mode, class_no):
    """Evaluate a binary segmentation model on clean and FGSM-perturbed inputs.

    The model is put into eval mode and left there. Gradients are taken with
    respect to the input images for the attack, so this runs with autograd
    enabled and leaves the parameters' ``.grad`` populated.

    Args:
        evaluatedata: iterable yielding ``(image, label, name)`` batches.
        model: the network; its output is passed through a sigmoid.
        device: torch device the batches are moved to.
        reverse_mode: when exactly ``True``, a pixel is predicted as foreground
            where the sigmoid output is *below* 0.5.
        class_no: forwarded to the metric helpers.

    Returns:
        A 17-tuple of per-batch averages:
        ``(iou, f1, recall, precision, FPs/Ns, FPs/Ps, FNs/Ns, FNs/Ps,
        FPs, FNs, TPs, TNs, Ps, Ns, h_dist, iou_attack, h_dist_attack)``.
        The two Hausdorff values are averaged only over batches in which both
        the prediction and the label have more than one foreground pixel.
        An average with nothing to average over is 0.
    """
    model.eval()
    # Loop-invariant constants for binarising sigmoid outputs.
    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    f1 = 0
    test_iou = 0
    recall = 0
    precision = 0
    FPs_Ns = 0
    FNs_Ps = 0
    FPs_Ps = 0
    FNs_Ns = 0
    TPs = 0
    TNs = 0
    FNs = 0
    FPs = 0
    Ps = 0
    Ns = 0
    test_h_dist = 0
    test_iou_attack = 0
    test_h_dist_attack = 0
    effective_h = 0
    effective_h_attack = 0
    n_batches = 0

    for testimg, testlabel, _ in evaluatedata:
        testimg = testimg.to(device=device, dtype=torch.float32)
        testlabel = testlabel.to(device=device, dtype=torch.float32)
        # Input gradients drive the FGSM perturbation below.
        testimg.requires_grad = True

        prob_testoutput = torch.sigmoid(model(testimg))

        # Attack the objective the model is trained on: the training loop in
        # this file optimises reverse-mode models against ``1 - labels``.
        # Attacking the plain label in reverse mode would push predictions
        # towards the correct answer instead of away from it.
        # NOTE(review): assumes fgsm_attack steps along +sign(grad) (standard FGSM).
        attack_target = 1.0 - testlabel if reverse_mode is True else testlabel
        loss = dice_loss(prob_testoutput, attack_target)
        model.zero_grad()
        loss.backward()
        perturbed_data = fgsm_attack(testimg, 0.2, testimg.grad.data)
        prob_attack = torch.sigmoid(model(perturbed_data))

        if reverse_mode is True:
            testoutput = torch.where(prob_testoutput < threshold, upper, lower)
            output_attack = torch.where(prob_attack < threshold, upper, lower)
        else:
            testoutput = torch.where(prob_testoutput > threshold, upper, lower)
            output_attack = torch.where(prob_attack > threshold, upper, lower)

        mean_iu_, _, __ = segmentation_scores(testlabel, testoutput, class_no)
        mean_iu_attack_, _, __ = segmentation_scores(testlabel, output_attack, class_no)

        if (testoutput == 1).sum() > 1 and (testlabel == 1).sum() > 1:
            test_h_dist += hd95(testoutput, testlabel, class_no)
            effective_h += 1
        if (output_attack == 1).sum() > 1 and (testlabel == 1).sum() > 1:
            test_h_dist_attack += hd95(output_attack, testlabel, class_no)
            effective_h_attack += 1

        f1_, recall_, precision_, TP, TN, FP, FN, P, N = f1_score(testlabel, testoutput, class_no)
        f1 += f1_
        test_iou += mean_iu_
        recall += recall_
        precision += precision_
        TPs += TP
        TNs += TN
        FPs += FP
        FNs += FN
        Ps += P
        Ns += N
        # Epsilons keep the ratios finite when a class is absent from a batch.
        FNs_Ps += (FN + 1e-10) / (P + 1e-10)
        FPs_Ns += (FP + 1e-10) / (N + 1e-10)
        FNs_Ns += (FN + 1e-10) / (N + 1e-10)
        FPs_Ps += (FP + 1e-10) / (P + 1e-10)
        test_iou_attack += mean_iu_attack_
        n_batches += 1

    # Divide by the number of contributing samples (previously the Hausdorff
    # sums were divided by count + 1, biasing them low). Clamping to 1 makes
    # an empty loader / no valid sample yield 0 instead of raising.
    n = max(n_batches, 1)
    return (test_iou / n, f1 / n, recall / n, precision / n,
            FPs_Ns / n, FPs_Ps / n, FNs_Ns / n, FNs_Ps / n,
            FPs / n, FNs / n, TPs / n, TNs / n, Ps / n, Ns / n,
            test_h_dist / max(effective_h, 1),
            test_iou_attack / n,
            test_h_dist_attack / max(effective_h_attack, 1))
def test1(data_1, model, device, class_no, save_location):
    """Evaluate ``model`` sample-by-sample on ``data_1`` and save the averaged metrics.

    Each item of ``data_1`` is unpacked as ``(image, label, name)``, where
    image and label are NumPy arrays; the image is read as ``(c, h, w)`` and
    given a batch dimension of 1. The averaged metrics are written as a
    stringified dict to ``<save_location>/Results_map/test_result_data.txt``.

    Args:
        data_1: iterable of ``(image, label, name)`` samples (NumPy arrays).
        model: network to evaluate; it is put into eval mode.
        device: torch device the tensors are moved to.
        class_no: ``2`` for binary (sigmoid, threshold 0.5, output reshaped to
            ``(1, h, w)``); any other value takes the arg-max over dimension 1.
        save_location: directory under which ``Results_map`` is created.

    Returns:
        ``(iou, f1, recall, precision, mse, predictions)`` where the first five
        are averaged over every sample in ``data_1`` (0 when it is empty) and
        ``predictions`` is the list of per-sample predicted masks.
    """
    model.eval()
    data_1_testoutputs = []
    f1_1 = 0
    test_iou_1 = 0
    recall_1 = 0
    precision_1 = 0
    mse_1 = 0
    n_samples = 0

    with torch.no_grad():
        for testimg, testlabel, _ in data_1:
            testimg = torch.from_numpy(testimg).to(device=device, dtype=torch.float32)
            testlabel = torch.from_numpy(testlabel).to(device=device, dtype=torch.float32)
            c, h, w = testimg.size()
            # Add a batch dimension of size 1.
            testimg = testimg.expand(1, c, h, w)
            testoutput_original = model(testimg)
            if class_no == 2:
                testoutput = torch.sigmoid(testoutput_original.view(1, h, w))
                testoutput = (testoutput > 0.5).float()
            else:
                # Previously this branch was missing and `testoutput` was unbound.
                _, testoutput = torch.max(testoutput_original, dim=1)
            data_1_testoutputs.append(testoutput)

            prediction = testoutput.cpu().detach()
            target = testlabel.cpu().detach()
            mean_iu_ = intersectionAndUnion(prediction, target, class_no)
            # f1_score is unpacked with nine values elsewhere in this file;
            # (f1, recall, precision) come first either way.
            f1_, recall_, precision_ = f1_score(target.numpy(), prediction.numpy(), class_no)[:3]
            mse_ = np.square(target.numpy() - prediction.numpy()).mean()

            f1_1 += f1_
            test_iou_1 += mean_iu_
            recall_1 += recall_
            precision_1 += precision_
            mse_1 += mse_
            n_samples += 1

    # Average over the samples actually processed. The old denominator was
    # len(range(0, len(data_1) - 1)) == N - 1, which inflated every metric
    # and divided by zero for a single sample.
    n = max(n_samples, 1)

    prediction_map_path = save_location + '/' + 'Results_map'
    os.makedirs(prediction_map_path, exist_ok=True)

    # save numerical results:
    result_dictionary = {'Test IoU data 1': str(test_iou_1 / n),
                         'Test f1 data 1': str(f1_1 / n),
                         'Test recall data 1': str(recall_1 / n),
                         'Test Precision data 1': str(precision_1 / n),
                         'Test MSE data 1': str(mse_1 / n)
                         }
    ff_path = prediction_map_path + '/test_result_data.txt'
    with open(ff_path, 'w') as ff:
        ff.write(str(result_dictionary))

    return test_iou_1 / n, f1_1 / n, recall_1 / n, precision_1 / n, mse_1 / n, data_1_testoutputs
def trainSingleModel(model, model_name, num_epochs, learning_rate, datasettag, train_dataset, train_batchsize,
                     trainloader, validateloader, testdata, reverse_mode, lr_schedule, class_no):
    """Train a binary segmentation model with periodic validation, then test saved checkpoints.

    Run artefacts go to ``./Results/<model_name>_<datasettag>_e<num_epochs>_lr<learning_rate>``
    (checkpoints in its ``trained_models`` subfolder); TensorBoard logs go to
    ``./Results/Log_<datasettag>/<same run name>``. Checkpoints of the last 10
    epochs are saved and then evaluated by ``test``.

    Training metrics are computed on FGSM-perturbed copies of the training
    images (epsilon 0.2), not on the clean images.

    Args:
        model: network to train; moved to CUDA. Its output is passed through a sigmoid.
        model_name: prefix of the run name; names containing ``fn``/``FN``
            get a per-epoch linear learning-rate warm-up for the first 10
            epochs when ``reverse_mode`` is ``True``.
        num_epochs: number of training epochs.
        learning_rate: initial AdamW learning rate.
        datasettag: dataset identifier used in the run and log names.
        train_dataset: its length (with ``train_batchsize``) sets how often to validate.
        train_batchsize: batch size used for that calculation.
        trainloader: iterable of ``(images, labels, names)`` training batches.
        validateloader: passed to ``evaluate``.
        testdata: passed to ``test`` after training.
        reverse_mode: when exactly ``True``, train against the inverted mask
            ``1 - labels`` and read foreground back as ``prob < 0.5``.
        lr_schedule: when exactly ``True``, multiply the learning rate by 0.1
            at 50% and 75% of ``num_epochs``.
        class_no: forwarded to metrics; the Hausdorff distance is only
            computed when it equals 2.

    Returns:
        The trained ``model``.
    """
    training_amount = len(train_dataset)
    # Validate roughly once per epoch. Clamp to 1 so a dataset smaller than two
    # full batches does not turn the modulo below into a division by zero.
    iteration_amount = max(training_amount // train_batchsize - 1, 1)
    device = torch.device('cuda')
    save_model_name = model_name + '_' + datasettag + '_e' + str(num_epochs) + '_lr' + str(learning_rate)
    saved_information_path = './Results/' + save_model_name
    saved_model_path = saved_information_path + '/trained_models'
    # Creates ./Results, the run folder and trained_models in one call.
    os.makedirs(saved_model_path, exist_ok=True)

    print('The current model is:')
    print(save_model_name)
    print('\n')

    writer = SummaryWriter('./Results/Log_' + datasettag + '/' + save_model_name)
    model.to(device)

    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-8,
                                  weight_decay=1e-5)
    if lr_schedule is True:
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[num_epochs // 2, 3 * num_epochs // 4],
                                             gamma=0.1)

    # Linear warm-up for FN-attention models in reverse mode; without it
    # training is sometimes hard to get started.
    use_warmup = ('fn' in model_name or 'FN' in model_name) and reverse_mode is True
    num_train_batches = len(trainloader)

    start = timeit.default_timer()
    for epoch in range(num_epochs):
        model.train()
        h_dists = 0
        accuracy_iou = 0
        running_loss = 0
        t_FPs_Ns = 0
        t_FPs_Ps = 0
        t_FNs_Ns = 0
        t_FNs_Ps = 0
        t_FPs = 0
        t_FNs = 0
        t_TPs = 0
        t_TNs = 0
        t_Ps = 0
        t_Ns = 0
        effective_h = 0

        for j, (images, labels, _) in enumerate(trainloader):
            optimizer.zero_grad()
            images = images.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=torch.float32)
            # Input gradients are reused below for the FGSM perturbation.
            images.requires_grad = True

            prob_outputs = torch.sigmoid(model(images))
            if reverse_mode is True:
                loss = dice_loss(prob_outputs, torch.ones_like(labels) - labels)
            else:
                loss = dice_loss(prob_outputs, labels)
            loss.backward()
            optimizer.step()

            # Binary segmentation is an easy task; to compensate, training
            # metrics are measured on adversarially perturbed images. These
            # forwards only feed metrics, so no autograd graph is needed.
            with torch.no_grad():
                perturbed_data = fgsm_attack(images, 0.2, images.grad.data)
                prob_outputs = torch.sigmoid(model(perturbed_data))
            if reverse_mode is True:
                class_outputs = torch.where(prob_outputs < threshold, upper, lower)
            else:
                class_outputs = torch.where(prob_outputs > threshold, upper, lower)

            # The Hausdorff distance is only defined for the binary task.
            if class_no == 2 and (class_outputs == 1).sum() > 1 and (labels == 1).sum() > 1:
                h_dists += hd95(class_outputs, labels, class_no)
                effective_h += 1

            mean_iu_ = segmentation_scores(labels, class_outputs, class_no)
            _, _, _, TPs_, TNs_, FPs_, FNs_, Ps_, Ns_ = f1_score(labels, class_outputs, class_no)
            # .item() yields a plain float; accumulating the loss tensor itself
            # would keep every iteration's autograd graph alive.
            running_loss += loss.item()
            accuracy_iou += mean_iu_
            t_TPs += TPs_
            t_TNs += TNs_
            t_FPs += FPs_
            t_FNs += FNs_
            t_Ps += Ps_
            t_Ns += Ns_
            t_FNs_Ps += (FNs_ + 1e-8) / (Ps_ + 1e-8)
            t_FPs_Ns += (FPs_ + 1e-8) / (Ns_ + 1e-8)
            t_FNs_Ns += (FNs_ + 1e-8) / (Ns_ + 1e-8)
            t_FPs_Ps += (FPs_ + 1e-8) / (Ps_ + 1e-8)

            if (j + 1) % iteration_amount == 0:
                # evaluate() returns 17 values; the trailing attack metrics are
                # not logged here, so absorb them instead of crashing on unpack.
                (validate_iou, validate_f1, validate_recall, validate_precision,
                 v_FPs_Ns, v_FPs_Ps, v_FNs_Ns, v_FNs_Ps, v_FPs, v_FNs, v_TPs, v_TNs,
                 v_Ps, v_Ns, v_h_dist, *_) = evaluate(validateloader, model, device,
                                                      reverse_mode=reverse_mode, class_no=class_no)
                # evaluate() puts the model in eval mode; restore training mode
                # for any batches remaining in this epoch.
                model.train()

                train_h_dist = h_dists / max(effective_h, 1)
                print(
                    'Step [{}/{}], Train loss: {:.4f}, '
                    'Train iou: {:.4f}, '
                    'Train h-dist:{:.4f}, '
                    'Val iou: {:.4f},'
                    'Val h-dist: {:.4f}'.format(epoch + 1, num_epochs,
                                                running_loss / (j + 1),
                                                accuracy_iou / (j + 1),
                                                train_h_dist,
                                                validate_iou,
                                                v_h_dist))

                # TensorboardX logging.
                writer.add_scalars('acc metrics', {'train iou': accuracy_iou / (j + 1),
                                                   'train hausdorff dist': train_h_dist,
                                                   'val iou': validate_iou,
                                                   'val hasudorff distance': v_h_dist,
                                                   'loss': running_loss / (j + 1)}, epoch + 1)

                writer.add_scalars('train confusion matrices analysis', {'train FPs/Ns': t_FPs_Ns / (j + 1),
                                                                         'train FNs/Ps': t_FNs_Ps / (j + 1),
                                                                         'train FPs/Ps': t_FPs_Ps / (j + 1),
                                                                         'train FNs/Ns': t_FNs_Ns / (j + 1),
                                                                         'train FNs': t_FNs / (j + 1),
                                                                         'train FPs': t_FPs / (j + 1),
                                                                         'train TNs': t_TNs / (j + 1),
                                                                         'train TPs': t_TPs / (j + 1),
                                                                         'train Ns': t_Ns / (j + 1),
                                                                         'train Ps': t_Ps / (j + 1),
                                                                         'train imbalance': t_Ps / (t_Ps + t_Ns)},
                                   epoch + 1)

                writer.add_scalars('val confusion matrices analysis', {'val FPs/Ns': v_FPs_Ns,
                                                                       'val FNs/Ps': v_FNs_Ps,
                                                                       'val FPs/Ps': v_FPs_Ps,
                                                                       'val FNs/Ns': v_FNs_Ns,
                                                                       'val FNs': v_FNs,
                                                                       'val FPs': v_FPs,
                                                                       'val TNs': v_TNs,
                                                                       'val TPs': v_TPs,
                                                                       'val Ns': v_Ns,
                                                                       'val Ps': v_Ps,
                                                                       'val imbalance': v_Ps / (v_Ps + v_Ns)},
                                   epoch + 1)

            # Ramp the learning rate linearly across the batches of each of the
            # first 10 epochs.
            if use_warmup and epoch < 10:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = learning_rate * (j / num_train_batches)

        if lr_schedule is True:
            scheduler.step()

        # Save models at the last 10 epochs.
        if epoch >= (num_epochs - 10):
            torch.save(model, saved_model_path + '/' + save_model_name + '_epoch' + str(epoch) + '.pt')

    # Flush and release the TensorBoard event file.
    writer.close()

    # Test all saved models and pool their metrics.
    test(testdata, saved_model_path, device, reverse_mode=reverse_mode, class_no=class_no,
         save_path=saved_information_path)

    stop = timeit.default_timer()
    print('Time: ', stop - start)
    print('\n')
    print('\nTraining finished and model saved\n')
    return model
def test(testdata, models_path, device, reverse_mode, class_no, save_path):
    """Evaluate every ``*.pt`` checkpoint in ``models_path`` on clean and FGSM-perturbed data.

    Only binary segmentation (``class_no == 2``) is supported. Metrics from all
    checkpoints and all batches are pooled. Their mean and standard deviation
    are written as a stringified dict to
    ``<save_path>/Test_result/test_results.txt``, and a short summary is printed.

    Args:
        testdata: iterable yielding ``(image, label, name)`` batches.
        models_path: directory searched for ``*.pt`` files (whole pickled models).
        device: torch device the batches are moved to.
        reverse_mode: when exactly ``True``, a pixel is predicted as foreground
            where the sigmoid output is *below* 0.5.
        class_no: must be 2.
        save_path: directory under which ``Test_result`` is created.

    Raises:
        NotImplementedError: if a batch is evaluated with ``class_no != 2``.
    """
    all_models = glob.glob(os.path.join(models_path, '*.pt'))
    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    def binarise(prob):
        # Reverse-mode models are trained against the inverted mask, so their
        # foreground is where the probability falls below the threshold.
        if reverse_mode is True:
            return torch.where(prob < threshold, upper, lower)
        return torch.where(prob > threshold, upper, lower)

    test_f1 = []
    test_iou = []
    test_h_dist = []
    test_recall = []
    test_precision = []
    test_iou_adv = []
    test_h_dist_adv = []

    for model_path in all_models:
        # NOTE(review): torch.load unpickles a whole module -- only load
        # checkpoints you produced yourself. Newer PyTorch releases default to
        # weights_only=True, which rejects full-module pickles; confirm version.
        model = torch.load(model_path)
        model.eval()
        for testimg, testlabel, _ in testdata:
            if class_no != 2:
                # TODO: add multi-class support; previously this path crashed
                # with a NameError on the undefined adversarial prediction.
                raise NotImplementedError('test() currently supports binary segmentation only (class_no == 2)')

            testimg = testimg.to(device=device, dtype=torch.float32)
            # Input gradients drive the FGSM perturbation below.
            testimg.requires_grad = True
            testlabel = testlabel.to(device=device, dtype=torch.float32)

            prob_testoutput = torch.sigmoid(model(testimg))
            testoutput = binarise(prob_testoutput)

            # Metrics before the attack.
            test_iou.append(segmentation_scores(testlabel, testoutput, class_no))
            if (testoutput == 1).sum() > 1 and (testlabel == 1).sum() > 1:
                test_h_dist.append(hd95(testoutput, testlabel, class_no))
            f1_, recall_, precision_, TP, TN, FP, FN, P, N = f1_score(testlabel, testoutput, class_no)
            test_f1.append(f1_)
            test_recall.append(recall_)
            test_precision.append(precision_)

            # Attack the objective the model was trained on (``1 - labels`` in
            # reverse mode). Using the plain label there would push predictions
            # towards the correct answer instead of away from it.
            # NOTE(review): assumes fgsm_attack steps along +sign(grad) (standard FGSM).
            attack_target = 1.0 - testlabel if reverse_mode is True else testlabel
            loss = dice_loss(prob_testoutput, attack_target)
            model.zero_grad()
            loss.backward()
            perturbed_data = fgsm_attack(testimg, 0.2, testimg.grad.data)
            testoutput_adv = binarise(torch.sigmoid(model(perturbed_data)))

            test_iou_adv.append(segmentation_scores(testlabel, testoutput_adv, class_no))
            if (testoutput_adv == 1).sum() > 1 and (testlabel == 1).sum() > 1:
                test_h_dist_adv.append(hd95(testoutput_adv, testlabel, class_no))

    # Store the test metrics.
    prediction_map_path = save_path + '/Test_result'
    os.makedirs(prediction_map_path, exist_ok=True)

    result_dictionary = {
        'Test IoU mean': str(np.mean(test_iou)),
        'Test IoU std': str(np.std(test_iou)),
        'Test f1 mean': str(np.mean(test_f1)),
        'Test f1 std': str(np.std(test_f1)),
        'Test H-dist mean': str(np.mean(test_h_dist)),
        'Test H-dist std': str(np.std(test_h_dist)),
        'Test precision mean': str(np.mean(test_precision)),
        'Test precision std': str(np.std(test_precision)),
        'Test recall mean': str(np.mean(test_recall)),
        'Test recall std': str(np.std(test_recall)),
        'Test IoU attack mean': str(np.mean(test_iou_adv)),
        'Test IoU attack std': str(np.std(test_iou_adv)),
        'Test H-dist attack mean': str(np.mean(test_h_dist_adv)),
        'Test H-dist attack std': str(np.std(test_h_dist_adv)),
    }

    ff_path = prediction_map_path + '/test_results.txt'
    with open(ff_path, 'w') as ff:
        ff.write(str(result_dictionary))

    # These are test-set numbers; the old label said "Val iou".
    print(
        'Test h-dist: {:.4f}, '
        'Test iou: {:.4f}, '.format(np.mean(test_h_dist), np.mean(test_iou)))