def experiment(exp, arch, loss, double_softmax, confidence_thresh, rampup,
               teacher_alpha, fix_ema, unsup_weight, cls_bal_scale,
               cls_bal_scale_range, cls_balance, cls_balance_loss,
               combine_batches, learning_rate, standardise_samples,
               src_affine_std, src_xlat_range, src_hflip, src_intens_flip,
               src_intens_scale_range, src_intens_offset_range,
               src_gaussian_noise_std, tgt_affine_std, tgt_xlat_range,
               tgt_hflip, tgt_intens_flip, tgt_intens_scale_range,
               tgt_intens_offset_range, tgt_gaussian_noise_std, num_epochs,
               batch_size, epoch_size, seed, log_file, model_file, device):
    """Train a student/teacher (self-ensembling) model for unsupervised
    domain adaptation on one of the small-image benchmark dataset pairs.

    The student network is trained with a supervised cross-entropy loss on
    labelled source-domain samples plus an unsupervised consistency
    ("augmentation") loss that penalises disagreement between student and
    teacher predictions on differently-augmented target-domain samples.
    The teacher's weights are an exponential moving average (EMA) of the
    student's weights and are never updated by gradient descent.

    Parameters (summary):
        exp: experiment name selecting the source/target pair, e.g.
            'svhn_mnist', 'cifar_stl', 'synsigns_gtsrb'.
        arch: network architecture name; '' selects a default for `exp`.
        loss: 'bce' for robust binary cross-entropy consistency loss,
            anything else for squared difference.
        double_softmax: apply softmax to the student output before the
            classification criterion (i.e. softmax applied twice).
        confidence_thresh: teacher confidence threshold used to mask the
            consistency loss when ramp-up is disabled.
        rampup: number of ramp-up epochs; 0 disables ramp-up and enables
            confidence-threshold masking instead.
        teacher_alpha: EMA coefficient for the teacher weight update.
        fix_ema: use the fixed `EMAWeightOptimizer` instead of
            `OldWeightEMA`.
        unsup_weight: weight applied to the unsupervised consistency loss.
        cls_bal_scale, cls_bal_scale_range, cls_balance, cls_balance_loss:
            class-balancing options applied to the consistency loss.
        combine_batches: push source and target samples through the
            networks as one combined batch rather than separately.
        src_*, tgt_*: data-augmentation parameters for the source and
            target domains respectively.
        num_epochs, batch_size, seed: training schedule / RNG seeding.
        epoch_size: 'large', 'small' or 'target'; chooses how many samples
            count as one epoch.
        log_file: '' derives a name from `exp`; 'none' disables logging.
        model_file: path for saving the best teacher weights ('' to skip).
        device: torch device string, e.g. 'cuda:0' or 'cpu'.

    Returns:
        None.  Progress is printed/logged and the best teacher state is
        optionally saved to `model_file`.
    """
    # Snapshot the arguments for the settings report; taken before any
    # other locals are created so it contains exactly the parameters.
    settings = locals().copy()

    import os
    import cmdline_helpers

    # Resolve the log path: '' -> auto-named per experiment, 'none' -> off.
    if log_file == '':
        log_file = 'output_aug_log_{}.txt'.format(exp)
    elif log_file == 'none':
        log_file = None

    # Refuse to overwrite an existing log; this also guards against
    # accidentally re-running a completed experiment.
    if log_file is not None:
        if os.path.exists(log_file):
            print('Output log file {} already exists'.format(log_file))
            return

    use_rampup = rampup > 0

    # Parse the textual intensity-augmentation range options into
    # (lower, upper) bounds for scale and offset, per domain.
    src_intens_scale_range_lower, src_intens_scale_range_upper, src_intens_offset_range_lower, src_intens_offset_range_upper = \
        cmdline_helpers.intens_aug_options(src_intens_scale_range, src_intens_offset_range)
    tgt_intens_scale_range_lower, tgt_intens_scale_range_upper, tgt_intens_offset_range_lower, tgt_intens_offset_range_upper = \
        cmdline_helpers.intens_aug_options(tgt_intens_scale_range, tgt_intens_offset_range)

    import time
    import math
    import numpy as np
    from batchup import data_source, work_pool
    import data_loaders
    import standardisation
    import network_architectures
    import augmentation
    import torch, torch.cuda
    from torch import nn
    from torch.nn import functional as F
    import optim_weight_ema

    torch_device = torch.device(device)
    # Background worker threads that run data augmentation off the main
    # thread so the GPU is not starved.
    pool = work_pool.WorkerThreadPool(2)

    # Load the source / target dataset pair selected by `exp`.  The target
    # dataset always loads without a validation split (val=False) since its
    # test set is used for evaluation.
    if exp == 'svhn_mnist':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=True)
        d_target = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True,
                                           val=False)
    elif exp == 'mnist_svhn':
        d_source = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True)
        d_target = data_loaders.load_svhn(zero_centre=False,
                                          greyscale=True,
                                          val=False)
    elif exp == 'svhn_mnist_rgb':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=False)
        d_target = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True,
                                           val=False,
                                           rgb=True)
    elif exp == 'mnist_svhn_rgb':
        d_source = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True,
                                           rgb=True)
        d_target = data_loaders.load_svhn(zero_centre=False,
                                          greyscale=False,
                                          val=False)
    elif exp == 'cifar_stl':
        d_source = data_loaders.load_cifar10(range_01=False)
        d_target = data_loaders.load_stl(zero_centre=False, val=False)
    elif exp == 'stl_cifar':
        d_source = data_loaders.load_stl(zero_centre=False)
        d_target = data_loaders.load_cifar10(range_01=False, val=False)
    elif exp == 'mnist_usps':
        d_source = data_loaders.load_mnist(zero_centre=False)
        d_target = data_loaders.load_usps(zero_centre=False,
                                          scale28=True,
                                          val=False)
    elif exp == 'usps_mnist':
        d_source = data_loaders.load_usps(zero_centre=False, scale28=True)
        d_target = data_loaders.load_mnist(zero_centre=False, val=False)
    elif exp == 'syndigits_svhn':
        d_source = data_loaders.load_syn_digits(zero_centre=False)
        d_target = data_loaders.load_svhn(zero_centre=False, val=False)
    elif exp == 'synsigns_gtsrb':
        d_source = data_loaders.load_syn_signs(zero_centre=False)
        d_target = data_loaders.load_gtsrb(zero_centre=False, val=False)
    else:
        print('Unknown experiment type \'{}\''.format(exp))
        return

    # Delete the training ground truths as we should not be using them
    del d_target.train_y

    if standardise_samples:
        standardisation.standardise_dataset(d_source)
        standardisation.standardise_dataset(d_target)

    n_classes = d_source.n_classes

    print('Loaded data')

    # Choose a default architecture appropriate to the experiment if none
    # was given on the command line.
    if arch == '':
        if exp in {'mnist_usps', 'usps_mnist'}:
            arch = 'mnist-bn-32-64-256'
        if exp in {'svhn_mnist', 'mnist_svhn'}:
            arch = 'grey-32-64-128-gp'
        if exp in {
                'cifar_stl', 'stl_cifar', 'syndigits_svhn', 'svhn_mnist_rgb',
                'mnist_svhn_rgb'
        }:
            arch = 'rgb-128-256-down-gp'
        if exp in {'synsigns_gtsrb'}:
            arch = 'rgb40-96-192-384-gp'

    net_class, expected_shape = network_architectures.get_net_and_shape_for_architecture(
        arch)

    # Sanity check: the chosen architecture must accept samples of the
    # shape this dataset provides.
    if expected_shape != d_source.train_X.shape[1:]:
        print(
            'Architecture {} not compatible with experiment {}; it needs samples of shape {}, '
            'data has samples of shape {}'.format(arch, exp, expected_shape,
                                                  d_source.train_X.shape[1:]))
        return

    # Student is trained by gradient descent; teacher is an EMA copy whose
    # parameters never receive gradients.
    student_net = net_class(n_classes).to(torch_device)
    teacher_net = net_class(n_classes).to(torch_device)
    student_params = list(student_net.parameters())
    teacher_params = list(teacher_net.parameters())
    for param in teacher_params:
        param.requires_grad = False

    student_optimizer = torch.optim.Adam(student_params, lr=learning_rate)
    if fix_ema:
        teacher_optimizer = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, alpha=teacher_alpha)
    else:
        teacher_optimizer = optim_weight_ema.OldWeightEMA(teacher_net,
                                                          student_net,
                                                          alpha=teacher_alpha)
    classification_criterion = nn.CrossEntropyLoss()

    print('Built network')

    # Per-domain augmentation pipelines.
    src_aug = augmentation.ImageAugmentation(
        src_hflip,
        src_xlat_range,
        src_affine_std,
        intens_flip=src_intens_flip,
        intens_scale_range_lower=src_intens_scale_range_lower,
        intens_scale_range_upper=src_intens_scale_range_upper,
        intens_offset_range_lower=src_intens_offset_range_lower,
        intens_offset_range_upper=src_intens_offset_range_upper,
        gaussian_noise_std=src_gaussian_noise_std)
    tgt_aug = augmentation.ImageAugmentation(
        tgt_hflip,
        tgt_xlat_range,
        tgt_affine_std,
        intens_flip=tgt_intens_flip,
        intens_scale_range_lower=tgt_intens_scale_range_lower,
        intens_scale_range_upper=tgt_intens_scale_range_upper,
        intens_offset_range_lower=tgt_intens_offset_range_lower,
        intens_offset_range_upper=tgt_intens_offset_range_upper,
        gaussian_noise_std=tgt_gaussian_noise_std)

    if combine_batches:

        # Combined-batch mode: both domains get student/teacher pairs of
        # independently augmented views.
        def augment(X_sup, y_src, X_tgt):
            X_src_stu, X_src_tea = src_aug.augment_pair(X_sup)
            X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
            return X_src_stu, X_src_tea, y_src, X_tgt_stu, X_tgt_tea
    else:

        # Separate-batch mode: the source batch gets a single augmentation
        # (supervised loss only); only the target batch needs a pair.
        def augment(X_src, y_src, X_tgt):
            X_src = src_aug.augment(X_src)
            X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
            return X_src, y_src, X_tgt_stu, X_tgt_tea

    # Single-element list used as a mutable cell so the training closure
    # sees the per-epoch ramp-up weight updated by the epoch loop below.
    rampup_weight_in_list = [0]

    cls_bal_fn = network_architectures.get_cls_bal_function(cls_balance_loss)

    def compute_aug_loss(stu_out, tea_out):
        """Compute the consistency loss between student and teacher
        probability predictions, plus confidence-mask statistics.

        Returns `(unsup_loss, conf_mask_count, unsup_mask_count)`; the two
        counts are None when ramp-up is in use (no masking is applied).
        """
        # Augmentation loss
        if use_rampup:
            # Ramp-up mode: no confidence masking; the loss is scaled by
            # the ramp-up weight instead.
            unsup_mask = None
            conf_mask_count = None
            unsup_mask_count = None
        else:
            # Mask out samples where the teacher's most confident class
            # probability does not exceed the threshold.
            conf_tea = torch.max(tea_out, 1)[0]
            unsup_mask = conf_mask = (conf_tea > confidence_thresh).float()
            unsup_mask_count = conf_mask_count = conf_mask.sum()

        if loss == 'bce':
            aug_loss = network_architectures.robust_binary_crossentropy(
                stu_out, tea_out)
        else:
            # Squared difference between student and teacher probabilities.
            d_aug_loss = stu_out - tea_out
            aug_loss = d_aug_loss * d_aug_loss

        # Class balance scaling
        if cls_bal_scale:
            if use_rampup:
                n_samples = float(aug_loss.shape[0])
            else:
                n_samples = unsup_mask.sum()
            # Scale each class's loss by (expected count / predicted count)
            # to discourage the network from collapsing onto few classes.
            avg_pred = n_samples / float(n_classes)
            bal_scale = avg_pred / torch.clamp(tea_out.sum(dim=0), min=1.0)
            if cls_bal_scale_range != 0.0:
                bal_scale = torch.clamp(bal_scale,
                                        min=1.0 / cls_bal_scale_range,
                                        max=cls_bal_scale_range)
            # Detach so the scaling acts as a constant, not a gradient path.
            bal_scale = bal_scale.detach()
            aug_loss = aug_loss * bal_scale[None, :]

        aug_loss = aug_loss.mean(dim=1)

        if use_rampup:
            unsup_loss = aug_loss.mean() * rampup_weight_in_list[0]
        else:
            unsup_loss = (aug_loss * unsup_mask).mean()

        # Class balance loss
        if cls_balance > 0.0:
            # Compute per-sample average predicated probability
            # Average over samples to get average class prediction
            avg_cls_prob = stu_out.mean(dim=0)
            # Compute loss
            equalise_cls_loss = cls_bal_fn(avg_cls_prob,
                                           float(1.0 / n_classes))

            equalise_cls_loss = equalise_cls_loss.mean() * n_classes

            if use_rampup:
                equalise_cls_loss = equalise_cls_loss * rampup_weight_in_list[0]
            else:
                # NOTE: in this branch use_rampup is False, so rampup <= 0
                # and this condition holds for any non-negative rampup.
                if rampup == 0:
                    equalise_cls_loss = equalise_cls_loss * unsup_mask.mean(
                        dim=0)

            unsup_loss += equalise_cls_loss * cls_balance

        return unsup_loss, conf_mask_count, unsup_mask_count

    if combine_batches:

        def f_train(X_src0, X_src1, y_src, X_tgt0, X_tgt1):
            """One training step in combined-batch mode.

            `*0` arrays are the student's augmented views, `*1` the
            teacher's.  Returns per-batch loss sums (and mask counts when
            confidence thresholding is active) for `batch_map_mean`.
            """
            X_src0 = torch.tensor(X_src0,
                                  dtype=torch.float,
                                  device=torch_device)
            X_src1 = torch.tensor(X_src1,
                                  dtype=torch.float,
                                  device=torch_device)
            y_src = torch.tensor(y_src, dtype=torch.long, device=torch_device)
            X_tgt0 = torch.tensor(X_tgt0,
                                  dtype=torch.float,
                                  device=torch_device)
            X_tgt1 = torch.tensor(X_tgt1,
                                  dtype=torch.float,
                                  device=torch_device)

            n_samples = X_src0.size()[0]
            n_total = n_samples + X_tgt0.size()[0]

            student_optimizer.zero_grad()
            student_net.train()
            teacher_net.train()

            # Concatenate source and target mini-batches
            X0 = torch.cat([X_src0, X_tgt0], 0)
            X1 = torch.cat([X_src1, X_tgt1], 0)

            student_logits_out = student_net(X0)
            student_prob_out = F.softmax(student_logits_out, dim=1)

            # The first n_samples rows of the combined batch are source
            # samples; only they contribute to the supervised loss.
            src_logits_out = student_logits_out[:n_samples]
            src_prob_out = student_prob_out[:n_samples]

            teacher_logits_out = teacher_net(X1)
            teacher_prob_out = F.softmax(teacher_logits_out, dim=1)

            # Supervised classification loss
            if double_softmax:
                clf_loss = classification_criterion(src_prob_out, y_src)
            else:
                clf_loss = classification_criterion(src_logits_out, y_src)

            # Consistency loss over the whole combined batch.
            unsup_loss, conf_mask_count, unsup_mask_count = compute_aug_loss(
                student_prob_out, teacher_prob_out)

            loss_expr = clf_loss + unsup_loss * unsup_weight

            loss_expr.backward()
            student_optimizer.step()
            # EMA update of the teacher from the freshly-stepped student.
            teacher_optimizer.step()

            # Scale by sample counts so batch_map_mean yields per-sample
            # means over the epoch.
            outputs = [
                float(clf_loss) * n_samples,
                float(unsup_loss) * n_total
            ]
            if not use_rampup:
                # Halved: the mask spans both source and target halves of
                # the combined batch, so this keeps the reported rate
                # comparable to the separate-batch path.
                mask_count = float(conf_mask_count) * 0.5
                unsup_count = float(unsup_mask_count) * 0.5

                outputs.append(mask_count)
                outputs.append(unsup_count)
            return tuple(outputs)
    else:

        def f_train(X_src, y_src, X_tgt0, X_tgt1):
            """One training step in separate-batch mode.

            `X_tgt0` feeds the student, `X_tgt1` the teacher.  Returns
            per-batch loss sums (and mask counts when confidence
            thresholding is active) for `batch_map_mean`.
            """
            X_src = torch.tensor(X_src, dtype=torch.float, device=torch_device)
            y_src = torch.tensor(y_src, dtype=torch.long, device=torch_device)
            X_tgt0 = torch.tensor(X_tgt0,
                                  dtype=torch.float,
                                  device=torch_device)
            X_tgt1 = torch.tensor(X_tgt1,
                                  dtype=torch.float,
                                  device=torch_device)

            student_optimizer.zero_grad()
            student_net.train()
            teacher_net.train()

            src_logits_out = student_net(X_src)
            student_tgt_logits_out = student_net(X_tgt0)
            student_tgt_prob_out = F.softmax(student_tgt_logits_out, dim=1)
            teacher_tgt_logits_out = teacher_net(X_tgt1)
            teacher_tgt_prob_out = F.softmax(teacher_tgt_logits_out, dim=1)

            # Supervised classification loss
            if double_softmax:
                clf_loss = classification_criterion(
                    F.softmax(src_logits_out, dim=1), y_src)
            else:
                clf_loss = classification_criterion(src_logits_out, y_src)

            # Consistency loss over the target batch only.
            unsup_loss, conf_mask_count, unsup_mask_count = compute_aug_loss(
                student_tgt_prob_out, teacher_tgt_prob_out)

            loss_expr = clf_loss + unsup_loss * unsup_weight

            loss_expr.backward()
            student_optimizer.step()
            # EMA update of the teacher from the freshly-stepped student.
            teacher_optimizer.step()

            n_samples = X_src.size()[0]

            # Scale by sample count so batch_map_mean yields per-sample
            # means over the epoch.
            outputs = [
                float(clf_loss) * n_samples,
                float(unsup_loss) * n_samples
            ]
            if not use_rampup:
                mask_count = float(conf_mask_count)
                unsup_count = float(unsup_mask_count)

                outputs.append(mask_count)
                outputs.append(unsup_count)
            return tuple(outputs)

    print('Compiled training function')

    def f_pred(X_sup):
        """Predict class probabilities for a batch with both networks
        (student, teacher), returned as NumPy arrays."""
        X_var = torch.tensor(X_sup, dtype=torch.float, device=torch_device)
        student_net.eval()
        teacher_net.eval()
        return (F.softmax(student_net(X_var), dim=1).detach().cpu().numpy(),
                F.softmax(teacher_net(X_var), dim=1).detach().cpu().numpy())

    # Source and target prediction were identical duplicated functions;
    # keep both names for clarity at the call sites.
    f_pred_src = f_pred
    f_pred_tgt = f_pred

    def f_eval(X_sup, y_sup):
        """Count student and teacher misclassifications on a batch; the
        counts are averaged into error rates by `batch_map_mean`."""
        y_pred_prob_stu, y_pred_prob_tea = f_pred(X_sup)
        y_pred_stu = np.argmax(y_pred_prob_stu, axis=1)
        y_pred_tea = np.argmax(y_pred_prob_tea, axis=1)
        return (float(
            (y_pred_stu != y_sup).sum()), float((y_pred_tea != y_sup).sum()))

    # Likewise, source/target evaluation share one implementation.
    f_eval_src = f_eval
    f_eval_tgt = f_eval

    print('Compiled evaluation function')

    # Setup output
    def log(text):
        """Print `text` and append it to the log file (if enabled)."""
        print(text)
        if log_file is not None:
            # The `with` block flushes and closes the file on exit.
            with open(log_file, 'a') as f:
                f.write(text + '\n')

    if log_file is not None:
        cmdline_helpers.ensure_containing_dir_exists(log_file)

    # Report settings
    log('Settings: {}'.format(', '.join([
        '{}={}'.format(key, settings[key])
        for key in sorted(list(settings.keys()))
    ])))

    # Report dataset size
    log('Dataset:')
    log('SOURCE Train: X.shape={}, y.shape={}'.format(d_source.train_X.shape,
                                                      d_source.train_y.shape))
    log('SOURCE Test: X.shape={}, y.shape={}'.format(d_source.test_X.shape,
                                                     d_source.test_y.shape))
    log('TARGET Train: X.shape={}'.format(d_target.train_X.shape))
    log('TARGET Test: X.shape={}, y.shape={}'.format(d_target.test_X.shape,
                                                     d_target.test_y.shape))

    print('Training...')
    # Infinite (repeats=-1) data sources zipped together so each training
    # batch carries source samples+labels and unlabelled target samples.
    sup_ds = data_source.ArrayDataSource([d_source.train_X, d_source.train_y],
                                         repeats=-1)
    tgt_train_ds = data_source.ArrayDataSource([d_target.train_X], repeats=-1)
    train_ds = data_source.CompositeDataSource([sup_ds,
                                                tgt_train_ds]).map(augment)
    train_ds = pool.parallel_data_source(train_ds)
    if epoch_size == 'large':
        n_samples = max(d_source.train_X.shape[0], d_target.train_X.shape[0])
    elif epoch_size == 'small':
        n_samples = min(d_source.train_X.shape[0], d_target.train_X.shape[0])
    elif epoch_size == 'target':
        n_samples = d_target.train_X.shape[0]
    else:
        # Previously an unknown value fell through silently and crashed
        # below with a NameError; fail with an explicit message instead.
        print('Unknown epoch size \'{}\''.format(epoch_size))
        return
    n_train_batches = n_samples // batch_size

    source_test_ds = data_source.ArrayDataSource(
        [d_source.test_X, d_source.test_y])
    target_test_ds = data_source.ArrayDataSource(
        [d_target.test_X, d_target.test_y])

    # seed == 0 means "unseeded": use NumPy's global RNG state.
    if seed != 0:
        shuffle_rng = np.random.RandomState(seed)
    else:
        shuffle_rng = np.random

    train_batch_iter = train_ds.batch_iterator(batch_size=batch_size,
                                               shuffle=shuffle_rng)

    # Keep the best teacher weights on the CPU as NumPy arrays so they can
    # be saved at the end regardless of the training device.
    best_teacher_model_state = {
        k: v.cpu().numpy()
        for k, v in teacher_net.state_dict().items()
    }

    best_conf_mask_rate = 0.0
    best_src_test_err = 1.0
    for epoch in range(num_epochs):
        t1 = time.time()

        if use_rampup:
            # Gaussian ramp-up: exp(-5 * (1 - epoch/rampup)^2), reaching
            # 1.0 once `rampup` epochs have elapsed.
            if epoch < rampup:
                p = max(0.0, float(epoch)) / float(rampup)
                p = 1.0 - p
                rampup_value = math.exp(-p * p * 5.0)
            else:
                rampup_value = 1.0

            # Publish the weight to the training closures via the shared
            # mutable cell.
            rampup_weight_in_list[0] = rampup_value

        train_res = data_source.batch_map_mean(f_train,
                                               train_batch_iter,
                                               n_batches=n_train_batches)

        train_clf_loss = train_res[0]
        if combine_batches:
            unsup_loss_string = 'unsup (both) loss={:.6f}'.format(train_res[1])
        else:
            unsup_loss_string = 'unsup (tgt) loss={:.6f}'.format(train_res[1])

        # Evaluation uses double-sized batches as no gradients are needed.
        src_test_err_stu, src_test_err_tea = source_test_ds.batch_map_mean(
            f_eval_src, batch_size=batch_size * 2)
        tgt_test_err_stu, tgt_test_err_tea = target_test_ds.batch_map_mean(
            f_eval_tgt, batch_size=batch_size * 2)

        if use_rampup:
            unsup_loss_string = '{}, rampup={:.3%}'.format(
                unsup_loss_string, rampup_value)
            # Model selection criterion in ramp-up mode: best source-domain
            # student test error.
            if src_test_err_stu < best_src_test_err:
                best_src_test_err = src_test_err_stu
                best_teacher_model_state = {
                    k: v.cpu().numpy()
                    for k, v in teacher_net.state_dict().items()
                }
                improve = '*** '
            else:
                improve = ''
        else:
            # Model selection criterion with confidence thresholding: best
            # (highest) confidence mask rate.
            conf_mask_rate = train_res[-2]
            unsup_mask_rate = train_res[-1]
            if conf_mask_rate > best_conf_mask_rate:
                best_conf_mask_rate = conf_mask_rate
                improve = '*** '
                best_teacher_model_state = {
                    k: v.cpu().numpy()
                    for k, v in teacher_net.state_dict().items()
                }
            else:
                improve = ''
            unsup_loss_string = '{}, conf mask={:.3%}, unsup mask={:.3%}'.format(
                unsup_loss_string, conf_mask_rate, unsup_mask_rate)

        t2 = time.time()

        log('{}Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}, {}; '
            'SRC TEST ERR={:.3%}, TGT TEST student err={:.3%}, TGT TEST teacher err={:.3%}'
            .format(improve, epoch, t2 - t1, train_clf_loss, unsup_loss_string,
                    src_test_err_stu, tgt_test_err_stu, tgt_test_err_tea))

    # Save network
    if model_file != '':
        cmdline_helpers.ensure_containing_dir_exists(model_file)
        with open(model_file, 'wb') as f:
            torch.save(best_teacher_model_state, f)
