def evaluate(data, model, device, class_no):
    """Average segmentation metrics of ``model`` over every batch of ``data``.

    The model is switched to evaluation mode and is left in that mode when
    the function returns; callers that continue training afterwards must
    call ``model.train()`` themselves.

    Args:
        data: iterable yielding ``(image, label, image_name)`` triples whose
            first two items are tensors (they are moved with ``.to``).
        model: network called as ``model(image)``.
        device: device the images and labels are moved to.
        class_no: number of classes. ``2`` selects the binary path
            (sigmoid followed by a 0.5 threshold); any other value takes the
            arg-max over dimension 1 of the model output.

    Returns:
        Tuple ``(iou, f1, recall, precision)``, each one the per-batch value
        averaged over the number of batches. When ``data`` yields no batches
        every entry is ``0.0`` (the previous implementation raised
        ``UnboundLocalError`` in that case).
    """
    model.eval()

    iou_total = 0
    f1_total = 0
    recall_total = 0
    precision_total = 0
    batch_count = 0

    with torch.no_grad():
        for testimg, testlabel, _ in data:
            testimg = testimg.to(device=device, dtype=torch.float32)
            testlabel = testlabel.to(device=device, dtype=torch.float32)

            testoutput = model(testimg)
            if class_no == 2:
                # Binary case: probabilities thresholded into a hard mask.
                testoutput = torch.sigmoid(testoutput)
                testoutput = (testoutput > 0.5).float()
            else:
                # Multi-class case: index of the largest score along dim 1
                # (presumably the class dimension -- confirm with the models).
                _, testoutput = torch.max(testoutput, dim=1)

            prediction = testoutput.cpu().detach()
            target = testlabel.cpu().detach()

            # intersectionAndUnion and f1_score are helpers defined outside
            # this block; f1_score is expected to return (f1, recall, precision).
            mean_iu_ = intersectionAndUnion(prediction, target, class_no)
            f1_, recall_, precision_ = f1_score(target.numpy(), prediction.numpy(), class_no)

            iou_total += mean_iu_
            f1_total += f1_
            recall_total += recall_
            precision_total += precision_
            batch_count += 1

    if batch_count == 0:
        # Nothing to average: report zeros instead of failing on the divisor.
        return 0.0, 0.0, 0.0, 0.0

    return (iou_total / batch_count,
            f1_total / batch_count,
            recall_total / batch_count,
            precision_total / batch_count)
