segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0 segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 1, testoutput_original[:, :, 1] == 1, testoutput_original[:, :, 2] == 1)] = 0 # segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 0 segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 255 segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 2, testoutput_original[:, :, 1] == 2, testoutput_original[:, :, 2] == 2)] = 0 # segmentation_map[:, :, 0][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 0 segmentation_map[:, :, 1][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 0 segmentation_map[:, :, 2][np.logical_and(testoutput_original[:, :, 0] == 3, testoutput_original[:, :, 1] == 3, testoutput_original[:, :, 2] == 3)] = 255 imageio.imsave(save_name_label, segmentation_map) # if __name__ == '__main__': # # This is to run segmentation of trained our models # To use it, change: # 1. model_path # 2. test_pat: path to testing data model_path = '../../saved_model.pt' # data path: test_path = '../../data_folder' dataset_tag = 'brats' label_mode = 'multi' test_data = CustomDataset_punet(dataset_location=test_path, dataset_tag=dataset_tag, noisylabel=label_mode, augmentation=False) segmentation(model_name='Unet_CMs', model_path=model_path, testdata=test_data, class_no=4, data_set=dataset_tag)
def train_punet(epochs,
                iteration,
                train_batch_size,
                lr,
                num_filters,
                input_channels,
                latent_dim,
                no_conv_fcomb,
                num_classes,
                beta,
                test_samples_no,
                dataset_path,
                dataset_tag):
    """Train, validate and test the baseline Probabilistic U-net.

    The complete experiment (build data loaders, build a fresh network,
    train, then test) is repeated ``iteration`` times.

    NOTE(review): this file contains a second, later definition of
    ``train_punet`` with the same signature; the later one is the binding
    that callers get after import.

    Args:
        epochs: number of passes over the training loader in each run.
        iteration: number of independent runs; every run creates a new
            network and a new optimizer.
        train_batch_size: batch size of the training loader; an incomplete
            final batch is dropped (``drop_last=True``).
        lr: learning rate handed to the Adam optimizer.
        num_filters: forwarded to ``ProbabilisticUnet``
            (e.g. ``[8, 16, 32, 64]``).
        input_channels: forwarded to ``ProbabilisticUnet`` (e.g. 3).
        latent_dim: forwarded to ``ProbabilisticUnet`` (e.g. 4).
        no_conv_fcomb: forwarded to ``ProbabilisticUnet`` as
            ``no_convs_fcomb`` (e.g. 2).
        num_classes: forwarded to ``ProbabilisticUnet`` and used as
            ``class_no`` by ``evaluate_punet``.
        beta: forwarded to ``ProbabilisticUnet`` (e.g. 10).
        test_samples_no: forwarded to ``test_punet`` as ``sampling_times``
            and embedded in the name of the results folder.
        dataset_path: folder that contains the ``train``, ``validate`` and
            ``test`` sub-folders.
        dataset_tag: forwarded to ``CustomDataset_punet``
            (presumably 'mnist', 'brats' or 'lidc' -- confirm against the
            dataset class).

    Returns:
        None. Progress is printed and the test stage is delegated to
        ``test_punet`` with a path below ``../Exp_Results_PUnet``.
    """
    for itr in range(iteration):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        train_path = dataset_path + '/train'
        validate_path = dataset_path + '/validate'
        test_path = dataset_path + '/test'

        dataset_train = CustomDataset_punet(dataset_location=train_path,
                                            dataset_tag=dataset_tag,
                                            noisylabel='p_unet',
                                            augmentation=True)
        dataset_val = CustomDataset_punet(dataset_location=validate_path,
                                          dataset_tag=dataset_tag,
                                          noisylabel='multi',
                                          augmentation=False)
        dataset_test = CustomDataset_punet(dataset_location=test_path,
                                           dataset_tag=dataset_tag,
                                           noisylabel='multi',
                                           augmentation=False)

        train_loader = DataLoader(dataset_train, batch_size=train_batch_size,
                                  shuffle=True, num_workers=4, drop_last=True)
        val_loader = DataLoader(dataset_val, batch_size=1, shuffle=False,
                                num_workers=1, drop_last=False)
        test_loader = DataLoader(dataset_test, batch_size=1, shuffle=False,
                                 num_workers=1, drop_last=False)

        net = ProbabilisticUnet(input_channels=input_channels,
                                num_classes=num_classes,
                                num_filters=num_filters,
                                latent_dim=latent_dim,
                                no_convs_fcomb=no_conv_fcomb,
                                beta=beta)
        net.to(device)

        optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-5)

        # Validation is triggered on the second-to-last batch of every epoch.
        # Clamp to 1 so that a loader yielding a single batch still validates
        # (the unclamped value would be 0 and never match ``step + 1``).
        training_iterations = max(len(dataset_train) // train_batch_size - 1, 1)

        for epoch in range(epochs):
            net.train()

            for step, (patch, mask, _mask_name) in enumerate(train_loader):
                patch = patch.to(device)
                mask = mask.to(device)

                # The forward pass stores what ``net.elbo`` needs afterwards.
                net.forward(patch, mask, training=True)
                elbo, reconstruction, kl = net.elbo(mask)

                # Maximising the ELBO == minimising its negative.
                loss = -elbo
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if (step + 1) == training_iterations:
                    validate_iou, generalized_energy_distance_epoch = evaluate_punet(
                        net=net,
                        val_data=val_loader,
                        class_no=num_classes,
                        sampling_no=4)
                    print('epoch:' + str(epoch))
                    print('val dice: ' + str(validate_iou))
                    print('val generalized_energy: ' + str(generalized_energy_distance_epoch))
                    print('train loss: ' + str(loss.item()))
                    print('kl is: ' + str(kl.item()))
                    print('reconstruction loss is: ' + str(reconstruction.item()))
                    print('\n')
                    # evaluate_punet is not visible here and may leave the
                    # network in eval mode; make sure the remaining batches
                    # of this epoch are trained in train mode.
                    net.train()

        print('\n')

        save_path = '../Exp_Results_PUnet'
        # Create the results root if needed; an existing folder is fine.
        os.makedirs(save_path, exist_ok=True)

        save_path = save_path + '/Exp_' + str(itr) + \
            '_punet_' + \
            '_train_batch_' + str(train_batch_size) + \
            '_latent_dim_' + str(latent_dim) + \
            '_lr_' + str(lr) + \
            '_epochs_' + str(epochs) + \
            '_beta_' + str(beta) + \
            '_test_sample_no_' + str(test_samples_no)

        test_punet(net=net, testdata=test_loader, save_path=save_path,
                   sampling_times=test_samples_no)

    print('Training is finished.')
def getData(train_batchsize, validate_batchsize, data_path, dataset_tag, label_mode):
    """Build the train / validate / test data loaders for one data set.

    Args:
        train_batchsize: batch size of the training loader.
        validate_batchsize: batch size of the validation and test loaders;
            the same value is also used as their ``num_workers``.
        data_path: folder containing the ``train``, ``validate`` and ``test``
            sub-folders.
        dataset_tag: forwarded to ``CustomDataset_punet``.
        label_mode: forwarded to ``CustomDataset_punet`` as ``noisylabel``.

    Returns:
        Tuple ``(trainloader, validateloader, testloader, n_train)`` where
        ``n_train`` is ``len()`` of the training data set.
    """
    def _build_split(subfolder, augment):
        # One CustomDataset_punet per split; only the folder and the
        # augmentation flag differ between splits.
        return CustomDataset_punet(dataset_location=data_path + subfolder,
                                   dataset_tag=dataset_tag,
                                   noisylabel=label_mode,
                                   augmentation=augment)

    training_set = _build_split('/train', True)
    validation_set = _build_split('/validate', False)
    testing_set = _build_split('/test', False)

    # Validation and test loaders share the exact same configuration.
    eval_loader_options = dict(batch_size=validate_batchsize,
                               shuffle=False,
                               num_workers=validate_batchsize,
                               drop_last=False)

    training_loader = data.DataLoader(training_set,
                                      batch_size=train_batchsize,
                                      shuffle=True,
                                      num_workers=5,
                                      drop_last=True)
    validation_loader = data.DataLoader(validation_set, **eval_loader_options)
    testing_loader = data.DataLoader(testing_set, **eval_loader_options)

    return training_loader, validation_loader, testing_loader, len(training_set)
def train_punet(epochs,
                iteration,
                train_batch_size,
                lr,
                num_filters,
                input_channels,
                latent_dim,
                no_conv_fcomb,
                num_classes,
                beta,
                test_samples_no,
                dataset_path,
                dataset_tag):
    """Train, validate and test the baseline Probabilistic U-net.

    The complete experiment (build data loaders, build a fresh network,
    train, then test) is repeated ``iteration`` times.

    NOTE(review): an earlier definition of ``train_punet`` with the same
    signature exists in this file; being the later one, this definition is
    the binding that callers get after import.

    Args:
        epochs: number of passes over the training loader in each run.
        iteration: number of independent runs; every run creates a new
            network and a new optimizer.
        train_batch_size: batch size of the training loader; an incomplete
            final batch is dropped (``drop_last=True``).
        lr: learning rate handed to the Adam optimizer.
        num_filters: forwarded to ``ProbabilisticUnet``
            (e.g. ``[8, 16, 32, 64]``).
        input_channels: forwarded to ``ProbabilisticUnet`` (e.g. 3).
        latent_dim: forwarded to ``ProbabilisticUnet`` (e.g. 4).
        no_conv_fcomb: forwarded to ``ProbabilisticUnet`` as
            ``no_convs_fcomb`` (e.g. 2).
        num_classes: forwarded to ``ProbabilisticUnet`` and used as
            ``class_no`` by ``evaluate_punet``.
        beta: forwarded to ``ProbabilisticUnet`` (e.g. 10).
        test_samples_no: forwarded to ``test_punet`` as ``sampling_times``
            and embedded in the name of the results folder.
        dataset_path: folder that contains the ``train``, ``validate`` and
            ``test`` sub-folders.
        dataset_tag: forwarded to ``CustomDataset_punet``.

    Returns:
        None. Progress is printed and the test stage is delegated to
        ``test_punet`` with a path below ``../Exp_Results_PUnet``.
    """
    for itr in range(iteration):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        train_path = dataset_path + '/train'
        validate_path = dataset_path + '/validate'
        test_path = dataset_path + '/test'

        dataset_train = CustomDataset_punet(dataset_location=train_path,
                                            dataset_tag=dataset_tag,
                                            noisylabel='p_unet',
                                            augmentation=True)
        dataset_val = CustomDataset_punet(dataset_location=validate_path,
                                          dataset_tag=dataset_tag,
                                          noisylabel='multi',
                                          augmentation=False)
        dataset_test = CustomDataset_punet(dataset_location=test_path,
                                           dataset_tag=dataset_tag,
                                           noisylabel='multi',
                                           augmentation=False)

        train_loader = DataLoader(dataset_train, batch_size=train_batch_size,
                                  shuffle=True, num_workers=4, drop_last=True)
        val_loader = DataLoader(dataset_val, batch_size=1, shuffle=False,
                                num_workers=1, drop_last=False)
        test_loader = DataLoader(dataset_test, batch_size=1, shuffle=False,
                                 num_workers=1, drop_last=False)

        net = ProbabilisticUnet(input_channels=input_channels,
                                num_classes=num_classes,
                                num_filters=num_filters,
                                latent_dim=latent_dim,
                                no_convs_fcomb=no_conv_fcomb,
                                beta=beta)
        net.to(device)

        optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-5)

        # Validation is triggered on the second-to-last batch of every epoch.
        # Clamp to 1 so that a loader yielding a single batch still validates
        # (the unclamped value would be 0 and never match ``step + 1``).
        training_iterations = max(len(dataset_train) // train_batch_size - 1, 1)

        for epoch in range(epochs):
            net.train()

            for step, (patch, mask, _mask_name) in enumerate(train_loader):
                patch = patch.to(device)
                mask = mask.to(device)

                # The forward pass stores what ``net.elbo`` needs afterwards.
                net.forward(patch, mask, training=True)
                elbo, reconstruction, kl = net.elbo(mask)

                # Maximising the ELBO == minimising its negative.
                loss = -elbo
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if (step + 1) == training_iterations:
                    validate_iou, generalized_energy_distance_epoch = evaluate_punet(
                        net=net,
                        val_data=val_loader,
                        class_no=num_classes,
                        sampling_no=4)
                    print('epoch:' + str(epoch))
                    print('val dice: ' + str(validate_iou))
                    print('val generalized_energy: ' + str(generalized_energy_distance_epoch))
                    print('train loss: ' + str(loss.item()))
                    print('kl is: ' + str(kl.item()))
                    print('reconstruction loss is: ' + str(reconstruction.item()))
                    print('\n')
                    # evaluate_punet is not visible here and may leave the
                    # network in eval mode; make sure the remaining batches
                    # of this epoch are trained in train mode.
                    net.train()

        print('\n')

        save_path = '../Exp_Results_PUnet'
        # Create the results root if needed; an existing folder is fine.
        os.makedirs(save_path, exist_ok=True)

        save_path = save_path + '/Exp_' + str(itr) + \
            '_punet_' + \
            '_train_batch_' + str(train_batch_size) + \
            '_latent_dim_' + str(latent_dim) + \
            '_lr_' + str(lr) + \
            '_epochs_' + str(epochs) + \
            '_beta_' + str(beta) + \
            '_test_sample_no_' + str(test_samples_no)

        test_punet(net=net, testdata=test_loader, save_path=save_path,
                   sampling_times=test_samples_no)

    print('Training is finished.')