예제 #1
0
def main():
    """Sample patches from two subjects and run a mock training loop.

    Demonstrates a torchio patch ``Queue`` fed to a PyTorch ``DataLoader``
    with an identity ``collate_fn``. The custom collation is required
    because the second subject is missing the T2 modality, so samples
    cannot be collated into a single dictionary; each batch is a *list*
    of per-patch sample dictionaries instead.
    """
    # Define training and patches sampling parameters
    num_epochs = 20
    patch_size = 128
    queue_length = 100
    samples_per_volume = 5
    batch_size = 2

    # Subject with the full set of modalities (T1, T2, label)
    one_subject = Subject(
        T1=Image('../BRATS2018_crop_renamed/LGG75_T1.nii.gz',
                 torchio.INTENSITY),
        T2=Image('../BRATS2018_crop_renamed/LGG75_T2.nii.gz',
                 torchio.INTENSITY),
        label=Image('../BRATS2018_crop_renamed/LGG75_Label.nii.gz',
                    torchio.LABEL),
    )

    # This subject doesn't have a T2 MRI!
    another_subject = Subject(
        T1=Image('../BRATS2018_crop_renamed/LGG74_T1.nii.gz',
                 torchio.INTENSITY),
        label=Image('../BRATS2018_crop_renamed/LGG74_Label.nii.gz',
                    torchio.LABEL),
    )

    subjects = [
        one_subject,
        another_subject,
    ]

    subjects_dataset = ImagesDataset(subjects)
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        samples_per_volume,
        patch_size,
        ImageSampler,
    )

    # This collate_fn is needed in the case of missing modalities:
    # the batch will be composed by a *list* of samples instead of
    # the typical Python dictionary that is collated by default in PyTorch
    batch_loader = DataLoader(
        queue_dataset,
        batch_size=batch_size,
        collate_fn=lambda x: x,
    )

    # Mock PyTorch model
    model = nn.Identity()

    for epoch_index in range(num_epochs):
        for batch in batch_loader:  # batch is a *list* here, not a dictionary
            logits = model(batch)
            # Index with len(batch), not batch_size: the final batch may be
            # smaller than batch_size, and range(batch_size) would then
            # raise IndexError.
            print([batch[idx].keys() for idx in range(len(batch))])
    print()
예제 #2
0
 def test_queue(self):
     """Smoke-test the patch Queue: build it, stringify it, and drain it."""
     dataset = ImagesDataset(self.subjects_list)
     queue = Queue(
         dataset,
         max_length=6,
         samples_per_volume=2,
         patch_size=10,
         sampler_class=ImageSampler,
         num_workers=2,
         verbose=True,
     )
     # The queue's __repr__ must not raise.
     _ = str(queue)
     loader = DataLoader(queue, batch_size=4)
     # Every batch must expose both expected keys with tensor data.
     for batch in loader:
         for key in ('one_modality', 'segmentation'):
             _ = batch[key][DATA]
예제 #3
0
 def run_queue(self, num_workers, **kwargs):
     """Build a patch Queue over the test subjects and drain one loader pass.

     NOTE: ``num_workers`` is accepted but not used directly in this body;
     worker options reach the Queue only through ``**kwargs``.
     """
     dataset = SubjectsDataset(self.subjects_list)
     sampler = UniformSampler(10)  # uniform sampler with patch size 10
     queue = Queue(
         dataset,
         max_length=6,
         samples_per_volume=2,
         sampler=sampler,
         **kwargs,
     )
     # The queue's __repr__ must not raise.
     _ = str(queue)
     loader = DataLoader(queue, batch_size=4)
     # Every batch must expose both expected keys with tensor data.
     for batch in loader:
         for key in ('one_modality', 'segmentation'):
             _ = batch[key][DATA]
예제 #4
0
    another_subject_dict,
]

subjects_dataset = ImagesDataset(subjects_paths, transform=transform)

# Run a benchmark for different numbers of workers
workers = range(mp.cpu_count() + 1)
for num_workers in workers:
    print('Number of workers:', num_workers)

    # Define the dataset as a queue of patches
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        samples_per_volume,
        patch_size,
        ImageSampler,
        num_workers=num_workers,
        shuffle_subjects=False,
    )

    # The identity collate_fn is needed when subjects may be missing
    # modalities: the batch becomes a *list* of samples instead of the
    # dictionary PyTorch's default collation would build.
    batch_loader = DataLoader(queue_dataset,
                              batch_size=batch_size,
                              collate_fn=lambda x: x)

    # NOTE(review): the elapsed-time report presumably follows the epoch
    # loop but is not visible in this chunk.
    start = time.time()
    for epoch_index in range(num_epochs):
        for batch in batch_loader:
            logits = model(batch)
            # Index with len(batch), not batch_size: the final batch may be
            # smaller than batch_size, and range(batch_size) would then
            # raise IndexError.
            print([batch[idx].keys() for idx in range(len(batch))])