def trainSingleModel(model_name,
                     depth_limit,
                     epochs,
                     width,
                     depth,
                     repeat,
                     lr,
                     lr_scedule,
                     train_dataset,
                     train_batch,
                     data_name,
                     data_augmentation_train,
                     data_augmentation_test,
                     train_loader,
                     validate_data,
                     test_data_1,
                     test_data_2,
                     shuffle,
                     loss,
                     norm,
                     log,
                     no_class,
                     input_channel):
    """Train one segmentation network, save it, and run the final test pass.

    Args:
        model_name: architecture tag: 'unet', 'Segnet', 'SOASNet_single',
            'SOASNet', 'SOASNet_large_kernel', 'SOASNet_multi_attn',
            'SOASNet_very_large_kernel', 'SOASNet_segnet',
            'SOASNet_segnet_skip', 'RelayNet' or 'attn_unet'.
        depth_limit: forwarded as ``downsampling_limit`` to the SOASNet family.
        epochs: number of training epochs.
        width: forwarded as ``width`` to the network constructors.
        depth: forwarded as ``depth`` to the SOASNet family.
        repeat: repetition index; only used in the experiment name.
        lr: initial learning rate for AdamW.
        lr_scedule: when exactly ``True``, the learning rate is decayed
            polynomially after every epoch.
        train_dataset: training set; only its length is used here.
        train_batch: batch size; used to derive the logging interval and in
            the experiment name.
        data_name: dataset tag; only used in the experiment name.
        data_augmentation_train: augmentation tag. If it contains 'mixup',
            ``train_loader`` must yield 7-tuples
            ``(images_1, labels_1, name_1, images_2, labels_2, mixed_image, lam)``;
            otherwise it must yield ``(images, labels, name)`` triples.
        data_augmentation_test: augmentation tag; only used in the name.
        train_loader: iterable over training batches.
        validate_data: validation data handed to ``evaluate``.
        test_data_1: first test set handed to ``test``.
        test_data_2: second test set handed to ``test``.
        shuffle: only used in the experiment name.
        loss: loss tag for the binary case: 'dice', 'ce' or 'hybrid'.
        norm: normalisation tag forwarded to the network constructors.
        log: tag used in the TensorBoard and checkpoint folder names.
        no_class: number of classes; ``2`` selects the binary losses, any
            other value cross-entropy.
        input_channel: number of input channels (the original notes say 4 for
            BRATS and 3 for CityScapes).

    Returns:
        Path of the saved model file.

    Raises:
        ValueError: if ``model_name`` is unknown, or if ``loss`` is unknown
            while ``no_class == 2``.

    Behaviour changes relative to the previous implementation:
        * ``model.train()`` is called again after each mid-epoch validation,
          because ``evaluate`` leaves the model in evaluation mode.
        * Multi-class cross-entropy receives raw logits, not softmax output.
        * The logging interval is clamped to at least 1.
        * The running loss is accumulated as a Python float.
        * The TensorBoard writer is closed after training.
        * CPU is used when CUDA is unavailable.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = _build_segmentation_model(model_name, input_channel, no_class, width, depth, depth_limit, norm, device)

    training_amount = len(train_dataset)
    # Metrics are reported once (j + 1) reaches a multiple of this interval.
    # Clamp to 1 so a single-batch epoch no longer divides by zero.
    iteration_amount = max(1, training_amount // train_batch - 1)

    # NOTE: the doubled underscores around the augmentation tags are kept on
    # purpose so existing log / checkpoint paths stay identical.
    model_name = (model_name + '_Epoch_' + str(epochs) +
                  '_Dataset_' + data_name +
                  '_Batch_' + str(train_batch) +
                  '_Width_' + str(width) +
                  '_Loss_' + loss +
                  '_Norm_' + norm +
                  '_ShuffleTraining_' + str(shuffle) +
                  '_Data_Augmentation_Train_' + data_augmentation_train + '_' +
                  '_Data_Augmentation_Test_' + data_augmentation_test + '_' +
                  '_lr_' + str(lr) +
                  '_Repeat_' + str(repeat))

    print(model_name)

    writer = SummaryWriter('../../Log_' + log + '/' + model_name)

    optimizer = AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5)

    use_mixup = 'mixup' in data_augmentation_train
    label_dtype = torch.float32 if no_class == 2 else torch.long

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        if not use_mixup:
            for j, (images, labels, _) in enumerate(train_loader):
                images = images.to(device=device, dtype=torch.float32)
                labels = labels.to(device=device, dtype=label_dtype)

                outputs_logits = model(images)
                optimizer.zero_grad()

                main_loss = _segmentation_loss(outputs_logits, labels, loss, no_class)

                # .item() keeps a plain float so the autograd graph of every
                # batch is not kept reachable through the accumulator.
                running_loss += main_loss.item()
                main_loss.backward()
                optimizer.step()

                if (j + 1) % iteration_amount == 0:
                    if no_class == 2:
                        outputs = torch.sigmoid(outputs_logits)
                    else:
                        _, outputs = torch.max(outputs_logits, dim=1)
                        labels = labels.squeeze(1)

                    mean_iu = intersectionAndUnion(outputs.cpu().detach(), labels.cpu().detach(), no_class)

                    validate_iou, validate_f1, validate_recall, validate_precision = evaluate(
                        data=validate_data, model=model, device=device, class_no=no_class)
                    # evaluate() leaves the model in eval mode; switch back so
                    # the remaining batches of this epoch really train.
                    model.train()

                    print('Step [{}/{}], '
                          'loss: {:.5f}, '
                          'train iou: {:.5f}, '
                          'val iou: {:.5f}'.format(epoch + 1, epochs, running_loss / (j + 1), mean_iu, validate_iou))

                    writer.add_scalars('scalars', {'train iou': mean_iu,
                                                   'val iou': validate_iou,
                                                   'val f1': validate_f1,
                                                   'val recall': validate_recall,
                                                   'val precision': validate_precision}, epoch + 1)
        else:
            # Mix-up: one forward pass on the blended image, the loss is the
            # lam-weighted blend of the losses against both label sets.
            for j, (_, labels_1, _, _, labels_2, mixed_up_image, lam) in enumerate(train_loader):
                mixed_up_image = mixed_up_image.to(device=device, dtype=torch.float32)
                lam = lam.to(device=device, dtype=torch.float32)
                labels_1 = labels_1.to(device=device, dtype=label_dtype)
                labels_2 = labels_2.to(device=device, dtype=label_dtype)

                outputs_logits = model(mixed_up_image)
                optimizer.zero_grad()

                main_loss = _mixup_segmentation_loss(outputs_logits, labels_1, labels_2, lam, loss, no_class)

                # lam is a tensor, so the blended loss is reduced with mean().
                batch_loss = main_loss.mean()
                running_loss += batch_loss.item()
                batch_loss.backward()
                optimizer.step()

                if (j + 1) % iteration_amount == 0:
                    if no_class == 2:
                        outputs = torch.sigmoid(outputs_logits)
                    else:
                        _, outputs = torch.max(outputs_logits, dim=1)
                        outputs = outputs.unsqueeze(1)

                    outputs_np = outputs.cpu().detach().numpy()
                    mean_iu_1 = segmentation_scores(labels_1.cpu().detach().numpy(), outputs_np, no_class)
                    mean_iu_2 = segmentation_scores(labels_2.cpu().detach().numpy(), outputs_np, no_class)
                    mean_iu = lam.data.sum() * mean_iu_1 + (1 - lam.data.sum()) * mean_iu_2

                    validate_iou, validate_f1, validate_recall, validate_precision = evaluate(
                        data=validate_data, model=model, device=device, class_no=no_class)
                    # Same reason as in the non-mixup branch.
                    model.train()

                    mean_iu = mean_iu.item()

                    print('Step [{}/{}], '
                          'loss: {:.4f}, '
                          'train iou: {:.4f}, '
                          'val iou: {:.4f}'.format(epoch + 1, epochs, running_loss / (j + 1), mean_iu, validate_iou))

                    writer.add_scalars('scalars', {'train iou': mean_iu,
                                                   'val iou': validate_iou,
                                                   'val f1': validate_f1,
                                                   'val recall': validate_recall,
                                                   'val precision': validate_precision}, epoch + 1)

        if lr_scedule is True:
            # Polynomial decay of the learning rate after each epoch.
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr * ((1 - epoch / epochs)**0.999)

    # Flush and release the TensorBoard event file.
    writer.close()

    # save model
    save_folder = '../../saved_models_' + log
    os.makedirs(save_folder, exist_ok=True)

    save_model_name = model_name + '_Final'
    save_model_name_full = save_folder + '/' + save_model_name + '.pt'
    torch.save(model, save_model_name_full)

    # Final testing on both test sets; `test` is defined outside this block.
    save_results_folder = save_folder + '/testing_results_' + model_name
    os.makedirs(save_results_folder, exist_ok=True)

    (test_iou_1, test_f1_1, test_recall_1, test_precision_1, mse_1,
     test_iou_2, test_f1_2, test_recall_2, test_precision_2, mse_2,
     outputs_1, outputs_2) = test(data_1=test_data_1,
                                  data_2=test_data_2,
                                  model=model,
                                  device=device,
                                  class_no=no_class,
                                  save_location=save_results_folder)

    print('test iou data 1: {:.4f}, '
          'test mse data 1: {:.4f}, '
          'test f1 data 1: {:.4f},'
          'test recall data 1: {:.4f}, '
          'test precision data 1: {:.4f}, '.format(test_iou_1, mse_1, test_f1_1, test_recall_1, test_precision_1))

    print('test iou data 2: {:.4f}, '
          'test mse data 2: {:.4f}, '
          'test f1 data 2: {:.4f},'
          'test recall data 2: {:.4f}, '
          'test precision data 2: {:.4f}, '.format(test_iou_2, mse_2, test_f1_2, test_recall_2, test_precision_2))

    print('\nTesting finished and results saved.\n')

    return save_model_name_full


def _build_segmentation_model(model_name, input_channel, no_class, width, depth, depth_limit, norm, device):
    """Instantiate the network selected by ``model_name`` on ``device``."""
    # Keyword arguments shared by every SOASNet-style constructor.
    soas_kwargs = dict(in_ch=input_channel,
                       width=width,
                       depth=depth,
                       norm=norm,
                       n_classes=no_class,
                       side_output=False,
                       downsampling_limit=depth_limit)

    # An if/elif chain (rather than a dict of classes) keeps the lookup lazy:
    # only the selected class name has to exist in the module.
    if model_name == 'unet':
        model = UNet(n_channels=input_channel, n_classes=no_class, bilinear=True)
    elif model_name == 'Segnet':
        model = SegNet(in_ch=input_channel, width=width, norm=norm, depth=4, n_classes=no_class, dropout=True, side_output=False)
    elif model_name == 'SOASNet_single':
        model = SOASNet_ss(mode='low_rank_attn', **soas_kwargs)
    elif model_name == 'SOASNet':
        model = SOASNet(mode='low_rank_attn', **soas_kwargs)
    elif model_name == 'SOASNet_large_kernel':
        model = SOASNet_ls(mode='low_rank_attn', **soas_kwargs)
    elif model_name == 'SOASNet_multi_attn':
        model = SOASNet_ma(mode='low_rank_attn', **soas_kwargs)
    elif model_name == 'SOASNet_very_large_kernel':
        model = SOASNet_vls(mode='low_rank_attn', **soas_kwargs)
    elif model_name == 'SOASNet_segnet':
        model = SOASNet_segnet(mode='low_rank_attn', **soas_kwargs)
    elif model_name == 'SOASNet_segnet_skip':
        model = SOASNet_segnet_skip(mode='low_rank_attn', **soas_kwargs)
    elif model_name == 'RelayNet':
        model = SOASNet_segnet_skip(mode='relaynet', **soas_kwargs)
    elif model_name == 'attn_unet':
        model = AttentionUNet(in_ch=input_channel, width=width, visulisation=False, class_no=no_class)
    else:
        raise ValueError('Unknown model_name: {!r}'.format(model_name))

    return model.to(device=device)


def _segmentation_loss(outputs_logits, labels, loss, no_class):
    """Return the training loss of one ordinary (non-mixup) batch."""
    if no_class != 2:
        # CrossEntropyLoss applies log-softmax itself, so it must be given
        # the raw logits (the old code wrongly applied softmax first).
        return nn.CrossEntropyLoss(reduction='mean', ignore_index=8)(outputs_logits, labels.squeeze(1))

    if loss == 'dice':
        return dice_loss(torch.sigmoid(outputs_logits), labels)
    if loss == 'ce':
        return nn.BCEWithLogitsLoss(reduction='mean')(outputs_logits, labels)
    if loss == 'hybrid':
        return dice_loss(torch.sigmoid(outputs_logits), labels) + nn.BCEWithLogitsLoss(reduction='mean')(outputs_logits, labels)
    raise ValueError('Unknown loss tag for binary segmentation: {!r}'.format(loss))


def _mixup_segmentation_loss(outputs_logits, labels_1, labels_2, lam, loss, no_class):
    """Return the lam-weighted loss of one mixup batch (not yet reduced)."""
    if no_class != 2:
        cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        return lam * cross_entropy(outputs_logits, labels_1.squeeze(1)) \
            + (1 - lam) * cross_entropy(outputs_logits, labels_2.squeeze(1))

    if loss == 'dice':
        probabilities = torch.sigmoid(outputs_logits)
        return lam * dice_loss(probabilities, labels_1) + (1 - lam) * dice_loss(probabilities, labels_2)
    if loss == 'ce':
        bce = nn.BCEWithLogitsLoss(reduction='mean')
        return lam * bce(outputs_logits, labels_1) + (1 - lam) * bce(outputs_logits, labels_2)
    if loss == 'hybrid':
        probabilities = torch.sigmoid(outputs_logits)
        bce = nn.BCEWithLogitsLoss(reduction='mean')
        return lam * dice_loss(probabilities, labels_1) \
            + (1 - lam) * dice_loss(probabilities, labels_2) \
            + lam * bce(outputs_logits, labels_1) \
            + (1 - lam) * bce(outputs_logits, labels_2)
    raise ValueError('Unknown loss tag for binary segmentation: {!r}'.format(loss))
def test1(data_1, model, device, class_no, save_location): model.eval() data_1_testoutputs = [] with torch.no_grad(): f1_1 = 0 test_iou_1 = 0 # test_h_dist_1 = 0 recall_1 = 0 precision_1 = 0 mse_1 = 0 # ============================================== evaluate_index_all_1 = range(0, len(data_1) - 1) # # ============================================== # evaluate_index_all_2 = range(0, len(data_2) - 1) # for j, (testimg, testlabel, testimgname) in enumerate(data_1): # extract a few random indexs every time in a range of the data # ======================================================================== # ======================================================================== testimg = torch.from_numpy(testimg).to(device=device, dtype=torch.float32) testlabel = torch.from_numpy(testlabel).to(device=device, dtype=torch.float32) c, h, w = testimg.size() testimg = testimg.expand(1, c, h, w) testoutput_original = model(testimg) if class_no == 2: testoutput = torch.sigmoid(testoutput_original.view(1, h, w)) testoutput = (testoutput > 0.5).float() data_1_testoutputs.append(testoutput) # else: # _, testoutput = torch.max(testoutput_original, dim=1) # mean_iu_ = intersectionAndUnion(testoutput.cpu().detach(), testlabel.cpu().detach(), class_no) f1_, recall_, precision_ = f1_score(testlabel.cpu().detach().numpy(), testoutput.cpu().detach().numpy(), class_no) mse_ = (np.square(testlabel.cpu().detach().numpy() - testoutput.cpu().detach().numpy())).mean() f1_1 += f1_ test_iou_1 += mean_iu_ recall_1 += recall_ precision_1 += precision_ mse_1 += mse_ # # # Plotting segmentation: # testoutput_original = np.asarray(testoutput_original.cpu().detach().numpy(), dtype=np.uint8) # testoutput_original = np.squeeze(testoutput_original, axis=0) # testoutput_original = np.repeat(testoutput_original[:, :, np.newaxis], 3, axis=2) # # # if class_no == 2: # segmentation_map = np.zeros((h, w, 3), dtype=np.uint8) # # # segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 1, 
testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 255 # segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0 # segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0 # # # else: # segmentation_map = np.zeros((h, w, 3), dtype=np.uint8) # if class_no == 4: # # multi class for brats 2018 # segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 255 # segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0 # segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0 # # # segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 0 # segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 255 # segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 0 # # # segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 0 # segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 0 # segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 255 # # # elif class_no == 8: # # multi class for cityscapes # segmentation_map[:, :, 
0][np.logical_and(testoutput_original[:, :, 0] == 0, testoutput_original[:, :, 1] == 0, testoutput_original[:, :, 2] == 0)] = 255 # segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 0, testoutput_original[:, :, 1] == 0, testoutput_original[:, :, 2] == 0)] = 0 # segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 0, testoutput_original[:, :, 1] == 0, testoutput_original[:, :, 2] == 0)] = 0 # # # segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0 # segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 255 # segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0 # # # segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 0 # segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 0 # segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 255 # # # segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 255 # segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 255 # segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 0 # # # segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 4, testoutput_original[:, :, 1] == 4, testoutput_original[:, :, 2] == 4)] = 153 # 
segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 4, testoutput_original[:, :, 1] == 4, testoutput_original[:, :, 2] == 4)] = 51 # segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 4, testoutput_original[:, :, 1] == 4, testoutput_original[:, :, 2] == 4)] = 255 # # # segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 5, testoutput_original[:, :, 1] == 5, testoutput_original[:, :, 2] == 5)] = 255 # segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 5, testoutput_original[:, :, 1] == 5, testoutput_original[:, :, 2] == 5)] = 102 # segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 5, testoutput_original[:, :, 1] == 5, testoutput_original[:, :, 2] == 5)] = 178 # # # segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 6, testoutput_original[:, :, 1] == 6, testoutput_original[:, :, 2] == 6)] = 102 # segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 6, testoutput_original[:, :, 1] == 6, testoutput_original[:, :, 2] == 6)] = 255 # segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 6, testoutput_original[:, :, 1] == 6, testoutput_original[:, :, 2] == 6)] = 102 # # # prediction_name = 'seg_' + test_imagename + '.png' # full_error_map_name = os.path.join(prediction_map_path, prediction_name) # imageio.imsave(full_error_map_name, segmentation_map) # prediction_map_path = save_location + '/' + 'Results_map' # try: os.mkdir(prediction_map_path) except OSError as exc: if exc.errno != errno.EEXIST: raise pass # save numerical results: result_dictionary = {'Test IoU data 1': str(test_iou_1 / len(evaluate_index_all_1)), 'Test f1 data 1': str(f1_1 / len(evaluate_index_all_1)), 'Test recall data 1': str(recall_1 / len(evaluate_index_all_1)), 'Test Precision data 1': str(precision_1 / len(evaluate_index_all_1)), 'Test MSE data 1': str(mse_1 / len(evaluate_index_all_1)) } ff_path = prediction_map_path + 
'/test_result_data.txt' ff = open(ff_path, 'w') ff.write(str(result_dictionary)) ff.close() return test_iou_1 / len(evaluate_index_all_1), f1_1 / len(evaluate_index_all_1), recall_1 / len(evaluate_index_all_1), precision_1 / len(evaluate_index_all_1), mse_1 / len(evaluate_index_all_1), \ data_1_testoutputs