Code example #1
File: main.py — Project: zzy950117/domainadaptation
def linked_batch_augmentation(input_batch, preds_unsup):
    """Augment an unlabeled input batch together with the teacher's predictions.

    Each (input, prediction) pair is passed through the same stochastic
    transform so the two stay spatially aligned, then the per-sample
    results are collated back into a single batch.
    """
    # Teacher transformation: PIL round-trip, elastic deformation, small
    # random affine, per-channel intensity shift.  labeled=False because
    # these samples carry no ground-truth masks.
    teacher_transform = tv.transforms.Compose([
        mt_transforms.ToPIL(labeled=False),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3,
                                       labeled=False),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03),
                                   labeled=False),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(labeled=False),
    ])

    # Move both tensors off the GPU and the autograd graph, as numpy arrays.
    inputs_np = input_batch.cpu().detach().numpy()
    preds_np = preds_unsup.cpu().detach().numpy()

    # Transform every (input, prediction) pair; listing both under the
    # 'input' key makes the transform apply one random draw to the pair.
    augmented = [
        teacher_transform({'input': [inputs_np[idx], preds_np[idx]]})
        for idx in range(input_batch.size(0))
    ]

    return mt_datasets.mt_collate(augmented)
Code example #2
def run_main():
    """Train a U-Net on the SC GM Challenge dataset and log to TensorBoard.

    Assumes the (previously resampled) challenge data lives under
    ``../data``; subjects 1-8 are used for training, 9-10 for validation.
    Requires a CUDA-capable GPU.
    """
    # Training-time augmentation: center crop, elastic deformation, small
    # random affine, per-channel intensity shift, then tensor + normalize.
    train_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Validation uses only deterministic preprocessing.
    val_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Here we assume that the SC GM Challenge data is inside the folder
    # "../data" and it was previously resampled.
    gmdataset_train = mt_datasets.SCGMChallenge2DTrain(
        root_dir="../data",
        subj_ids=range(1, 9),
        transform=train_transform,
        slice_filter_fn=mt_filters.SliceFilter())

    gmdataset_val = mt_datasets.SCGMChallenge2DTrain(root_dir="../data",
                                                     subj_ids=range(9, 11),
                                                     transform=val_transform)

    train_loader = DataLoader(gmdataset_train,
                              batch_size=16,
                              shuffle=True,
                              pin_memory=True,
                              collate_fn=mt_datasets.mt_collate,
                              num_workers=1)

    val_loader = DataLoader(gmdataset_val,
                            batch_size=16,
                            shuffle=True,
                            pin_memory=True,
                            collate_fn=mt_datasets.mt_collate,
                            num_workers=1)

    model = mt_models.Unet(drop_rate=0.4, bn_momentum=0.1)
    model.cuda()

    num_epochs = 200
    initial_lr = 0.001

    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    writer = SummaryWriter(log_dir="log_exp")
    for epoch in tqdm(range(1, num_epochs + 1)):
        start_time = time.time()

        # NOTE(review): recent PyTorch expects scheduler.step() *after* the
        # epoch's optimizer steps; kept here to preserve the original
        # learning-rate schedule.
        scheduler.step()

        lr = scheduler.get_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        # Training loop ------------------------------------------------------
        model.train()
        train_loss_total = 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            input_samples, gt_samples = batch["input"], batch["gt"]

            var_input = input_samples.cuda()
            var_gt = gt_samples.cuda(non_blocking=True)

            preds = model(var_input)

            loss = mt_losses.dice_loss(preds, var_gt)
            train_loss_total += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            num_steps += 1

            # Log image grids every 5th epoch (written once per batch; the
            # last batch of the epoch is what ends up in TensorBoard).
            if epoch % 5 == 0:
                grid_img = vutils.make_grid(input_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Input', grid_img, epoch)

                grid_img = vutils.make_grid(preds.data.cpu(),
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Predictions', grid_img, epoch)

                grid_img = vutils.make_grid(gt_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Ground Truth', grid_img, epoch)

        train_loss_total_avg = train_loss_total / num_steps

        # Validation loop ----------------------------------------------------
        model.eval()
        val_loss_total = 0.0
        num_steps = 0

        metric_fns = [
            mt_metrics.dice_score, mt_metrics.hausdorff_score,
            mt_metrics.precision_score, mt_metrics.recall_score,
            mt_metrics.specificity_score, mt_metrics.intersection_over_union,
            mt_metrics.accuracy_score
        ]

        metric_mgr = mt_metrics.MetricManager(metric_fns)

        for i, batch in enumerate(val_loader):
            input_samples, gt_samples = batch["input"], batch["gt"]

            with torch.no_grad():
                var_input = input_samples.cuda()
                # BUG FIX: the original used `.cuda(async=True)`, which is a
                # SyntaxError on Python >= 3.7 (`async` is a keyword); the
                # argument was renamed `non_blocking` in PyTorch 0.4, matching
                # the training loop above.
                var_gt = gt_samples.cuda(non_blocking=True)

                preds = model(var_input)
                loss = mt_losses.dice_loss(preds, var_gt)
                val_loss_total += loss.item()

            # Metrics computation on binarized predictions (channel axis
            # squeezed out to get HxW label maps).
            gt_npy = gt_samples.numpy().astype(np.uint8)
            gt_npy = gt_npy.squeeze(axis=1)

            preds = preds.data.cpu().numpy()
            preds = threshold_predictions(preds)
            preds = preds.astype(np.uint8)
            preds = preds.squeeze(axis=1)

            metric_mgr(preds, gt_npy)

            num_steps += 1

        metrics_dict = metric_mgr.get_results()
        metric_mgr.reset()

        writer.add_scalars('metrics', metrics_dict, epoch)

        val_loss_total_avg = val_loss_total / num_steps

        # Both losses are written once here; the original also re-wrote
        # 'train_loss' under the same tag/step below, which was redundant
        # and has been removed.
        writer.add_scalars('losses', {
            'val_loss': val_loss_total_avg,
            'train_loss': train_loss_total_avg
        }, epoch)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))
Code example #3
def cmd_train(context):
    """Main command to train the network.

    :param context: this is a dictionary with all data from the
                    configuration file:
                        - 'command': run the specified command (e.g. train, test)
                        - 'gpu': ID of the used GPU
                        - 'bids_path_train': list of relative paths of the BIDS folders of each training center
                        - 'bids_path_validation': list of relative paths of the BIDS folders of each validation center
                        - 'bids_path_test': list of relative paths of the BIDS folders of each test center
                        - 'batch_size'
                        - 'dropout_rate'
                        - 'batch_norm_momentum'
                        - 'num_epochs'
                        - 'initial_lr': initial learning rate
                        - 'log_directory': folder name where log files are saved
    """
    # Set the GPU
    gpu_number = context["gpu"]
    torch.cuda.set_device(gpu_number)

    # These are the training transformations
    train_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((128, 128)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # These are the validation/testing transformations
    val_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((128, 128)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # This code will iterate over the folders and load the data, filtering
    # the slices without labels and then concatenating all the datasets together
    train_datasets = []
    for bids_ds in tqdm(context["bids_path_train"],
                        desc="Loading training set"):
        ds_train = loader.BidsDataset(bids_ds,
                                      transform=train_transform,
                                      slice_filter_fn=loader.SliceFilter())
        train_datasets.append(ds_train)

    ds_train = ConcatDataset(train_datasets)
    print(f"Loaded {len(ds_train)} axial slices for the training set.")
    train_loader = DataLoader(ds_train,
                              batch_size=context["batch_size"],
                              shuffle=True,
                              pin_memory=True,
                              collate_fn=mt_datasets.mt_collate,
                              num_workers=1)

    # Validation dataset ------------------------------------------------------
    validation_datasets = []
    for bids_ds in tqdm(context["bids_path_validation"],
                        desc="Loading validation set"):
        ds_val = loader.BidsDataset(bids_ds,
                                    transform=val_transform,
                                    slice_filter_fn=loader.SliceFilter())
        validation_datasets.append(ds_val)

    ds_val = ConcatDataset(validation_datasets)
    print(f"Loaded {len(ds_val)} axial slices for the validation set.")
    val_loader = DataLoader(ds_val,
                            batch_size=context["batch_size"],
                            shuffle=True,
                            pin_memory=True,
                            collate_fn=mt_datasets.mt_collate,
                            num_workers=1)

    model = M.Classifier(drop_rate=context["dropout_rate"],
                         bn_momentum=context["batch_norm_momentum"])
    model.cuda()

    num_epochs = context["num_epochs"]
    initial_lr = context["initial_lr"]

    # Using SGD with cosine annealing learning rate
    optimizer = optim.SGD(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=context["log_directory"])

    # Cross Entropy Loss
    criterion = nn.CrossEntropyLoss()

    # Training loop -----------------------------------------------------------
    best_validation_loss = float("inf")

    lst_train_loss = []
    lst_val_loss = []
    lst_accuracy = []

    for epoch in tqdm(range(1, num_epochs + 1), desc="Training"):
        start_time = time.time()

        # NOTE(review): recent PyTorch expects scheduler.step() after the
        # epoch's optimizer steps; kept here to preserve the original schedule.
        scheduler.step()

        lr = scheduler.get_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        model.train()
        train_loss_total = 0.0
        num_steps = 0

        for i, batch in enumerate(train_loader):
            input_samples = batch["input"]
            input_labels = get_modality(batch)

            var_input = input_samples.cuda()
            # Modernized: the torch.cuda.LongTensor constructor is deprecated
            # and the original's trailing .cuda() on an already-CUDA tensor
            # was a no-op.
            var_labels = torch.tensor(input_labels, dtype=torch.long).cuda(
                non_blocking=True)

            outputs = model(var_input)

            loss = criterion(outputs, var_labels)
            train_loss_total += loss.item()

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            num_steps += 1

        train_loss_total_avg = train_loss_total / num_steps
        lst_train_loss.append(train_loss_total_avg)

        tqdm.write(f"Epoch {epoch} training loss: {train_loss_total_avg:.4f}.")

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total = 0.0
        num_steps = 0

        val_accuracy = 0
        num_samples = 0

        for i, batch in enumerate(val_loader):
            input_samples = batch["input"]
            input_labels = get_modality(batch)

            with torch.no_grad():
                var_input = input_samples.cuda()
                var_labels = torch.tensor(input_labels, dtype=torch.long).cuda(
                    non_blocking=True)

                outputs = model(var_input)
                _, preds = torch.max(outputs, 1)

                loss = criterion(outputs, var_labels)
                val_loss_total += loss.item()

                val_accuracy += int((var_labels == preds).sum())

            num_steps += 1
            # BUG FIX: the original added context['batch_size'] every step,
            # overcounting when the last batch is partial and deflating the
            # reported accuracy; count the actual batch size instead.
            num_samples += input_samples.size(0)

        val_loss_total_avg = val_loss_total / num_steps
        lst_val_loss.append(val_loss_total_avg)
        tqdm.write(f"Epoch {epoch} validation loss: {val_loss_total_avg:.4f}.")

        val_accuracy_avg = 100 * val_accuracy / num_samples
        lst_accuracy.append(val_accuracy_avg)
        tqdm.write(f"Epoch {epoch} accuracy : {val_accuracy_avg:.4f}.")

        # add metrics for tensorboard
        writer.add_scalars('validation metrics', {
            'accuracy': val_accuracy_avg,
        }, epoch)

        writer.add_scalars('losses', {
            'train_loss': train_loss_total_avg,
            'val_loss': val_loss_total_avg,
        }, epoch)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

        # Keep a copy of the model with the best validation loss so far.
        if val_loss_total_avg < best_validation_loss:
            best_validation_loss = val_loss_total_avg
            torch.save(model,
                       "./" + context["log_directory"] + "/best_model.pt")

    # save final model
    torch.save(model, "./" + context["log_directory"] + "/final_model.pt")

    # save the metrics
    parameters = "CrossEntropyLoss/batchsize=" + str(context['batch_size'])
    parameters += "/initial_lr=" + str(context['initial_lr'])
    parameters += "/dropout=" + str(context['dropout_rate'])

    plt.subplot(2, 1, 1)
    plt.title(parameters)
    plt.plot(lst_train_loss, color='red', label='Training')
    plt.plot(lst_val_loss, color='blue', label='Validation')
    plt.legend(loc='upper right')
    plt.ylabel('Loss')

    plt.subplot(2, 1, 2)
    plt.plot(lst_accuracy)
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')

    # BUG FIX: `parameters` contains '/' separators, which savefig would
    # interpret as (non-existent) directories; sanitize it for use as a
    # file name.  The title above keeps the readable '/' form.
    plt.savefig(parameters.replace('/', '_') + '.png')

    return
Code example #4
import numpy as np
from batchgenerators.transforms.color_transforms import GammaTransform
from batchgenerators.transforms.spatial_transforms import MirrorTransform, SpatialTransform, ZoomTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
from medicaltorch import transforms as mt_transforms
from medicaltorch import losses as mt_losses
from torchvision import transforms

# Example 2D augmentation pipeline from medicaltorch: elastic deformation,
# a small random affine warp, a per-channel intensity shift, then conversion
# to a tensor.  Cropping and normalization are left commented out.
train_transform = transforms.Compose([
    # mt_transforms.CenterCrop2D((200, 200)),
    mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                   sigma_range=(3.5, 4.0),
                                   p=0.3),
    mt_transforms.RandomAffine(degrees=4.6,
                               scale=(0.98, 1.02),
                               translate=(0.03, 0.03)),
    mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
    mt_transforms.ToTensor()
    # mt_transforms.NormalizeInstance(),
])

# batchgenerators transforms operating on dict batches: the image lives under
# the "img" key and (where relevant) the segmentation under "seg".
gamma_t = GammaTransform(data_key="img", gamma_range=(0.1, 10))

mirror_t = MirrorTransform(data_key="img", label_key="seg")

# Spatial deformation cropped to 8x8x8 patches.
spatial_t = SpatialTransform(patch_size=(8, 8, 8),
                             data_key="img",
                             label_key="seg")

gauss_noise_t = GaussianNoiseTransform(data_key="img", noise_variance=(0, 1))
Code example #5
    def __init__(self,
                 tcia_folder=None,
                 brats_folder=None,
                 lgg_folder=None,
                 hgg_folder=None,
                 type_str='T1',
                 stage='lgg',
                 flag_3d=False,
                 mode='train',
                 channel_size_3d=32,
                 mri_slice_dim=128,
                 aug=False):
        """Build a combined TCIA / BraTS / TCGA glioma segmentation dataset.

        :param tcia_folder: root of the TCIA data (one subfolder per subject).
        :param brats_folder: root of the BraTS data (with LGG/ and HGG/ subdirs).
        :param lgg_folder: root of the TCGA low-grade-glioma data.
        :param hgg_folder: root of the TCGA high-grade-glioma data.
        :param type_str: MRI contrast, one of 'T1', 'T2', 'FLAIR'.
        :param stage: tumor grade to load, 'lgg' or 'hgg'.
        :param flag_3d: if True use 3D volume transforms, else 2D slices.
        :param mode: split selector: 'train', 'val', anything else = test.
        :param channel_size_3d: number of slices per 3D chunk.
        :param mri_slice_dim: target square size for resizing.
        :param aug: enable random augmentation for the training split.
        """
        assert stage in ['lgg', 'hgg']
        assert type_str in ['T1', 'T2', 'FLAIR']

        Dataset.__init__(self)

        self.flag_3d = flag_3d

        self.tcia_folder = tcia_folder
        self.brats_folder = brats_folder
        self.lgg_folder = lgg_folder
        self.hgg_folder = hgg_folder

        # Collect subject folders and record which source dataset each one
        # came from; the source decides the file-name patterns below.
        if stage == 'lgg':
            self.folders = glob.glob(self.tcia_folder + '/*')
            self.dataset_types = ['tcia' for _ in range(len(self.folders))]

            brats = glob.glob(self.brats_folder + '/LGG/*')
            self.folders.extend(brats)
            self.dataset_types.extend(['brats' for _ in range(len(brats))])

            lgg = [
                x for x in glob.glob(self.lgg_folder + '/*')
                if os.path.isdir(x)
            ]
            self.folders.extend(lgg)
            self.dataset_types.extend(['lgg' for _ in range(len(lgg))])
        else:
            self.folders = glob.glob(self.brats_folder + '/HGG/*')
            self.dataset_types = ['brats' for _ in range(len(self.folders))]

            hgg = [
                x for x in glob.glob(self.hgg_folder + '/*')
                if os.path.isdir(x)
            ]
            self.folders.extend(hgg)
            self.dataset_types.extend(['hgg' for _ in range(len(hgg))])

        self.type = type_str
        self.stage = stage
        self.channel_size_3d = channel_size_3d

        # Segmentation file-name patterns per source; the TCGA sets prefer the
        # manually corrected mask with a wildcard fallback (handled below).
        self.seg_mapping = {
            'tcia': 'Segmentation',
            'brats': 'seg',
            'lgg': ['GlistrBoost_ManuallyCorrected', 'GlistrBoost*'],
            'hgg': ['GlistrBoost_ManuallyCorrected', 'GlistrBoost*']
        }

        # Image file-name patterns per source for the requested contrast.
        self.type_mapping = {
            'tcia': self.type + '*',
            'brats': self.type.lower(),
            'lgg': self.type.lower(),
            'hgg': self.type.lower()
        }

        # Pair every image volume with its segmentation volume.
        self.segmentation_pairs = []
        for idx in range(len(self.folders)):

            if isinstance(self.seg_mapping[self.dataset_types[idx]], list):
                # Prefer the manually corrected mask; fall back to the
                # wildcard pattern when it is absent.
                try:
                    seg_fname = glob.glob(
                        self.folders[idx] + '/*' +
                        self.seg_mapping[self.dataset_types[idx]][0] +
                        '.nii.gz')[0]
                except Exception:
                    seg_fname = glob.glob(
                        self.folders[idx] + '/*' +
                        self.seg_mapping[self.dataset_types[idx]][1] +
                        '.nii.gz')[0]
            else:
                seg_fname = glob.glob(
                    self.folders[idx] + '/*' +
                    self.seg_mapping[self.dataset_types[idx]] + '.nii.gz')[0]

            vox_fname_list = glob.glob(
                self.folders[idx] + '/*' +
                self.type_mapping[self.dataset_types[idx]] + '.nii.gz')

            # Skip subjects with no image volume in the requested contrast.
            if vox_fname_list == []:
                continue
            else:
                vox_fname = vox_fname_list[0]

            self.segmentation_pairs.append([vox_fname, seg_fname])

        # 80/10/10 train/val/test split by index ranges.
        spl = [.8, .1, .1]

        train_ptr = int(spl[0] * len(self.segmentation_pairs))
        val_ptr = train_ptr + int(spl[1] * len(self.segmentation_pairs))

        if not flag_3d:
            if aug:
                train_transforms = transforms.Compose([
                    MTResize((mri_slice_dim, mri_slice_dim)),
                    transforms.RandomChoice([
                        mt_transforms.RandomRotation(30),
                        mt_transforms.ElasticTransform(alpha=2000, sigma=50),
                        mt_transforms.AdditiveGaussianNoise(mean=0.05,
                                                            std=0.01),
                        # BUG FIX: the original passed the bare name
                        # `degrees`, which is undefined here and raised a
                        # NameError whenever aug=True.  30 degrees matches the
                        # RandomRotation(30) alternative in this same
                        # RandomChoice — confirm the intended range.
                        mt_transforms.RandomAffine(degrees=30,
                                                   translate=0.2,
                                                   scale=(0.8, 1.2),
                                                   shear=0.2)
                    ]),
                    mt_transforms.ToTensor(),
                    MTNormalize()
                ])

            else:
                train_transforms = transforms.Compose([
                    MTResize((mri_slice_dim, mri_slice_dim)),
                    mt_transforms.ToTensor(),
                    MTNormalize()
                ])

            val_transforms = transforms.Compose([
                # NOTE(review): this uses torchvision's Resize while the
                # training pipelines use MTResize — presumably they should
                # match; confirm before relying on the 2D validation split.
                transforms.Resize((mri_slice_dim, mri_slice_dim)),
                mt_transforms.ToTensor(),
                MTNormalize()
            ])

            train_unnormalized = train_transforms

        else:

            if aug:
                train_transforms = transforms.Compose([
                    ToPILImage3D(),
                    Resize3D((mri_slice_dim, mri_slice_dim)),
                    transforms.RandomChoice([
                        RandomHorizontalFlip3D(),
                        RandomVerticalFlip3D(),
                        RandomRotation3D(30),
                        RandomShear3D(45,
                                      translate=.4,
                                      scale=(.7, 1.3),
                                      shear=.2)
                    ]),
                    ToTensor3D(),
                    Normalize3D('min_max')
                ])

                # Same augmentations minus normalization, applied to the
                # segmentation masks.
                train_unnormalized = transforms.Compose([
                    ToPILImage3D(),
                    Resize3D((mri_slice_dim, mri_slice_dim)),
                    transforms.RandomChoice([
                        RandomHorizontalFlip3D(),
                        RandomVerticalFlip3D(),
                        RandomRotation3D(30)
                    ]),
                    ToTensor3D(),
                ])
            else:
                train_transforms = transforms.Compose([
                    ToPILImage3D(),
                    Resize3D((mri_slice_dim, mri_slice_dim)),
                    ToTensor3D(),
                    Normalize3D('min_max')
                ])

                train_unnormalized = transforms.Compose([
                    ToPILImage3D(),
                    Resize3D((mri_slice_dim, mri_slice_dim)),
                    ToTensor3D(),
                ])

            val_transforms = transforms.Compose([
                ToPILImage3D(),
                Resize3D((mri_slice_dim, mri_slice_dim)),
                ToTensor3D(),
                IndividualNormalize3D(),
            ])

        # Select the split and the matching transforms.
        if mode == 'train':
            self.segmentation_pairs = self.segmentation_pairs[:train_ptr]
            self.transforms = train_transforms
            self.seg_transforms = train_unnormalized
        elif mode == 'val':
            self.segmentation_pairs = self.segmentation_pairs[
                train_ptr:val_ptr]
            self.transforms = val_transforms
            self.seg_transforms = train_unnormalized
        else:
            self.segmentation_pairs = self.segmentation_pairs[val_ptr:]
            self.transforms = val_transforms
            self.seg_transforms = train_unnormalized

        # For 2D operation, delegate slicing to medicaltorch's 2D dataset.
        if not flag_3d:
            self.twod_slices_dataset = mt_datasets.MRI2DSegmentationDataset(
                self.segmentation_pairs, transform=self.transforms)
Code example #6
from medicaltorch import transforms as mt_transforms
from medicaltorch import losses as mt_losses
from torchvision import transforms

# Reusable medicaltorch augmentation pieces: a rotation drawn from the
# 90-180 degree range, elastic deformation, additive Gaussian noise, a
# random affine warp with default parameters, and final tensor conversion.
packed_transforms = [
    mt_transforms.RandomRotation(degrees=(90, 180)),
    mt_transforms.ElasticTransform(),
    mt_transforms.AdditiveGaussianNoise(mean=0.0, std=0.05),
    mt_transforms.RandomAffine(),
    mt_transforms.ToTensor()
]
Code example #7
File: main.py — Project: zzy950117/domainadaptation
def cmd_train(ctx):
    global_step = 0

    num_workers = ctx["num_workers"]
    num_epochs = ctx["num_epochs"]
    experiment_name = ctx["experiment_name"]
    cons_weight = ctx["cons_weight"]
    initial_lr = ctx["initial_lr"]
    consistency_rampup = ctx["consistency_rampup"]
    weight_decay = ctx["weight_decay"]
    rootdir_gmchallenge_train = ctx["rootdir_gmchallenge_train"]
    rootdir_gmchallenge_test = ctx["rootdir_gmchallenge_test"]
    supervised_only = ctx["supervised_only"]
    """
    experiment_name
    """
    # experiment_name += '-e%s-cw%s-lr%s-cr%s-lramp%s-wd%s-cl%s-sc%s-ac%s-vc%s' % \
    #                     (num_epochs, cons_weight, initial_lr, consistency_rampup,
    #                     ctx["initial_lr_rampup"], weight_decay, ctx["consistency_loss"],
    #                     ctx["source_centers"], ctx["adapt_centers"], ctx["val_centers"])

    # Decay for learning rate
    if "constant" in ctx["decay_lr"]:
        decay_lr_fn = decay_constant_lr

    if "poly" in ctx["decay_lr"]:
        decay_lr_fn = decay_poly_lr

    if "cosine" in ctx["decay_lr"]:
        decay_lr_fn = cosine_lr

    # Consistency loss
    #
    # mse = Mean Squared Error
    # dice = Dice loss
    # cross_entropy = Cross Entropy
    # mse_confidence = MSE with Confidence Threshold
    if ctx["consistency_loss"] == "dice":
        consistency_loss_fn = mt_losses.dice_loss
    if ctx["consistency_loss"] == "mse":
        consistency_loss_fn = F.mse_loss
    if ctx["consistency_loss"] == "cross_entropy":
        consistency_loss_fn = F.binary_cross_entropy
    if ctx["consistency_loss"] == "mse_confident":
        confidence_threshold = ctx["confidence_threshold"]
        consistency_loss_fn = mt_losses.ConfidentMSELoss(confidence_threshold)

    # Xs, Ys = Source input and source label, train
    # Xt1, Xt2 = Target, domain adaptation, no label, different aug (same sample), train
    # Xv, Yv = Target input and target label, validation

    # Sample Xs and Ys from this
    source_train = mt_datasets.SCGMChallenge2DTrain(
        rootdir_gmchallenge_train,
        slice_filter_fn=mt_filters.SliceFilter(),
        site_ids=ctx["source_centers"],  # Test = 1,2,3, train = 1,2
        subj_ids=range(1, 11))

    # Sample Xt1, Xt2 from this
    unlabeled_filter = mt_filters.SliceFilter(filter_empty_mask=False)
    target_adapt_train = mt_datasets.SCGMChallenge2DTest(
        rootdir_gmchallenge_test,
        slice_filter_fn=unlabeled_filter,
        site_ids=ctx["adapt_centers"],  # 3 = train, 4 = test
        subj_ids=range(11, 21))

    # Sample Xv, Yv from this
    validation_centers = []
    for center in ctx["val_centers"]:
        validation_centers.append(
            mt_datasets.SCGMChallenge2DTrain(
                rootdir_gmchallenge_train,
                slice_filter_fn=mt_filters.SliceFilter(),
                site_ids=[center],  # 3 = train, 4 = test
                subj_ids=range(1, 11)))

    # Training source data augmentation
    source_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Target adaptation data augmentation
    target_adapt_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200), labeled=False),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Target adaptation data augmentation
    target_val_adapt_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    source_train.set_transform(source_transform)

    target_adapt_train.set_transform(target_adapt_transform)
    for center in validation_centers:
        center.set_transform(target_val_adapt_transform)

    source_train_loader = DataLoader(source_train,
                                     batch_size=ctx["source_batch_size"],
                                     shuffle=True,
                                     drop_last=True,
                                     num_workers=num_workers,
                                     collate_fn=mt_datasets.mt_collate,
                                     pin_memory=True)

    target_adapt_train_loader = DataLoader(target_adapt_train,
                                           batch_size=ctx["target_batch_size"],
                                           shuffle=True,
                                           drop_last=True,
                                           num_workers=num_workers,
                                           collate_fn=mt_datasets.mt_collate,
                                           pin_memory=True)

    validation_centers_loaders = []
    for center in validation_centers:
        validation_centers_loaders.append(
            DataLoader(center,
                       batch_size=ctx["target_batch_size"],
                       shuffle=False,
                       drop_last=False,
                       num_workers=num_workers,
                       collate_fn=mt_datasets.mt_collate,
                       pin_memory=True))

    model = create_model(ctx)

    if not supervised_only:
        model_ema = create_model(ctx, ema=True)
    else:
        model_ema = None

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=initial_lr,
                                 weight_decay=weight_decay)

    writer = SummaryWriter(log_dir="log_{}".format(experiment_name))

    # Training loop
    for epoch in tqdm(range(1, num_epochs + 1), desc="Epochs"):
        start_time = time.time()

        # Rampup -----
        initial_lr_rampup = ctx["initial_lr_rampup"]

        if initial_lr_rampup > 0:
            if epoch <= initial_lr_rampup:
                lr = initial_lr * sigmoid_rampup(epoch, initial_lr_rampup)
            else:
                lr = decay_lr_fn(epoch - initial_lr_rampup,
                                 num_epochs - initial_lr_rampup, initial_lr)
        else:
            lr = decay_lr_fn(epoch, num_epochs, initial_lr)

        writer.add_scalar('learning_rate', lr, epoch)

        for param_group in optimizer.param_groups:
            tqdm.write("Learning Rate: {:.6f}".format(lr))
            param_group['lr'] = lr

        consistency_weight = get_current_consistency_weight(
            cons_weight, epoch, consistency_rampup)
        writer.add_scalar('consistency_weight', consistency_weight, epoch)

        # Train mode
        model.train()

        if not supervised_only:
            model_ema.train()

        composite_loss_total = 0.0
        class_loss_total = 0.0
        consistency_loss_total = 0.0

        num_steps = 0
        target_adapt_train_iter = iter(target_adapt_train_loader)

        for i, train_batch in enumerate(source_train_loader):
            # Keys: 'input', 'gt', 'input_metadata', 'gt_metadata'

            # Supervised component --------------------------------------------
            train_input, train_gt = train_batch["input"], train_batch["gt"]
            train_input = train_input.cuda()
            # FIX: was `.cuda(async=True)` — `async` is a reserved word since
            # Python 3.7 (SyntaxError); PyTorch renamed the keyword to
            # `non_blocking` in 0.4.
            train_gt = train_gt.cuda(non_blocking=True)
            preds_supervised = model(train_input)
            class_loss = mt_losses.dice_loss(preds_supervised, train_gt)

            if not supervised_only:

                # Unsupervised component ------------------------------------------
                # Cycle the target-domain loader forever: restart it when
                # exhausted so it can be shorter than the source loader.
                try:
                    # FIX: `iterator.next()` is Python-2-only; use the
                    # built-in next().
                    target_adapt_batch = next(target_adapt_train_iter)
                except StopIteration:
                    target_adapt_train_iter = iter(target_adapt_train_loader)
                    target_adapt_batch = next(target_adapt_train_iter)

                target_adapt_input = target_adapt_batch["input"]
                target_adapt_input = target_adapt_input.cuda()

                # Teacher forward pass; no gradients flow through the EMA
                # model.
                with torch.no_grad():
                    teacher_preds_unsup = model_ema(target_adapt_input)

                # Apply one shared random augmentation to the input and the
                # teacher prediction so the consistency target stays aligned.
                linked_aug_batch = \
                    linked_batch_augmentation(target_adapt_input, teacher_preds_unsup)

                adapt_input_batch = linked_aug_batch['input'][0].cuda()
                teacher_preds_unsup_aug = linked_aug_batch['input'][1].cuda()

                # Student forward pass on the augmented input.
                student_preds_unsup = model(adapt_input_batch)

                consistency_loss = consistency_weight * consistency_loss_fn(
                    student_preds_unsup, teacher_preds_unsup_aug)
            else:
                # Supervised-only runs: the consistency term contributes
                # nothing but keeps the composite-loss bookkeeping uniform.
                consistency_loss = torch.FloatTensor([0.]).cuda()

            composite_loss = class_loss + consistency_loss

            optimizer.zero_grad()
            composite_loss.backward()

            optimizer.step()

            composite_loss_total += composite_loss.item()
            consistency_loss_total += consistency_loss.item()
            class_loss_total += class_loss.item()

            num_steps += 1
            global_step += 1

            if not supervised_only:
                # Update the teacher as an EMA of the student; a different
                # (late) alpha takes over after `ema_late_epoch`.
                if epoch <= ctx["ema_late_epoch"]:
                    update_ema_variables(model, model_ema, ctx["ema_alpha"],
                                         global_step)
                else:
                    update_ema_variables(model, model_ema,
                                         ctx["ema_alpha_late"], global_step)

        # Write histogram of the probs
        if not supervised_only:
            # NOTE: these tensors leak out of the batch loop, so the
            # histograms reflect only the LAST batch of the epoch.
            npy_teacher_preds = teacher_preds_unsup.detach().cpu().numpy()
            writer.add_histogram("Teacher Preds Hist", npy_teacher_preds,
                                 epoch)

            npy_student_preds = student_preds_unsup.detach().cpu().numpy()
            writer.add_histogram("Student Preds Hist", npy_student_preds,
                                 epoch)

        npy_supervised_preds = preds_supervised.detach().cpu().numpy()
        writer.add_histogram("Supervised Preds Hist", npy_supervised_preds,
                             epoch)

        # NOTE(review): raises ZeroDivisionError if the source loader yields
        # no batches (num_steps == 0) — confirm loaders are never empty.
        composite_loss_avg = composite_loss_total / num_steps
        class_loss_avg = class_loss_total / num_steps
        consistency_loss_avg = consistency_loss_total / num_steps

        tqdm.write("Steps p/ Epoch: {}".format(num_steps))
        tqdm.write("Consistency Weight: {:.6f}".format(consistency_weight))
        tqdm.write("Composite Loss: {:.6f}".format(composite_loss_avg))
        tqdm.write("Class Loss: {:.6f}".format(class_loss_avg))
        tqdm.write("Consistency Loss: {:.6f}".format(consistency_loss_avg))

        # Write sample images (best-effort; every `write_images_interval`
        # epochs, grids of the last batch's tensors go to TensorBoard).
        if ctx["write_images"] and epoch % ctx["write_images_interval"] == 0:
            try:
                plot_img = vutils.make_grid(preds_supervised,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Train Source Prediction', plot_img, epoch)

                plot_img = vutils.make_grid(train_input,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Train Source Input', plot_img, epoch)

                plot_img = vutils.make_grid(train_gt,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Train Source Ground Truth', plot_img, epoch)

                # Unsupervised component viz
                if not supervised_only:
                    # NOTE(review): the tag/variable pairings below look
                    # swapped (teacher predictions are logged under "Student"
                    # tags and vice versa), and the "(augmented)" grid
                    # re-plots `student_preds_unsup` instead of
                    # `teacher_preds_unsup_aug`. Tags and tensors are kept
                    # as-is to avoid breaking existing TensorBoard runs —
                    # confirm the intended pairing.
                    plot_img = vutils.make_grid(target_adapt_input,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Student Input', plot_img,
                                     epoch)

                    plot_img = vutils.make_grid(teacher_preds_unsup,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Student Preds', plot_img,
                                     epoch)

                    plot_img = vutils.make_grid(adapt_input_batch,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Teacher Input', plot_img,
                                     epoch)

                    plot_img = vutils.make_grid(student_preds_unsup,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Teacher Preds', plot_img,
                                     epoch)

                    plot_img = vutils.make_grid(student_preds_unsup,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Student Preds (augmented)',
                                     plot_img, epoch)
            except Exception:
                # FIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit; image logging stays
                # best-effort but the run remains interruptible.
                tqdm.write("*** Error writing images ***")

        writer.add_scalars(
            'losses', {
                'composite_loss': composite_loss_avg,
                'class_loss': class_loss_avg,
                'consistency_loss': consistency_loss_avg
            }, epoch)

        # Evaluation mode (disables dropout/batch-norm updates for the
        # validation passes).
        model.eval()

        if not supervised_only:
            model_ema.eval()

        # Segmentation metrics computed by validation() on each center.
        metric_fns = [
            mt_metrics.dice_score, mt_metrics.jaccard_score,
            mt_metrics.hausdorff_score, mt_metrics.precision_score,
            mt_metrics.recall_score, mt_metrics.specificity_score,
            mt_metrics.intersection_over_union, mt_metrics.accuracy_score
        ]

        # One validation pass per center; results are tagged
        # 'val_<center name>' in TensorBoard.
        for center, loader in enumerate(validation_centers_loaders):
            validation(model, model_ema, loader, writer, metric_fns, epoch,
                       ctx, 'val_%s' % ctx["val_centers"][center])

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))