예제 #5
0
    def set_data_loader(self,
                        train_csv_file='',
                        val_csv_file='',
                        transforms=None,
                        batch_size=1,
                        num_workers=0,
                        par_queue=None,
                        save_to_dir=None,
                        load_from_dir=None,
                        replicate_suj=0,
                        shuffel_train=True,
                        get_condition_csv=None,
                        get_condition_field='',
                        get_condition_nb_wanted=1 / 4,
                        collate_fn=None,
                        add_to_load=None,
                        add_to_load_regexp=None):
        """Build the train/val datasets and dataloaders and store them on self.

        Two data sources are supported:
          * ``load_from_dir`` set: pre-saved ``sample.*pt`` files are globbed
            from the given directory (or ``[train_dir, val_dir]`` pair),
            optionally subsampled according to a condition CSV.
          * otherwise: subject lists are built from ``train_csv_file`` /
            ``val_csv_file``.

        If ``par_queue`` is given, datasets are wrapped in patch ``Queue``s
        (patch-based training); otherwise whole-image dataloaders are built.

        Side effects: sets ``self.train_dataloader``, ``self.val_dataloader``,
        appends to ``self.log_string`` and ``self.res_name``, and may set
        ``self.patch``, ``self.train_dataset`` and
        ``self.train_csv_load_file_train``.
        """

        # Wrap a plain sequence of transforms into a Compose if needed.
        if not isinstance(transforms, torchvision.transforms.transforms.Compose
                          ) and transforms is not None:
            transforms = Compose(transforms)

        if load_from_dir is not None:
            # A single directory string means train and val share the same dir.
            if type(load_from_dir) == str:
                load_from_dir = [load_from_dir, load_from_dir]
            # Glob the pre-saved torch sample files for train and val.
            fsample_train, fsample_val = gfile(load_from_dir[0],
                                               'sample.*pt'), gfile(
                                                   load_from_dir[1],
                                                   'sample.*pt')
            #random.shuffle(fsample_train)
            #fsample_train = fsample_train[0:10000]

            if get_condition_csv is not None:
                # Subsample the training files so that the values of
                # `get_condition_field` are roughly uniformly represented:
                # split the value range into 100 intervals and keep at most
                # `nb_wanted_per_interval` random samples per interval.
                res = pd.read_csv(load_from_dir[0] + '/' + get_condition_csv)
                cond_val = res[get_condition_field].values

                y = np.linspace(np.min(cond_val), np.max(cond_val), 101)
                nb_wanted_per_interval = int(
                    np.round(len(cond_val) * get_condition_nb_wanted / 100))
                y_select = []
                for i in range(len(y) - 1):
                    # Indices whose condition value falls in interval i
                    # (strict bounds: values equal to an edge are dropped).
                    indsel = np.where((cond_val > y[i])
                                      & (cond_val < y[i + 1]))[0]
                    nb_select = len(indsel)
                    if nb_select < nb_wanted_per_interval:
                        # Not enough samples in this interval: keep them all.
                        print(
                            ' only {} / {} for interval {} {:,.3f} |  {:,.3f} '
                            .format(nb_select, nb_wanted_per_interval, i, y[i],
                                    y[i + 1]))
                        y_select.append(indsel)
                    else:
                        # Keep a random subset of the interval's samples.
                        pind = np.random.permutation(range(0, nb_select))
                        y_select.append(indsel[pind[0:nb_wanted_per_interval]])
                        #print('{} selecting {}'.format(i, len(y_select[-1])))
                ind_select = np.hstack(y_select)
                y = cond_val[ind_select]
                fsample_train = [fsample_train[ii] for ii in ind_select]
                self.log_string += '\nfinal selection {} soit {:,.3f} % instead of {:,.3f} %'.format(
                    len(y),
                    len(y) / len(cond_val) * 100,
                    get_condition_nb_wanted * 100)

                #conditions = [("MSE", ">", 0.0028),]
                #select_ind = apply_conditions_on_dataset(res,conditions)
                #fsel = [fsample_train[ii] for ii,jj in enumerate(select_ind) if jj]

            self.log_string += '\nloading {} train sample from {}'.format(
                len(fsample_train), load_from_dir[0])
            self.log_string += '\nloading {} val   sample from {}'.format(
                len(fsample_val), load_from_dir[1])
            train_dataset = ImagesDataset(
                fsample_train,
                load_from_dir=load_from_dir[0],
                transform=transforms,
                add_to_load=add_to_load,
                add_to_load_regexp=add_to_load_regexp)
            self.train_csv_load_file_train = fsample_train

            val_dataset = ImagesDataset(fsample_val,
                                        load_from_dir=load_from_dir[1],
                                        transform=transforms,
                                        add_to_load=add_to_load,
                                        add_to_load_regexp=add_to_load_regexp)
            # NOTE(review): this overwrites the train file list stored three
            # statements above with the *val* list — probably meant to be a
            # separate attribute (e.g. train_csv_load_file_val); confirm
            # before relying on train_csv_load_file_train downstream.
            self.train_csv_load_file_train = fsample_val

        else:
            # Build subject lists from the train/val CSV files.
            data_parameters = {
                'image': {
                    'csv_file': train_csv_file
                },
            }
            data_parameters_val = {
                'image': {
                    'csv_file': val_csv_file
                },
            }

            paths_dict, info = get_subject_list_and_csv_info_from_data_prameters(
                data_parameters, fpath_idx='filename')
            paths_dict_val, info_val = get_subject_list_and_csv_info_from_data_prameters(
                data_parameters_val, fpath_idx='filename', shuffle_order=False)

            # Optionally duplicate the training subjects replicate_suj times.
            if replicate_suj:
                lll = []
                for i in range(0, replicate_suj):
                    lll.extend(paths_dict)
                paths_dict = lll
                self.log_string += 'Replicating train dataSet {} times, new length is {}'.format(
                    replicate_suj, len(lll))

            train_dataset = ImagesDataset(paths_dict,
                                          transform=transforms,
                                          save_to_dir=save_to_dir)
            val_dataset = ImagesDataset(paths_dict_val,
                                        transform=transforms,
                                        save_to_dir=save_to_dir)

        # Encode batch size and worker count into the result/run name.
        self.res_name += '_B{}_nw{}'.format(batch_size, num_workers)

        if par_queue is not None:
            # Patch-based mode: wrap the datasets in torchio Queues.
            self.patch = True
            windows_size = par_queue['windows_size']
            # A single dimension means a cubic patch.
            if len(windows_size) == 1:
                windows_size = [
                    windows_size[0], windows_size[0], windows_size[0]
                ]

            train_queue = Queue(train_dataset,
                                par_queue['queue_length'],
                                par_queue['samples_per_volume'],
                                windows_size,
                                ImageSampler,
                                num_workers=num_workers,
                                verbose=self.verbose)

            # Validation queue: one patch per volume, no shuffling, so that
            # evaluation is deterministic.
            val_queue = Queue(val_dataset,
                              par_queue['queue_length'],
                              1,
                              windows_size,
                              ImageSampler,
                              num_workers=num_workers,
                              shuffle_subjects=False,
                              shuffle_patches=False,
                              verbose=self.verbose)
            self.res_name += '_spv{}'.format(par_queue['samples_per_volume'])

            self.train_dataloader = DataLoader(train_queue,
                                               batch_size=batch_size,
                                               shuffle=shuffel_train,
                                               collate_fn=collate_fn)
            self.val_dataloader = DataLoader(val_queue,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             collate_fn=collate_fn)

        else:
            # Whole-image mode: plain dataloaders over the datasets.
            self.train_dataset = train_dataset
            self.train_dataloader = DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=shuffel_train,
                                               num_workers=num_workers,
                                               collate_fn=collate_fn)
            self.val_dataloader = DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             collate_fn=collate_fn)
