def main():
    """Example training loop over a torchio patch queue whose subjects have
    heterogeneous modalities (the second subject has no T2 image), which is
    why a list-returning ``collate_fn`` is required.
    """
    # Define training and patches sampling parameters
    num_epochs = 20
    patch_size = 128
    queue_length = 100
    samples_per_volume = 5
    batch_size = 2

    # Populate a list with images
    one_subject = Subject(
        T1=Image('../BRATS2018_crop_renamed/LGG75_T1.nii.gz', torchio.INTENSITY),
        T2=Image('../BRATS2018_crop_renamed/LGG75_T2.nii.gz', torchio.INTENSITY),
        label=Image('../BRATS2018_crop_renamed/LGG75_Label.nii.gz', torchio.LABEL),
    )

    # This subject doesn't have a T2 MRI!
    another_subject = Subject(
        T1=Image('../BRATS2018_crop_renamed/LGG74_T1.nii.gz', torchio.INTENSITY),
        label=Image('../BRATS2018_crop_renamed/LGG74_Label.nii.gz', torchio.LABEL),
    )

    subjects = [
        one_subject,
        another_subject,
    ]

    subjects_dataset = ImagesDataset(subjects)
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        samples_per_volume,
        patch_size,
        ImageSampler,
    )

    # This collate_fn is needed in the case of missing modalities.
    # In this case, the batch will be composed by a *list* of samples instead
    # of the typical Python dictionary that is collated by default in PyTorch.
    batch_loader = DataLoader(
        queue_dataset,
        batch_size=batch_size,
        collate_fn=lambda x: x,
    )

    # Mock PyTorch model
    model = nn.Identity()

    for epoch_index in range(num_epochs):
        for batch in batch_loader:
            # batch is a *list* here, not a dictionary
            logits = model(batch)
            # Fix: iterate over the samples actually present instead of
            # indexing with range(batch_size) — the last batch may contain
            # fewer samples than batch_size, which raised IndexError.
            print([sample.keys() for sample in batch])
            print()
def test_queue(self):
    """Smoke-test the patch Queue: construct it, render its repr, and
    drain one full pass of the loader."""
    dataset = ImagesDataset(self.subjects_list)
    patches_queue = Queue(
        dataset,
        max_length=6,
        samples_per_volume=2,
        patch_size=10,
        sampler_class=ImageSampler,
        num_workers=2,
        verbose=True,
    )
    _ = str(patches_queue)  # the repr must not raise
    loader = DataLoader(patches_queue, batch_size=4)
    for batch in loader:
        # Both modalities must be present and indexable in every batch.
        for key in ('one_modality', 'segmentation'):
            _ = batch[key][DATA]
def run_queue(self, num_workers, **kwargs):
    """Build a patch Queue over the test subjects and drain one epoch.

    Args:
        num_workers: number of worker processes the queue should use to
            load subjects in parallel.
        **kwargs: extra keyword arguments forwarded to ``Queue``.
    """
    subjects_dataset = SubjectsDataset(self.subjects_list)
    patch_size = 10
    sampler = UniformSampler(patch_size)
    queue_dataset = Queue(
        subjects_dataset,
        max_length=6,
        samples_per_volume=2,
        sampler=sampler,
        # Bug fix: num_workers was accepted but silently ignored, so the
        # queue always ran with its default worker count.
        num_workers=num_workers,
        **kwargs,
    )
    _ = str(queue_dataset)  # exercise __repr__
    batch_loader = DataLoader(queue_dataset, batch_size=4)
    for batch in batch_loader:
        _ = batch['one_modality'][DATA]
        _ = batch['segmentation'][DATA]
