Exemplo n.º 1
0
    def test_gmchallenge_dataset(self):
        """Smoke-test SCGMChallenge2D: dataset sizes, batching and iteration."""
        transform_pipeline = transforms.Compose([
            mt_transforms.CenterCrop2D((200, 200)),
            mt_transforms.ToTensor(),
        ])

        # Full dataset: every rater, every subject.
        full_dataset = mt_datasets.SCGMChallenge2D(
            root_dir=ROOT_DIR_GMCHALLENGE,
            transform=transform_pipeline)
        assert len(full_dataset) == 2204

        # Restricted view: rater 4 only, subjects 1 and 2.
        subset = mt_datasets.SCGMChallenge2D(
            root_dir=ROOT_DIR_GMCHALLENGE,
            rater_ids=[4, ], subj_ids=[1, 2],
            transform=transform_pipeline)
        assert len(subset) == 107

        loader = DataLoader(subset, batch_size=4,
                            shuffle=True, num_workers=4,
                            collate_fn=mt_datasets.mt_collate)

        first_batch = next(iter(loader))
        assert len(first_batch) == 4
        assert first_batch['input'].size() == (4, 1, 200, 200)

        # 107 samples at batch size 4 -> 27 minibatches.
        assert sum(1 for _ in loader) == 27
Exemplo n.º 2
0
def train_hr_transform(crop_size):
    """Build the HR training pipeline: center-crop to *crop_size*, then to tensor."""
    steps = [
        mt_transforms.CenterCrop2D(crop_size),
        mt_transforms.ToTensor(),
    ]
    return Compose(steps)