예제 #6
0
        'csv_file': '/data/romain/data_exemple/file_p3.csv'
    },
    'sampler': {
        'csv_file': '/data/romain/data_exemple/file_mask.csv'
    }
}
# Optional ROI weighting file (disabled).
roi_path = None  #'/data/romain/data_exemple/roi16_weighted.txt'

# Build the subject list (and per-subject CSV info) from the parameter dict.
subject_list, res_info = get_subject_list_and_csv_info_from_data_prameters(
    data_parameters)

# Dataset + patch queue: samples_per_volume patches of size windows_size are
# drawn from each subject and shuffled before batching.
train_dataset = ImagesDataset(subject_list, transform=transforms)
train_queue = Queue(train_dataset,
                    queue_length,
                    samples_per_volume,
                    windows_size,
                    ImageSampler,
                    num_workers=num_workers,
                    shuffle_patches=True,
                    verbose=False)

train_dataloader = DataLoader(train_queue, batch_size=batch_size, shuffle=True)

# d=next(iter(train_dataloader))
# d['image'].shape
# d['label'].shape

#loss = BCEWithLogitsLoss()  # sigmoid + bcel
#loss = dice_coef_loss #dice_loss
#loss = dice_loss()
# Select the loss from the configured losstype.
# NOTE(review): no fallback branch — `loss` stays undefined for any other
# losstype value; confirm the allowed values upstream.
if losstype == 'BCE': loss = tnn.BCELoss()
elif losstype == 'dice': loss = dice_loss(type=1)
예제 #7
0
# This subject doesn't have a T2 MRI!
another_subject = Subject(
    Image('T1', '../BRATS2018_crop_renamed/LGG74_T1.nii.gz', torchio.INTENSITY),
    Image('label', '../BRATS2018_crop_renamed/LGG74_Label.nii.gz', torchio.LABEL),
)

subjects = [
    one_subject,
    another_subject,
]

subjects_dataset = ImagesDataset(subjects)
queue_dataset = Queue(
    subjects_dataset,
    queue_length,
    samples_per_volume,
    patch_size,
    ImageSampler,
)

# This collate_fn is needed in the case of missing modalities:
# the batch will be composed by a *list* of samples instead of
# the typical Python dictionary that is collated by default in PyTorch
batch_loader = DataLoader(
    queue_dataset,
    batch_size=batch_size,
    collate_fn=lambda x: x,
)