another_subject_dict, ] subjects_dataset = ImagesDataset(subjects_paths, transform=transform) # Run a benchmark for different numbers of workers workers = range(mp.cpu_count() + 1) for num_workers in workers: print('Number of workers:', num_workers) # Define the dataset as a queue of patches queue_dataset = Queue( subjects_dataset, queue_length, samples_per_volume, patch_size, ImageSampler, num_workers=num_workers, shuffle_subjects=False, ) # This collate_fn is needed in the case of missing modalities (TODO: elaborate) batch_loader = DataLoader(queue_dataset, batch_size=batch_size, collate_fn=lambda x: x) start = time.time() for epoch_index in range(num_epochs): for batch in batch_loader: logits = model(batch) print([batch[idx].keys() for idx in range(batch_size)])
def set_data_loader(self, train_csv_file='', val_csv_file='', transforms=None,
                    batch_size=1, num_workers=0, par_queue=None,
                    save_to_dir=None, load_from_dir=None, replicate_suj=0,
                    shuffel_train=True, get_condition_csv=None,
                    get_condition_field='', get_condition_nb_wanted=1 / 4,
                    collate_fn=None, add_to_load=None, add_to_load_regexp=None):
    """Build ``self.train_dataloader`` / ``self.val_dataloader``.

    Two data sources are supported:
      * ``load_from_dir`` set: load pre-saved ``sample*.pt`` files from one
        (or two: [train_dir, val_dir]) directories, optionally subsampling
        the training files so that the distribution of a CSV field is
        flattened across 100 equal-width value intervals.
      * otherwise: read subject lists from ``train_csv_file`` /
        ``val_csv_file``, optionally replicating the training list.

    If ``par_queue`` is given, the datasets are wrapped in patch Queues
    (patch-based training); otherwise plain DataLoaders are built.

    NOTE(review): ``shuffel_train`` is a typo for ``shuffle_train`` but is
    part of the public signature, so it is kept.
    """
    # Wrap a plain list of transforms into a Compose, unless the caller
    # already passed a torchvision Compose (or None).
    if not isinstance(transforms, torchvision.transforms.transforms.Compose
                      ) and transforms is not None:
        transforms = Compose(transforms)

    if load_from_dir is not None:
        # A single directory means train and val share the same location.
        if type(load_from_dir) == str:
            load_from_dir = [load_from_dir, load_from_dir]
        fsample_train, fsample_val = gfile(load_from_dir[0],
                                          'sample.*pt'), gfile(
                                              load_from_dir[1], 'sample.*pt')
        #random.shuffle(fsample_train)
        #fsample_train = fsample_train[0:10000]
        if get_condition_csv is not None:
            # Flatten the histogram of `get_condition_field`: split its range
            # into 100 intervals and keep at most `nb_wanted_per_interval`
            # random samples per interval.
            res = pd.read_csv(load_from_dir[0] + '/' + get_condition_csv)
            cond_val = res[get_condition_field].values
            y = np.linspace(np.min(cond_val), np.max(cond_val), 101)
            nb_wanted_per_interval = int(
                np.round(len(cond_val) * get_condition_nb_wanted / 100))
            y_select = []
            for i in range(len(y) - 1):
                # NOTE(review): strict inequalities mean values exactly on an
                # interval edge are never selected — confirm this is intended.
                indsel = np.where((cond_val > y[i]) & (cond_val < y[i + 1]))[0]
                nb_select = len(indsel)
                if nb_select < nb_wanted_per_interval:
                    # Interval is under-populated: keep everything in it.
                    print(
                        ' only {} / {} for interval {} {:,.3f} | {:,.3f} '
                        .format(nb_select, nb_wanted_per_interval, i, y[i],
                                y[i + 1]))
                    y_select.append(indsel)
                else:
                    # Interval is over-populated: keep a random subset.
                    pind = np.random.permutation(range(0, nb_select))
                    y_select.append(indsel[pind[0:nb_wanted_per_interval]])
                #print('{} selecting {}'.format(i, len(y_select[-1])))
            ind_select = np.hstack(y_select)
            y = cond_val[ind_select]
            fsample_train = [fsample_train[ii] for ii in ind_select]
            self.log_string += '\nfinal selection {} soit {:,.3f} % instead of {:,.3f} %'.format(
                len(y),
                len(y) / len(cond_val) * 100, get_condition_nb_wanted * 100)
        #conditions = [("MSE", ">", 0.0028),]
        #select_ind = apply_conditions_on_dataset(res,conditions)
        #fsel = [fsample_train[ii] for ii,jj in enumerate(select_ind) if jj]
        self.log_string += '\nloading {} train sample from {}'.format(
            len(fsample_train), load_from_dir[0])
        self.log_string += '\nloading {} val sample from {}'.format(
            len(fsample_val), load_from_dir[1])
        train_dataset = ImagesDataset(
            fsample_train,
            load_from_dir=load_from_dir[0],
            transform=transforms,
            add_to_load=add_to_load,
            add_to_load_regexp=add_to_load_regexp)
        self.train_csv_load_file_train = fsample_train
        val_dataset = ImagesDataset(fsample_val,
                                    load_from_dir=load_from_dir[1],
                                    transform=transforms,
                                    add_to_load=add_to_load,
                                    add_to_load_regexp=add_to_load_regexp)
        # NOTE(review): this overwrites the train file list stored just above
        # with the *val* list — looks like it should be a separate attribute
        # (e.g. train_csv_load_file_val). Confirm before relying on it.
        self.train_csv_load_file_train = fsample_val
    else:
        # CSV-driven path: build subject lists for train and val.
        data_parameters = {
            'image': {
                'csv_file': train_csv_file
            },
        }
        data_parameters_val = {
            'image': {
                'csv_file': val_csv_file
            },
        }
        paths_dict, info = get_subject_list_and_csv_info_from_data_prameters(
            data_parameters, fpath_idx='filename')
        paths_dict_val, info_val = get_subject_list_and_csv_info_from_data_prameters(
            data_parameters_val, fpath_idx='filename', shuffle_order=False)
        if replicate_suj:
            # Duplicate the whole training list `replicate_suj` times.
            lll = []
            for i in range(0, replicate_suj):
                lll.extend(paths_dict)
            paths_dict = lll
            self.log_string += 'Replicating train dataSet {} times, new length is {}'.format(
                replicate_suj, len(lll))
        train_dataset = ImagesDataset(paths_dict,
                                      transform=transforms,
                                      save_to_dir=save_to_dir)
        val_dataset = ImagesDataset(paths_dict_val,
                                    transform=transforms,
                                    save_to_dir=save_to_dir)
    # Record batch size / workers in the result name for bookkeeping.
    self.res_name += '_B{}_nw{}'.format(batch_size, num_workers)
    if par_queue is not None:
        # Patch-based training: wrap datasets in torchio Queues.
        self.patch = True
        windows_size = par_queue['windows_size']
        # A single value means a cubic patch.
        if len(windows_size) == 1:
            windows_size = [
                windows_size[0], windows_size[0], windows_size[0]
            ]
        train_queue = Queue(train_dataset,
                            par_queue['queue_length'],
                            par_queue['samples_per_volume'],
                            windows_size,
                            ImageSampler,
                            num_workers=num_workers,
                            verbose=self.verbose)
        # Validation queue: one sample per volume, no shuffling, so that
        # evaluation is deterministic.
        val_queue = Queue(val_dataset,
                          par_queue['queue_length'],
                          1,
                          windows_size,
                          ImageSampler,
                          num_workers=num_workers,
                          shuffle_subjects=False,
                          shuffle_patches=False,
                          verbose=self.verbose)
        self.res_name += '_spv{}'.format(par_queue['samples_per_volume'])
        self.train_dataloader = DataLoader(train_queue,
                                           batch_size=batch_size,
                                           shuffle=shuffel_train,
                                           collate_fn=collate_fn)
        self.val_dataloader = DataLoader(val_queue,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         collate_fn=collate_fn)
    else:
        # Whole-volume training: plain DataLoaders over the datasets.
        self.train_dataset = train_dataset
        self.train_dataloader = DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=shuffel_train,
                                           num_workers=num_workers,
                                           collate_fn=collate_fn)
        self.val_dataloader = DataLoader(val_dataset,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers,
                                         collate_fn=collate_fn)
