import glob
import os
import timeit

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW, lr_scheduler
from torch.utils.tensorboard import SummaryWriter

# Project-local helpers and model definitions are used throughout this file.
# The module paths below are assumptions for illustration only; adjust them to
# this repository's actual layout before uncommenting:
# from models import (UNet, SegNet, AttentionUNet, SOASNet, SOASNet_ss, SOASNet_ls,
#                     SOASNet_ma, SOASNet_vls, SOASNet_segnet, SOASNet_segnet_skip)
# from losses import dice_loss
# from metrics import segmentation_scores, f1_score, hd95, intersectionAndUnion
# from adversarial import fgsm_attack


def trainSingleModel(model_name, depth_limit, epochs, width, depth, repeat, lr,
                     lr_schedule, train_dataset, train_batch, data_name,
                     data_augmentation_train, data_augmentation_test,
                     train_loader, validate_data, test_data_1, test_data_2,
                     shuffle, loss, norm, log, no_class, input_channel):
    # :param model_name: network architecture tag (see the if/elif chain below)
    # :param depth_limit: maximum number of down-sampling stages
    # :param epochs: total training epochs
    # :param width: first encoder channel number
    # :param depth: network depth
    # :param repeat: index of the repeat of the same experiment
    # :param lr: learning rate
    # :param lr_schedule: True or False for using a learning rate schedule
    # :param train_dataset: training data set
    # :param train_batch: batch size
    # :param data_name: name of the dataset
    # :param data_augmentation_train: augmentation tag for training data
    # :param data_augmentation_test: augmentation tag for testing data
    # :param train_loader: training loader
    # :param validate_data: validation data
    # :param test_data_1: first testing data set
    # :param test_data_2: second testing data set
    # :param shuffle: shuffle training data or not
    # :param loss: loss function tag: 'dice', 'ce' for cross-entropy, or 'hybrid'
    # :param norm: normalisation tag
    # :param log: log tag for recording experiments
    # :param no_class: 2 for binary, otherwise the number of classes
    # :param input_channel: 4 for BRATS, 3 for CityScapes
    # :return: path of the saved model
    device = torch.device('cuda:0')

    if model_name == 'unet':
        model = UNet(n_channels=input_channel, n_classes=no_class,
                     bilinear=True).to(device=device)
    elif model_name == 'Segnet':
        model = SegNet(in_ch=input_channel, width=width, norm=norm, depth=4,
                       n_classes=no_class, dropout=True,
                       side_output=False).to(device=device)
    elif model_name == 'SOASNet_single':
        model = SOASNet_ss(in_ch=input_channel, width=width, depth=depth,
                           norm=norm, n_classes=no_class, mode='low_rank_attn',
                           side_output=False,
                           downsampling_limit=depth_limit).to(device=device)
    elif model_name == 'SOASNet':
        model = SOASNet(in_ch=input_channel, width=width, depth=depth,
                        norm=norm, n_classes=no_class, mode='low_rank_attn',
                        side_output=False,
                        downsampling_limit=depth_limit).to(device=device)
    elif model_name == 'SOASNet_large_kernel':
        model = SOASNet_ls(in_ch=input_channel, width=width, depth=depth,
                           norm=norm, n_classes=no_class, mode='low_rank_attn',
                           side_output=False,
                           downsampling_limit=depth_limit).to(device=device)
    elif model_name == 'SOASNet_multi_attn':
        model = SOASNet_ma(in_ch=input_channel, width=width, depth=depth,
                           norm=norm, n_classes=no_class, mode='low_rank_attn',
                           side_output=False,
                           downsampling_limit=depth_limit).to(device=device)
    elif model_name == 'SOASNet_very_large_kernel':
        model = SOASNet_vls(in_ch=input_channel, width=width, depth=depth,
                            norm=norm, n_classes=no_class, mode='low_rank_attn',
                            side_output=False,
                            downsampling_limit=depth_limit).to(device=device)
    elif model_name == 'SOASNet_segnet':
        model = SOASNet_segnet(in_ch=input_channel, width=width, depth=depth,
                               norm=norm, n_classes=no_class,
                               mode='low_rank_attn', side_output=False,
                               downsampling_limit=depth_limit).to(device=device)
    elif model_name == 'SOASNet_segnet_skip':
        model = SOASNet_segnet_skip(in_ch=input_channel, width=width,
                                    depth=depth, norm=norm, n_classes=no_class,
                                    mode='low_rank_attn', side_output=False,
                                    downsampling_limit=depth_limit).to(device=device)
    elif model_name == 'RelayNet':
        model = SOASNet_segnet_skip(in_ch=input_channel, width=width,
                                    depth=depth, norm=norm, n_classes=no_class,
                                    mode='relaynet', side_output=False,
                                    downsampling_limit=depth_limit).to(device=device)
    elif model_name == 'attn_unet':
        model = AttentionUNet(in_ch=input_channel, width=width,
                              visulisation=False,
                              class_no=no_class).to(device=device)

    # ==================================
    training_amount = len(train_dataset)
    iteration_amount = training_amount // train_batch - 1

    model_name = model_name + '_Epoch_' + str(epochs) + \
        '_Dataset_' + data_name + \
        '_Batch_' + str(train_batch) + \
        '_Width_' + str(width) + \
        '_Loss_' + loss + \
        '_Norm_' + norm + \
        '_ShuffleTraining_' + str(shuffle) + \
        '_Data_Augmentation_Train_' + data_augmentation_train + \
        '_Data_Augmentation_Test_' + data_augmentation_test + \
        '_lr_' + str(lr) + \
        '_Repeat_' + str(repeat)

    print(model_name)

    writer = SummaryWriter('../../Log_' + log + '/' + model_name)

    optimizer = AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-8,
                      weight_decay=1e-5)

    for epoch in range(epochs):
        model.train()
        running_loss = 0
        # j: index of mini batch
        if 'mixup' not in data_augmentation_train:
            for j, (images, labels, imagename) in enumerate(train_loader):
                images = images.to(device=device, dtype=torch.float32)
                if no_class == 2:
                    labels = labels.to(device=device, dtype=torch.float32)
                else:
                    labels = labels.to(device=device, dtype=torch.long)

                outputs_logits = model(images)
                optimizer.zero_grad()

                # calculate the main loss
                if no_class == 2:
                    if loss == 'dice':
                        main_loss = dice_loss(torch.sigmoid(outputs_logits), labels)
                    elif loss == 'ce':
                        main_loss = nn.BCEWithLogitsLoss(reduction='mean')(outputs_logits, labels)
                    elif loss == 'hybrid':
                        main_loss = dice_loss(torch.sigmoid(outputs_logits), labels) + \
                            nn.BCEWithLogitsLoss(reduction='mean')(outputs_logits, labels)
                else:
                    # CrossEntropyLoss expects raw logits, not softmax outputs
                    main_loss = nn.CrossEntropyLoss(reduction='mean', ignore_index=8)(
                        outputs_logits, labels.squeeze(1))

                running_loss += main_loss.item()
                main_loss.backward()
                optimizer.step()

                # ==============================================================
                # Calculate training and validation metrics at the last
                # iteration of each epoch
                # ==============================================================
                if (j + 1) % iteration_amount == 0:
                    if no_class == 2:
                        outputs = torch.sigmoid(outputs_logits)
                    else:
                        _, outputs = torch.max(outputs_logits, dim=1)
                        labels = labels.squeeze(1)

                    mean_iu = intersectionAndUnion(outputs.cpu().detach(),
                                                   labels.cpu().detach(),
                                                   no_class)
                    validate_iou, validate_f1, validate_recall, validate_precision = evaluate(
                        data=validate_data, model=model, device=device,
                        class_no=no_class)

                    print('Step [{}/{}], '
                          'loss: {:.5f}, '
                          'train iou: {:.5f}, '
                          'val iou: {:.5f}'.format(epoch + 1, epochs,
                                                   running_loss / (j + 1),
                                                   mean_iu, validate_iou))

                    writer.add_scalars('scalars',
                                       {'train iou': mean_iu,
                                        'val iou': validate_iou,
                                        'val f1': validate_f1,
                                        'val recall': validate_recall,
                                        'val precision': validate_precision},
                                       epoch + 1)
        else:
            # the mix-up strategy requires more calculations:
            for j, (images_1, labels_1, imagename_1, images_2, labels_2,
                    mixed_up_image, lam) in enumerate(train_loader):
                mixed_up_image = mixed_up_image.to(device=device, dtype=torch.float32)
                lam = lam.to(device=device, dtype=torch.float32)
                if no_class == 2:
                    labels_1 = labels_1.to(device=device, dtype=torch.float32)
                    labels_2 = labels_2.to(device=device, dtype=torch.float32)
                else:
                    labels_1 = labels_1.to(device=device, dtype=torch.long)
                    labels_2 = labels_2.to(device=device, dtype=torch.long)

                outputs_logits = model(mixed_up_image)
                optimizer.zero_grad()

                # the mixed-up loss is the lam-weighted sum of the losses
                # against both sets of labels
                if no_class == 2:
                    if loss == 'dice':
                        main_loss = lam * dice_loss(torch.sigmoid(outputs_logits), labels_1) + \
                            (1 - lam) * dice_loss(torch.sigmoid(outputs_logits), labels_2)
                    elif loss == 'ce':
                        main_loss = lam * nn.BCEWithLogitsLoss(reduction='mean')(outputs_logits, labels_1) + \
                            (1 - lam) * nn.BCEWithLogitsLoss(reduction='mean')(outputs_logits, labels_2)
                    elif loss == 'hybrid':
                        main_loss = lam * dice_loss(torch.sigmoid(outputs_logits), labels_1) \
                            + (1 - lam) * dice_loss(torch.sigmoid(outputs_logits), labels_2) \
                            + lam * nn.BCEWithLogitsLoss(reduction='mean')(outputs_logits, labels_1) \
                            + (1 - lam) * nn.BCEWithLogitsLoss(reduction='mean')(outputs_logits, labels_2)
                else:
                    main_loss = lam * nn.CrossEntropyLoss(reduction='mean')(outputs_logits, labels_1.squeeze(1)) + \
                        (1 - lam) * nn.CrossEntropyLoss(reduction='mean')(outputs_logits, labels_2.squeeze(1))

                running_loss += main_loss.mean().item()
                main_loss.mean().backward()
                optimizer.step()

                # ==============================================================
                # Calculate training and validation metrics at the last
                # iteration of each epoch
                # ==============================================================
                if (j + 1) % iteration_amount == 0:
                    if no_class == 2:
                        outputs = torch.sigmoid(outputs_logits)
                    else:
                        _, outputs = torch.max(outputs_logits, dim=1)
                        outputs = outputs.unsqueeze(1)

                    mean_iu_1 = segmentation_scores(labels_1.cpu().detach().numpy(),
                                                    outputs.cpu().detach().numpy(),
                                                    no_class)
                    mean_iu_2 = segmentation_scores(labels_2.cpu().detach().numpy(),
                                                    outputs.cpu().detach().numpy(),
                                                    no_class)
                    mean_iu = lam.data.sum() * mean_iu_1 + (1 - lam.data.sum()) * mean_iu_2

                    validate_iou, validate_f1, validate_recall, validate_precision = evaluate(
                        data=validate_data, model=model, device=device,
                        class_no=no_class)

                    mean_iu = mean_iu.item()

                    print('Step [{}/{}], '
                          'loss: {:.4f}, '
                          'train iou: {:.4f}, '
                          'val iou: {:.4f}'.format(epoch + 1, epochs,
                                                   running_loss / (j + 1),
                                                   mean_iu, validate_iou))

                    writer.add_scalars('scalars',
                                       {'train iou': mean_iu,
                                        'val iou': validate_iou,
                                        'val f1': validate_f1,
                                        'val recall': validate_recall,
                                        'val precision': validate_precision},
                                       epoch + 1)

        if lr_schedule is True:
            # polynomial decay of the learning rate over epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr * ((1 - epoch / epochs) ** 0.999)

    # save the trained model
    save_folder = '../../saved_models_' + log
    os.makedirs(save_folder, exist_ok=True)

    save_model_name = model_name + '_Final'
    save_model_name_full = save_folder + '/' + save_model_name + '.pt'
    torch.save(model, save_model_name_full)

    # =======================================================================
    # testing (run once after training, because per-epoch testing is too slow)
    # =======================================================================
    save_results_folder = save_folder + '/testing_results_' + model_name
    os.makedirs(save_results_folder, exist_ok=True)

    test_iou_1, test_f1_1, test_recall_1, test_precision_1, mse_1, \
        test_iou_2, test_f1_2, test_recall_2, test_precision_2, mse_2, \
        outputs_1, outputs_2 = test(data_1=test_data_1, data_2=test_data_2,
                                    model=model, device=device,
                                    class_no=no_class,
                                    save_location=save_results_folder)

    print('test iou data 1: {:.4f}, '
          'test mse data 1: {:.4f}, '
          'test f1 data 1: {:.4f}, '
          'test recall data 1: {:.4f}, '
          'test precision data 1: {:.4f}'.format(test_iou_1, mse_1, test_f1_1,
                                                 test_recall_1, test_precision_1))

    print('test iou data 2: {:.4f}, '
          'test mse data 2: {:.4f}, '
          'test f1 data 2: {:.4f}, '
          'test recall data 2: {:.4f}, '
          'test precision data 2: {:.4f}'.format(test_iou_2, mse_2, test_f1_2,
                                                 test_recall_2, test_precision_2))

    print('\nTesting finished and results saved.\n')

    return save_model_name_full
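
# A minimal usage sketch for trainSingleModel above, not part of the original
# repo. 'make_brats_loaders' is an assumed placeholder standing in for the
# project's actual data pipeline, and every tag value below is an example,
# not a recommendation.
def example_train_soasnet():
    train_dataset, train_loader, validate_data, test_data_1, test_data_2 = \
        make_brats_loaders(batch_size=4)  # assumed helper, replace as needed
    return trainSingleModel(model_name='SOASNet', depth_limit=4, epochs=50,
                            width=32, depth=4, repeat=1, lr=1e-3,
                            lr_schedule=True, train_dataset=train_dataset,
                            train_batch=4, data_name='BRATS',
                            data_augmentation_train='flip',
                            data_augmentation_test='none',
                            train_loader=train_loader,
                            validate_data=validate_data,
                            test_data_1=test_data_1, test_data_2=test_data_2,
                            shuffle=True, loss='dice', norm='bn',
                            log='brats_experiments', no_class=2,
                            input_channel=4)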
def trainSingleModel(model, model_name, num_epochs, learning_rate, datasettag,
                     train_dataset, train_batchsize, trainloader,
                     validateloader, testdata, reverse_mode, lr_schedule,
                     class_no):
    # change log names
    training_amount = len(train_dataset)
    iteration_amount = training_amount // train_batchsize - 1
    device = torch.device('cuda')

    lr_str = str(learning_rate)
    epoches_str = str(num_epochs)
    save_model_name = model_name + '_' + datasettag + '_e' + epoches_str + '_lr' + lr_str

    saved_information_path = './Results'
    os.makedirs(saved_information_path, exist_ok=True)
    saved_information_path = saved_information_path + '/' + save_model_name
    os.makedirs(saved_information_path, exist_ok=True)
    saved_model_path = saved_information_path + '/trained_models'
    os.makedirs(saved_model_path, exist_ok=True)

    print('The current model is:')
    print(save_model_name)
    print('\n')

    writer = SummaryWriter('./Results/Log_' + datasettag + '/' + save_model_name)

    model.to(device)

    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate,
                                  betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5)

    if lr_schedule is True:
        scheduler = lr_scheduler.MultiStepLR(
            optimizer, milestones=[num_epochs // 2, 3 * num_epochs // 4], gamma=0.1)

    start = timeit.default_timer()

    for epoch in range(num_epochs):
        model.train()
        train_iou = []
        train_loss = []
        # j: index of iteration
        for j, (images, labels, imagename) in enumerate(trainloader):
            optimizer.zero_grad()
            images = images.to(device=device, dtype=torch.float32)
            if class_no == 2:
                labels = labels.to(device=device, dtype=torch.float32)
            else:
                labels = labels.to(device=device, dtype=torch.long)

            outputs = model(images)

            if class_no == 2:
                prob_outputs = torch.sigmoid(outputs)
                loss = dice_loss(prob_outputs, labels)
                class_outputs = torch.where(prob_outputs > threshold, upper, lower)
            else:
                # CrossEntropyLoss expects raw logits, not softmax outputs
                loss = nn.CrossEntropyLoss(reduction='mean')(outputs, labels.squeeze(1))
                _, class_outputs = torch.max(outputs, dim=1)

            loss.backward()
            optimizer.step()

            mean_iu_, _, __ = segmentation_scores(labels, class_outputs, class_no)
            train_iou.append(mean_iu_)
            train_loss.append(loss.item())

        if lr_schedule is True:
            scheduler.step()

        model.eval()
        with torch.no_grad():
            validate_iou = []
            validate_f1 = []
            validate_h_dist = []
            for i, (val_images, val_label, imagename) in enumerate(validateloader):
                val_img = val_images.to(device=device, dtype=torch.float32)
                if class_no == 2:
                    val_label = val_label.to(device=device, dtype=torch.float32)
                else:
                    val_label = val_label.to(device=device, dtype=torch.long)
                assert torch.max(val_label) != 100.0

                val_outputs = model(val_img)
                if class_no == 2:
                    val_class_outputs = torch.sigmoid(val_outputs)
                    val_class_outputs = (val_class_outputs > 0.5).float()
                else:
                    val_class_outputs = torch.softmax(val_outputs, dim=1)
                    _, val_class_outputs = torch.max(val_class_outputs, dim=1)

                eval_mean_iu_, _, __ = segmentation_scores(val_label, val_class_outputs, class_no)
                eval_f1_, eval_recall_, eval_precision_, eTP, eTN, eFP, eFN, eP, eN = f1_score(
                    val_label, val_class_outputs, class_no)

                validate_iou.append(eval_mean_iu_)
                validate_f1.append(eval_f1_)
                # the 95th-percentile Hausdorff distance is only defined when
                # both masks contain foreground, and only for the binary case
                if (val_class_outputs == 1).sum() > 1 and (val_label == 1).sum() > 1 and class_no == 2:
                    v_dist_ = hd95(val_class_outputs, val_label, class_no)
                    validate_h_dist.append(v_dist_)

        print('Step [{}/{}], '
              'Train loss: {:.4f}, '
              'Train iou: {:.4f}, '
              'val iou: {:.4f}'.format(epoch + 1, num_epochs,
                                       np.nanmean(train_loss),
                                       np.nanmean(train_iou),
                                       np.nanmean(validate_iou)))

        writer.add_scalars('acc metrics',
                           {'train iou': np.nanmean(train_iou),
                            'val iou': np.nanmean(validate_iou),
                            'val f1': np.nanmean(validate_f1)},
                           epoch + 1)

        # save checkpoints for the last epochs; test() below averages over them
        if epoch > num_epochs - 10:
            save_model_name_full = saved_model_path + '/epoch' + str(epoch) + '.pt'
            torch.save(model, save_model_name_full)

    test(testdata, saved_model_path, device, reverse_mode=reverse_mode,
         class_no=class_no, save_path=saved_model_path)

    stop = timeit.default_timer()
    print('Time: ', stop - start)
    print('\nTraining finished and model saved\n')

    return model
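
# The soft Dice loss used throughout this file is defined elsewhere in the
# repo. A minimal sketch of a standard formulation is given below for
# reference; the repo's own dice_loss may differ, e.g. in smoothing or in
# per-channel reduction.
def dice_loss_sketch(prob, target, smooth=1e-5):
    # prob, target: (B, 1, H, W) tensors with values in [0, 1]
    prob = prob.contiguous().view(prob.shape[0], -1)
    target = target.contiguous().view(target.shape[0], -1)
    intersection = (prob * target).sum(dim=1)
    union = prob.sum(dim=1) + target.sum(dim=1)
    dice = (2.0 * intersection + smooth) / (union + smooth)
    # Dice is a similarity in [0, 1]; the loss is its complement
    return 1.0 - dice.mean()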
def test(testdata, models_path, device, reverse_mode, class_no, save_path):
    # evaluate every checkpoint saved in models_path and average the metrics
    all_models = glob.glob(os.path.join(models_path, '*.pt'))

    test_f1 = []
    test_iou = []
    test_h_dist = []
    test_acc = []
    test_w_acc = []
    test_recall = []
    test_precision = []
    test_bf = []
    test_iou_adv = []
    test_h_dist_adv = []

    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    for model_path in all_models:
        model = torch.load(model_path)
        model.eval()
        for j, (testimg, testlabel, testname) in enumerate(testdata):
            # the test loader batch size is set elsewhere in the repo
            testimg = testimg.to(device=device, dtype=torch.float32)
            if class_no == 2:
                testlabel = testlabel.to(device=device, dtype=torch.float32)
            else:
                testlabel = testlabel.to(device=device, dtype=torch.long)
            if torch.max(testlabel) == 255.:
                testlabel = testlabel / 255.

            # gradients w.r.t. the input are needed for the FGSM attack below
            testimg.requires_grad = True

            testoutput_logits = model(testimg)
            if class_no == 2:
                prob_testoutput = torch.sigmoid(testoutput_logits)
                testoutput = (prob_testoutput > 0.5).float()
            else:
                prob_testoutput = torch.softmax(testoutput_logits, dim=1)
                _, testoutput = torch.max(prob_testoutput, dim=1)

            # attack the testing data:
            if class_no == 2:
                loss = dice_loss(prob_testoutput, testlabel)
            else:
                # CrossEntropyLoss expects raw logits, not softmax outputs
                loss = nn.CrossEntropyLoss(reduction='mean')(testoutput_logits, testlabel.squeeze(1))
            model.zero_grad()
            loss.backward()
            data_grad = testimg.grad.data
            perturbed_data = fgsm_attack(testimg, 0.2, data_grad)

            output_attack = model(perturbed_data)
            if class_no == 2:
                output_attack = torch.sigmoid(output_attack)
                output_attack = (output_attack > 0.5).float()
            else:
                output_attack = torch.softmax(output_attack, dim=1)
                _, output_attack = torch.max(output_attack, dim=1)

            # metrics on clean data:
            mean_iu_, acc_, w_acc_ = segmentation_scores(testlabel, testoutput, class_no)
            test_iou.append(mean_iu_)
            test_acc.append(acc_)
            test_w_acc.append(w_acc_)

            # metrics on attacked data:
            mean_iu_adv_, _, __ = segmentation_scores(testlabel, output_attack, class_no)
            test_iou_adv.append(mean_iu_adv_)

            if (testoutput == 1).sum() > 1 and (testlabel == 1).sum() > 1 and class_no == 2:
                h_dis95_ = hd95(testoutput, testlabel, class_no)
                test_h_dist.append(h_dis95_)
            if (output_attack == 1).sum() > 1 and (testlabel == 1).sum() > 1 and class_no == 2:
                h_dis95_attack_ = hd95(output_attack, testlabel, class_no)
                test_h_dist_adv.append(h_dis95_attack_)

            f1_, recall_, precision_, TP, TN, FP, FN, P, N = f1_score(testlabel, testoutput, class_no)
            # harmonic mean of precision and recall; epsilon avoids division by zero
            bf_ = 2 * precision_ * recall_ / (recall_ + precision_ + 1e-10)
            test_f1.append(f1_)
            test_recall.append(recall_)
            test_precision.append(precision_)
            test_bf.append(bf_)

    prediction_map_path = save_path + '/Test'
    os.makedirs(prediction_map_path, exist_ok=True)

    result_dictionary = {
        'Test IoU mean': str(np.mean(test_iou)),
        'Test IoU std': str(np.std(test_iou)),
        'Test Acc mean': str(np.mean(test_acc)),
        'Test Acc std': str(np.std(test_acc)),
        'Test W ACC mean': str(np.mean(test_w_acc)),
        'Test W ACC std': str(np.std(test_w_acc)),
        'Test BF mean': str(np.mean(test_bf)),
        'Test BF std': str(np.std(test_bf)),
        'Test f1 mean': str(np.mean(test_f1)),
        'Test f1 std': str(np.std(test_f1)),
        'Test H-dist mean': str(np.mean(test_h_dist)),
        'Test H-dist std': str(np.std(test_h_dist)),
        'Test precision mean': str(np.mean(test_precision)),
        'Test precision std': str(np.std(test_precision)),
        'Test recall mean': str(np.mean(test_recall)),
        'Test recall std': str(np.std(test_recall)),
        'Test IoU attack mean': str(np.mean(test_iou_adv)),
        'Test IoU attack std': str(np.std(test_iou_adv)),
        'Test H-dist attack mean': str(np.mean(test_h_dist_adv)),
        'Test H-dist attack std': str(np.std(test_h_dist_adv)),
    }

    ff_path = prediction_map_path + '/test_result_data.txt'
    with open(ff_path, 'w') as ff:
        ff.write(str(result_dictionary))

    print('Test h-dist: {:.4f}, '
          'Test iou: {:.4f}'.format(np.mean(test_h_dist), np.mean(test_iou)))
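
# fgsm_attack is imported from elsewhere in the repo. A standard FGSM step
# (Goodfellow et al., 2015) is sketched below for reference; the repo's
# version may clip to a different range or normalise differently.
def fgsm_attack_sketch(image, epsilon, data_grad):
    # perturb each pixel by epsilon in the direction that increases the loss
    perturbed = image + epsilon * data_grad.sign()
    # keep the result a valid image; [0, 1] is an assumption about input scaling
    return torch.clamp(perturbed, 0.0, 1.0)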
def evaluate(evaluatedata, model, device, reverse_mode, class_no):
    model.eval()

    f1 = 0
    test_iou = 0
    test_h_dist = 0
    recall = 0
    precision = 0
    FPs_Ns = 0
    FNs_Ps = 0
    FPs_Ps = 0
    FNs_Ns = 0
    TPs = 0
    TNs = 0
    FNs = 0
    FPs = 0
    Ps = 0
    Ns = 0
    test_iou_attack = 0
    test_h_dist_attack = 0
    effective_h = 0
    effective_h_attack = 0

    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    for j, (testimg, testlabel, testname) in enumerate(evaluatedata):
        # the validation batch size is set elsewhere in the repo
        testimg = testimg.to(device=device, dtype=torch.float32)
        testlabel = testlabel.to(device=device, dtype=torch.float32)
        # gradients w.r.t. the input are needed for the FGSM attack below
        testimg.requires_grad = True

        testoutput = model(testimg)
        prob_testoutput = torch.sigmoid(testoutput)

        # attack the evaluation data:
        loss = dice_loss(prob_testoutput, testlabel)
        model.zero_grad()
        loss.backward()
        data_grad = testimg.grad.data
        perturbed_data = fgsm_attack(testimg, 0.2, data_grad)
        output_attack = model(perturbed_data)
        output_attack = torch.sigmoid(output_attack)

        if reverse_mode is True:
            testoutput = torch.where(prob_testoutput < threshold, upper, lower)
            output_attack = torch.where(output_attack < threshold, upper, lower)
        else:
            testoutput = torch.where(prob_testoutput > threshold, upper, lower)
            output_attack = torch.where(output_attack > threshold, upper, lower)

        mean_iu_, _, __ = segmentation_scores(testlabel, testoutput, class_no)
        mean_iu_attack_, _, __ = segmentation_scores(testlabel, output_attack, class_no)

        # the Hausdorff distance is only defined when both masks contain foreground
        if (testoutput == 1).sum() > 1 and (testlabel == 1).sum() > 1:
            h_dis95_ = hd95(testoutput, testlabel, class_no)
            test_h_dist += h_dis95_
            effective_h = effective_h + 1
        if (output_attack == 1).sum() > 1 and (testlabel == 1).sum() > 1:
            h_dis95_attack_ = hd95(output_attack, testlabel, class_no)
            effective_h_attack = effective_h_attack + 1
            test_h_dist_attack += h_dis95_attack_

        f1_, recall_, precision_, TP, TN, FP, FN, P, N = f1_score(testlabel, testoutput, class_no)

        f1 += f1_
        test_iou += mean_iu_
        recall += recall_
        precision += precision_
        TPs += TP
        TNs += TN
        FPs += FP
        FNs += FN
        Ps += P
        Ns += N
        # ratios of confusion matrix entries; epsilons guard against empty masks
        FNs_Ps += (FN + 1e-10) / (P + 1e-10)
        FPs_Ns += (FP + 1e-10) / (N + 1e-10)
        FNs_Ns += (FN + 1e-10) / (N + 1e-10)
        FPs_Ps += (FP + 1e-10) / (P + 1e-10)
        test_iou_attack += mean_iu_attack_

    return test_iou / (j + 1), f1 / (j + 1), recall / (j + 1), precision / (j + 1), \
        FPs_Ns / (j + 1), FPs_Ps / (j + 1), FNs_Ns / (j + 1), FNs_Ps / (j + 1), \
        FPs / (j + 1), FNs / (j + 1), TPs / (j + 1), TNs / (j + 1), \
        Ps / (j + 1), Ns / (j + 1), test_h_dist / (effective_h + 1), \
        test_iou_attack / (j + 1), test_h_dist_attack / (effective_h_attack + 1)
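
# segmentation_scores is a project-local metric helper. Judging by its call
# sites here it returns a 3-tuple (mean IoU, pixel accuracy, a weighted
# accuracy); the numpy sketch below assumes frequency-weighted IoU for the
# third value, which may differ from the repo's definition.
def segmentation_scores_sketch(label, pred, n_classes):
    # label, pred: tensors of class indices with the same shape
    label = label.cpu().detach().numpy().astype(np.int64).ravel()
    pred = pred.cpu().detach().numpy().astype(np.int64).ravel()
    ious, freqs = [], []
    for c in range(n_classes):
        union = np.logical_or(label == c, pred == c).sum()
        if union > 0:
            inter = np.logical_and(label == c, pred == c).sum()
            ious.append(inter / union)
            freqs.append((label == c).mean())
    mean_iou = float(np.mean(ious)) if ious else 0.0
    pixel_acc = float((label == pred).mean())
    freq_weighted_iou = float(np.sum(np.array(ious) * np.array(freqs)))
    return mean_iou, pixel_acc, freq_weighted_iou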
def trainSingleModel(model, model_name, num_epochs, learning_rate, datasettag,
                     train_dataset, train_batchsize, trainloader,
                     validateloader, testdata, reverse_mode, lr_schedule,
                     class_no):
    # change log names
    training_amount = len(train_dataset)
    iteration_amount = training_amount // train_batchsize - 1
    device = torch.device('cuda')

    lr_str = str(learning_rate)
    epoches_str = str(num_epochs)
    save_model_name = model_name + '_' + datasettag + '_e' + epoches_str + '_lr' + lr_str

    saved_information_path = './Results'
    os.makedirs(saved_information_path, exist_ok=True)
    saved_information_path = saved_information_path + '/' + save_model_name
    os.makedirs(saved_information_path, exist_ok=True)
    saved_model_path = saved_information_path + '/trained_models'
    os.makedirs(saved_model_path, exist_ok=True)

    print('The current model is:')
    print(save_model_name)
    print('\n')

    writer = SummaryWriter('./Results/Log_' + datasettag + '/' + save_model_name)

    model.to(device)

    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate,
                                  betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5)

    if lr_schedule is True:
        scheduler = lr_scheduler.MultiStepLR(
            optimizer, milestones=[num_epochs // 2, 3 * num_epochs // 4], gamma=0.1)

    start = timeit.default_timer()

    for epoch in range(num_epochs):
        model.train()
        h_dists = 0
        f1 = 0
        accuracy_iou = 0
        running_loss = 0
        recall = 0
        precision = 0
        t_FPs_Ns = 0
        t_FPs_Ps = 0
        t_FNs_Ns = 0
        t_FNs_Ps = 0
        t_FPs = 0
        t_FNs = 0
        t_TPs = 0
        t_TNs = 0
        t_Ps = 0
        t_Ns = 0
        effective_h = 0

        # j: index of iteration
        for j, (images, labels, imagename) in enumerate(trainloader):
            optimizer.zero_grad()
            images = images.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=torch.float32)
            # gradients w.r.t. the input are needed for the FGSM attack below
            images.requires_grad = True

            if reverse_mode is True:
                # flip the labels: foreground becomes background and vice versa
                inverse_labels = torch.ones_like(labels)
                inverse_labels = inverse_labels.to(device=device, dtype=torch.float32)
                inverse_labels = inverse_labels - labels

            outputs = model(images)
            prob_outputs = torch.sigmoid(outputs)

            if reverse_mode is True:
                loss = dice_loss(prob_outputs, inverse_labels)
            else:
                loss = dice_loss(prob_outputs, labels)
            loss.backward()
            optimizer.step()

            # The binary segmentation task is too easy; to compensate for its
            # simplicity, the training metrics are computed on adversarially
            # perturbed images:
            data_grad = images.grad.data
            perturbed_data = fgsm_attack(images, 0.2, data_grad)
            prob_outputs = model(perturbed_data)
            prob_outputs = torch.sigmoid(prob_outputs)

            if reverse_mode is True:
                class_outputs = torch.where(prob_outputs < threshold, upper, lower)
            else:
                class_outputs = torch.where(prob_outputs > threshold, upper, lower)

            # the Hausdorff distance is only defined for binary masks with foreground
            if class_no == 2:
                if (class_outputs == 1).sum() > 1 and (labels == 1).sum() > 1:
                    dist_ = hd95(class_outputs, labels, class_no)
                    h_dists += dist_
                    effective_h = effective_h + 1

            mean_iu_ = segmentation_scores(labels, class_outputs, class_no)
            f1_, recall_, precision_, TPs_, TNs_, FPs_, FNs_, Ps_, Ns_ = f1_score(
                labels, class_outputs, class_no)

            running_loss += loss.item()
            f1 += f1_
            accuracy_iou += mean_iu_
            recall += recall_
            precision += precision_
            t_TPs += TPs_
            t_TNs += TNs_
            t_FPs += FPs_
            t_FNs += FNs_
            t_Ps += Ps_
            t_Ns += Ns_
            t_FNs_Ps += (FNs_ + 1e-8) / (Ps_ + 1e-8)
            t_FPs_Ns += (FPs_ + 1e-8) / (Ns_ + 1e-8)
            t_FNs_Ns += (FNs_ + 1e-8) / (Ns_ + 1e-8)
            t_FPs_Ps += (FPs_ + 1e-8) / (Ps_ + 1e-8)

            if (j + 1) % iteration_amount == 0:
                validate_iou, validate_f1, validate_recall, validate_precision, \
                    v_FPs_Ns, v_FPs_Ps, v_FNs_Ns, v_FNs_Ps, v_FPs, v_FNs, \
                    v_TPs, v_TNs, v_Ps, v_Ns, v_h_dist, v_iou_attack, \
                    v_h_dist_attack = evaluate(validateloader, model, device,
                                               reverse_mode=reverse_mode,
                                               class_no=class_no)

                print('Step [{}/{}], Train loss: {:.4f}, '
                      'Train iou: {:.4f}, '
                      'Train h-dist: {:.4f}, '
                      'Val iou: {:.4f}, '
                      'Val h-dist: {:.4f}'.format(epoch + 1, num_epochs,
                                                  running_loss / (j + 1),
                                                  accuracy_iou / (j + 1),
                                                  h_dists / (effective_h + 1),
                                                  validate_iou, v_h_dist))

                # ============================================================
                # TensorboardX Logging
                # ============================================================
                writer.add_scalars('acc metrics',
                                   {'train iou': accuracy_iou / (j + 1),
                                    'train hausdorff dist': h_dists / (effective_h + 1),
                                    'val iou': validate_iou,
                                    'val hausdorff distance': v_h_dist,
                                    'loss': running_loss / (j + 1)},
                                   epoch + 1)
                writer.add_scalars('train confusion matrices analysis',
                                   {'train FPs/Ns': t_FPs_Ns / (j + 1),
                                    'train FNs/Ps': t_FNs_Ps / (j + 1),
                                    'train FPs/Ps': t_FPs_Ps / (j + 1),
                                    'train FNs/Ns': t_FNs_Ns / (j + 1),
                                    'train FNs': t_FNs / (j + 1),
                                    'train FPs': t_FPs / (j + 1),
                                    'train TNs': t_TNs / (j + 1),
                                    'train TPs': t_TPs / (j + 1),
                                    'train Ns': t_Ns / (j + 1),
                                    'train Ps': t_Ps / (j + 1),
                                    'train imbalance': t_Ps / (t_Ps + t_Ns)},
                                   epoch + 1)
                writer.add_scalars('val confusion matrices analysis',
                                   {'val FPs/Ns': v_FPs_Ns,
                                    'val FNs/Ps': v_FNs_Ps,
                                    'val FPs/Ps': v_FPs_Ps,
                                    'val FNs/Ns': v_FNs_Ns,
                                    'val FNs': v_FNs,
                                    'val FPs': v_FPs,
                                    'val TNs': v_TNs,
                                    'val TPs': v_TPs,
                                    'val Ns': v_Ns,
                                    'val Ps': v_Ps,
                                    'val imbalance': v_Ps / (v_Ps + v_Ns)},
                                   epoch + 1)

            # A learning rate schedule plan for fn attention models: the
            # learning rate ramps up linearly within each of the first 10
            # epochs, because without the warm-up they are sometimes hard
            # to train.
            if ('fn' in model_name or 'FN' in model_name) and reverse_mode is True and epoch < 10:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = learning_rate * (j / len(trainloader))

        if lr_schedule is True:
            scheduler.step()

        # save models at the last 10 epochs
        if epoch >= (num_epochs - 10):
            save_model_name_full = saved_model_path + '/' + save_model_name + '_epoch' + str(epoch) + '.pt'
            torch.save(model, save_model_name_full)

    # Test with all saved models and average their metrics:
    test(testdata, saved_model_path, device, reverse_mode=reverse_mode,
         class_no=class_no, save_path=saved_information_path)

    stop = timeit.default_timer()
    print('Time: ', stop - start)
    print('\nTraining finished and model saved\n')

    return model
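
# f1_score is a project-local helper returning, per its call sites,
# (f1, recall, precision, TP, TN, FP, FN, P, N). A binary-case sketch under
# that assumption is given below; class_no is accepted only for signature
# compatibility, and the repo's version presumably also handles class_no > 2.
def f1_score_sketch(label, pred, class_no, eps=1e-10):
    label = label.cpu().detach().numpy().ravel()
    pred = pred.cpu().detach().numpy().ravel()
    tp = float(np.logical_and(label == 1, pred == 1).sum())
    tn = float(np.logical_and(label == 0, pred == 0).sum())
    fp = float(np.logical_and(label == 0, pred == 1).sum())
    fn = float(np.logical_and(label == 1, pred == 0).sum())
    p, n = tp + fn, tn + fp
    recall = tp / (tp + fn + eps)
    precision = tp / (tp + fp + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return f1, recall, precision, tp, tn, fp, fn, p, n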
def test(testdata, models_path, device, reverse_mode, class_no, save_path):
    # evaluate every checkpoint saved in models_path and average the metrics
    all_models = glob.glob(os.path.join(models_path, '*.pt'))

    test_f1 = []
    test_iou = []
    test_h_dist = []
    test_recall = []
    test_precision = []
    test_iou_adv = []
    test_h_dist_adv = []

    threshold = torch.tensor([0.5], dtype=torch.float32, device=device, requires_grad=False)
    upper = torch.tensor([1.0], dtype=torch.float32, device=device, requires_grad=False)
    lower = torch.tensor([0.0], dtype=torch.float32, device=device, requires_grad=False)

    for model_path in all_models:
        model = torch.load(model_path)
        model.eval()
        for j, (testimg, testlabel, testname) in enumerate(testdata):
            testimg = testimg.to(device=device, dtype=torch.float32)
            # gradients w.r.t. the input are needed for the FGSM attack below
            testimg.requires_grad = True
            testlabel = testlabel.to(device=device, dtype=torch.float32)

            testoutput = model(testimg)
            # (todo) add for multi-class
            prob_testoutput = torch.sigmoid(testoutput)

            if class_no == 2:
                if reverse_mode is True:
                    testoutput = torch.where(prob_testoutput < threshold, upper, lower)
                else:
                    testoutput = torch.where(prob_testoutput > threshold, upper, lower)

            # metrics before the attack:
            mean_iu_ = segmentation_scores(testlabel, testoutput, class_no)
            test_iou.append(mean_iu_)
            if (testoutput == 1).sum() > 1 and (testlabel == 1).sum() > 1:
                h_dis95_ = hd95(testoutput, testlabel, class_no)
                test_h_dist.append(h_dis95_)
            f1_, recall_, precision_, TP, TN, FP, FN, P, N = f1_score(testlabel, testoutput, class_no)
            test_f1.append(f1_)
            test_recall.append(recall_)
            test_precision.append(precision_)

            # attack the testing data:
            loss = dice_loss(prob_testoutput, testlabel)
            model.zero_grad()
            loss.backward()
            data_grad = testimg.grad.data
            perturbed_data = fgsm_attack(testimg, 0.2, data_grad)

            prob_testoutput_adv = model(perturbed_data)
            prob_testoutput_adv = torch.sigmoid(prob_testoutput_adv)
            if class_no == 2:
                if reverse_mode is True:
                    testoutput_adv = torch.where(prob_testoutput_adv < threshold, upper, lower)
                else:
                    testoutput_adv = torch.where(prob_testoutput_adv > threshold, upper, lower)

            # metrics after the attack:
            mean_iu_adv_ = segmentation_scores(testlabel, testoutput_adv, class_no)
            test_iou_adv.append(mean_iu_adv_)
            if (testoutput_adv == 1).sum() > 1 and (testlabel == 1).sum() > 1:
                h_dis95_adv_ = hd95(testoutput_adv, testlabel, class_no)
                test_h_dist_adv.append(h_dis95_adv_)

    # store the test metrics
    prediction_map_path = save_path + '/Test_result'
    os.makedirs(prediction_map_path, exist_ok=True)

    # save numerical results:
    result_dictionary = {
        'Test IoU mean': str(np.mean(test_iou)),
        'Test IoU std': str(np.std(test_iou)),
        'Test f1 mean': str(np.mean(test_f1)),
        'Test f1 std': str(np.std(test_f1)),
        'Test H-dist mean': str(np.mean(test_h_dist)),
        'Test H-dist std': str(np.std(test_h_dist)),
        'Test precision mean': str(np.mean(test_precision)),
        'Test precision std': str(np.std(test_precision)),
        'Test recall mean': str(np.mean(test_recall)),
        'Test recall std': str(np.std(test_recall)),
        'Test IoU attack mean': str(np.mean(test_iou_adv)),
        'Test IoU attack std': str(np.std(test_iou_adv)),
        'Test H-dist attack mean': str(np.mean(test_h_dist_adv)),
        'Test H-dist attack std': str(np.std(test_h_dist_adv)),
    }

    ff_path = prediction_map_path + '/test_results.txt'
    with open(ff_path, 'w') as ff:
        ff.write(str(result_dictionary))

    print('Test h-dist: {:.4f}, '
          'Test iou: {:.4f}'.format(np.mean(test_h_dist), np.mean(test_iou)))
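
# hd95 computes the 95th-percentile Hausdorff distance between a predicted
# and a reference mask. If the repo does not ship its own implementation, a
# thin wrapper over medpy (an assumed dependency) would look like the sketch
# below; it covers only the binary case and assumes one image per batch.
def hd95_sketch(pred, label, class_no):
    from medpy.metric.binary import hd95 as medpy_hd95  # assumed dependency
    pred = pred.cpu().detach().numpy().astype(bool).squeeze()
    label = label.cpu().detach().numpy().astype(bool).squeeze()
    # medpy requires both masks to contain foreground, which the call sites
    # above already guarantee via the (x == 1).sum() > 1 checks
    return medpy_hd95(pred, label)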