def train_seg_semisup_aug_mt(
        submit_config: job_helper.SubmitConfig, dataset, model, arch,
        freeze_bn, opt_type, sgd_momentum, sgd_nesterov, sgd_weight_decay,
        learning_rate, lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power,
        teacher_alpha, bin_fill_holes, crop_size, aug_offset_range, aug_hflip,
        aug_vflip, aug_hvflip, aug_scale_hung, aug_max_scale,
        aug_scale_non_uniform, aug_rot_mag, aug_free_scale_rot, cons_loss_fn,
        cons_weight, conf_thresh, conf_per_pixel, rampup, unsup_batch_ratio,
        num_epochs, iters_per_epoch, batch_size, n_sup, n_unsup, n_val,
        split_seed, split_path, val_seed, save_preds, save_model, num_workers):
    settings = locals().copy()
    del settings['submit_config']

    import os
    import math
    import time
    import itertools
    import numpy as np
    import torch.nn as nn, torch.nn.functional as F
    from architectures import network_architectures
    import torch.utils.data
    from datapipe import datasets
    from datapipe import seg_data, seg_transforms, seg_transforms_cv
    import evaluation
    import optim_weight_ema
    import lr_schedules

    from datapipe import torch_utils
    affine_align_corners_kw = torch_utils.affine_align_corners_kw(True)

    if crop_size == '':
        crop_size = None
    else:
        crop_size = [int(x.strip()) for x in crop_size.split(',')]

    torch_device = torch.device('cuda:0')

    #
    # Load data sets
    #
    ds_dict = datasets.load_dataset(dataset, n_val, val_seed, n_sup, n_unsup,
                                    split_seed, split_path)

    ds_src = ds_dict['ds_src']
    ds_tgt = ds_dict['ds_tgt']
    tgt_val_ndx = ds_dict['val_ndx_tgt']
    src_val_ndx = ds_dict['val_ndx_src'] if ds_src is not ds_tgt else None
    test_ndx = ds_dict['test_ndx_tgt']
    sup_ndx = ds_dict['sup_ndx']
    unsup_ndx = ds_dict['unsup_ndx']

    n_classes = ds_src.num_classes
    root_n_classes = math.sqrt(n_classes)

    if bin_fill_holes and n_classes != 2:
        print(
            'Binary hole filling can only be used with binary (2-class) segmentation datasets'
        )
        return

    print('Loaded data')

    # Build network
    NetClass = network_architectures.seg.get(arch)

    student_net = NetClass(ds_src.num_classes).to(torch_device)

    if opt_type == 'adam':
        student_optim = torch.optim.Adam([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ])
    elif opt_type == 'sgd':
        student_optim = torch.optim.SGD([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ],
                                        momentum=sgd_momentum,
                                        nesterov=sgd_nesterov,
                                        weight_decay=sgd_weight_decay)
    else:
        raise ValueError('Unknown opt_type {}'.format(opt_type))

    if model == 'mean_teacher':
        teacher_net = NetClass(ds_src.num_classes).to(torch_device)

        for p in teacher_net.parameters():
            p.requires_grad = False

        teacher_optim = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, teacher_alpha)
        eval_net = teacher_net
    elif model == 'pi':
        teacher_net = student_net
        teacher_optim = None
        eval_net = student_net
    else:
        print('Unknown model type {}'.format(model))
        return

    BLOCK_SIZE = student_net.BLOCK_SIZE
    NET_MEAN, NET_STD = seg_transforms.get_mean_std(ds_tgt, student_net)

    if freeze_bn:
        if not hasattr(student_net, 'freeze_batchnorm'):
            raise ValueError(
                'Network {} does not support batchnorm freezing'.format(arch))

    clf_crossent_loss = nn.CrossEntropyLoss(ignore_index=255)

    print('Built network')

    if iters_per_epoch == -1:
        iters_per_epoch = len(unsup_ndx) // batch_size
    total_iters = iters_per_epoch * num_epochs

    lr_epoch_scheduler, lr_iter_scheduler = lr_schedules.make_lr_schedulers(
        optimizer=student_optim,
        total_iters=total_iters,
        schedule_type=lr_sched,
        step_epochs=lr_step_epochs,
        step_gamma=lr_step_gamma,
        poly_power=lr_poly_power)

    # Train data pipeline: transforms
    train_transforms = []

    if crop_size is not None:
        if aug_scale_hung:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropScaleHung(
                    crop_size, (aug_offset_range, aug_offset_range),
                    uniform_scale=not aug_scale_non_uniform))
        elif aug_max_scale != 1.0 or aug_rot_mag != 0.0:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropRotateScale(
                    crop_size, (aug_offset_range, aug_offset_range),
                    rot_mag=aug_rot_mag,
                    max_scale=aug_max_scale,
                    uniform_scale=not aug_scale_non_uniform,
                    constrain_rot_scale=not aug_free_scale_rot))
        else:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCrop(
                    crop_size, (aug_offset_range, aug_offset_range)))
    else:
        if aug_scale_hung:
            raise NotImplementedError('aug_scale_hung requires a crop_size')

    if aug_hflip or aug_vflip or aug_hvflip:
        train_transforms.append(
            seg_transforms_cv.SegCVTransformRandomFlip(aug_hflip, aug_vflip,
                                                       aug_hvflip))
    train_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))

    # Train data pipeline: supervised and unsupervised data sets
    train_sup_ds = ds_src.dataset(
        labels=True,
        mask=False,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    train_unsup_ds = ds_src.dataset(
        labels=False,
        mask=True,
        xf=True,
        pair=True,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')

    collate_fn = seg_data.SegCollate(BLOCK_SIZE)

    # Train data pipeline: data loaders
    sup_sampler = seg_data.RepeatSampler(
        torch.utils.data.SubsetRandomSampler(sup_ndx))
    train_sup_loader = torch.utils.data.DataLoader(train_sup_ds,
                                                   batch_size,
                                                   sampler=sup_sampler,
                                                   collate_fn=collate_fn,
                                                   num_workers=num_workers)
    if cons_weight > 0.0:
        unsup_sampler = seg_data.RepeatSampler(
            torch.utils.data.SubsetRandomSampler(unsup_ndx))
        train_unsup_loader = torch.utils.data.DataLoader(
            train_unsup_ds,
            batch_size,
            sampler=unsup_sampler,
            collate_fn=collate_fn,
            num_workers=num_workers)
    else:
        train_unsup_loader = None

    # Eval pipeline
    src_val_loader, tgt_val_loader, test_loader = datasets.eval_data_pipeline(
        ds_src, ds_tgt, src_val_ndx, tgt_val_ndx, test_ndx, batch_size,
        collate_fn, NET_MEAN, NET_STD, num_workers)

    # Report setttings
    print('Settings:')
    print(', '.join([
        '{}={}'.format(key, settings[key])
        for key in sorted(list(settings.keys()))
    ]))

    # Report dataset size
    print('Dataset:')
    print('len(sup_ndx)={}'.format(len(sup_ndx)))
    print('len(unsup_ndx)={}'.format(len(unsup_ndx)))
    if ds_src is not ds_tgt:
        print('len(src_val_ndx)={}'.format(len(tgt_val_ndx)))
        print('len(tgt_val_ndx)={}'.format(len(tgt_val_ndx)))
    else:
        print('len(val_ndx)={}'.format(len(tgt_val_ndx)))
    if test_ndx is not None:
        print('len(test_ndx)={}'.format(len(test_ndx)))

    if n_sup != -1:
        print('sup_ndx={}'.format(sup_ndx.tolist()))

    # Track mIoU for early stopping
    best_tgt_miou = None
    best_epoch = 0

    eval_net_state = {
        key: value.detach().cpu().numpy()
        for key, value in eval_net.state_dict().items()
    }

    # Create iterators
    train_sup_iter = iter(train_sup_loader)
    train_unsup_iter = iter(
        train_unsup_loader) if train_unsup_loader is not None else None

    iter_i = 0
    print('Training...')
    for epoch_i in range(num_epochs):
        if lr_epoch_scheduler is not None:
            lr_epoch_scheduler.step(epoch_i)

        t1 = time.time()

        if rampup > 0:
            ramp_val = network_architectures.sigmoid_rampup(epoch_i, rampup)
        else:
            ramp_val = 1.0

        student_net.train()
        if teacher_net is not student_net:
            teacher_net.train()

        if freeze_bn:
            student_net.freeze_batchnorm()
            if teacher_net is not student_net:
                teacher_net.freeze_batchnorm()

        sup_loss_acc = 0.0
        consistency_loss_acc = 0.0
        conf_rate_acc = 0.0
        n_sup_batches = 0
        n_unsup_batches = 0

        src_val_iter = iter(
            src_val_loader) if src_val_loader is not None else None
        tgt_val_iter = iter(
            tgt_val_loader) if tgt_val_loader is not None else None

        for sup_batch in itertools.islice(train_sup_iter, iters_per_epoch):
            if lr_iter_scheduler is not None:
                lr_iter_scheduler.step(iter_i)

            student_optim.zero_grad()

            #
            # Supervised branch
            #

            batch_x = sup_batch['image'].to(torch_device)
            batch_y = sup_batch['labels'].to(torch_device)

            logits_sup = student_net(batch_x)
            sup_loss = clf_crossent_loss(logits_sup, batch_y[:, 0, :, :])
            sup_loss.backward()

            if cons_weight > 0.0:
                for _ in range(unsup_batch_ratio):
                    #
                    # Unsupervised branch
                    #

                    # Cut mode: batch consists of unsupervised samples and mask params
                    unsup_batch = next(train_unsup_iter)

                    # Input images to torch tensor
                    batch_ux0 = unsup_batch['sample0']['image'].to(
                        torch_device)
                    batch_um0 = unsup_batch['sample0']['mask'].to(torch_device)
                    batch_ux1 = unsup_batch['sample1']['image'].to(
                        torch_device)
                    batch_um1 = unsup_batch['sample1']['mask'].to(torch_device)
                    batch_ufx0_to_1 = unsup_batch['xf0_to_1'].to(torch_device)

                    # Get teacher predictions for image0
                    with torch.no_grad():
                        logits_cons_tea = teacher_net(batch_ux0).detach()
                    # Get student prediction for image1
                    logits_cons_stu = student_net(batch_ux1)

                    # Transformation from teacher to student space
                    grid_tea_to_stu = F.affine_grid(batch_ufx0_to_1,
                                                    batch_ux0.shape,
                                                    **affine_align_corners_kw)
                    # Transform teacher predicted logits to student space
                    logits_cons_tea_in_stu = F.grid_sample(
                        logits_cons_tea, grid_tea_to_stu,
                        **affine_align_corners_kw)
                    # Transform mask from teacher to student space and multiply by student space mask
                    mask_tea_in_stu = F.grid_sample(
                        batch_um0, grid_tea_to_stu, **
                        affine_align_corners_kw) * batch_um1

                    # Logits -> probs
                    prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                    prob_cons_stu = F.softmax(logits_cons_stu, dim=1)
                    # Transform teacher predicted probabilities to student space
                    prob_cons_tea_in_stu = F.grid_sample(
                        prob_cons_tea, grid_tea_to_stu,
                        **affine_align_corners_kw)

                    # for i in range(len(batch_ux0)):
                    #     plt.figure(figsize=(18, 12))
                    #
                    #     x_0_in_1 = F.grid_sample(batch_ux0, grid_tea_to_stu)
                    #     d_x0_in_1 = torch.abs(x_0_in_1 - batch_ux1) * mask_tea_in_stu
                    #     mask_tea_in_stu_np = mask_tea_in_stu.detach().cpu().numpy()
                    #
                    #     plt.subplot(2, 4, 1)
                    #     plt.imshow(batch_ux0[i].detach().cpu().numpy().transpose(1, 2, 0) * 0.25 + 0.5)
                    #     plt.subplot(2, 4, 2)
                    #     plt.imshow(batch_ux1[i].detach().cpu().numpy().transpose(1, 2, 0) * 0.25 + 0.5)
                    #     plt.subplot(2, 4, 3)
                    #     plt.imshow(x_0_in_1[i].detach().cpu().numpy().transpose(1, 2, 0) * 0.25 + 0.5)
                    #     plt.subplot(2, 4, 4)
                    #     plt.imshow(d_x0_in_1[i].detach().cpu().numpy().transpose(1, 2, 0) * 10 + 0.5, cmap='gray')
                    #
                    #     plt.subplot(2, 4, 5)
                    #     plt.imshow(batch_um0[i,0].detach().cpu().numpy(), cmap='gray')
                    #     plt.subplot(2, 4, 6)
                    #     plt.imshow(batch_um1[i,0].detach().cpu().numpy(), cmap='gray')
                    #     plt.subplot(2, 4, 7)
                    #     plt.imshow(mask_tea_in_stu[i,0].detach().cpu().numpy(), cmap='gray')
                    #
                    #     plt.show()

                    loss_mask = mask_tea_in_stu

                    # Confidence thresholding
                    if conf_thresh > 0.0:
                        # Compute confidence of teacher predictions
                        conf_tea = prob_cons_tea_in_stu.max(dim=1)[0]
                        # Compute confidence mask
                        conf_mask = (conf_tea >=
                                     conf_thresh).float()[:, None, :, :]
                        # Record rate for reporting
                        conf_rate_acc += float(conf_mask.mean())
                        # Average confidence mask if requested
                        if not conf_per_pixel:
                            conf_mask = conf_mask.mean()

                        loss_mask = loss_mask * conf_mask
                    elif rampup > 0:
                        conf_rate_acc += ramp_val

                    # Compute per-pixel consistency loss
                    # Note that the way we aggregate the loss across the class/channel dimension (1)
                    # depends on the loss function used. Generally, summing over the class dimension
                    # keeps the magnitude of the gradient of the loss w.r.t. the logits
                    # nearly constant w.r.t. the number of classes. When using logit-variance,
                    # dividing by `sqrt(num_classes)` helps.
                    if cons_loss_fn == 'var':
                        delta_prob = prob_cons_stu - prob_cons_tea_in_stu
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'logits_var':
                        delta_logits = logits_cons_stu - logits_cons_tea_in_stu
                        consistency_loss = delta_logits * delta_logits
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'logits_smoothl1':
                        consistency_loss = F.smooth_l1_loss(
                            logits_cons_stu,
                            logits_cons_tea_in_stu,
                            reduce=False)
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'bce':
                        consistency_loss = network_architectures.robust_binary_crossentropy(
                            prob_cons_stu, prob_cons_tea_in_stu)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'kld':
                        consistency_loss = F.kl_div(F.log_softmax(
                            logits_cons_stu, dim=1),
                                                    prob_cons_tea_in_stu,
                                                    reduce=False)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    else:
                        raise ValueError(
                            'Unknown consistency loss function {}'.format(
                                cons_loss_fn))

                    # Apply consistency loss mask and take the mean over pixels and images
                    consistency_loss = (consistency_loss * loss_mask).mean()

                    # Modulate with rampup if desired
                    if rampup > 0:
                        consistency_loss = consistency_loss * ramp_val

                    # Weight the consistency loss and back-prop
                    unsup_loss = consistency_loss * cons_weight
                    unsup_loss.backward()

                    consistency_loss_acc += float(consistency_loss.detach())

                    n_unsup_batches += 1

            student_optim.step()
            if teacher_optim is not None:
                teacher_optim.step()

            sup_loss_acc += float(sup_loss.detach())
            n_sup_batches += 1
            iter_i += 1

        sup_loss_acc /= n_sup_batches
        if n_unsup_batches > 0:
            consistency_loss_acc /= n_unsup_batches
            conf_rate_acc /= n_unsup_batches

        eval_net.eval()

        if src_val_iter is not None:
            src_iou_eval = evaluation.EvaluatorIoU(ds_src.num_classes,
                                                   bin_fill_holes)
            with torch.no_grad():
                for batch in src_val_iter:
                    batch_x = batch['image'].to(torch_device)
                    batch_y = batch['labels'].numpy()

                    logits = eval_net(batch_x)
                    pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                    for sample_i in range(len(batch_y)):
                        src_iou_eval.sample(batch_y[sample_i, 0],
                                            pred_y[sample_i],
                                            ignore_value=255)

            src_iou = src_iou_eval.score()
            src_miou = src_iou.mean()
        else:
            src_iou_eval = src_iou = src_miou = None

        tgt_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                               bin_fill_holes)
        with torch.no_grad():
            for batch in tgt_val_iter:
                batch_x = batch['image'].to(torch_device)
                batch_y = batch['labels'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i in range(len(batch_y)):
                    tgt_iou_eval.sample(batch_y[sample_i, 0],
                                        pred_y[sample_i],
                                        ignore_value=255)

        tgt_iou = tgt_iou_eval.score()
        tgt_miou = tgt_iou.mean()

        t2 = time.time()

        if ds_src is not ds_tgt:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, '
                'SRC VAL mIoU={:.3%}, TGT VAL mIoU={:.3%}'.format(
                    epoch_i + 1, t2 - t1, sup_loss_acc, consistency_loss_acc,
                    conf_rate_acc, src_miou, tgt_miou))
            print('-- SRC {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in src_iou])))
            print('-- TGT {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))
        else:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, VAL mIoU={:.3%}'
                .format(epoch_i + 1, t2 - t1, sup_loss_acc,
                        consistency_loss_acc, conf_rate_acc, tgt_miou))
            print('-- {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))

    if save_model:
        model_path = os.path.join(submit_config.run_dir, "model.pth")
        torch.save(eval_net, model_path)

    if save_preds:
        out_dir = os.path.join(submit_config.run_dir, 'preds')
        os.makedirs(out_dir, exist_ok=True)
        with torch.no_grad():
            for batch in tgt_val_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    ds_tgt.save_prediction_by_index(
                        out_dir, pred_y[sample_i].astype(np.uint32),
                        sample_ndx)
    else:
        out_dir = None

    if test_loader is not None:
        test_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                                bin_fill_holes)
        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    if save_preds:
                        ds_tgt.save_prediction_by_index(
                            out_dir, pred_y[sample_i].astype(np.uint32),
                            sample_ndx)
                    test_iou_eval.sample(batch_y[sample_i, 0],
                                         pred_y[sample_i],
                                         ignore_value=255)

        test_iou = test_iou_eval.score()
        test_miou = test_iou.mean()

        print('FINAL TEST: mIoU={:.3%}'.format(test_miou))
        print('-- TEST {}'.format(', '.join(
            ['{:.3%}'.format(x) for x in test_iou])))
def train_seg_semisup_vat_mt(
        submit_config: job_helper.SubmitConfig, dataset, model, arch,
        freeze_bn, opt_type, sgd_momentum, sgd_nesterov, sgd_weight_decay,
        learning_rate, lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power,
        teacher_alpha, bin_fill_holes, crop_size, aug_hflip, aug_vflip,
        aug_hvflip, aug_scale_hung, aug_max_scale, aug_scale_non_uniform,
        aug_rot_mag, vat_radius, adaptive_vat_radius, vat_dir_from_student,
        cons_loss_fn, cons_weight, conf_thresh, conf_per_pixel, rampup,
        unsup_batch_ratio, num_epochs, iters_per_epoch, batch_size, n_sup,
        n_unsup, n_val, split_seed, split_path, val_seed, save_preds,
        save_model, num_workers):
    settings = locals().copy()
    del settings['submit_config']

    import os
    import time
    import itertools
    import math
    import numpy as np
    import torch, torch.nn as nn, torch.nn.functional as F
    from architectures import network_architectures
    import torch.utils.data
    from datapipe import datasets
    from datapipe import seg_data, seg_transforms, seg_transforms_cv
    import evaluation
    import optim_weight_ema
    import lr_schedules

    if crop_size == '':
        crop_size = None
    else:
        crop_size = [int(x.strip()) for x in crop_size.split(',')]

    torch_device = torch.device('cuda:0')

    #
    # Load data sets
    #
    ds_dict = datasets.load_dataset(dataset, n_val, val_seed, n_sup, n_unsup,
                                    split_seed, split_path)

    ds_src = ds_dict['ds_src']
    ds_tgt = ds_dict['ds_tgt']
    tgt_val_ndx = ds_dict['val_ndx_tgt']
    src_val_ndx = ds_dict['val_ndx_src'] if ds_src is not ds_tgt else None
    test_ndx = ds_dict['test_ndx_tgt']
    sup_ndx = ds_dict['sup_ndx']
    unsup_ndx = ds_dict['unsup_ndx']

    n_classes = ds_src.num_classes
    root_n_classes = math.sqrt(n_classes)

    if bin_fill_holes and n_classes != 2:
        print(
            'Binary hole filling can only be used with binary (2-class) segmentation datasets'
        )
        return

    print('Loaded data')

    # Build network
    NetClass = network_architectures.seg.get(arch)

    student_net = NetClass(ds_src.num_classes).to(torch_device)

    if opt_type == 'adam':
        student_optim = torch.optim.Adam([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ])
    elif opt_type == 'sgd':
        student_optim = torch.optim.SGD([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ],
                                        momentum=sgd_momentum,
                                        nesterov=sgd_nesterov,
                                        weight_decay=sgd_weight_decay)
    else:
        raise ValueError('Unknown opt_type {}'.format(opt_type))

    if model == 'mean_teacher':
        teacher_net = NetClass(ds_src.num_classes).to(torch_device)

        for p in teacher_net.parameters():
            p.requires_grad = False

        teacher_optim = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, teacher_alpha)
        eval_net = teacher_net
    elif model == 'pi':
        teacher_net = student_net
        teacher_optim = None
        eval_net = student_net
    else:
        print('Unknown model type {}'.format(model))
        return

    if vat_dir_from_student:
        vat_dir_net = student_net
    else:
        vat_dir_net = teacher_net

    BLOCK_SIZE = student_net.BLOCK_SIZE
    NET_MEAN, NET_STD = seg_transforms.get_mean_std(ds_tgt, student_net)

    if freeze_bn:
        if not hasattr(student_net, 'freeze_batchnorm'):
            raise ValueError(
                'Network {} does not support batchnorm freezing'.format(arch))

    clf_crossent_loss = nn.CrossEntropyLoss(ignore_index=255)

    print('Built network')

    if iters_per_epoch == -1:
        iters_per_epoch = len(unsup_ndx) // batch_size
    total_iters = iters_per_epoch * num_epochs

    lr_epoch_scheduler, lr_iter_scheduler = lr_schedules.make_lr_schedulers(
        optimizer=student_optim,
        total_iters=total_iters,
        schedule_type=lr_sched,
        step_epochs=lr_step_epochs,
        step_gamma=lr_step_gamma,
        poly_power=lr_poly_power)

    # Train data pipeline: transforms
    train_transforms = []

    if crop_size is not None:
        if aug_scale_hung:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropScaleHung(
                    crop_size, (0, 0),
                    uniform_scale=not aug_scale_non_uniform))
        elif aug_max_scale != 1.0 or aug_rot_mag != 0.0:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropRotateScale(
                    crop_size, (0, 0),
                    rot_mag=aug_rot_mag,
                    max_scale=aug_max_scale,
                    uniform_scale=not aug_scale_non_uniform,
                    constrain_rot_scale=True))
        else:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCrop(crop_size, (0, 0)))
    else:
        if aug_scale_hung:
            raise NotImplementedError('aug_scale_hung requires a crop_size')

    if aug_hflip or aug_vflip or aug_hvflip:
        train_transforms.append(
            seg_transforms_cv.SegCVTransformRandomFlip(aug_hflip, aug_vflip,
                                                       aug_hvflip))
    train_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))

    # Train data pipeline: supervised and unsupervised data sets
    train_sup_ds = ds_src.dataset(
        labels=True,
        mask=False,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    train_unsup_ds = ds_src.dataset(
        labels=False,
        mask=True,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')

    collate_fn = seg_data.SegCollate(BLOCK_SIZE)

    # Train data pipeline: data loaders
    sup_sampler = seg_data.RepeatSampler(
        torch.utils.data.SubsetRandomSampler(sup_ndx))
    train_sup_loader = torch.utils.data.DataLoader(train_sup_ds,
                                                   batch_size,
                                                   sampler=sup_sampler,
                                                   collate_fn=collate_fn,
                                                   num_workers=num_workers)
    if cons_weight > 0.0:
        unsup_sampler = seg_data.RepeatSampler(
            torch.utils.data.SubsetRandomSampler(unsup_ndx))
        train_unsup_loader = torch.utils.data.DataLoader(
            train_unsup_ds,
            batch_size,
            sampler=unsup_sampler,
            collate_fn=collate_fn,
            num_workers=num_workers)
    else:
        train_unsup_loader = None

    # Eval pipeline
    src_val_loader, tgt_val_loader, test_loader = datasets.eval_data_pipeline(
        ds_src, ds_tgt, src_val_ndx, tgt_val_ndx, test_ndx, batch_size,
        collate_fn, NET_MEAN, NET_STD, num_workers)

    # Report setttings
    print('Settings:')
    print(', '.join([
        '{}={}'.format(key, settings[key])
        for key in sorted(list(settings.keys()))
    ]))

    # Report dataset size
    print('Dataset:')
    print('len(sup_ndx)={}'.format(len(sup_ndx)))
    print('len(unsup_ndx)={}'.format(len(unsup_ndx)))
    if ds_src is not ds_tgt:
        print('len(src_val_ndx)={}'.format(len(tgt_val_ndx)))
        print('len(tgt_val_ndx)={}'.format(len(tgt_val_ndx)))
    else:
        print('len(val_ndx)={}'.format(len(tgt_val_ndx)))
    if test_ndx is not None:
        print('len(test_ndx)={}'.format(len(test_ndx)))

    if n_sup != -1:
        print('sup_ndx={}'.format(sup_ndx.tolist()))

    def t_dot(a, b):
        """Per-pixel dot product across the channel dimension.

        Multiplies `a` and `b` element-wise and sums over dim 1 with
        `keepdim=True`, so the `(N, 1, H, W)` result broadcasts against
        `(N, C, H, W)` tensors.
        """
        prod = a * b
        return prod.sum(dim=1, keepdim=True)

    def normalize_eps(x):
        """Rescale each sample in the batch to (approximately) unit L2 norm.

        The norm is taken over all non-batch dimensions; the small epsilon
        in the denominator guards against division by zero when a sample
        is all zeros.
        """
        flat = x.view(len(x), -1)
        norms = flat.pow(2).sum(dim=1).sqrt()
        return x / (norms[:, None, None, None] + 1e-12)

    def normalized_noise(x, requires_grad=False, scale=1.0):
        """Draw a random perturbation shaped like `x` with per-sample norm `scale`.

        Gaussian noise is normalized per sample (via `normalize_eps`) and
        scaled. When `requires_grad` is True the result is returned as a
        detached leaf tensor that participates in autograd, so gradients
        can be taken with respect to it.
        """
        noise = torch.randn(x.shape, dtype=torch.float, device=x.device)
        noise = normalize_eps(noise) * scale
        if not requires_grad:
            return noise
        return noise.clone().detach().requires_grad_(True)

    def vat_direction(x):
        """
        Compute the VAT (virtual adversarial training) perturbation direction.

        Perturbs `x` with a tiny random offset, measures how much the
        direction network's prediction changes under the chosen consistency
        loss, and differentiates that change w.r.t. the offset to obtain the
        adversarial direction.

        :param x: input image as a `(N, C, H, W)` tensor
        :return: tuple of (normalized VAT direction as a `(N, C, H, W)` tensor,
            predicted logits for `x`, predicted probabilities for `x`)
        :raises ValueError: if `cons_loss_fn` is not one of
            'var', 'bce', 'kld' or 'logits_var'
        """
        # Put the network used to get the VAT direction in eval mode and get the predicted
        # logits and probabilities for the batch of samples x
        # NOTE(review): the network is left in eval mode on return; the caller
        # appears to rely on the training loop re-entering train mode each
        # epoch — confirm this is intended.
        vat_dir_net.eval()
        with torch.no_grad():
            y_pred_logits = vat_dir_net(x).detach()
        y_pred_prob = F.softmax(y_pred_logits, dim=1)

        # Initial noise offset vector with requires_grad=True.
        # The scale grows with the image area so the probe perturbation stays
        # proportionally tiny per pixel.
        noise_scale = 1.0e-6 * x.shape[2] * x.shape[3] / 1000
        eps = normalized_noise(x, requires_grad=True, scale=noise_scale)

        # Predict logits and probs for sample perturbed by eps
        # (x is detached so gradients flow only into eps, not the input)
        eps_pred_logits = vat_dir_net(x.detach() + eps)
        eps_pred_prob = F.softmax(eps_pred_logits, dim=1)

        # Choose our loss function: measures divergence between clean and
        # perturbed predictions, summed over all elements.
        if cons_loss_fn == 'var':
            # Squared difference of probabilities
            delta = (eps_pred_prob - y_pred_prob)
            loss = (delta * delta).sum()
        elif cons_loss_fn == 'bce':
            # Numerically-robust binary cross-entropy between probabilities
            loss = network_architectures.robust_binary_crossentropy(
                eps_pred_prob, y_pred_prob).sum()
        elif cons_loss_fn == 'kld':
            # KL divergence; NOTE(review): `reduce=False` is the deprecated
            # spelling of `reduction='none'` — consider updating when the
            # rest of the file is migrated.
            loss = F.kl_div(F.log_softmax(eps_pred_logits, dim=1),
                            y_pred_prob,
                            reduce=False).sum()
        elif cons_loss_fn == 'logits_var':
            # Squared difference of raw logits
            delta = (eps_pred_logits - y_pred_logits)
            loss = (delta * delta).sum()
        else:
            raise ValueError(
                'Unknown consistency loss function {}'.format(cons_loss_fn))

        # Differentiate the loss w.r.t. the perturbation
        # NOTE(review): `create_graph=True` builds a higher-order graph, but
        # the caller (`vat_perburbation`) detaches the result — this may be
        # unnecessary work; confirm before changing.
        eps_adv = torch.autograd.grad(outputs=loss,
                                      inputs=eps,
                                      create_graph=True,
                                      retain_graph=True,
                                      only_inputs=True)[0]

        # Normalize the adversarial perturbation
        return normalize_eps(eps_adv), y_pred_logits, y_pred_prob

    def vat_perburbation(x, m):
        """
        Compute the scaled VAT perturbation for a batch of images.

        (Name keeps the original spelling "perburbation" [sic] since callers
        use it.)

        :param x: input image as a `(N, C, H, W)` tensor
        :param m: valid-pixel mask; NOTE(review): currently unused here —
            confirm whether masking was meant to be applied to the radius
        :return: tuple of (detached perturbation to add to `x`,
            predicted logits for `x`, predicted probabilities for `x`)
        """
        eps_adv_nrm, y_pred_logits, y_pred_prob = vat_direction(x)

        if adaptive_vat_radius:
            # We view semantic segmentation as predicting the class of a pixel
            # given a patch centred on that pixel.
            # The most similar patch in terms of pixel content to a patch P
            # is a patch Q whose central pixel is an immediate neighbour
            # of the central pixel P.
            # We therefore use the image Jacobian (gradient w.r.t. x and y) to
            # get a sense of the distance between neighbouring patches
            # so we can scale the VAT radius according to the image content.

            # Delta in vertical and horizontal directions (central differences)
            delta_v = x[:, :, 2:, :] - x[:, :, :-2, :]
            delta_h = x[:, :, :, 2:] - x[:, :, :, :-2]

            # delta_h and delta_v are the difference between pixels where the step size is 2, rather than 1
            # So divide by 2 to get the magnitude of the Jacobian
            # (the `* 0.5` at the end of the expression below)

            delta_v = delta_v.view(len(delta_v), -1)
            delta_h = delta_h.view(len(delta_h), -1)
            # Per-sample radius, broadcast to (N, 1, 1, 1)
            adv_radius = vat_radius * torch.sqrt(
                (delta_v**2).sum(dim=1) +
                (delta_h**2).sum(dim=1))[:, None, None, None] * 0.5
        else:
            # Fixed radius scaled by the sqrt of the number of elements per
            # sample, so the per-pixel magnitude is roughly size-independent
            scale = math.sqrt(float(x.shape[1] * x.shape[2] * x.shape[3]))
            adv_radius = vat_radius * scale

        # Detach: the perturbation is treated as a constant input offset;
        # no gradients flow back through its construction
        return (eps_adv_nrm * adv_radius).detach(), y_pred_logits, y_pred_prob

    # Track mIoU for early stopping
    best_tgt_miou = None
    best_epoch = 0

    eval_net_state = {
        key: value.detach().cpu().numpy()
        for key, value in eval_net.state_dict().items()
    }

    # Create iterators
    train_sup_iter = iter(train_sup_loader)
    train_unsup_iter = iter(
        train_unsup_loader) if train_unsup_loader is not None else None

    iter_i = 0
    print('Training...')
    for epoch_i in range(num_epochs):
        if lr_epoch_scheduler is not None:
            lr_epoch_scheduler.step(epoch_i)

        t1 = time.time()

        if rampup > 0:
            ramp_val = network_architectures.sigmoid_rampup(epoch_i, rampup)
        else:
            ramp_val = 1.0

        student_net.train()
        if teacher_net is not student_net:
            teacher_net.train()

        if freeze_bn:
            student_net.freeze_batchnorm()
            if teacher_net is not student_net:
                teacher_net.freeze_batchnorm()

        sup_loss_acc = 0.0
        consistency_loss_acc = 0.0
        conf_rate_acc = 0.0
        n_sup_batches = 0
        n_unsup_batches = 0

        src_val_iter = iter(
            src_val_loader) if src_val_loader is not None else None
        tgt_val_iter = iter(
            tgt_val_loader) if tgt_val_loader is not None else None

        for sup_batch in itertools.islice(train_sup_iter, iters_per_epoch):
            if lr_iter_scheduler is not None:
                lr_iter_scheduler.step(iter_i)
            student_optim.zero_grad()

            #
            # Supervised branch
            #

            batch_x = sup_batch['image'].to(torch_device)
            batch_y = sup_batch['labels'].to(torch_device)

            logits_sup = student_net(batch_x)
            sup_loss = clf_crossent_loss(logits_sup, batch_y[:, 0, :, :])
            sup_loss.backward()

            if cons_weight > 0.0:
                for _ in range(unsup_batch_ratio):
                    #
                    # Unsupervised branch
                    #

                    unsup_batch = next(train_unsup_iter)

                    # Input images to torch tensor
                    batch_ux = unsup_batch['image'].to(torch_device)
                    batch_um = unsup_batch['mask'].to(torch_device)

                    # batch_um is a mask that is 1 for valid pixels, 0 for invalid pixels.
                    # It us used later on to scale the consistency loss, so that consistency loss is
                    # only computed for valid pixels.
                    # Explanation:
                    # When using geometric augmentations such as rotations, some pixels in the training
                    # crop may come from outside the bounds of the input image. These pixels will have a value
                    # of 0 in these masks. Similarly, when using scaled crops, the size of the crop
                    # from the input image that must be scaled to the size of the training crop may be
                    # larger than one/both of the input image dimensions. Pixels in the training crop
                    # that arise from outside the input image bounds will once again be given a value
                    # of 0 in these masks.

                    # Compute VAT perburbation
                    x_perturb, logits_cons_tea, prob_cons_tea = vat_perburbation(
                        batch_ux, batch_um)

                    # Perturb image
                    batch_ux_adv = batch_ux + x_perturb

                    # Get teacher predictions for original image
                    with torch.no_grad():
                        logits_cons_tea = teacher_net(batch_ux).detach()
                    # Get student prediction for cut image
                    logits_cons_stu = student_net(batch_ux_adv)

                    # Logits -> probs
                    prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                    prob_cons_stu = F.softmax(logits_cons_stu, dim=1)

                    loss_mask = batch_um

                    # Confidence thresholding
                    if conf_thresh > 0.0:
                        # Compute confidence of teacher predictions
                        conf_tea = prob_cons_tea.max(dim=1)[0]
                        # Compute confidence mask
                        conf_mask = (conf_tea >=
                                     conf_thresh).float()[:, None, :, :]
                        # Record rate for reporting
                        conf_rate_acc += float(conf_mask.mean())
                        # Average confidence mask if requested
                        if not conf_per_pixel:
                            conf_mask = conf_mask.mean()

                        loss_mask = loss_mask * conf_mask
                    elif rampup > 0:
                        conf_rate_acc += ramp_val

                    # Compute per-pixel consistency loss
                    # Note that the way we aggregate the loss across the class/channel dimension (1)
                    # depends on the loss function used. Generally, summing over the class dimension
                    # keeps the magnitude of the gradient of the loss w.r.t. the logits
                    # nearly constant w.r.t. the number of classes. When using logit-variance,
                    # dividing by `sqrt(num_classes)` helps.
                    if cons_loss_fn == 'var':
                        delta_prob = prob_cons_stu - prob_cons_tea
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'logits_var':
                        delta_logits = logits_cons_stu - logits_cons_tea
                        consistency_loss = delta_logits * delta_logits
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'bce':
                        consistency_loss = network_architectures.robust_binary_crossentropy(
                            prob_cons_stu, prob_cons_tea)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'kld':
                        consistency_loss = F.kl_div(F.log_softmax(
                            logits_cons_stu, dim=1),
                                                    prob_cons_tea,
                                                    reduce=False)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    else:
                        raise ValueError(
                            'Unknown consistency loss function {}'.format(
                                cons_loss_fn))

                    # Apply consistency loss mask and take the mean over pixels and images
                    consistency_loss = (consistency_loss * loss_mask).mean()

                    # Modulate with rampup if desired
                    if rampup > 0:
                        consistency_loss = consistency_loss * ramp_val

                    # Weight the consistency loss and back-prop
                    unsup_loss = consistency_loss * cons_weight
                    unsup_loss.backward()

                    consistency_loss_val = float(consistency_loss.detach())
                    consistency_loss_acc += consistency_loss_val

                    if np.isnan(consistency_loss_val):
                        print(
                            'NaN detected in consistency loss; bailing out...')
                        return

                    n_unsup_batches += 1

            student_optim.step()
            if teacher_optim is not None:
                teacher_optim.step()

            sup_loss_val = float(sup_loss.detach())
            sup_loss_acc += sup_loss_val

            if np.isnan(sup_loss_val):
                print('NaN detected in supervised loss; bailing out...')
                return

            n_sup_batches += 1
            iter_i += 1

        sup_loss_acc /= n_sup_batches
        if n_unsup_batches > 0:
            consistency_loss_acc /= n_unsup_batches
            conf_rate_acc /= n_unsup_batches

        eval_net.eval()

        if src_val_iter is not None:
            src_iou_eval = evaluation.EvaluatorIoU(ds_src.num_classes,
                                                   bin_fill_holes)
            with torch.no_grad():
                for batch in src_val_iter:
                    batch_x = batch['image'].to(torch_device)
                    batch_y = batch['labels'].numpy()

                    logits = eval_net(batch_x)
                    pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                    for sample_i in range(len(batch_y)):
                        src_iou_eval.sample(batch_y[sample_i, 0],
                                            pred_y[sample_i],
                                            ignore_value=255)

            src_iou = src_iou_eval.score()
            src_miou = src_iou.mean()
        else:
            src_iou_eval = src_iou = src_miou = None

        tgt_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                               bin_fill_holes)
        with torch.no_grad():
            for batch in tgt_val_iter:
                batch_x = batch['image'].to(torch_device)
                batch_y = batch['labels'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i in range(len(batch_y)):
                    tgt_iou_eval.sample(batch_y[sample_i, 0],
                                        pred_y[sample_i],
                                        ignore_value=255)

        tgt_iou = tgt_iou_eval.score()
        tgt_miou = tgt_iou.mean()

        t2 = time.time()

        if ds_src is not ds_tgt:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, '
                'SRC VAL mIoU={:.3%}, TGT VAL mIoU={:.3%}'.format(
                    epoch_i + 1, t2 - t1, sup_loss_acc, consistency_loss_acc,
                    conf_rate_acc, src_miou, tgt_miou))
            print('-- SRC {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in src_iou])))
            print('-- TGT {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))
        else:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, VAL mIoU={:.3%}'
                .format(epoch_i + 1, t2 - t1, sup_loss_acc,
                        consistency_loss_acc, conf_rate_acc, tgt_miou))
            print('-- {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))

    if save_model:
        model_path = os.path.join(submit_config.run_dir, "model.pth")
        torch.save(eval_net, model_path)

    if save_preds:
        out_dir = os.path.join(submit_config.run_dir, 'preds')
        os.makedirs(out_dir, exist_ok=True)
        with torch.no_grad():
            for batch in tgt_val_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    ds_tgt.save_prediction_by_index(
                        out_dir, pred_y[sample_i].astype(np.uint32),
                        sample_ndx)
    else:
        out_dir = None

    if test_loader is not None:
        test_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                                bin_fill_holes)
        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    if save_preds:
                        ds_tgt.save_prediction_by_index(
                            out_dir, pred_y[sample_i].astype(np.uint32),
                            sample_ndx)
                    test_iou_eval.sample(batch_y[sample_i, 0],
                                         pred_y[sample_i],
                                         ignore_value=255)

        test_iou = test_iou_eval.score()
        test_miou = test_iou.mean()

        print('FINAL TEST: mIoU={:.3%}'.format(test_miou))
        print('-- TEST {}'.format(', '.join(
            ['{:.3%}'.format(x) for x in test_iou])))
# ---- Example #4 (site artifact; original marker "예제 #4", score 0) ----
def train_toy2d(submit_config: job_helper.SubmitConfig, dataset,
                region_erode_radius, img_noise_std, n_sup, balance_classes,
                seed, sup_path, model, n_hidden, hidden_size, hidden_act,
                norm_layer, perturb_noise_std, dist_contour_range, conf_thresh,
                conf_avg, cons_weight, cons_loss_fn, cons_no_dropout,
                learning_rate, teacher_alpha, num_epochs, batch_size,
                render_cons_grad, render_pred, device, save_output):
    """Train a small MLP on a 2D toy classification problem with
    semi-supervised consistency regularization.

    The dataset is either generated from a black-and-white image
    (``dataset='img:<path>'``) or a synthetic spiral (``dataset='spiral'``).
    Depending on ``model``, consistency is computed between a student
    network and either an EMA "mean teacher" copy (``'mean_teacher'``)
    or the student itself (``'pi'`` / ``'pi_onebatch'``).

    Each epoch renders a visualization of the current predictions; with
    ``save_output`` frames are written to ``submit_config.run_dir``,
    otherwise they are shown in an OpenCV window (ESC terminates
    training early).  Finally prints the error rate over all samples
    (supervised and unsupervised combined).
    """
    # Capture all arguments for logging; `submit_config` is dropped since it
    # is an object rather than a plain settings value.
    settings = locals().copy()
    del settings['submit_config']

    import sys

    print('Command line:')
    print(' '.join(sys.argv))
    print('Settings:')
    print(', '.join(
        ['{}={}'.format(k, settings[k]) for k in sorted(settings.keys())]))

    import os
    import numpy as np
    import time
    import cv2
    from scipy.ndimage.morphology import distance_transform_edt
    import optim_weight_ema
    from toy2d import generate_data
    from datapipe.seg_data import RepeatSampler

    import torch, torch.nn as nn, torch.nn.functional as F
    import torch.utils.data

    rng = np.random.RandomState(seed)

    # Generate/load the dataset
    if dataset.startswith('img:'):
        # Generate a dataset from a black and white image
        image_path = dataset[4:]
        ds = generate_data.classification_dataset_from_image(
            image_path, region_erode_radius, img_noise_std, n_sup,
            balance_classes, rng)
        image = ds.image
    elif dataset == 'spiral':
        # Generate a spiral dataset
        ds = generate_data.spiral_classification_dataset(
            n_sup, balance_classes, rng)
        image = None
    else:
        print('Unknown dataset {}, should be spiral or img:<path>'.format(
            dataset))
        return

    # If a path to a supervised dataset has been provided, load it
    if sup_path is not None:
        ds.load_supervised(sup_path)

    # If we are constraining perturbations to lie along the contours of the distance map to the ground truth class boundary
    if dist_contour_range > 0.0:
        if image is None:
            print(
                'Constraining perturbations to lying on distance map contours is only supported for \'image\' experiments'
            )
            return
        img_1 = image >= 0.5
        # Compute signed distance map to boundary:
        # positive inside the foreground (image >= 0.5), negative outside.
        dist_1 = distance_transform_edt(img_1)
        dist_0 = distance_transform_edt(~img_1)
        dist_map = dist_1 * img_1 + -dist_0 * (~img_1)
    else:
        dist_map = None

    # PyTorch device
    torch_device = torch.device(device)

    # Convert perturbation noise std-dev to [y,x].
    # Falls back to a default of [6.0, 6.0] pixels if the string cannot
    # be parsed as comma-separated floats.
    try:
        perturb_noise_std = np.array(
            [float(x.strip()) for x in perturb_noise_std.split(',')])
    except ValueError:
        perturb_noise_std = np.array([6.0, 6.0])

    # Assume that perturbation noise std-dev is in pixel space (for image experiments), so convert
    # (presumably `ds.img_scale` maps pixel units to the dataset's
    # normalized coordinate space — TODO confirm against generate_data)
    perturb_noise_std_real_scale = perturb_noise_std / ds.img_scale * 2.0
    perturb_noise_std_real_scale = torch.tensor(perturb_noise_std_real_scale,
                                                dtype=torch.float,
                                                device=torch_device)

    # Define the neural network model (an MLP)
    class Network(nn.Module):
        """MLP with `n_hidden` hidden layers of `hidden_size` units,
        optional normalization (spectral/weight/batch/group norm) and a
        2-class linear output head.  Dropout before the final layer can
        be disabled per-call via `use_dropout`."""

        def __init__(self):
            super(Network, self).__init__()

            self.drop = nn.Dropout()

            hidden = []
            chn_in = 2
            for i in range(n_hidden):
                if norm_layer == 'spectral_norm':
                    hidden.append(
                        nn.utils.spectral_norm(nn.Linear(chn_in, hidden_size)))
                elif norm_layer == 'weight_norm':
                    hidden.append(
                        nn.utils.weight_norm(nn.Linear(chn_in, hidden_size)))
                else:
                    hidden.append(nn.Linear(chn_in, hidden_size))

                if norm_layer == 'batch_norm':
                    hidden.append(nn.BatchNorm1d(hidden_size))
                elif norm_layer == 'group_norm':
                    hidden.append(nn.GroupNorm(4, hidden_size))

                if hidden_act == 'relu':
                    hidden.append(nn.ReLU())
                elif hidden_act == 'lrelu':
                    hidden.append(nn.LeakyReLU(0.01))
                else:
                    raise ValueError

                chn_in = hidden_size
            self.hidden = nn.Sequential(*hidden)

            # Final layer; 2-class output
            self.l_final = nn.Linear(chn_in, 2)

        def forward(self, x, use_dropout=True):
            x = self.hidden(x)
            if use_dropout:
                x = self.drop(x)
            x = self.l_final(x)
            return x

    # Build student network, optimizer and supervised loss criterion
    student_net = Network().to(torch_device)
    student_params = list(student_net.parameters())

    student_optimizer = torch.optim.Adam(student_params, lr=learning_rate)
    classification_criterion = nn.CrossEntropyLoss()

    # Build teacher network and optimizer.
    # Teacher weights are updated as an EMA of the student's weights, so
    # gradients are disabled on the teacher parameters.
    if model == 'mean_teacher':
        teacher_net = Network().to(torch_device)
        teacher_params = list(teacher_net.parameters())
        for param in teacher_params:
            param.requires_grad = False
        teacher_optimizer = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, ema_alpha=teacher_alpha)
        pred_net = teacher_net
    else:
        teacher_net = None
        teacher_optimizer = None
        pred_net = student_net

    # Robust BCE helper: epsilon inside the logs avoids log(0) for
    # probabilities at exactly 0 or 1.
    def robust_binary_crossentropy(pred, tgt):
        inv_tgt = -tgt + 1.0
        inv_pred = -pred + 1.0 + 1e-6
        return -(tgt * torch.log(pred + 1.0e-6) +
                 inv_tgt * torch.log(inv_pred))

    # If we are constraining perturbations to lie on distance map contours, load the distance map as a Torch tensor
    if dist_contour_range > 0.0:
        t_dist_map = torch.tensor(dist_map[None, None, ...],
                                  dtype=torch.float,
                                  device=torch_device)
    else:
        t_dist_map = None

    # Helper function to compute confidence thresholding factor
    def conf_factor(teacher_pred_prob):
        """Per-sample 0/1 mask of teacher predictions whose max class
        probability meets `conf_thresh` (all-ones if thresholding is
        disabled).  With `conf_avg`, the mask is replaced by its mean,
        broadcast to every sample."""
        # Compute confidence
        conf_tea = torch.max(teacher_pred_prob, 1)[0]
        conf_tea = conf_tea.detach()
        # Compute factor based on threshold and `conf_avg` flag
        if conf_thresh > 0.0:
            conf_fac = (conf_tea >= conf_thresh).float()
        else:
            conf_fac = torch.ones(conf_tea.shape,
                                  dtype=torch.float,
                                  device=conf_tea.device)
        if conf_avg:
            conf_fac = torch.ones_like(conf_fac) * conf_fac.mean()
        return conf_fac

    # Helper function that constrains consistency loss to operate only when perturbations lie along
    # distance map contours.
    # When this feature is enabled, it masks to zero the loss for any unsupervised sample whose random perturbation
    # deviates too far from the distance map contour
    def dist_map_weighting(t_dist_map, batch_u_X, batch_u_X_1):
        if t_dist_map is not None and dist_contour_range > 0:
            # For each sample in `batch_u_X` and `batch_u_X_1`, both of which are
            # of shape `[n_points, [y,x]]` we want to get the value from the
            # distance map. For this we use `torch.nn.functional.grid_sample`.
            # This function expects grid look-up co-ordinates to have
            # the shape `[batch, height, width, [x, y]]`.
            # NOTE(review): grid_sample expects co-ordinates normalized to
            # [-1, 1]; this assumes the dataset points are already in that
            # range — TODO confirm against generate_data.
            # We reshape `batch_u_X` and `batch_u_X_1` to `[1, 1, n_points, [x,y]]` and stack along
            # the height dimension, making two rows to send to `grid_sample`.
            # The final shape will be `[1, 2, n_points, [x,y]]`:
            # 1 sample (1 image)
            # 2 rows; batch_u_X and batch_u_X_1
            # n_points columns
            # (x,y)
            # `[n_points, [y,x]]` -> `[1, 1, n_points, [x,y]]`
            sample_points_0 = torch.cat([
                batch_u_X[:, 1].view(1, 1, -1, 1), batch_u_X[:, 0].view(
                    1, 1, -1, 1)
            ],
                                        dim=3)
            # `[n_points, [y,x]]` -> `[1, 1, n_points, [x,y]]`
            sample_points_1 = torch.cat([
                batch_u_X_1[:, 1].view(1, 1, -1, 1), batch_u_X_1[:, 0].view(
                    1, 1, -1, 1)
            ],
                                        dim=3)
            # -> `[1, 2, n_points, [x,y]]`
            sample_points = torch.cat([sample_points_0, sample_points_1],
                                      dim=1)
            # Get distance to class boundary from distance map
            dist_from_boundary = F.grid_sample(t_dist_map, sample_points)
            # Get the squared difference between the distances from `batch_u_X` to the boundary
            # and the distances from `batch_u_X_1` to the boundary.
            delta_dist_sqr = (dist_from_boundary[0, 0, 0, :] -
                              dist_from_boundary[0, 0, 1, :]).pow(2)
            # Per-sample loss mask based on difference between distances
            weight = (delta_dist_sqr <=
                      (dist_contour_range * dist_contour_range)).float()

            return weight
        else:
            return torch.ones(len(batch_u_X),
                              dtype=torch.float,
                              device=batch_u_X.device)

    # Supervised dataset, sampler and loader
    sup_dataset = torch.utils.data.TensorDataset(
        torch.tensor(ds.sup_X, dtype=torch.float),
        torch.tensor(ds.sup_y, dtype=torch.long))
    # RepeatSampler makes the (small) supervised set repeat indefinitely so
    # that zip() with the unsupervised loader is bounded by the latter.
    sup_sampler = RepeatSampler(torch.utils.data.RandomSampler(sup_dataset))
    sup_sep_loader = torch.utils.data.DataLoader(sup_dataset,
                                                 batch_size,
                                                 sampler=sup_sampler,
                                                 num_workers=1)

    # Unsupervised dataset, sampler and loader
    unsup_dataset = torch.utils.data.TensorDataset(
        torch.tensor(ds.unsup_X, dtype=torch.float))
    unsup_sampler = torch.utils.data.RandomSampler(unsup_dataset)
    unsup_loader = torch.utils.data.DataLoader(unsup_dataset,
                                               batch_size,
                                               sampler=unsup_sampler,
                                               num_workers=1)

    # Complete dataset and loader
    all_dataset = torch.utils.data.TensorDataset(
        torch.tensor(ds.X, dtype=torch.float))
    all_loader = torch.utils.data.DataLoader(all_dataset,
                                             16384,
                                             shuffle=False,
                                             num_workers=1)

    # Grid points used to render visualizations
    vis_grid_dataset = torch.utils.data.TensorDataset(
        torch.tensor(ds.px_grid_vis, dtype=torch.float))
    vis_grid_loader = torch.utils.data.DataLoader(vis_grid_dataset,
                                                  16384,
                                                  shuffle=False,
                                                  num_workers=1)

    # Evaluation mode initially
    student_net.eval()
    if teacher_net is not None:
        teacher_net.eval()

    # Compute the magnitude of the gradient of the consistency loss at the logits.
    # NOTE(review): returns a 1-tuple containing a numpy array, not a
    # tensor — see the caller in render_output_image below.
    def consistency_loss_logit_grad_mag(batch_u_X):
        u_shape = batch_u_X.shape

        batch_u_X_1 = batch_u_X + torch.randn(u_shape, dtype=torch.float, device=torch_device) * \
                                  perturb_noise_std_real_scale[None, :]

        student_optimizer.zero_grad()

        # Captured by grad_hook; filled in when backward() reaches the
        # student logits.
        grads = [None]

        if teacher_net is not None:
            teacher_unsup_logits = teacher_net(batch_u_X).detach()
        else:
            teacher_unsup_logits = student_net(batch_u_X)
        teacher_unsup_prob = F.softmax(teacher_unsup_logits, dim=1)
        student_unsup_logits = student_net(batch_u_X_1)

        def grad_hook(grad):
            # Per-sample L2 norm of the gradient w.r.t. the logits
            grads[0] = torch.sqrt((grad * grad).sum(dim=1))

        student_unsup_logits.register_hook(grad_hook)
        student_unsup_prob = F.softmax(student_unsup_logits, dim=1)

        weight = dist_map_weighting(t_dist_map, batch_u_X, batch_u_X_1)

        mod_fac = conf_factor(teacher_unsup_prob) * weight

        if cons_loss_fn == 'bce':
            aug_loss = robust_binary_crossentropy(student_unsup_prob,
                                                  teacher_unsup_prob)
            aug_loss = aug_loss.mean(dim=1) * mod_fac
            unsup_loss = aug_loss.mean()
        elif cons_loss_fn == 'var':
            d_aug_loss = student_unsup_prob - teacher_unsup_prob
            aug_loss = d_aug_loss * d_aug_loss
            aug_loss = aug_loss.mean(dim=1) * mod_fac
            unsup_loss = aug_loss.mean()
        elif cons_loss_fn == 'logits_var':
            d_aug_loss = student_unsup_logits - teacher_unsup_logits
            aug_loss = d_aug_loss * d_aug_loss
            aug_loss = aug_loss.mean(dim=1) * mod_fac
            unsup_loss = aug_loss.mean()
        else:
            raise ValueError

        unsup_loss.backward()

        return (grads[0].cpu().numpy(), )

    # Helper function for rendering an output image for visualization
    def render_output_image():
        # Generate output for plotting
        with torch.no_grad():
            vis_pred = []
            vis_grad = [] if render_cons_grad else None
            for (batch_X, ) in vis_grid_loader:
                batch_X = batch_X.to(torch_device)
                batch_pred_logits = pred_net(batch_X)
                if render_pred == 'prob':
                    batch_vis = F.softmax(batch_pred_logits, dim=1)[:, 1]
                elif render_pred == 'class':
                    batch_vis = torch.argmax(batch_pred_logits, dim=1)
                else:
                    raise ValueError(
                        'Unknown prediction render {}'.format(render_pred))
                batch_vis = batch_vis.detach().cpu().numpy()
                vis_pred.append(batch_vis)

                if render_cons_grad:
                    # NOTE(review): this path looks broken —
                    # consistency_loss_logit_grad_mag calls backward()
                    # inside this torch.no_grad() block, and it returns a
                    # tuple of numpy arrays, on which .detach() would fail.
                    # TODO confirm before relying on render_cons_grad.
                    batch_grad = consistency_loss_logit_grad_mag(batch_X)
                    vis_grad.append(batch_grad.detach().cpu().numpy())

            vis_pred = np.concatenate(vis_pred, axis=0)
            if render_cons_grad:
                vis_grad = np.concatenate(vis_grad, axis=0)

        out_image = ds.semisup_image_plot(vis_pred, vis_grad)
        return out_image

    # Output image for first frame
    if save_output and submit_config.run_dir is not None:
        plot_path = os.path.join(submit_config.run_dir,
                                 'epoch_{:05d}.png'.format(0))
        cv2.imwrite(plot_path, render_output_image())
    else:
        cv2.imshow('Vis', render_output_image())
        k = cv2.waitKey(1)

    # Train
    print('|sup|={}'.format(len(ds.sup_X)))
    print('|unsup|={}'.format(len(ds.unsup_X)))
    print('|all|={}'.format(len(ds.X)))
    print('Training...')

    terminated = False
    for epoch in range(num_epochs):
        t1 = time.time()
        student_net.train()
        if teacher_net is not None:
            teacher_net.train()

        batch_sup_loss_accum = 0.0
        batch_conf_mask_sum_accum = 0.0
        batch_cons_loss_accum = 0.0
        batch_N_accum = 0.0
        for sup_batch, unsup_batch in zip(sup_sep_loader, unsup_loader):
            (batch_X, batch_y) = sup_batch
            (batch_u_X, ) = unsup_batch

            batch_X = batch_X.to(torch_device)
            batch_y = batch_y.to(torch_device)
            batch_u_X = batch_u_X.to(torch_device)

            # Apply perturbation to generate `batch_u_X_1`
            aug_perturbation = torch.randn(batch_u_X.shape,
                                           dtype=torch.float,
                                           device=torch_device)
            batch_u_X_1 = batch_u_X + aug_perturbation * perturb_noise_std_real_scale[
                None, :]

            # Supervised loss path
            student_optimizer.zero_grad()
            student_sup_logits = student_net(batch_X)
            sup_loss = classification_criterion(student_sup_logits, batch_y)

            if cons_weight > 0.0:
                # Consistency loss path

                # Logits are computed differently depending on model
                if model == 'mean_teacher':
                    teacher_unsup_logits = teacher_net(
                        batch_u_X, use_dropout=not cons_no_dropout).detach()
                    student_unsup_logits = student_net(
                        batch_u_X_1, use_dropout=not cons_no_dropout)
                elif model == 'pi':
                    teacher_unsup_logits = student_net(
                        batch_u_X, use_dropout=not cons_no_dropout)
                    student_unsup_logits = student_net(
                        batch_u_X_1, use_dropout=not cons_no_dropout)
                elif model == 'pi_onebatch':
                    # Both views pass through the network in one batch so
                    # that e.g. batch norm statistics are shared.
                    batch_both = torch.cat([batch_u_X, batch_u_X_1], dim=0)
                    both_unsup_logits = student_net(
                        batch_both, use_dropout=not cons_no_dropout)
                    teacher_unsup_logits = both_unsup_logits[:len(batch_u_X)]
                    student_unsup_logits = both_unsup_logits[len(batch_u_X):]
                else:
                    raise RuntimeError

                # Compute predicted probabilities
                teacher_unsup_prob = F.softmax(teacher_unsup_logits, dim=1)
                student_unsup_prob = F.softmax(student_unsup_logits, dim=1)

                # Distance map weighting
                # (if dist_contour_range is 0 then weight will just be 1)
                weight = dist_map_weighting(t_dist_map, batch_u_X, batch_u_X_1)

                # Confidence thresholding
                conf_fac = conf_factor(teacher_unsup_prob)
                mod_fac = conf_fac * weight

                # Compute consistency loss
                if cons_loss_fn == 'bce':
                    aug_loss = robust_binary_crossentropy(
                        student_unsup_prob, teacher_unsup_prob)
                    aug_loss = aug_loss.mean(dim=1) * mod_fac
                    cons_loss = aug_loss.sum() / weight.sum()
                elif cons_loss_fn == 'var':
                    d_aug_loss = student_unsup_prob - teacher_unsup_prob
                    aug_loss = d_aug_loss * d_aug_loss
                    aug_loss = aug_loss.mean(dim=1) * mod_fac
                    cons_loss = aug_loss.sum() / weight.sum()
                elif cons_loss_fn == 'logits_var':
                    d_aug_loss = student_unsup_logits - teacher_unsup_logits
                    aug_loss = d_aug_loss * d_aug_loss
                    aug_loss = aug_loss.mean(dim=1) * mod_fac
                    cons_loss = aug_loss.sum() / weight.sum()
                else:
                    raise ValueError

                # Combine supervised and consistency loss
                loss = sup_loss + cons_loss * cons_weight

                conf_rate = float(conf_fac.sum())
            else:
                loss = sup_loss
                conf_rate = 0.0
                cons_loss = 0.0

            loss.backward()
            student_optimizer.step()
            if teacher_optimizer is not None:
                # EMA update of teacher weights from the student
                teacher_optimizer.step()

            batch_sup_loss_accum += float(sup_loss)
            batch_conf_mask_sum_accum += conf_rate
            batch_cons_loss_accum += float(cons_loss)
            batch_N_accum += len(batch_X)

        # NOTE(review): the accumulated per-batch loss means are divided by
        # the total number of supervised samples, not the number of batches
        # — presumably intentional scaling for logging; verify if these
        # numbers are compared across batch sizes.
        if batch_N_accum > 0:
            batch_sup_loss_accum /= batch_N_accum
            batch_conf_mask_sum_accum /= batch_N_accum
            batch_cons_loss_accum /= batch_N_accum

        student_net.eval()
        if teacher_net is not None:
            teacher_net.eval()

        # Generate output for plotting
        if save_output and submit_config.run_dir is not None:
            plot_path = os.path.join(submit_config.run_dir,
                                     'epoch_{:05d}.png'.format(epoch + 1))
            cv2.imwrite(plot_path, render_output_image())
        else:
            cv2.imshow('Vis', render_output_image())

            # ESC key terminates training early
            k = cv2.waitKey(1)
            if (k & 255) == 27:
                terminated = True
                break

        t2 = time.time()
        # print('Epoch {}: took {:.3f}s: clf loss={:.6f}'.format(epoch, t2-t1, clf_loss))
        print(
            'Epoch {}: took {:.3f}s: clf loss={:.6f}, conf rate={:.3%}, cons loss={:.6f}'
            .format(epoch + 1, t2 - t1, batch_sup_loss_accum,
                    batch_conf_mask_sum_accum, batch_cons_loss_accum))

    # Get final score based on all samples
    all_pred_y = []
    with torch.no_grad():
        for (batch_X, ) in all_loader:
            batch_X = batch_X.to(torch_device)
            batch_pred_logits = pred_net(batch_X)
            batch_pred_cls = torch.argmax(batch_pred_logits, dim=1)
            all_pred_y.append(batch_pred_cls.detach().cpu().numpy())
    all_pred_y = np.concatenate(all_pred_y, axis=0)
    err_rate = (all_pred_y != ds.y).mean()
    print(
        'FINAL RESULT: Error rate={:.6%} (supervised and unsupervised samples)'
        .format(err_rate))

    if not save_output:
        # Close output window
        if not terminated:
            cv2.waitKey()

        cv2.destroyAllWindows()
# ---- Example #5 (site artifact; original marker "예제 #5", score 0) ----
def train_seg_semisup_mask_mt(
        submit_config: job_helper.SubmitConfig, dataset, model, arch,
        freeze_bn, opt_type, sgd_momentum, sgd_nesterov, sgd_weight_decay,
        learning_rate, lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power,
        teacher_alpha, bin_fill_holes, crop_size, aug_hflip, aug_vflip,
        aug_hvflip, aug_scale_hung, aug_max_scale, aug_scale_non_uniform,
        aug_rot_mag, mask_mode, mask_prop_range, boxmask_n_boxes,
        boxmask_fixed_aspect_ratio, boxmask_by_size, boxmask_outside_bounds,
        boxmask_no_invert, cons_loss_fn, cons_weight, conf_thresh,
        conf_per_pixel, rampup, unsup_batch_ratio, num_epochs, iters_per_epoch,
        batch_size, n_sup, n_unsup, n_val, split_seed, split_path, val_seed,
        save_preds, save_model, num_workers):
    """Train a semi-supervised segmentation network with mask-based
    consistency regularisation (CutMix-style mixing or Cutout-style zeroing)
    and an optional mean-teacher.

    A student network is trained with cross-entropy on the supervised subset.
    On the unsupervised subset, box masks are generated per batch; in 'mix'
    mode two unlabelled images are blended by the mask and the student's
    prediction for the mixed image is pushed towards the identically mixed
    teacher predictions, while in 'zero' mode the masked regions of a single
    image are zeroed out.  The consistency loss (selected by `cons_loss_fn`),
    optionally confidence-thresholded and ramped up, is weighted by
    `cons_weight` and back-propagated alongside the supervised loss.

    `model` selects 'mean_teacher' (EMA teacher evaluated) or 'pi' (student
    is its own teacher).  The external interface is unchanged; requires a
    CUDA device ('cuda:0').  Returns None; prints progress, optionally saves
    the model and per-sample predictions into `submit_config.run_dir`.
    """
    settings = locals().copy()
    del settings['submit_config']

    # `mask_prop_range` may be a single proportion or a 'lo:hi' range string.
    if ':' in mask_prop_range:
        a, b = mask_prop_range.split(':')
        mask_prop_range = (float(a.strip()), float(b.strip()))
        del a, b
    else:
        mask_prop_range = float(mask_prop_range)

    if mask_mode == 'zero':
        mask_mix = False
    elif mask_mode == 'mix':
        mask_mix = True
    else:
        raise ValueError('Unknown mask_mode {}'.format(mask_mode))
    del mask_mode

    import os
    import math
    import time
    import itertools
    import numpy as np
    import torch.nn as nn, torch.nn.functional as F
    from architectures import network_architectures
    import torch.utils.data
    from datapipe import datasets
    from datapipe import seg_data, seg_transforms, seg_transforms_cv
    import evaluation
    import optim_weight_ema
    import mask_gen
    import lr_schedules

    # `crop_size` is a 'H,W' string; empty string disables cropping.
    if crop_size == '':
        crop_size = None
    else:
        crop_size = [int(x.strip()) for x in crop_size.split(',')]

    torch_device = torch.device('cuda:0')

    #
    # Load data sets
    #
    ds_dict = datasets.load_dataset(dataset, n_val, val_seed, n_sup, n_unsup,
                                    split_seed, split_path)

    ds_src = ds_dict['ds_src']
    ds_tgt = ds_dict['ds_tgt']
    tgt_val_ndx = ds_dict['val_ndx_tgt']
    # A separate source validation set only exists in the domain-shift case.
    src_val_ndx = ds_dict['val_ndx_src'] if ds_src is not ds_tgt else None
    test_ndx = ds_dict['test_ndx_tgt']
    sup_ndx = ds_dict['sup_ndx']
    unsup_ndx = ds_dict['unsup_ndx']

    n_classes = ds_src.num_classes
    # Used to normalise logit-space consistency losses w.r.t. class count.
    root_n_classes = math.sqrt(n_classes)

    if bin_fill_holes and n_classes != 2:
        print(
            'Binary hole filling can only be used with binary (2-class) segmentation datasets'
        )
        return

    print('Loaded data')

    # Build network
    NetClass = network_architectures.seg.get(arch)

    student_net = NetClass(ds_src.num_classes).to(torch_device)

    # Pre-trained (backbone) parameters get a 10x smaller learning rate than
    # freshly initialised layers.
    if opt_type == 'adam':
        student_optim = torch.optim.Adam([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ])
    elif opt_type == 'sgd':
        student_optim = torch.optim.SGD([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ],
                                        momentum=sgd_momentum,
                                        nesterov=sgd_nesterov,
                                        weight_decay=sgd_weight_decay)
    else:
        raise ValueError('Unknown opt_type {}'.format(opt_type))

    if model == 'mean_teacher':
        # Teacher is an EMA copy of the student; it is never trained by SGD.
        teacher_net = NetClass(ds_src.num_classes).to(torch_device)

        for p in teacher_net.parameters():
            p.requires_grad = False

        teacher_optim = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, teacher_alpha)
        eval_net = teacher_net
    elif model == 'pi':
        # Pi-model: the student provides its own consistency targets.
        teacher_net = student_net
        teacher_optim = None
        eval_net = student_net
    else:
        print('Unknown model type {}'.format(model))
        return

    BLOCK_SIZE = student_net.BLOCK_SIZE
    NET_MEAN, NET_STD = seg_transforms.get_mean_std(ds_tgt, student_net)

    if freeze_bn:
        if not hasattr(student_net, 'freeze_batchnorm'):
            raise ValueError(
                'Network {} does not support batchnorm freezing'.format(arch))

    # 255 is the conventional 'ignore' label for unannotated pixels.
    clf_crossent_loss = nn.CrossEntropyLoss(ignore_index=255)

    print('Built network')

    mask_generator = mask_gen.BoxMaskGenerator(
        prop_range=mask_prop_range,
        n_boxes=boxmask_n_boxes,
        random_aspect_ratio=not boxmask_fixed_aspect_ratio,
        prop_by_area=not boxmask_by_size,
        within_bounds=not boxmask_outside_bounds,
        invert=not boxmask_no_invert)

    # -1 means 'one pass over the unsupervised set per epoch'.
    if iters_per_epoch == -1:
        iters_per_epoch = len(unsup_ndx) // batch_size
    total_iters = iters_per_epoch * num_epochs

    lr_epoch_scheduler, lr_iter_scheduler = lr_schedules.make_lr_schedulers(
        optimizer=student_optim,
        total_iters=total_iters,
        schedule_type=lr_sched,
        step_epochs=lr_step_epochs,
        step_gamma=lr_step_gamma,
        poly_power=lr_poly_power)

    train_transforms = []
    eval_transforms = []

    if crop_size is not None:
        if aug_scale_hung:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropScaleHung(
                    crop_size, (0, 0),
                    uniform_scale=not aug_scale_non_uniform))
        elif aug_max_scale != 1.0 or aug_rot_mag != 0.0:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropRotateScale(
                    crop_size, (0, 0),
                    rot_mag=aug_rot_mag,
                    max_scale=aug_max_scale,
                    uniform_scale=not aug_scale_non_uniform,
                    constrain_rot_scale=True))
        else:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCrop(crop_size, (0, 0)))
    else:
        if aug_scale_hung:
            raise NotImplementedError('aug_scale_hung requires a crop_size')

    if aug_hflip or aug_vflip or aug_hvflip:
        train_transforms.append(
            seg_transforms_cv.SegCVTransformRandomFlip(aug_hflip, aug_vflip,
                                                       aug_hvflip))
    train_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))
    eval_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))

    # Supervised samples carry labels; unsupervised samples carry a validity
    # mask instead (1 = pixel originates inside the source image).
    train_sup_ds = ds_src.dataset(
        labels=True,
        mask=False,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    train_unsup_ds = ds_src.dataset(
        labels=False,
        mask=True,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    eval_ds = ds_src.dataset(
        labels=True,
        mask=False,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(eval_transforms),
        pipeline_type='cv')

    add_mask_params_to_batch = mask_gen.AddMaskParamsToBatch(mask_generator)

    collate_fn = seg_data.SegCollate(BLOCK_SIZE)
    # This collate additionally attaches box-mask parameters to each batch.
    mask_collate_fn = seg_data.SegCollate(
        BLOCK_SIZE, batch_aug_fn=add_mask_params_to_batch)

    # Train data pipeline: data loaders.  RepeatSampler makes the loaders
    # effectively infinite so iterators below never exhaust mid-epoch.
    sup_sampler = seg_data.RepeatSampler(
        torch.utils.data.SubsetRandomSampler(sup_ndx))
    train_sup_loader = torch.utils.data.DataLoader(train_sup_ds,
                                                   batch_size,
                                                   sampler=sup_sampler,
                                                   collate_fn=collate_fn,
                                                   num_workers=num_workers)
    if cons_weight > 0.0:
        unsup_sampler = seg_data.RepeatSampler(
            torch.utils.data.SubsetRandomSampler(unsup_ndx))
        train_unsup_loader_0 = torch.utils.data.DataLoader(
            train_unsup_ds,
            batch_size,
            sampler=unsup_sampler,
            collate_fn=mask_collate_fn,
            num_workers=num_workers)
        if mask_mix:
            # Mix mode needs a second stream of unlabelled images to blend
            # with; only the first stream supplies the mask parameters.
            train_unsup_loader_1 = torch.utils.data.DataLoader(
                train_unsup_ds,
                batch_size,
                sampler=unsup_sampler,
                collate_fn=collate_fn,
                num_workers=num_workers)
        else:
            train_unsup_loader_1 = None
    else:
        train_unsup_loader_0 = None
        train_unsup_loader_1 = None

    # Eval pipeline
    src_val_loader, tgt_val_loader, test_loader = datasets.eval_data_pipeline(
        ds_src, ds_tgt, src_val_ndx, tgt_val_ndx, test_ndx, batch_size,
        collate_fn, NET_MEAN, NET_STD, num_workers)

    # Report settings
    print('Settings:')
    print(', '.join([
        '{}={}'.format(key, settings[key])
        for key in sorted(list(settings.keys()))
    ]))

    # Report dataset size
    print('Dataset:')
    print('len(sup_ndx)={}'.format(len(sup_ndx)))
    print('len(unsup_ndx)={}'.format(len(unsup_ndx)))
    if ds_src is not ds_tgt:
        # Bug fix: previously printed len(tgt_val_ndx) under the src label.
        print('len(src_val_ndx)={}'.format(len(src_val_ndx)))
        print('len(tgt_val_ndx)={}'.format(len(tgt_val_ndx)))
    else:
        print('len(val_ndx)={}'.format(len(tgt_val_ndx)))
    if test_ndx is not None:
        print('len(test_ndx)={}'.format(len(test_ndx)))

    if n_sup != -1:
        print('sup_ndx={}'.format(sup_ndx.tolist()))

    # Track mIoU for early stopping
    # NOTE(review): these are never updated in the loop below — the early
    # stopping / best-snapshot logic appears unimplemented.
    best_tgt_miou = None
    best_epoch = 0

    eval_net_state = {
        key: value.detach().cpu().numpy()
        for key, value in eval_net.state_dict().items()
    }

    # Create iterators over the (infinitely repeating) training loaders.
    train_sup_iter = iter(train_sup_loader)
    train_unsup_iter_0 = iter(
        train_unsup_loader_0) if train_unsup_loader_0 is not None else None
    train_unsup_iter_1 = iter(
        train_unsup_loader_1) if train_unsup_loader_1 is not None else None

    iter_i = 0
    print('Training...')
    for epoch_i in range(num_epochs):
        if lr_epoch_scheduler is not None:
            lr_epoch_scheduler.step(epoch_i)

        t1 = time.time()

        if rampup > 0:
            ramp_val = network_architectures.sigmoid_rampup(epoch_i, rampup)
        else:
            ramp_val = 1.0

        student_net.train()
        if teacher_net is not student_net:
            teacher_net.train()

        if freeze_bn:
            student_net.freeze_batchnorm()
            if teacher_net is not student_net:
                teacher_net.freeze_batchnorm()

        sup_loss_acc = 0.0
        consistency_loss_acc = 0.0
        conf_rate_acc = 0.0
        n_sup_batches = 0
        n_unsup_batches = 0

        # Fresh single-pass validation iterators for this epoch's evaluation.
        src_val_iter = iter(
            src_val_loader) if src_val_loader is not None else None
        tgt_val_iter = iter(
            tgt_val_loader) if tgt_val_loader is not None else None

        for sup_batch in itertools.islice(train_sup_iter, iters_per_epoch):
            if lr_iter_scheduler is not None:
                lr_iter_scheduler.step(iter_i)
            student_optim.zero_grad()

            #
            # Supervised branch
            #

            batch_x = sup_batch['image'].to(torch_device)
            batch_y = sup_batch['labels'].to(torch_device)

            logits_sup = student_net(batch_x)
            # labels come as (N,1,H,W); CrossEntropyLoss wants (N,H,W)
            sup_loss = clf_crossent_loss(logits_sup, batch_y[:, 0, :, :])
            sup_loss.backward()

            if cons_weight > 0.0:
                for _ in range(unsup_batch_ratio):
                    #
                    # Unsupervised branch
                    #

                    if mask_mix:
                        # Mix mode: batch consists of paired unsupervised samples and mask parameters
                        unsup_batch0 = next(train_unsup_iter_0)
                        unsup_batch1 = next(train_unsup_iter_1)
                        batch_ux0 = unsup_batch0['image'].to(torch_device)
                        batch_um0 = unsup_batch0['mask'].to(torch_device)
                        batch_ux1 = unsup_batch1['image'].to(torch_device)
                        batch_um1 = unsup_batch1['mask'].to(torch_device)
                        batch_mask_params = unsup_batch0['mask_params'].to(
                            torch_device)

                        # batch_um0 and batch_um1 are masks that are 1 for valid pixels, 0 for invalid pixels.
                        # They are used later on to scale the consistency loss, so that consistency loss is
                        # only computed for valid pixels.
                        # Explanation:
                        # When using geometric augmentations such as rotations, some pixels in the training
                        # crop may come from outside the bounds of the input image. These pixels will have a value
                        # of 0 in these masks. Similarly, when using scaled crops, the size of the crop
                        # from the input image that must be scaled to the size of the training crop may be
                        # larger than one/both of the input image dimensions. Pixels in the training crop
                        # that arise from outside the input image bounds will once again be given a value
                        # of 0 in these masks.

                        # Convert mask parameters to masks of shape (N,1,H,W)
                        batch_mix_masks = mask_generator.torch_masks_from_params(
                            batch_mask_params, batch_ux0.shape[2:4],
                            torch_device)

                        # Mix images with masks
                        batch_ux_mixed = batch_ux0 * (
                            1 - batch_mix_masks) + batch_ux1 * batch_mix_masks
                        batch_um_mixed = batch_um0 * (
                            1 - batch_mix_masks) + batch_um1 * batch_mix_masks

                        # Get teacher predictions for original images
                        with torch.no_grad():
                            logits_u0_tea = teacher_net(batch_ux0).detach()
                            logits_u1_tea = teacher_net(batch_ux1).detach()
                        # Get student prediction for mixed image
                        logits_cons_stu = student_net(batch_ux_mixed)

                        # Mix teacher predictions using same mask
                        # It makes no difference whether we do this with logits or probabilities as
                        # the mask pixels are either 1 or 0
                        logits_cons_tea = logits_u0_tea * (
                            1 -
                            batch_mix_masks) + logits_u1_tea * batch_mix_masks

                        # Logits -> probs
                        prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                        prob_cons_stu = F.softmax(logits_cons_stu, dim=1)

                        loss_mask = batch_um_mixed

                    else:
                        # Cut mode: batch consists of unsupervised samples and mask params
                        unsup_batch = next(train_unsup_iter_0)
                        batch_ux = unsup_batch['image'].to(torch_device)
                        batch_um = unsup_batch['mask'].to(torch_device)
                        batch_mask_params = unsup_batch['mask_params'].to(
                            torch_device)

                        # Convert mask parameters to masks of shape (N,1,H,W)
                        batch_cut_masks = mask_generator.torch_masks_from_params(
                            batch_mask_params, batch_ux.shape[2:4],
                            torch_device)

                        # Cut image with mask (mask regions to zero)
                        batch_ux_cut = batch_ux * batch_cut_masks

                        # Get teacher predictions for original image
                        with torch.no_grad():
                            logits_cons_tea = teacher_net(batch_ux).detach()
                        # Get student prediction for cut image
                        logits_cons_stu = student_net(batch_ux_cut)

                        # Logits -> probs
                        prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                        prob_cons_stu = F.softmax(logits_cons_stu, dim=1)

                        loss_mask = batch_cut_masks * batch_um

                    # -- shared by mix and cut --

                    # Confidence thresholding
                    if conf_thresh > 0.0:
                        # Compute confidence of teacher predictions
                        conf_tea = prob_cons_tea.max(dim=1)[0]
                        # Compute confidence mask
                        conf_mask = (conf_tea >=
                                     conf_thresh).float()[:, None, :, :]
                        # Record rate for reporting
                        conf_rate_acc += float(conf_mask.mean())
                        # Average confidence mask if requested
                        if not conf_per_pixel:
                            conf_mask = conf_mask.mean()

                        loss_mask = loss_mask * conf_mask
                    elif rampup > 0:
                        conf_rate_acc += ramp_val

                    # Compute per-pixel consistency loss
                    # Note that the way we aggregate the loss across the class/channel dimension (1)
                    # depends on the loss function used. Generally, summing over the class dimension
                    # keeps the magnitude of the gradient of the loss w.r.t. the logits
                    # nearly constant w.r.t. the number of classes. When using logit-variance,
                    # dividing by `sqrt(num_classes)` helps.
                    if cons_loss_fn == 'var':
                        delta_prob = prob_cons_stu - prob_cons_tea
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'logits_var':
                        delta_logits = logits_cons_stu - logits_cons_tea
                        consistency_loss = delta_logits * delta_logits
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'logits_smoothl1':
                        # reduction='none' replaces the deprecated reduce=False
                        consistency_loss = F.smooth_l1_loss(logits_cons_stu,
                                                            logits_cons_tea,
                                                            reduction='none')
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'bce':
                        consistency_loss = network_architectures.robust_binary_crossentropy(
                            prob_cons_stu, prob_cons_tea)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'kld':
                        # reduction='none' replaces the deprecated reduce=False
                        consistency_loss = F.kl_div(F.log_softmax(
                            logits_cons_stu, dim=1),
                                                    prob_cons_tea,
                                                    reduction='none')
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    else:
                        raise ValueError(
                            'Unknown consistency loss function {}'.format(
                                cons_loss_fn))

                    # Apply consistency loss mask and take the mean over pixels and images
                    consistency_loss = (consistency_loss * loss_mask).mean()

                    # Modulate with rampup if desired
                    if rampup > 0:
                        consistency_loss = consistency_loss * ramp_val

                    # Weight the consistency loss and back-prop
                    unsup_loss = consistency_loss * cons_weight
                    unsup_loss.backward()

                    consistency_loss_acc += float(consistency_loss.detach())

                    n_unsup_batches += 1

            student_optim.step()
            if teacher_optim is not None:
                # Update EMA teacher weights from the student.
                teacher_optim.step()

            sup_loss_val = float(sup_loss.detach())
            if np.isnan(sup_loss_val):
                print('NaN detected; network dead, bailing.')
                return

            sup_loss_acc += sup_loss_val
            n_sup_batches += 1
            iter_i += 1

        sup_loss_acc /= n_sup_batches
        if n_unsup_batches > 0:
            consistency_loss_acc /= n_unsup_batches
            conf_rate_acc /= n_unsup_batches

        eval_net.eval()

        # Source-domain validation (domain-shift case only).
        if ds_src is not ds_tgt:
            src_iou_eval = evaluation.EvaluatorIoU(ds_src.num_classes,
                                                   bin_fill_holes)
            with torch.no_grad():
                for batch in src_val_iter:
                    batch_x = batch['image'].to(torch_device)
                    batch_y = batch['labels'].numpy()

                    logits = eval_net(batch_x)
                    pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                    for sample_i in range(len(batch_y)):
                        src_iou_eval.sample(batch_y[sample_i, 0],
                                            pred_y[sample_i],
                                            ignore_value=255)

            src_iou = src_iou_eval.score()
            src_miou = src_iou.mean()
        else:
            src_iou_eval = src_iou = src_miou = None

        # Target-domain validation.
        tgt_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                               bin_fill_holes)
        with torch.no_grad():
            for batch in tgt_val_iter:
                batch_x = batch['image'].to(torch_device)
                batch_y = batch['labels'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i in range(len(batch_y)):
                    tgt_iou_eval.sample(batch_y[sample_i, 0],
                                        pred_y[sample_i],
                                        ignore_value=255)

        tgt_iou = tgt_iou_eval.score()
        tgt_miou = tgt_iou.mean()

        t2 = time.time()

        if ds_src is not ds_tgt:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, '
                'SRC VAL mIoU={:.3%}, TGT VAL mIoU={:.3%}'.format(
                    epoch_i + 1, t2 - t1, sup_loss_acc, consistency_loss_acc,
                    conf_rate_acc, src_miou, tgt_miou))
            print('-- SRC {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in src_iou])))
            print('-- TGT {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))
        else:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, VAL mIoU={:.3%}'
                .format(epoch_i + 1, t2 - t1, sup_loss_acc,
                        consistency_loss_acc, conf_rate_acc, tgt_miou))
            print('-- {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))

    if save_model:
        model_path = os.path.join(submit_config.run_dir, "model.pth")
        torch.save(eval_net, model_path)

    if save_preds:
        out_dir = os.path.join(submit_config.run_dir, 'preds')
        os.makedirs(out_dir, exist_ok=True)
        with torch.no_grad():
            for batch in tgt_val_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    ds_tgt.save_prediction_by_index(
                        out_dir, pred_y[sample_i].astype(np.uint32),
                        sample_ndx)
    else:
        out_dir = None

    if test_loader is not None:
        test_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                                bin_fill_holes)
        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['image'].to(torch_device)
                # Bug fix: ground truth must come from THIS batch; previously
                # `batch_y` was stale from the last validation batch, so the
                # test score compared predictions against the wrong labels.
                batch_y = batch['labels'].numpy()
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    if save_preds:
                        ds_tgt.save_prediction_by_index(
                            out_dir, pred_y[sample_i].astype(np.uint32),
                            sample_ndx)
                    test_iou_eval.sample(batch_y[sample_i, 0],
                                         pred_y[sample_i],
                                         ignore_value=255)

        test_iou = test_iou_eval.score()
        test_miou = test_iou.mean()

        print('FINAL TEST: mIoU={:.3%}'.format(test_miou))
        print('-- TEST {}'.format(', '.join(
            ['{:.3%}'.format(x) for x in test_iou])))