'csv_file': '/data/romain/data_exemple/file_p3.csv' }, 'sampler': { 'csv_file': '/data/romain/data_exemple/file_mask.csv' } } roi_path = None #'/data/romain/data_exemple/roi16_weighted.txt' subject_list, res_info = get_subject_list_and_csv_info_from_data_prameters( data_parameters) train_dataset = ImagesDataset(subject_list, transform=transforms) train_queue = Queue(train_dataset, queue_length, samples_per_volume, windows_size, ImageSampler, num_workers=num_workers, shuffle_patches=True, verbose=False) train_dataloader = DataLoader(train_queue, batch_size=batch_size, shuffle=True) # d=next(iter(train_dataloader)) # d['image'].shape # d['label'].shape #loss = BCEWithLogitsLoss() # sigmoid + bcel #loss = dice_coef_loss #dice_loss #loss = dice_loss() if losstype == 'BCE': loss = tnn.BCELoss() elif losstype == 'dice': loss = dice_loss(type=1)
# This subject doesn't have a T2 MRI! another_subject = Subject( Image('T1', '../BRATS2018_crop_renamed/LGG74_T1.nii.gz', torchio.INTENSITY), Image('label', '../BRATS2018_crop_renamed/LGG74_Label.nii.gz', torchio.LABEL), ) subjects = [ one_subject, another_subject, ] subjects_dataset = ImagesDataset(subjects) queue_dataset = Queue( subjects_dataset, queue_length, samples_per_volume, patch_size, ImageSampler, ) # This collate_fn is needed in the case of missing modalities # In this case, the batch will be composed by a *list* of samples instead of # the typical Python dictionary that is collated by default in Pytorch batch_loader = DataLoader( queue_dataset, batch_size=batch_size, collate_fn=lambda x: x, ) # Mock PyTorch model model = nn.Module()
def __init__(self, images_dir, labels_dir):
    """Collect image/label paths, wrap them as ``tio.Subject`` objects, and
    build a patch queue over the transformed training set.

    Args:
        images_dir: directory containing the source images.
        labels_dir: directory with the label maps; in multi-class mode it
            must contain 'artery', 'lung', 'trachea' and 'vein'
            subdirectories.
    """
    # 2D mode still samples 3D patches, just with a one-voxel third axis.
    if hp.mode == '3d':
        patch_size = hp.patch_size
    elif hp.mode == '2d':
        patch_size = (hp.patch_size, hp.patch_size, 1)
    else:
        raise Exception('no such kind of mode!')

    queue_length = 5
    samples_per_volume = 5

    self.subjects = []

    if (hp.in_class == 1) and (hp.out_class == 1):
        # Single input / single output: pair each image with one label map.
        # Both listings are sorted so zip() pairs files by name order —
        # assumes matching filenames in both directories (TODO confirm).
        images_dir = Path(images_dir)
        self.image_paths = sorted(images_dir.glob(hp.fold_arch))
        labels_dir = Path(labels_dir)
        self.label_paths = sorted(labels_dir.glob(hp.fold_arch))

        for (image_path, label_path) in zip(self.image_paths,
                                            self.label_paths):
            subject = tio.Subject(
                source=tio.ScalarImage(image_path),
                label=tio.LabelMap(label_path),
            )
            self.subjects.append(subject)
    else:
        # Multi-structure mode: one label map per anatomical structure.
        images_dir = Path(images_dir)
        self.image_paths = sorted(images_dir.glob(hp.fold_arch))

        artery_labels_dir = Path(labels_dir + '/artery')
        self.artery_label_paths = sorted(
            artery_labels_dir.glob(hp.fold_arch))

        lung_labels_dir = Path(labels_dir + '/lung')
        self.lung_label_paths = sorted(lung_labels_dir.glob(hp.fold_arch))

        trachea_labels_dir = Path(labels_dir + '/trachea')
        self.trachea_label_paths = sorted(
            trachea_labels_dir.glob(hp.fold_arch))

        vein_labels_dir = Path(labels_dir + '/vein')
        self.vein_label_paths = sorted(vein_labels_dir.glob(hp.fold_arch))

        for (image_path, artery_label_path, lung_label_path,
             trachea_label_path, vein_label_path) in zip(
                 self.image_paths, self.artery_label_paths,
                 self.lung_label_paths, self.trachea_label_paths,
                 self.vein_label_paths):
            subject = tio.Subject(
                source=tio.ScalarImage(image_path),
                # NOTE(review): 'atery' looks like a typo for 'artery' but is
                # kept as-is — downstream code may index this exact key.
                atery=tio.LabelMap(artery_label_path),
                lung=tio.LabelMap(lung_label_path),
                trachea=tio.LabelMap(trachea_label_path),
                vein=tio.LabelMap(vein_label_path),
            )
            self.subjects.append(subject)

    self.transforms = self.transform()

    self.training_set = tio.SubjectsDataset(self.subjects,
                                            transform=self.transforms)

    # Patch queue feeding the training loop with uniformly-sampled patches.
    self.queue_dataset = Queue(
        self.training_set,
        queue_length,
        samples_per_volume,
        UniformSampler(patch_size),
    )