# Mock PyTorch model. nn.Identity() is used instead of a bare nn.Module():
# nn.Module has no forward implementation, so calling it would raise
# NotImplementedError and the mock could not stand in for a real model.
model = nn.Identity()
예제 #8
0
    def __init__(self, images_dir, labels_dir):
        """Collect image/label paths into torchio Subjects and a patch Queue.

        Two layouts are supported, selected by the global ``hp`` config:
          * single-class (``hp.in_class == 1 and hp.out_class == 1``):
            one label file per image, matched by sorted glob order;
          * multi-class: artery/lung/trachea/vein label subdirectories,
            one LabelMap per structure.

        Patch size is taken from ``hp`` (a 2D mode uses depth-1 patches).
        Builds ``self.subjects``, ``self.transforms``, ``self.training_set``
        and ``self.queue_dataset``.
        """

        if hp.mode == '3d':
            patch_size = hp.patch_size
        elif hp.mode == '2d':
            # 2D training is emulated with depth-1 3D patches.
            patch_size = (hp.patch_size, hp.patch_size, 1)
        else:
            raise Exception('no such kind of mode!')

        queue_length = 5
        samples_per_volume = 5

        self.subjects = []

        if (hp.in_class == 1) and (hp.out_class == 1):

            images_dir = Path(images_dir)
            self.image_paths = sorted(images_dir.glob(hp.fold_arch))
            labels_dir = Path(labels_dir)
            self.label_paths = sorted(labels_dir.glob(hp.fold_arch))

            # Images and labels are paired by their position in the two
            # sorted listings — filenames must sort consistently.
            for (image_path, label_path) in zip(self.image_paths,
                                                self.label_paths):
                subject = tio.Subject(
                    source=tio.ScalarImage(image_path),
                    label=tio.LabelMap(label_path),
                )
                self.subjects.append(subject)
        else:
            images_dir = Path(images_dir)
            self.image_paths = sorted(images_dir.glob(hp.fold_arch))

            # One label subdirectory per anatomical structure.
            artery_labels_dir = Path(labels_dir + '/artery')
            self.artery_label_paths = sorted(
                artery_labels_dir.glob(hp.fold_arch))

            lung_labels_dir = Path(labels_dir + '/lung')
            self.lung_label_paths = sorted(lung_labels_dir.glob(hp.fold_arch))

            trachea_labels_dir = Path(labels_dir + '/trachea')
            self.trachea_label_paths = sorted(
                trachea_labels_dir.glob(hp.fold_arch))

            vein_labels_dir = Path(labels_dir + '/vein')
            self.vein_label_paths = sorted(vein_labels_dir.glob(hp.fold_arch))

            for (image_path, artery_label_path, lung_label_path,
                 trachea_label_path, vein_label_path) in zip(
                     self.image_paths, self.artery_label_paths,
                     self.lung_label_paths, self.trachea_label_paths,
                     self.vein_label_paths):
                # NOTE(review): 'atery' looks like a typo for 'artery'; kept
                # as-is because downstream code may index this exact key.
                subject = tio.Subject(
                    source=tio.ScalarImage(image_path),
                    atery=tio.LabelMap(artery_label_path),
                    lung=tio.LabelMap(lung_label_path),
                    trachea=tio.LabelMap(trachea_label_path),
                    vein=tio.LabelMap(vein_label_path),
                )
                self.subjects.append(subject)

        self.transforms = self.transform()

        self.training_set = tio.SubjectsDataset(self.subjects,
                                                transform=self.transforms)

        self.queue_dataset = Queue(
            self.training_set,
            queue_length,
            samples_per_volume,
            UniformSampler(patch_size),
        )