Exemplo n.º 3
0
def run_main():
    """Train a U-Net on the SC GM Challenge data and log progress to TensorBoard.

    Assumes the (previously resampled) challenge data lives in ``../data``
    and that a CUDA device is available — the model and every batch are
    moved to the GPU. Dice loss is used for training; a battery of
    segmentation metrics is computed on the validation set each epoch.
    """
    # Training-time augmentation: elastic + affine + channel shift, then
    # tensor conversion and per-instance normalization.
    train_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Validation uses only the deterministic transforms.
    val_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Here we assume that the SC GM Challenge data is inside the folder
    # "../data" and it was previously resampled.
    gmdataset_train = mt_datasets.SCGMChallenge2DTrain(
        root_dir="../data",
        subj_ids=range(1, 9),
        transform=train_transform,
        slice_filter_fn=mt_filters.SliceFilter())

    # Subjects 9-10 are held out for validation.
    gmdataset_val = mt_datasets.SCGMChallenge2DTrain(root_dir="../data",
                                                     subj_ids=range(9, 11),
                                                     transform=val_transform)

    train_loader = DataLoader(gmdataset_train,
                              batch_size=16,
                              shuffle=True,
                              pin_memory=True,
                              collate_fn=mt_datasets.mt_collate,
                              num_workers=1)

    val_loader = DataLoader(gmdataset_val,
                            batch_size=16,
                            shuffle=True,
                            pin_memory=True,
                            collate_fn=mt_datasets.mt_collate,
                            num_workers=1)

    model = mt_models.Unet(drop_rate=0.4, bn_momentum=0.1)
    model.cuda()

    num_epochs = 200
    initial_lr = 0.001

    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    writer = SummaryWriter(log_dir="log_exp")
    for epoch in tqdm(range(1, num_epochs + 1)):
        start_time = time.time()

        # NOTE(review): stepping the scheduler before the optimizer is the
        # pre-PyTorch-1.1 convention; kept as in the original — verify
        # against the installed torch version.
        scheduler.step()

        lr = scheduler.get_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        # Training loop --------------------------------------------------
        model.train()
        train_loss_total = 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            input_samples, gt_samples = batch["input"], batch["gt"]

            var_input = input_samples.cuda()
            var_gt = gt_samples.cuda(non_blocking=True)

            preds = model(var_input)

            loss = mt_losses.dice_loss(preds, var_gt)
            train_loss_total += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            num_steps += 1

            # NOTE(review): this writes the same image tags once per batch
            # on multiples-of-5 epochs, so only the last batch survives in
            # TensorBoard; kept as in the original.
            if epoch % 5 == 0:
                grid_img = vutils.make_grid(input_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Input', grid_img, epoch)

                grid_img = vutils.make_grid(preds.data.cpu(),
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Predictions', grid_img, epoch)

                grid_img = vutils.make_grid(gt_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Ground Truth', grid_img, epoch)

        train_loss_total_avg = train_loss_total / num_steps

        # Validation loop ------------------------------------------------
        model.eval()
        val_loss_total = 0.0
        num_steps = 0

        metric_fns = [
            mt_metrics.dice_score, mt_metrics.hausdorff_score,
            mt_metrics.precision_score, mt_metrics.recall_score,
            mt_metrics.specificity_score, mt_metrics.intersection_over_union,
            mt_metrics.accuracy_score
        ]

        metric_mgr = mt_metrics.MetricManager(metric_fns)

        for i, batch in enumerate(val_loader):
            input_samples, gt_samples = batch["input"], batch["gt"]

            with torch.no_grad():
                var_input = input_samples.cuda()
                # FIX: was `gt_samples.cuda(async=True)` — `async` is a
                # reserved keyword since Python 3.7 (SyntaxError); the
                # argument was renamed to `non_blocking` in PyTorch 0.4.
                var_gt = gt_samples.cuda(non_blocking=True)

                preds = model(var_input)
                loss = mt_losses.dice_loss(preds, var_gt)
                val_loss_total += loss.item()

            # Metrics computation on thresholded, binarized predictions.
            gt_npy = gt_samples.numpy().astype(np.uint8)
            gt_npy = gt_npy.squeeze(axis=1)

            preds = preds.data.cpu().numpy()
            preds = threshold_predictions(preds)
            preds = preds.astype(np.uint8)
            preds = preds.squeeze(axis=1)

            metric_mgr(preds, gt_npy)

            num_steps += 1

        metrics_dict = metric_mgr.get_results()
        metric_mgr.reset()

        writer.add_scalars('metrics', metrics_dict, epoch)

        val_loss_total_avg = val_loss_total / num_steps

        # FIX: the original wrote the 'losses' scalars a second time at the
        # end of the epoch (train_loss only), duplicating the data point;
        # a single combined write is enough.
        writer.add_scalars('losses', {
            'val_loss': val_loss_total_avg,
            'train_loss': train_loss_total_avg
        }, epoch)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))
Exemplo n.º 4
0
def cmd_train(context):
    """Main command to train the network.

    :param context: this is a dictionary with all data from the
                    configuration file:
                        - 'command': run the specified command (e.g. train, test)
                        - 'gpu': ID of the used GPU
                        - 'bids_path_train': list of relative paths of the BIDS folders of each training center
                        - 'bids_path_validation': list of relative paths of the BIDS folders of each validation center
                        - 'bids_path_test': list of relative paths of the BIDS folders of each test center
                        - 'batch_size'
                        - 'dropout_rate'
                        - 'batch_norm_momentum'
                        - 'num_epochs'
                        - 'initial_lr': initial learning rate
                        - 'log_directory': folder name where log files are saved
    """
    # Set the GPU
    gpu_number = context["gpu"]
    torch.cuda.set_device(gpu_number)

    # These are the training transformations
    train_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((128, 128)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # These are the validation/testing transformations
    val_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((128, 128)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Iterate over the BIDS folders, load the data (filtering slices
    # without labels) and concatenate all training datasets together.
    train_datasets = []
    for bids_ds in tqdm(context["bids_path_train"],
                        desc="Loading training set"):
        ds_train = loader.BidsDataset(bids_ds,
                                      transform=train_transform,
                                      slice_filter_fn=loader.SliceFilter())
        train_datasets.append(ds_train)

    ds_train = ConcatDataset(train_datasets)
    print(f"Loaded {len(ds_train)} axial slices for the training set.")
    train_loader = DataLoader(ds_train,
                              batch_size=context["batch_size"],
                              shuffle=True,
                              pin_memory=True,
                              collate_fn=mt_datasets.mt_collate,
                              num_workers=1)

    # Validation dataset ------------------------------------------------------
    validation_datasets = []
    for bids_ds in tqdm(context["bids_path_validation"],
                        desc="Loading validation set"):
        ds_val = loader.BidsDataset(bids_ds,
                                    transform=val_transform,
                                    slice_filter_fn=loader.SliceFilter())
        validation_datasets.append(ds_val)

    ds_val = ConcatDataset(validation_datasets)
    print(f"Loaded {len(ds_val)} axial slices for the validation set.")
    val_loader = DataLoader(ds_val,
                            batch_size=context["batch_size"],
                            shuffle=True,
                            pin_memory=True,
                            collate_fn=mt_datasets.mt_collate,
                            num_workers=1)

    model = M.Classifier(drop_rate=context["dropout_rate"],
                         bn_momentum=context["batch_norm_momentum"])
    model.cuda()

    num_epochs = context["num_epochs"]
    initial_lr = context["initial_lr"]

    # Using SGD with cosine annealing learning rate
    optimizer = optim.SGD(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=context["log_directory"])

    # Cross Entropy Loss
    criterion = nn.CrossEntropyLoss()

    # Training loop -----------------------------------------------------------
    best_validation_loss = float("inf")

    lst_train_loss = []
    lst_val_loss = []
    lst_accuracy = []

    for epoch in tqdm(range(1, num_epochs + 1), desc="Training"):
        start_time = time.time()

        # NOTE(review): stepping the scheduler before the optimizer is the
        # pre-PyTorch-1.1 convention; kept as in the original.
        scheduler.step()

        lr = scheduler.get_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        model.train()
        train_loss_total = 0.0
        num_steps = 0

        for i, batch in enumerate(train_loader):
            input_samples = batch["input"]
            input_labels = get_modality(batch)

            var_input = input_samples.cuda()
            # FIX: the original built a CUDA tensor with
            # torch.cuda.LongTensor(...) and then called .cuda() on it
            # again; build a long tensor once and move it.
            var_labels = torch.tensor(input_labels,
                                      dtype=torch.long).cuda(
                                          non_blocking=True)

            outputs = model(var_input)

            loss = criterion(outputs, var_labels)
            train_loss_total += loss.item()

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            num_steps += 1

        train_loss_total_avg = train_loss_total / num_steps
        lst_train_loss.append(train_loss_total_avg)

        tqdm.write(f"Epoch {epoch} training loss: {train_loss_total_avg:.4f}.")

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total = 0.0
        num_steps = 0

        val_accuracy = 0
        num_samples = 0

        for i, batch in enumerate(val_loader):
            input_samples = batch["input"]
            input_labels = get_modality(batch)

            with torch.no_grad():
                var_input = input_samples.cuda()
                var_labels = torch.tensor(input_labels,
                                          dtype=torch.long).cuda(
                                              non_blocking=True)

                outputs = model(var_input)
                _, preds = torch.max(outputs, 1)

                loss = criterion(outputs, var_labels)
                val_loss_total += loss.item()

                val_accuracy += int((var_labels == preds).sum())

            num_steps += 1
            # FIX: the original added context['batch_size'] per step, which
            # over-counts when the last batch is smaller and skews the
            # accuracy; count the actual number of samples seen.
            num_samples += var_labels.size(0)

        val_loss_total_avg = val_loss_total / num_steps
        lst_val_loss.append(val_loss_total_avg)
        tqdm.write(f"Epoch {epoch} validation loss: {val_loss_total_avg:.4f}.")

        val_accuracy_avg = 100 * val_accuracy / num_samples
        lst_accuracy.append(val_accuracy_avg)
        tqdm.write(f"Epoch {epoch} accuracy : {val_accuracy_avg:.4f}.")

        # add metrics for tensorboard
        writer.add_scalars('validation metrics', {
            'accuracy': val_accuracy_avg,
        }, epoch)

        writer.add_scalars('losses', {
            'train_loss': train_loss_total_avg,
            'val_loss': val_loss_total_avg,
        }, epoch)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

        # Checkpoint the model whenever the validation loss improves.
        if val_loss_total_avg < best_validation_loss:
            best_validation_loss = val_loss_total_avg
            torch.save(model,
                       "./" + context["log_directory"] + "/best_model.pt")

    # save final model
    torch.save(model, "./" + context["log_directory"] + "/final_model.pt")

    # save the metrics
    parameters = "CrossEntropyLoss/batchsize=" + str(context['batch_size'])
    parameters += "/initial_lr=" + str(context['initial_lr'])
    parameters += "/dropout=" + str(context['dropout_rate'])

    plt.subplot(2, 1, 1)
    plt.title(parameters)
    plt.plot(lst_train_loss, color='red', label='Training')
    plt.plot(lst_val_loss, color='blue', label='Validation')
    plt.legend(loc='upper right')
    plt.ylabel('Loss')

    plt.subplot(2, 1, 2)
    plt.plot(lst_accuracy)
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')

    # FIX: `parameters` contains '/' characters, so using it verbatim as a
    # filename points into nonexistent directories and savefig raises;
    # flatten the separators for the file name (title keeps the slashes).
    plt.savefig(parameters.replace('/', '-') + '.png')

    return
Exemplo n.º 5
0
# Pull the ground-truth slice out of the (input, gt) pair produced above.
# NOTE(review): `slice_pair` and `input_slice` are defined earlier in the
# original script, outside this excerpt.
gt_slice = slice_pair["gt"]

print(input_slice.shape)

# img = input_slice
# plt.imshow(img)
# plt.show()

# img = input_slice
# plt.imshow(img)
# plt.show()

# transformer: resample, center-crop, then convert to a tensor.
composed_transform = transforms.Compose([
    mt_transforms.Resample(0.25, 0.25),
    mt_transforms.CenterCrop2D((200, 200)),
    mt_transforms.ToTensor(),
])

# load data
train_dataset = mt_datasets.SCGMChallenge2DTrain(root_dir=ROOT_DIR_GMCHALLENGE,
                                                 transform=composed_transform)
print(len(train_dataset))

# PyTorch data loader
dataloader = DataLoader(train_dataset,
                        batch_size=4,
                        shuffle=True,
                        num_workers=4,
                        collate_fn=mt_datasets.mt_collate)
def plot_2_pic(i):  # 对某个人画图
    img = nib.load('./data_ct/' + str(i) + 'Venous_tra_5mm.nii')
    img_arr = img.get_fdata()
    img_gt = nib.load('./data_ct/' + str(i) + 'Venous_tra_5mm_roi.nii')
    img_arr_gt = img_gt.get_fdata()
    for i in range(img_arr.shape[-1]):
        plt.subplot(1, 2, 1)
        plt.imshow(img_arr[:, :, i], cmap='gray')
        plt.subplot(1, 2, 2)
        plt.imshow(img_arr_gt[:, :, i], cmap='gray')
        plt.show()


# Training pipeline: resample + center-crop, then tensor conversion and
# per-instance normalization (the augmentation steps are disabled below).
train_transform = transforms.Compose([
    mt_transforms.Resample(1.6, 1.6),
    mt_transforms.CenterCrop2D((256, 256)),
    #         mt_transforms.ElasticTransform(alpha_range=(40.0, 60.0),
    #                                        sigma_range=(2.5, 4.0),
    #                                        p=0.3),  # elastic deformation
    #         mt_transforms.RandomAffine(degrees=4.6,
    #                                    scale=(0.98, 1.02),
    #                                    translate=(0.03, 0.03)),  # random affine transform
    #         mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),  # random channel shift (original note was unsure)
    #         transforms.Resize(256),
    mt_transforms.ToTensor(),
    #         transforms.Resize(256),
    mt_transforms.NormalizeInstance(),
])

val_transform = transforms.Compose([
    mt_transforms.Resample(1.6, 1.6),
Exemplo n.º 7
0
def cmd_train(ctx):
    """Mean-teacher training loop for domain adaptation on the SC GM Challenge.

    Trains a student model on labeled source data and — unless
    ``ctx["supervised_only"]`` is set — maintains an EMA teacher on
    unlabeled target data, adding a ramped-up consistency loss between
    student and teacher predictions. Losses, histograms and sample images
    go to TensorBoard; each validation center is evaluated once per epoch.
    """
    global_step = 0

    num_workers = ctx["num_workers"]
    num_epochs = ctx["num_epochs"]
    experiment_name = ctx["experiment_name"]
    cons_weight = ctx["cons_weight"]
    initial_lr = ctx["initial_lr"]
    consistency_rampup = ctx["consistency_rampup"]
    weight_decay = ctx["weight_decay"]
    rootdir_gmchallenge_train = ctx["rootdir_gmchallenge_train"]
    rootdir_gmchallenge_test = ctx["rootdir_gmchallenge_test"]
    supervised_only = ctx["supervised_only"]

    # Decay schedule for the learning rate (when several keywords appear in
    # ctx["decay_lr"], the last matching branch wins — as in the original).
    decay_lr_fn = None
    if "constant" in ctx["decay_lr"]:
        decay_lr_fn = decay_constant_lr

    if "poly" in ctx["decay_lr"]:
        decay_lr_fn = decay_poly_lr

    if "cosine" in ctx["decay_lr"]:
        decay_lr_fn = cosine_lr

    # FIX: fail fast with a clear message instead of a NameError deep in
    # the epoch loop when the config value is unrecognized.
    if decay_lr_fn is None:
        raise ValueError("Unknown decay_lr setting: %r" % (ctx["decay_lr"],))

    # Consistency loss
    #
    # mse = Mean Squared Error
    # dice = Dice loss
    # cross_entropy = Cross Entropy
    # mse_confidence = MSE with Confidence Threshold
    if ctx["consistency_loss"] == "dice":
        consistency_loss_fn = mt_losses.dice_loss
    elif ctx["consistency_loss"] == "mse":
        consistency_loss_fn = F.mse_loss
    elif ctx["consistency_loss"] == "cross_entropy":
        consistency_loss_fn = F.binary_cross_entropy
    elif ctx["consistency_loss"] == "mse_confident":
        confidence_threshold = ctx["confidence_threshold"]
        consistency_loss_fn = mt_losses.ConfidentMSELoss(confidence_threshold)
    else:
        # FIX: same fail-fast treatment as decay_lr above.
        raise ValueError(
            "Unknown consistency_loss setting: %r" % (ctx["consistency_loss"],))

    # Xs, Ys = Source input and source label, train
    # Xt1, Xt2 = Target, domain adaptation, no label, different aug (same sample), train
    # Xv, Yv = Target input and target label, validation

    # Sample Xs and Ys from this
    source_train = mt_datasets.SCGMChallenge2DTrain(
        rootdir_gmchallenge_train,
        slice_filter_fn=mt_filters.SliceFilter(),
        site_ids=ctx["source_centers"],  # Test = 1,2,3, train = 1,2
        subj_ids=range(1, 11))

    # Sample Xt1, Xt2 from this (empty masks kept: data is unlabeled)
    unlabeled_filter = mt_filters.SliceFilter(filter_empty_mask=False)
    target_adapt_train = mt_datasets.SCGMChallenge2DTest(
        rootdir_gmchallenge_test,
        slice_filter_fn=unlabeled_filter,
        site_ids=ctx["adapt_centers"],  # 3 = train, 4 = test
        subj_ids=range(11, 21))

    # Sample Xv, Yv from this (one dataset per validation center)
    validation_centers = []
    for center in ctx["val_centers"]:
        validation_centers.append(
            mt_datasets.SCGMChallenge2DTrain(
                rootdir_gmchallenge_train,
                slice_filter_fn=mt_filters.SliceFilter(),
                site_ids=[center],  # 3 = train, 4 = test
                subj_ids=range(1, 11)))

    # Training source data augmentation
    source_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Target adaptation data augmentation (no labels on this stream)
    target_adapt_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200), labeled=False),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Validation data transforms (deterministic only)
    target_val_adapt_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    source_train.set_transform(source_transform)

    target_adapt_train.set_transform(target_adapt_transform)
    for center in validation_centers:
        center.set_transform(target_val_adapt_transform)

    source_train_loader = DataLoader(source_train,
                                     batch_size=ctx["source_batch_size"],
                                     shuffle=True,
                                     drop_last=True,
                                     num_workers=num_workers,
                                     collate_fn=mt_datasets.mt_collate,
                                     pin_memory=True)

    target_adapt_train_loader = DataLoader(target_adapt_train,
                                           batch_size=ctx["target_batch_size"],
                                           shuffle=True,
                                           drop_last=True,
                                           num_workers=num_workers,
                                           collate_fn=mt_datasets.mt_collate,
                                           pin_memory=True)

    validation_centers_loaders = []
    for center in validation_centers:
        validation_centers_loaders.append(
            DataLoader(center,
                       batch_size=ctx["target_batch_size"],
                       shuffle=False,
                       drop_last=False,
                       num_workers=num_workers,
                       collate_fn=mt_datasets.mt_collate,
                       pin_memory=True))

    model = create_model(ctx)

    # The EMA teacher only exists in the semi-supervised setting.
    if not supervised_only:
        model_ema = create_model(ctx, ema=True)
    else:
        model_ema = None

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=initial_lr,
                                 weight_decay=weight_decay)

    writer = SummaryWriter(log_dir="log_{}".format(experiment_name))

    # Training loop
    for epoch in tqdm(range(1, num_epochs + 1), desc="Epochs"):
        start_time = time.time()

        # Learning-rate rampup, then the configured decay -----------------
        initial_lr_rampup = ctx["initial_lr_rampup"]

        if initial_lr_rampup > 0:
            if epoch <= initial_lr_rampup:
                lr = initial_lr * sigmoid_rampup(epoch, initial_lr_rampup)
            else:
                lr = decay_lr_fn(epoch - initial_lr_rampup,
                                 num_epochs - initial_lr_rampup, initial_lr)
        else:
            lr = decay_lr_fn(epoch, num_epochs, initial_lr)

        writer.add_scalar('learning_rate', lr, epoch)

        for param_group in optimizer.param_groups:
            tqdm.write("Learning Rate: {:.6f}".format(lr))
            param_group['lr'] = lr

        # The consistency loss weight ramps up over the first epochs.
        consistency_weight = get_current_consistency_weight(
            cons_weight, epoch, consistency_rampup)
        writer.add_scalar('consistency_weight', consistency_weight, epoch)

        # Train mode
        model.train()

        if not supervised_only:
            model_ema.train()

        composite_loss_total = 0.0
        class_loss_total = 0.0
        consistency_loss_total = 0.0

        num_steps = 0
        target_adapt_train_iter = iter(target_adapt_train_loader)

        for i, train_batch in enumerate(source_train_loader):
            # Keys: 'input', 'gt', 'input_metadata', 'gt_metadata'

            # Supervised component --------------------------------------------
            train_input, train_gt = train_batch["input"], train_batch["gt"]
            train_input = train_input.cuda()
            # FIX: was `train_gt.cuda(async=True)` — `async` is a reserved
            # keyword since Python 3.7 (SyntaxError); PyTorch renamed the
            # argument to `non_blocking`.
            train_gt = train_gt.cuda(non_blocking=True)
            preds_supervised = model(train_input)
            class_loss = mt_losses.dice_loss(preds_supervised, train_gt)

            if not supervised_only:

                # Unsupervised component ------------------------------------------
                # The target loader is shorter than the source loader, so it
                # is restarted whenever it runs out.
                # FIX: `iterator.next()` is Python-2-only; Python 3 uses the
                # next() builtin.
                try:
                    target_adapt_batch = next(target_adapt_train_iter)
                except StopIteration:
                    target_adapt_train_iter = iter(target_adapt_train_loader)
                    target_adapt_batch = next(target_adapt_train_iter)

                target_adapt_input = target_adapt_batch["input"]
                target_adapt_input = target_adapt_input.cuda()

                # Teacher forward
                with torch.no_grad():
                    teacher_preds_unsup = model_ema(target_adapt_input)

                # Apply the SAME augmentation to the input and the teacher
                # predictions so they stay aligned for the student pass.
                linked_aug_batch = \
                    linked_batch_augmentation(target_adapt_input, teacher_preds_unsup)

                adapt_input_batch = linked_aug_batch['input'][0].cuda()
                teacher_preds_unsup_aug = linked_aug_batch['input'][1].cuda()

                # Student forward
                student_preds_unsup = model(adapt_input_batch)

                consistency_loss = consistency_weight * consistency_loss_fn(
                    student_preds_unsup, teacher_preds_unsup_aug)
            else:
                consistency_loss = torch.FloatTensor([0.]).cuda()

            composite_loss = class_loss + consistency_loss

            optimizer.zero_grad()
            composite_loss.backward()

            optimizer.step()

            composite_loss_total += composite_loss.item()
            consistency_loss_total += consistency_loss.item()
            class_loss_total += class_loss.item()

            num_steps += 1
            global_step += 1

            # EMA update of the teacher; a different alpha is used after
            # ema_late_epoch.
            if not supervised_only:
                if epoch <= ctx["ema_late_epoch"]:
                    update_ema_variables(model, model_ema, ctx["ema_alpha"],
                                         global_step)
                else:
                    update_ema_variables(model, model_ema,
                                         ctx["ema_alpha_late"], global_step)

        # Write histograms of the probabilities (from the LAST batch only).
        if not supervised_only:
            npy_teacher_preds = teacher_preds_unsup.detach().cpu().numpy()
            writer.add_histogram("Teacher Preds Hist", npy_teacher_preds,
                                 epoch)

            npy_student_preds = student_preds_unsup.detach().cpu().numpy()
            writer.add_histogram("Student Preds Hist", npy_student_preds,
                                 epoch)

        npy_supervised_preds = preds_supervised.detach().cpu().numpy()
        writer.add_histogram("Supervised Preds Hist", npy_supervised_preds,
                             epoch)

        composite_loss_avg = composite_loss_total / num_steps
        class_loss_avg = class_loss_total / num_steps
        consistency_loss_avg = consistency_loss_total / num_steps

        tqdm.write("Steps p/ Epoch: {}".format(num_steps))
        tqdm.write("Consistency Weight: {:.6f}".format(consistency_weight))
        tqdm.write("Composite Loss: {:.6f}".format(composite_loss_avg))
        tqdm.write("Class Loss: {:.6f}".format(class_loss_avg))
        tqdm.write("Consistency Loss: {:.6f}".format(consistency_loss_avg))

        # Write sample images
        if ctx["write_images"] and epoch % ctx["write_images_interval"] == 0:
            try:
                plot_img = vutils.make_grid(preds_supervised,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Train Source Prediction', plot_img, epoch)

                plot_img = vutils.make_grid(train_input,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Train Source Input', plot_img, epoch)

                plot_img = vutils.make_grid(train_gt,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Train Source Ground Truth', plot_img, epoch)

                # Unsupervised component viz
                # NOTE(review): the Student/Teacher labels below look
                # swapped relative to the code above (target_adapt_input
                # feeds the TEACHER, adapt_input_batch feeds the STUDENT),
                # and the last grid re-plots student_preds_unsup instead of
                # teacher_preds_unsup_aug — kept unchanged; confirm before
                # relabeling, as downstream dashboards may rely on the tags.
                if not supervised_only:
                    plot_img = vutils.make_grid(target_adapt_input,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Student Input', plot_img,
                                     epoch)

                    plot_img = vutils.make_grid(teacher_preds_unsup,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Student Preds', plot_img,
                                     epoch)

                    plot_img = vutils.make_grid(adapt_input_batch,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Teacher Input', plot_img,
                                     epoch)

                    plot_img = vutils.make_grid(student_preds_unsup,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Teacher Preds', plot_img,
                                     epoch)

                    plot_img = vutils.make_grid(student_preds_unsup,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Student Preds (augmented)',
                                     plot_img, epoch)
            # FIX: was a bare `except:`, which also swallows
            # KeyboardInterrupt/SystemExit; narrowed to Exception.
            except Exception:
                tqdm.write("*** Error writing images ***")

        writer.add_scalars(
            'losses', {
                'composite_loss': composite_loss_avg,
                'class_loss': class_loss_avg,
                'consistency_loss': consistency_loss_avg
            }, epoch)

        # Evaluation mode
        model.eval()

        if not supervised_only:
            model_ema.eval()

        metric_fns = [
            mt_metrics.dice_score, mt_metrics.jaccard_score,
            mt_metrics.hausdorff_score, mt_metrics.precision_score,
            mt_metrics.recall_score, mt_metrics.specificity_score,
            mt_metrics.intersection_over_union, mt_metrics.accuracy_score
        ]

        for center, loader in enumerate(validation_centers_loaders):
            validation(model, model_ema, loader, writer, metric_fns, epoch,
                       ctx, 'val_%s' % ctx["val_centers"][center])

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))