def train(paths_dict, model, transformation, device, save_path, opt):
    """Train a segmentation model on 'control' and 'lesion' datasets with
    patch queues, a moving-average validation criterion, CSV logging and
    best-checkpoint saving.

    Args:
        paths_dict: {dataset: {'training': [...], 'validation': [...]}}.
        model: the segmentation network to optimize.
        transformation: {'training'/'validation': {dataset: transform}}.
        device: torch device the tensors are moved to (CUDA assumed — the
            labels are cast with torch.cuda.IntTensor).
        save_path: format string for checkpoint paths (e.g. '...{}.pth').
        opt: options namespace providing model_dir, learning_rate, warmup.
    """
    since = time.time()
    dataloaders = dict()

    # Define transforms for data normalization and augmentation
    for dataset in DATASETS:
        subjects_dataset_train = ImagesDataset(
            paths_dict[dataset]['training'],
            transform=transformation['training'][dataset])
        subjects_dataset_val = ImagesDataset(
            paths_dict[dataset]['validation'],
            transform=transformation['validation'][dataset])

        # Number of queue workers
        workers = 10

        # Define the dataset data as a queue of patches
        queue_dataset_train = Queue(
            subjects_dataset_train,
            queue_length,
            samples_per_volume,
            patch_size,
            ImageSampler,
            num_workers=workers,
            shuffle_subjects=True,
            shuffle_patches=True,
        )
        queue_dataset_val = Queue(
            subjects_dataset_val,
            queue_length,
            samples_per_volume,
            patch_size,
            ImageSampler,
            num_workers=workers,
            shuffle_subjects=True,
            shuffle_patches=True,
        )
        # Infinite iterators so both datasets can be drawn from in lockstep
        # regardless of their respective lengths.
        batch_loader_dataset_train = infinite_iterable(
            DataLoader(queue_dataset_train, batch_size=batch_size[dataset]))
        batch_loader_dataset_val = infinite_iterable(
            DataLoader(queue_dataset_val, batch_size=batch_size[dataset]))

        dataloaders_dataset = dict()
        dataloaders_dataset['training'] = batch_loader_dataset_train
        dataloaders_dataset['validation'] = batch_loader_dataset_val
        dataloaders[dataset] = dataloaders_dataset

    # Resume from the CSV log if one exists; otherwise start fresh.
    df_path = os.path.join(opt.model_dir, 'log.csv')
    if os.path.isfile(df_path):
        df = pd.read_csv(df_path)
        epoch = df.iloc[-1]['epoch']
        best_epoch = df.iloc[-1]['best_epoch']
        val_eval_criterion_MA = df.iloc[-1]['MA']
        best_val_eval_criterion_MA = df.iloc[-1]['best_MA']
        initial_lr = df.iloc[-1]['lr']
        model.load_state_dict(torch.load(save_path.format('best')))
    else:
        df = pd.DataFrame(
            columns=['epoch', 'best_epoch', 'MA', 'best_MA', 'lr'])
        val_eval_criterion_MA = None
        best_epoch = 0
        epoch = 0
        initial_lr = opt.learning_rate

    # Optimisation policy
    optimizer = torch.optim.Adam(model.parameters(),
                                 opt.learning_rate,
                                 weight_decay=weight_decay,
                                 amsgrad=True)
    lr_s = lr_scheduler.ReduceLROnPlateau(optimizer,
                                          mode='min',
                                          factor=0.2,
                                          patience=patience_lr,
                                          verbose=True,
                                          threshold=1e-3,
                                          threshold_mode="abs")

    model = model.to(device)
    continue_training = True

    # One optimisation step per `samples_per_volume * n_subjects /
    # batch_size` — these arrays only define how many iterations to run.
    ind_batch_train = np.arange(
        0, samples_per_volume * len(paths_dict['lesion']['training']),
        batch_size['lesion'])
    ind_batch_val = np.arange(
        0, samples_per_volume * len(paths_dict['lesion']['validation']),
        batch_size['lesion'])
    ind_batch = dict()
    ind_batch['training'] = ind_batch_train
    ind_batch['validation'] = ind_batch_val

    while continue_training:
        epoch += 1
        print('-' * 10)
        print('Epoch {}/'.format(epoch))
        for param_group in optimizer.param_groups:
            print("Current learning rate is: {}".format(param_group['lr']))

        # Each epoch has a training and validation phase
        for phase in ['training', 'validation']:
            #for phase in ['validation','training']:
            print(phase)
            if phase == 'training':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_loss_lesion = 0.0
            running_loss_control = 0.0
            running_loss_intermod = 0.0
            epoch_samples = 0

            # Iterate over data
            for _ in tqdm(ind_batch[phase]):
                # Control data (T1)
                batch_control = next(dataloaders['control'][phase])
                labels_control = batch_control['label'][DATA].to(device).type(
                    torch.cuda.IntTensor)
                inputs_control = {
                    mod: batch_control[mod][DATA].to(device)
                    for mod in MODALITIES_CONTROL
                }

                # Lesion data (T1 + full set of modalities)
                batch_lesion = next(dataloaders['lesion'][phase])
                labels_lesion = batch_lesion['label'][DATA].to(device).type(
                    torch.cuda.IntTensor)
                inputs_lesion = {
                    'T1':
                    batch_lesion['T1'][DATA].to(device),
                    'all':
                    torch.cat(
                        [batch_lesion[mod][DATA] for mod in MODALITIES_LESION],
                        1).to(device)
                }
                inputs_lesion_T1 = {
                    mod: batch_lesion[mod][DATA].to(device)
                    for mod in MODALITIES_CONTROL
                }

                # Batch sizes (useful when the batch is not full at the end of the epoch)
                bs_control = inputs_control['T1'].shape[0]
                bs_lesion = inputs_lesion['T1'].shape[0]

                # zero the parameter gradients
                optimizer.zero_grad()

                # track history only if in train
                with torch.set_grad_enabled(phase == 'training'):
                    # Concatenating the 2 predictions made on the T1 for a faster computation
                    input_T1 = {
                        'T1':
                        torch.cat(
                            [inputs_control['T1'], inputs_lesion_T1['T1']], 0)
                    }
                    outputs_T1, _ = model(input_T1)

                    # Deconcatenating the 2 predictions
                    outputs_control = outputs_T1[:bs_control, ...]
                    outputs_lesion_T1 = outputs_T1[bs_control:bs_control +
                                                   bs_lesion, ...]

                    # Prediction using the full set of modalities (lesion data)
                    outputs_lesion, _ = model(inputs_lesion)

                    # Loss computation using the Jaccard distance.
                    # The +6.0 offset maps lesion labels into a distinct
                    # label range — presumably past the 6 tissue classes
                    # (TODO confirm against jaccard_lesion).
                    loss_control = jaccard_tissue(outputs_control,
                                                  labels_control)
                    loss_lesion = jaccard_lesion(outputs_lesion,
                                                 6.0 + labels_lesion)
                    # Consistency between full-modality and T1-only preds.
                    loss_intermod = jaccard_tissue(outputs_lesion,
                                                   outputs_lesion_T1,
                                                   is_prob=True)

                    # The inter-modality term only kicks in after warmup.
                    if epoch > opt.warmup and phase == 'training':
                        loss = loss_control + loss_lesion + loss_intermod
                    else:
                        loss = loss_control + loss_lesion

                    # backward + optimize only if in training phase
                    if phase == 'training':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += 1
                running_loss += loss.item()
                running_loss_control += loss_control.item()
                running_loss_lesion += loss_lesion.item()
                running_loss_intermod += loss_intermod.item()

            epoch_loss = running_loss / epoch_samples
            epoch_loss_control = running_loss_control / epoch_samples
            epoch_loss_lesion = running_loss_lesion / epoch_samples
            epoch_loss_intermod = running_loss_intermod / epoch_samples

            print('{} Loss Seg Control: {:.4f}'.format(
                phase, epoch_loss_control))
            print('{} Loss Seg Lesion: {:.4f}'.format(phase,
                                                      epoch_loss_lesion))
            print('{} Loss Seg Intermod: {:.4f}'.format(
                phase, epoch_loss_intermod))

            if phase == 'validation':
                if val_eval_criterion_MA is None:
                    # First iteration: initialise the moving average.
                    val_eval_criterion_MA = epoch_loss
                    best_val_eval_criterion_MA = val_eval_criterion_MA
                    print(val_eval_criterion_MA)
                else:
                    # Exponential moving average of the validation loss.
                    val_eval_criterion_MA = val_eval_criterion_alpha * val_eval_criterion_MA + (
                        1 - val_eval_criterion_alpha) * epoch_loss
                    print(val_eval_criterion_MA)

                lr_s.step(val_eval_criterion_MA)
                # NOTE(review): DataFrame.append is deprecated/removed in
                # recent pandas; pd.concat would be the modern equivalent.
                # `param_group` here is the last group from the print loop
                # above — assumes a single param group (TODO confirm).
                df = df.append(
                    {
                        'epoch': epoch,
                        'best_epoch': best_epoch,
                        'MA': val_eval_criterion_MA,
                        'best_MA': best_val_eval_criterion_MA,
                        'lr': param_group['lr']
                    },
                    ignore_index=True)
                df.to_csv(df_path)

                if val_eval_criterion_MA < best_val_eval_criterion_MA:
                    best_val_eval_criterion_MA = val_eval_criterion_MA
                    best_epoch = epoch
                    torch.save(model.state_dict(), save_path.format('best'))
                else:
                    # Early stopping on MA stagnation.
                    if epoch - best_epoch > nb_patience:
                        continue_training = False

                # Snapshot the model at the end of the warmup period.
                if epoch == opt.warmup:
                    torch.save(model.state_dict(),
                               save_path.format(opt.warmup))

    time_elapsed = time.time() - since
    print('Training completed in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best epoch is {}'.format(best_epoch))