예제 #9
0
def train(paths_dict, model, transformation, device, save_path, opt):
    """Train `model` on patch queues from control and lesion datasets.

    For each dataset in the module-level ``DATASETS``, builds train/val
    patch queues and infinite dataloaders, then alternates training and
    validation phases until early stopping. The loss combines:
      * tissue Jaccard on control T1 predictions,
      * lesion Jaccard on full-modality lesion predictions,
      * (after ``opt.warmup`` epochs) an inter-modality consistency term
        between full-modality and T1-only lesion predictions.

    Progress is tracked with an exponential moving average (MA) of the
    validation loss; the best model is checkpointed to
    ``save_path.format('best')`` and training stops once no improvement is
    seen for ``nb_patience`` epochs. If ``opt.model_dir/log.csv`` exists,
    training state (epoch, MA, lr, best checkpoint) is resumed from it.

    Relies on several module-level names not visible here (``queue_length``,
    ``samples_per_volume``, ``patch_size``, ``batch_size``, ``weight_decay``,
    ``patience_lr``, ``val_eval_criterion_alpha``, ``nb_patience``,
    ``MODALITIES_CONTROL``, ``MODALITIES_LESION``, ``infinite_iterable``,
    ``jaccard_tissue``, ``jaccard_lesion``).
    """

    since = time.time()

    dataloaders = dict()
    # Define transforms for data normalization and augmentation
    for dataset in DATASETS:
        subjects_dataset_train = ImagesDataset(
            paths_dict[dataset]['training'],
            transform=transformation['training'][dataset])

        subjects_dataset_val = ImagesDataset(
            paths_dict[dataset]['validation'],
            transform=transformation['validation'][dataset])

        # Number of data-loading workers
        workers = 10

        # Define the dataset data as a queue of patches
        queue_dataset_train = Queue(
            subjects_dataset_train,
            queue_length,
            samples_per_volume,
            patch_size,
            ImageSampler,
            num_workers=workers,
            shuffle_subjects=True,
            shuffle_patches=True,
        )

        queue_dataset_val = Queue(
            subjects_dataset_val,
            queue_length,
            samples_per_volume,
            patch_size,
            ImageSampler,
            num_workers=workers,
            shuffle_subjects=True,
            shuffle_patches=True,
        )

        # Wrap the loaders so `next()` can be called indefinitely; batch
        # counts per epoch are controlled by ind_batch below instead.
        batch_loader_dataset_train = infinite_iterable(
            DataLoader(queue_dataset_train, batch_size=batch_size[dataset]))
        batch_loader_dataset_val = infinite_iterable(
            DataLoader(queue_dataset_val, batch_size=batch_size[dataset]))

        dataloaders_dataset = dict()
        dataloaders_dataset['training'] = batch_loader_dataset_train
        dataloaders_dataset['validation'] = batch_loader_dataset_val
        dataloaders[dataset] = dataloaders_dataset

    # Resume training state from an existing CSV log if present.
    df_path = os.path.join(opt.model_dir, 'log.csv')
    if os.path.isfile(df_path):
        df = pd.read_csv(df_path)
        epoch = df.iloc[-1]['epoch']
        best_epoch = df.iloc[-1]['best_epoch']

        val_eval_criterion_MA = df.iloc[-1]['MA']
        best_val_eval_criterion_MA = df.iloc[-1]['best_MA']

        initial_lr = df.iloc[-1]['lr']

        # Resume from the best checkpoint so far.
        model.load_state_dict(torch.load(save_path.format('best')))

    else:
        # Fresh run: start a new log and reset counters.
        df = pd.DataFrame(
            columns=['epoch', 'best_epoch', 'MA', 'best_MA', 'lr'])
        val_eval_criterion_MA = None
        best_epoch = 0
        epoch = 0

        initial_lr = opt.learning_rate

    # Optimisation policy
    optimizer = torch.optim.Adam(model.parameters(),
                                 opt.learning_rate,
                                 weight_decay=weight_decay,
                                 amsgrad=True)
    lr_s = lr_scheduler.ReduceLROnPlateau(optimizer,
                                          mode='min',
                                          factor=0.2,
                                          patience=patience_lr,
                                          verbose=True,
                                          threshold=1e-3,
                                          threshold_mode="abs")

    model = model.to(device)

    continue_training = True

    # One index per batch needed to cover samples_per_volume patches for
    # every lesion subject; only the length of these ranges is used.
    ind_batch_train = np.arange(
        0, samples_per_volume * len(paths_dict['lesion']['training']),
        batch_size['lesion'])
    ind_batch_val = np.arange(
        0, samples_per_volume * len(paths_dict['lesion']['validation']),
        batch_size['lesion'])
    ind_batch = dict()
    ind_batch['training'] = ind_batch_train
    ind_batch['validation'] = ind_batch_val

    while continue_training:
        epoch += 1
        print('-' * 10)
        print('Epoch {}/'.format(epoch))
        for param_group in optimizer.param_groups:
            print("Current learning rate is: {}".format(param_group['lr']))

        # Each epoch has a training and validation phase
        for phase in ['training', 'validation']:
            #for phase in ['validation','training']:
            print(phase)
            if phase == 'training':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_loss_lesion = 0.0
            running_loss_control = 0.0
            running_loss_intermod = 0.0

            epoch_samples = 0

            # Iterate over data
            for _ in tqdm(ind_batch[phase]):
                # Control data (T1)
                # NOTE(review): torch.cuda.IntTensor assumes CUDA is
                # available regardless of `device` — confirm on CPU runs.
                batch_control = next(dataloaders['control'][phase])
                labels_control = batch_control['label'][DATA].to(device).type(
                    torch.cuda.IntTensor)
                inputs_control = {
                    mod: batch_control[mod][DATA].to(device)
                    for mod in MODALITIES_CONTROL
                }

                # Lesion data (T1 + full set of modalities)
                batch_lesion = next(dataloaders['lesion'][phase])
                labels_lesion = batch_lesion['label'][DATA].to(device).type(
                    torch.cuda.IntTensor)
                inputs_lesion = {
                    'T1':
                    batch_lesion['T1'][DATA].to(device),
                    'all':
                    torch.cat(
                        [batch_lesion[mod][DATA] for mod in MODALITIES_LESION],
                        1).to(device)
                }
                inputs_lesion_T1 = {
                    mod: batch_lesion[mod][DATA].to(device)
                    for mod in MODALITIES_CONTROL
                }

                # Batch sizes (useful when the batch is not full at the end of the epoch)
                bs_control = inputs_control['T1'].shape[0]
                bs_lesion = inputs_lesion['T1'].shape[0]

                # zero the parameter gradients
                optimizer.zero_grad()

                # track history if only in train
                with torch.set_grad_enabled(phase == 'training'):
                    # Concatenating the 2 predictions made on the T1 for a faster computation
                    input_T1 = {
                        'T1':
                        torch.cat(
                            [inputs_control['T1'], inputs_lesion_T1['T1']], 0)
                    }
                    outputs_T1, _ = model(input_T1)

                    # Deconcatenating the 2 predictions
                    outputs_control = outputs_T1[:bs_control, ...]
                    outputs_lesion_T1 = outputs_T1[bs_control:bs_control +
                                                   bs_lesion, ...]

                    # Prediction using the full set of modalities (lesion data)
                    outputs_lesion, _ = model(inputs_lesion)

                    # Loss computation using the Jaccard distance
                    loss_control = jaccard_tissue(outputs_control,
                                                  labels_control)
                    # The 6.0 offset shifts lesion labels into the range the
                    # lesion Jaccard expects (see jaccard_lesion).
                    loss_lesion = jaccard_lesion(outputs_lesion,
                                                 6.0 + labels_lesion)
                    loss_intermod = jaccard_tissue(outputs_lesion,
                                                   outputs_lesion_T1,
                                                   is_prob=True)

                    # The inter-modality consistency term only contributes
                    # to the training loss after the warmup period.
                    if epoch > opt.warmup and phase == 'training':
                        loss = loss_control + loss_lesion + loss_intermod
                    else:
                        loss = loss_control + loss_lesion

                    # backward + optimize only if in training phase
                    if phase == 'training':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += 1
                running_loss += loss.item()
                running_loss_control += loss_control.item()
                running_loss_lesion += loss_lesion.item()
                running_loss_intermod += loss_intermod.item()

            epoch_loss = running_loss / epoch_samples
            epoch_loss_control = running_loss_control / epoch_samples
            epoch_loss_lesion = running_loss_lesion / epoch_samples
            epoch_loss_intermod = running_loss_intermod / epoch_samples

            print('{}  Loss Seg Control: {:.4f}'.format(
                phase, epoch_loss_control))
            print('{}  Loss Seg Lesion: {:.4f}'.format(phase,
                                                       epoch_loss_lesion))
            print('{}  Loss Seg Intermod: {:.4f}'.format(
                phase, epoch_loss_intermod))

            if phase == 'validation':
                if val_eval_criterion_MA is None:  #first iteration
                    val_eval_criterion_MA = epoch_loss
                    best_val_eval_criterion_MA = val_eval_criterion_MA
                    print(val_eval_criterion_MA)

                else:  #update criterion: exponential moving average
                    val_eval_criterion_MA = val_eval_criterion_alpha * val_eval_criterion_MA + (
                        1 - val_eval_criterion_alpha) * epoch_loss
                    print(val_eval_criterion_MA)

                # Scheduler steps on the smoothed validation loss.
                lr_s.step(val_eval_criterion_MA)
                # NOTE(review): DataFrame.append is removed in pandas >= 2.0;
                # pd.concat would be needed there. `param_group` is the last
                # group from the lr print loop above.
                df = df.append(
                    {
                        'epoch': epoch,
                        'best_epoch': best_epoch,
                        'MA': val_eval_criterion_MA,
                        'best_MA': best_val_eval_criterion_MA,
                        'lr': param_group['lr']
                    },
                    ignore_index=True)
                df.to_csv(df_path)

                if val_eval_criterion_MA < best_val_eval_criterion_MA:
                    # New best smoothed validation loss: checkpoint it.
                    best_val_eval_criterion_MA = val_eval_criterion_MA
                    best_epoch = epoch
                    torch.save(model.state_dict(), save_path.format('best'))

                else:
                    # Early stopping after nb_patience epochs w/o improvement.
                    if epoch - best_epoch > nb_patience:
                        continue_training = False

                # Snapshot the model at the end of the warmup period.
                if epoch == opt.warmup:
                    torch.save(model.state_dict(),
                               save_path.format(opt.warmup))

    time_elapsed = time.time() - since
    print('Training completed in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    print('Best epoch is {}'.format(best_epoch))
예제 #10
0
def train(paths_dict, model, discriminator, transformation, device, save_path,
          opt):
    """Adversarial training loop for a segmenter plus a domain discriminator.

    For each dataset in ``DATASETS`` a TorchIO patch queue is built for the
    training and validation splits.  Each iteration then alternates between
    two phases:

    1. Train the discriminator to tell control skip-features from lesion
       skip-features (segmenter frozen).
    2. Train the segmenter on both domains; after ``opt.warmup_discriminator``
       epochs an adversarial term (the *inverted* discriminator loss) is
       ramped in, reaching full weight after ``opt.warmup`` epochs, pushing
       the skip features to be domain-invariant.

    Training state (epoch, moving-average validation loss, best epoch,
    learning rate) is checkpointed to ``<opt.model_dir>/log.csv`` so an
    interrupted run can resume.  Early stopping triggers when the validation
    moving average has not improved for ``nb_patience`` epochs.

    Args:
        paths_dict: ``{dataset: {'training': [...], 'validation': [...]}}``
            subject lists, one entry per name in ``DATASETS``.
        model: segmentation network; accepts a dict of modality tensors and
            supports ``skip_only=True`` to return only skip features.
        discriminator: domain classifier applied to the merged skip features.
        transformation: ``{'training'|'validation': {dataset: transform}}``.
        device: torch device inputs are moved to (CUDA is assumed: labels
            are cast with ``torch.cuda.IntTensor`` below).
        save_path: format string for checkpoint paths, e.g. ``'model_{}.pth'``.
        opt: options namespace providing ``model_dir``, ``learning_rate``,
            ``warmup``, ``warmup_discriminator`` and ``weight_discri``.
    """
    since = time.time()

    # Build one (training, validation) pair of infinite patch loaders per
    # dataset so batches can be drawn with next() inside the epoch loop.
    dataloaders = dict()
    for dataset in DATASETS:
        subjects_dataset_train = ImagesDataset(
            paths_dict[dataset]['training'],
            transform=transformation['training'][dataset])

        subjects_dataset_val = ImagesDataset(
            paths_dict[dataset]['validation'],
            transform=transformation['validation'][dataset])

        # Number of worker processes filling each patch queue.
        workers = 10

        # Define the dataset data as a queue of patches
        queue_dataset_train = Queue(
            subjects_dataset_train,
            queue_length,
            samples_per_volume,
            patch_size,
            ImageSampler,
            num_workers=workers,
            shuffle_subjects=True,
            shuffle_patches=True,
        )

        queue_dataset_val = Queue(
            subjects_dataset_val,
            queue_length,
            samples_per_volume,
            patch_size,
            ImageSampler,
            num_workers=workers,
            shuffle_subjects=True,
            shuffle_patches=True,
        )

        batch_loader_dataset_train = infinite_iterable(
            DataLoader(queue_dataset_train, batch_size=batch_size[dataset]))
        batch_loader_dataset_val = infinite_iterable(
            DataLoader(queue_dataset_val, batch_size=batch_size[dataset]))

        dataloaders[dataset] = {
            'training': batch_loader_dataset_train,
            'validation': batch_loader_dataset_val,
        }

    # Resume from the CSV log when a previous run exists; otherwise start
    # from scratch.
    df_path = os.path.join(opt.model_dir, 'log.csv')
    if os.path.isfile(df_path):
        df = pd.read_csv(df_path)
        # Cast to int: pandas may read the counters back as floats.
        epoch = int(df.iloc[-1]['epoch'])
        best_epoch = int(df.iloc[-1]['best_epoch'])

        val_eval_criterion_MA = df.iloc[-1]['MA']
        best_val_eval_criterion_MA = df.iloc[-1]['best_MA']

        initial_lr = df.iloc[-1]['lr']

        model.load_state_dict(torch.load(save_path.format('best')))

    else:
        df = pd.DataFrame(
            columns=['epoch', 'best_epoch', 'MA', 'best_MA', 'lr'])
        val_eval_criterion_MA = None
        best_epoch = 0
        epoch = 0

        initial_lr = opt.learning_rate

    # Optimisation policy.
    # FIX: use initial_lr (the learning rate saved in the log) instead of
    # always opt.learning_rate, so a resumed run keeps the rate that
    # ReduceLROnPlateau had reached; initial_lr was previously loaded but
    # never used.
    optimizer = torch.optim.Adam(model.parameters(),
                                 initial_lr,
                                 weight_decay=weight_decay,
                                 amsgrad=True)
    lr_s = lr_scheduler.ReduceLROnPlateau(optimizer,
                                          mode='min',
                                          factor=0.2,
                                          patience=patience_lr,
                                          verbose=True,
                                          threshold=1e-3,
                                          threshold_mode="abs")
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=1e-4)
    criterion_discriminator = torch.nn.BCEWithLogitsLoss()

    model = model.to(device)
    discriminator = discriminator.to(device)

    continue_training = True

    # One "epoch" draws as many batches as there are lesion patches; the
    # arrays below only fix the number of iterations per phase.
    ind_batch_train = np.arange(
        0, samples_per_volume * len(paths_dict['lesion']['training']),
        batch_size['lesion'])
    ind_batch_val = np.arange(
        0, samples_per_volume * len(paths_dict['lesion']['validation']),
        batch_size['lesion'])
    ind_batch = dict()
    ind_batch['training'] = ind_batch_train
    ind_batch['validation'] = ind_batch_val

    while continue_training:
        epoch += 1
        print('-' * 10)
        print('Epoch {}'.format(epoch))
        # NOTE: param_group deliberately leaks out of this loop; it is
        # reused further down to log the current learning rate.
        for param_group in optimizer.param_groups:
            print("Current learning rate is: {}".format(param_group['lr']))

        # Each epoch has a training and validation phase
        for phase in ['training', 'validation']:
            print(phase)
            if phase == 'training':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_loss_lesion = 0.0
            running_loss_control = 0.0
            running_loss_intermod = 0.0
            running_accuracy_discriminator = 0.0
            epoch_samples = 0

            # Iterate over data
            for _ in tqdm(ind_batch[phase]):
                batch_control = next(dataloaders['control'][phase])
                labels_control = batch_control['label'][DATA].to(device).type(
                    torch.cuda.IntTensor)
                inputs_control = {
                    mod: batch_control[mod][DATA].to(device)
                    for mod in MODALITIES_CONTROL
                }

                batch_lesion = next(dataloaders['lesion'][phase])
                labels_lesion = batch_lesion['label'][DATA].to(device).type(
                    torch.cuda.IntTensor)
                # 'all' stacks every lesion modality along the channel axis.
                inputs_lesion = {
                    'T1':
                    batch_lesion['T1'][DATA].to(device),
                    'all':
                    torch.cat(
                        [batch_lesion[mod][DATA] for mod in MODALITIES_LESION],
                        1).to(device)
                }

                bs_control = inputs_control['T1'].shape[0]
                bs_lesion = inputs_lesion['T1'].shape[0]
                sum_batch_size = bs_control + bs_lesion

                # PHASE 1: TRAIN DISCRIMINATOR (segmenter frozen).
                with torch.set_grad_enabled(phase == 'training'):
                    model.eval()
                    if phase == 'training':
                        discriminator.train()  # Set model to training mode
                    else:
                        discriminator.eval()  # Set model to evaluate mode

                    set_requires_grad(model, requires_grad=False)
                    set_requires_grad(discriminator, requires_grad=True)
                    skips = model(
                        {
                            'T1':
                            torch.cat(
                                [inputs_control['T1'], inputs_lesion['T1']], 0)
                        },
                        skip_only=True)

                    inputs_discriminator, z_dim = merge_skips(skips)

                    # Domain targets: 0 for control samples, 1 for lesion.
                    labels_discriminator = torch.cat((torch.zeros(
                        bs_control * z_dim), torch.ones(bs_lesion * z_dim)),
                                                     0).to(device)

                    outputs_discriminator = discriminator(
                        inputs_discriminator).squeeze()
                    loss_discriminator = criterion_discriminator(
                        outputs_discriminator, labels_discriminator)
                    accuracy_discriminator = ((outputs_discriminator > 0).long(
                    ) == labels_discriminator.long()).float().mean().item()

                    if phase == 'training':
                        optimizer_discriminator.zero_grad()
                        loss_discriminator.backward()
                        optimizer_discriminator.step()

                # PHASE 2: TRAIN SEGMENTER (discriminator frozen).
                with torch.set_grad_enabled(phase == 'training'):
                    discriminator.eval()
                    if phase == 'training':
                        model.train()  # Set model to training mode
                    else:
                        model.eval()  # Set model to evaluate mode

                    set_requires_grad(model, requires_grad=True)
                    set_requires_grad(discriminator, requires_grad=False)

                    outputs_lesion, _ = model(inputs_lesion)
                    outputs, skips = model({
                        'T1':
                        torch.cat([inputs_control['T1'], inputs_lesion['T1']],
                                  0)
                    })

                    inputs_discriminator, _ = merge_skips(skips)

                    # Reuses labels_discriminator from phase 1 (same batch,
                    # same ordering: control first, lesion second).
                    outputs_discriminator = discriminator(
                        inputs_discriminator).squeeze()
                    loss_discriminator_inv = criterion_discriminator(
                        outputs_discriminator, labels_discriminator)

                    outputs_control = outputs[:bs_control, ...]
                    outputs_lesion_ascontrol = outputs[bs_control:, ...]

                    loss_control = jaccard_tissue(outputs_control,
                                                  labels_control)
                    # Lesion labels are offset by 6 tissue classes.
                    loss_lesion = jaccard_lesion(outputs_lesion,
                                                 6.0 + labels_lesion)
                    loss_intermod = jaccard_tissue(outputs_lesion,
                                                   outputs_lesion_ascontrol,
                                                   is_prob=True)

                    # Adversarial term: subtracting loss_discriminator_inv
                    # *maximizes* the frozen discriminator's loss, i.e.
                    # makes skips domain-invariant.  It is linearly ramped
                    # between the two warmup milestones.
                    if epoch > opt.warmup and phase == 'training':
                        loss = loss_control + loss_lesion + loss_intermod - 0.1 * opt.weight_discri * loss_discriminator_inv
                    elif epoch > opt.warmup_discriminator and phase == 'training':
                        loss = loss_control + loss_lesion - 0.1 * opt.weight_discri * (
                            epoch - opt.warmup_discriminator
                        ) / (opt.warmup -
                             opt.warmup_discriminator) * loss_discriminator_inv
                    else:
                        loss = loss_control + loss_lesion

                    if phase == 'training':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += 1
                running_loss += loss.item()
                running_loss_control += loss_control.item()
                running_loss_lesion += loss_lesion.item()
                running_loss_intermod += loss_intermod.item()
                running_accuracy_discriminator += accuracy_discriminator

            epoch_loss = running_loss / epoch_samples
            epoch_loss_control = running_loss_control / epoch_samples
            epoch_loss_lesion = running_loss_lesion / epoch_samples
            epoch_loss_intermod = running_loss_intermod / epoch_samples
            epoch_accuracy_discriminator = (running_accuracy_discriminator /
                                            epoch_samples)

            print('{}  Loss Seg Control: {:.4f}'.format(
                phase, epoch_loss_control))
            print('{}  Loss Seg Lesion: {:.4f}'.format(phase,
                                                       epoch_loss_lesion))
            print('{}  Loss Seg Intermod: {:.4f}'.format(
                phase, epoch_loss_intermod))
            # FIX: this metric is the mean discriminator accuracy, not a
            # loss; the old message labelled it "Loss Discriminator".
            print('{}  Discriminator Accuracy: {:.4f}'.format(
                phase, epoch_accuracy_discriminator))

            if phase == 'validation':
                if val_eval_criterion_MA is None:  # first iteration
                    val_eval_criterion_MA = epoch_loss
                    best_val_eval_criterion_MA = val_eval_criterion_MA
                    print(val_eval_criterion_MA)

                else:  # exponential moving average of the validation loss
                    val_eval_criterion_MA = val_eval_criterion_alpha * val_eval_criterion_MA + (
                        1 - val_eval_criterion_alpha) * epoch_loss
                    print(val_eval_criterion_MA)

                lr_s.step(val_eval_criterion_MA)
                # FIX: DataFrame.append was removed in pandas >= 2.0;
                # build a one-row frame and concatenate instead.
                row = pd.DataFrame([{
                    'epoch': epoch,
                    'best_epoch': best_epoch,
                    'MA': val_eval_criterion_MA,
                    'best_MA': best_val_eval_criterion_MA,
                    'lr': param_group['lr'],
                }])
                df = pd.concat([df, row], ignore_index=True)
                df.to_csv(df_path)

                if val_eval_criterion_MA < best_val_eval_criterion_MA:
                    best_val_eval_criterion_MA = val_eval_criterion_MA
                    best_epoch = epoch
                    torch.save(model.state_dict(), save_path.format('best'))
                    torch.save(discriminator.state_dict(),
                               save_path.format('discriminator_best'))

                else:
                    # Early stopping: no improvement for nb_patience epochs.
                    if epoch - best_epoch > nb_patience:
                        continue_training = False

                # Snapshot both networks at the warmup milestone.
                if epoch == opt.warmup:
                    torch.save(model.state_dict(), save_path.format(epoch))
                    torch.save(discriminator.state_dict(),
                               save_path.format('discriminator_' + str(epoch)))

    time_elapsed = time.time() - since
    print('Training completed in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    print('Best epoch is {}'.format(best_epoch))