def train(paths_dict, model, discriminator, transformation, device, save_path,
          opt):
    """Adversarial variant of the training loop: alternates between
    (1) training a discriminator to tell control vs lesion skip features
    apart and (2) training the segmenter, whose loss is penalised by the
    (frozen) discriminator so features become domain-invariant.

    Args:
        paths_dict: {dataset: {'training': [...], 'validation': [...]}}.
        model: segmentation network; supports skip_only=True to return
            skip-connection features only.
        discriminator: domain classifier trained with BCEWithLogitsLoss.
        transformation: {'training'/'validation': {dataset: transform}}.
        device: torch device (CUDA assumed — torch.cuda.IntTensor casts).
        save_path: format string for checkpoint paths.
        opt: options namespace (model_dir, learning_rate, warmup,
            warmup_discriminator, weight_discri).
    """
    since = time.time()
    dataloaders = dict()

    # Define transforms for data normalization and augmentation
    for dataset in DATASETS:
        subjects_dataset_train = ImagesDataset(
            paths_dict[dataset]['training'],
            transform=transformation['training'][dataset])
        subjects_dataset_val = ImagesDataset(
            paths_dict[dataset]['validation'],
            transform=transformation['validation'][dataset])

        # Number of queue workers
        workers = 10

        # Define the dataset data as a queue of patches
        queue_dataset_train = Queue(
            subjects_dataset_train,
            queue_length,
            samples_per_volume,
            patch_size,
            ImageSampler,
            num_workers=workers,
            shuffle_subjects=True,
            shuffle_patches=True,
        )
        queue_dataset_val = Queue(
            subjects_dataset_val,
            queue_length,
            samples_per_volume,
            patch_size,
            ImageSampler,
            num_workers=workers,
            shuffle_subjects=True,
            shuffle_patches=True,
        )
        # Infinite iterators so both datasets can be drawn in lockstep.
        batch_loader_dataset_train = infinite_iterable(
            DataLoader(queue_dataset_train, batch_size=batch_size[dataset]))
        batch_loader_dataset_val = infinite_iterable(
            DataLoader(queue_dataset_val, batch_size=batch_size[dataset]))

        dataloaders_dataset = dict()
        dataloaders_dataset['training'] = batch_loader_dataset_train
        dataloaders_dataset['validation'] = batch_loader_dataset_val
        dataloaders[dataset] = dataloaders_dataset

    # Resume from the CSV log if one exists; otherwise start fresh.
    df_path = os.path.join(opt.model_dir, 'log.csv')
    if os.path.isfile(df_path):
        df = pd.read_csv(df_path)
        epoch = df.iloc[-1]['epoch']
        best_epoch = df.iloc[-1]['best_epoch']
        val_eval_criterion_MA = df.iloc[-1]['MA']
        best_val_eval_criterion_MA = df.iloc[-1]['best_MA']
        initial_lr = df.iloc[-1]['lr']
        model.load_state_dict(torch.load(save_path.format('best')))
    else:
        df = pd.DataFrame(
            columns=['epoch', 'best_epoch', 'MA', 'best_MA', 'lr'])
        val_eval_criterion_MA = None
        best_epoch = 0
        epoch = 0
        initial_lr = opt.learning_rate

    # Optimisation policy
    optimizer = torch.optim.Adam(model.parameters(),
                                 opt.learning_rate,
                                 weight_decay=weight_decay,
                                 amsgrad=True)
    lr_s = lr_scheduler.ReduceLROnPlateau(optimizer,
                                          mode='min',
                                          factor=0.2,
                                          patience=patience_lr,
                                          verbose=True,
                                          threshold=1e-3,
                                          threshold_mode="abs")

    # The discriminator has its own optimizer and BCE-with-logits loss.
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=1e-4)
    criterion_discriminator = torch.nn.BCEWithLogitsLoss()

    model = model.to(device)
    discriminator = discriminator.to(device)
    continue_training = True

    # These arrays only define how many iterations each phase runs.
    ind_batch_train = np.arange(
        0, samples_per_volume * len(paths_dict['lesion']['training']),
        batch_size['lesion'])
    ind_batch_val = np.arange(
        0, samples_per_volume * len(paths_dict['lesion']['validation']),
        batch_size['lesion'])
    ind_batch = dict()
    ind_batch['training'] = ind_batch_train
    ind_batch['validation'] = ind_batch_val

    while continue_training:
        epoch += 1
        print('-' * 10)
        print('Epoch {}/'.format(epoch))
        for param_group in optimizer.param_groups:
            print("Current learning rate is: {}".format(param_group['lr']))

        # Each epoch has a training and validation phase
        for phase in ['training', 'validation']:
            #for phase in ['validation','training']:
            print(phase)
            if phase == 'training':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_loss_lesion = 0.0
            running_loss_control = 0.0
            running_loss_intermod = 0.0
            running_loss_discriminator = 0.0
            running_loss_DA = 0.0
            epoch_samples = 0

            # Iterate over data
            for _ in tqdm(ind_batch[phase]):
                # Control data (T1 only).
                batch_control = next(dataloaders['control'][phase])
                labels_control = batch_control['label'][DATA].to(device).type(
                    torch.cuda.IntTensor)
                inputs_control = {
                    mod: batch_control[mod][DATA].to(device)
                    for mod in MODALITIES_CONTROL
                }

                # Lesion data (T1 + full set of modalities).
                batch_lesion = next(dataloaders['lesion'][phase])
                labels_lesion = batch_lesion['label'][DATA].to(device).type(
                    torch.cuda.IntTensor)
                inputs_lesion = {
                    'T1':
                    batch_lesion['T1'][DATA].to(device),
                    'all':
                    torch.cat(
                        [batch_lesion[mod][DATA] for mod in MODALITIES_LESION],
                        1).to(device)
                }
                #inputs_lesion_ascontrol = {mod:batch_lesion[mod][DATA].to(device) for mod in MODALITIES_CONTROL}

                # Batch sizes (the last batch of an epoch may be smaller).
                bs_control = inputs_control['T1'].shape[0]
                bs_lesion = inputs_lesion['T1'].shape[0]
                sum_batch_size = bs_control + bs_lesion
                #bs_lesion_T1 = inputs_lesion_T1['T1'].shape[0]

                # PHASE 1: TRAIN DISCRIMINATOR
                # Segmenter frozen, discriminator learns to classify the
                # origin (control=0 / lesion=1) of the skip features.
                with torch.set_grad_enabled(phase == 'training'):
                    model.eval()
                    if phase == 'training':
                        discriminator.train()  # Set model to training mode
                    else:
                        discriminator.eval()  # Set model to evaluate mode
                    set_requires_grad(model, requires_grad=False)
                    set_requires_grad(discriminator, requires_grad=True)

                    skips = model(
                        {
                            'T1':
                            torch.cat(
                                [inputs_control['T1'], inputs_lesion['T1']],
                                0)
                        },
                        skip_only=True)
                    inputs_discrimator, z_dim = merge_skips(skips)
                    # Domain labels: zeros for control, ones for lesion,
                    # replicated z_dim times by the skip merge.
                    labels_discriminator = torch.cat((torch.zeros(
                        bs_control * z_dim), torch.ones(bs_lesion * z_dim)),
                                                     0).to(device)

                    outputs_discriminator = discriminator(
                        inputs_discrimator).squeeze()
                    loss_discriminator = criterion_discriminator(
                        outputs_discriminator, labels_discriminator)
                    # Logit > 0 corresponds to a sigmoid probability > 0.5.
                    accuracy_discriminator = ((outputs_discriminator > 0).long(
                    ) == labels_discriminator.long()).float().mean().item()

                    if phase == 'training':
                        optimizer_discriminator.zero_grad()
                        loss_discriminator.backward()
                        optimizer_discriminator.step()

                # PHASE 2: TRAIN SEGMENTER
                # Discriminator frozen; the segmenter is rewarded for
                # confusing it (adversarial term subtracted from the loss).
                with torch.set_grad_enabled(phase == 'training'):
                    discriminator.eval()
                    if phase == 'training':
                        model.train()  # Set model to training mode
                    else:
                        model.eval()  # Set model to evaluate mode
                    set_requires_grad(model, requires_grad=True)
                    set_requires_grad(discriminator, requires_grad=False)

                    outputs_lesion, _ = model(inputs_lesion)
                    outputs, skips = model({
                        'T1':
                        torch.cat([inputs_control['T1'], inputs_lesion['T1']],
                                  0)
                    })
                    inputs_discrimator, _ = merge_skips(skips)
                    outputs_discriminator = discriminator(
                        inputs_discrimator).squeeze()
                    loss_discriminator_inv = criterion_discriminator(
                        outputs_discriminator, labels_discriminator)

                    # Deconcatenate the joint T1 prediction.
                    outputs_control = outputs[:bs_control, ...]
                    outputs_lesion_ascontrol = outputs[bs_control:, ...]

                    # Jaccard-based segmentation losses; the +6.0 offset maps
                    # lesion labels into a distinct range — presumably past
                    # the tissue classes (TODO confirm with jaccard_lesion).
                    loss_control = jaccard_tissue(outputs_control,
                                                  labels_control)
                    loss_lesion = jaccard_lesion(outputs_lesion,
                                                 6.0 + labels_lesion)
                    loss_intermod = jaccard_tissue(outputs_lesion,
                                                   outputs_lesion_ascontrol,
                                                   is_prob=True)

                    # Loss schedule: full adversarial + intermod term after
                    # opt.warmup; linearly-ramped adversarial term between
                    # warmup_discriminator and warmup; plain seg loss before.
                    if epoch > opt.warmup and phase == 'training':
                        loss = loss_control + loss_lesion + loss_intermod - 0.1 * opt.weight_discri * loss_discriminator_inv
                    elif epoch > opt.warmup_discriminator and phase == 'training':
                        loss = loss_control + loss_lesion - 0.1 * opt.weight_discri * (
                            epoch - opt.warmup_discriminator
                        ) / (opt.warmup -
                             opt.warmup_discriminator) * loss_discriminator_inv
                    else:
                        loss = loss_control + loss_lesion

                    if phase == 'training':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # statistics (discriminator accuracy is accumulated here,
                # not its loss, despite the variable name)
                epoch_samples += 1
                running_loss += loss.item()
                running_loss_control += loss_control.item()
                running_loss_lesion += loss_lesion.item()
                running_loss_intermod += loss_intermod.item()
                running_loss_discriminator += accuracy_discriminator

            epoch_loss = running_loss / epoch_samples
            epoch_loss_control = running_loss_control / epoch_samples
            epoch_loss_lesion = running_loss_lesion / epoch_samples
            epoch_loss_intermod = running_loss_intermod / epoch_samples
            epoch_loss_discrimator = running_loss_discriminator / epoch_samples

            print('{} Loss Seg Control: {:.4f}'.format(
                phase, epoch_loss_control))
            print('{} Loss Seg Lesion: {:.4f}'.format(phase,
                                                      epoch_loss_lesion))
            print('{} Loss Seg Intermod: {:.4f}'.format(
                phase, epoch_loss_intermod))
            print('{} Loss Discriminator: {:.4f}'.format(
                phase, epoch_loss_discrimator))

            if phase == 'validation':
                if val_eval_criterion_MA is None:
                    # First iteration: initialise the moving average.
                    val_eval_criterion_MA = epoch_loss
                    best_val_eval_criterion_MA = val_eval_criterion_MA
                    print(val_eval_criterion_MA)
                else:
                    # Exponential moving average of the validation loss.
                    val_eval_criterion_MA = val_eval_criterion_alpha * val_eval_criterion_MA + (
                        1 - val_eval_criterion_alpha) * epoch_loss
                    print(val_eval_criterion_MA)

                lr_s.step(val_eval_criterion_MA)
                # NOTE(review): DataFrame.append is deprecated/removed in
                # recent pandas; pd.concat would be the modern equivalent.
                # `param_group` is the last group from the print loop above
                # — assumes a single param group (TODO confirm).
                df = df.append(
                    {
                        'epoch': epoch,
                        'best_epoch': best_epoch,
                        'MA': val_eval_criterion_MA,
                        'best_MA': best_val_eval_criterion_MA,
                        'lr': param_group['lr']
                    },
                    ignore_index=True)
                df.to_csv(df_path)

                if val_eval_criterion_MA < best_val_eval_criterion_MA:
                    best_val_eval_criterion_MA = val_eval_criterion_MA
                    best_epoch = epoch
                    torch.save(model.state_dict(), save_path.format('best'))
                    torch.save(discriminator.state_dict(),
                               save_path.format('discriminator_best'))
                else:
                    # Early stopping on MA stagnation.
                    if epoch - best_epoch > nb_patience:
                        continue_training = False

                # Snapshot both networks at the end of the warmup period.
                if epoch == opt.warmup:
                    torch.save(model.state_dict(), save_path.format(epoch))
                    torch.save(discriminator.state_dict(),
                               save_path.format('discriminator_' + str(epoch)))

    time_elapsed = time.time() - since
    print('Training completed in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best epoch is {}'.format(best_epoch))