def train_seg_semisup_vat_mt(
        submit_config: job_helper.SubmitConfig, dataset, model, arch,
        freeze_bn, opt_type, sgd_momentum, sgd_nesterov, sgd_weight_decay,
        learning_rate, lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power,
        teacher_alpha, bin_fill_holes, crop_size, aug_hflip, aug_vflip,
        aug_hvflip, aug_scale_hung, aug_max_scale, aug_scale_non_uniform,
        aug_rot_mag, vat_radius, adaptive_vat_radius, vat_dir_from_student,
        cons_loss_fn, cons_weight, conf_thresh, conf_per_pixel, rampup,
        unsup_batch_ratio, num_epochs, iters_per_epoch, batch_size, n_sup,
        n_unsup, n_val, split_seed, split_path, val_seed, save_preds,
        save_model, num_workers):
    """Train a semi-supervised segmentation network with a VAT consistency loss.

    Each training iteration back-propagates a supervised cross-entropy loss
    on one labelled batch and, when ``cons_weight > 0``, a consistency loss on
    ``unsup_batch_ratio`` unlabelled batches. The consistency loss compares
    the teacher's prediction for the clean image with the student's
    prediction for the same image offset by a virtual adversarial (VAT)
    perturbation. After every epoch the evaluation network is scored (IoU) on
    the validation split(s); after training the model and/or predictions are
    optionally saved and the test split, if any, is scored.

    Selected parameters (the rest are forwarded to project helpers):

    :param submit_config: job configuration; only ``run_dir`` is read here
        (output directory for the saved model and predictions)
    :param model: ``'mean_teacher'`` (separate teacher network updated via
        ``optim_weight_ema.EMAWeightOptimizer``; the teacher is evaluated) or
        ``'pi'`` (the student acts as its own teacher)
    :param opt_type: ``'adam'`` or ``'sgd'``
    :param crop_size: comma separated integers as a string (e.g. ``'321,321'``)
        or ``''`` for no cropping
    :param vat_dir_from_student: if true the VAT direction is computed with
        the student network, otherwise with the teacher network
    :param cons_loss_fn: one of ``'var'``, ``'logits_var'``, ``'bce'``, ``'kld'``
    :param iters_per_epoch: iterations per epoch; ``-1`` means
        ``len(unsup_ndx) // batch_size``
    :return: ``None``. The function also returns early (after printing a
        message) on an unknown ``model``, on binary hole filling with a
        non-binary dataset, and when a NaN loss is detected.
    :raises ValueError: unknown ``opt_type`` or ``cons_loss_fn``, or
        ``freeze_bn`` requested for a network without ``freeze_batchnorm``
    :raises NotImplementedError: ``aug_scale_hung`` without a ``crop_size``
    """
    # Snapshot of the arguments for the settings report below. Must be taken
    # before any other local is created.
    settings = locals().copy()
    del settings['submit_config']

    import os
    import time
    import itertools
    import math
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from architectures import network_architectures
    import torch.utils.data
    from datapipe import datasets
    from datapipe import seg_data, seg_transforms, seg_transforms_cv
    import evaluation
    import optim_weight_ema
    import lr_schedules

    if crop_size == '':
        crop_size = None
    else:
        crop_size = [int(x.strip()) for x in crop_size.split(',')]

    torch_device = torch.device('cuda:0')

    #
    # Load data sets
    #
    ds_dict = datasets.load_dataset(dataset, n_val, val_seed, n_sup, n_unsup,
                                    split_seed, split_path)

    ds_src = ds_dict['ds_src']
    ds_tgt = ds_dict['ds_tgt']
    tgt_val_ndx = ds_dict['val_ndx_tgt']
    src_val_ndx = ds_dict['val_ndx_src'] if ds_src is not ds_tgt else None
    test_ndx = ds_dict['test_ndx_tgt']
    sup_ndx = ds_dict['sup_ndx']
    unsup_ndx = ds_dict['unsup_ndx']

    n_classes = ds_src.num_classes
    root_n_classes = math.sqrt(n_classes)

    if bin_fill_holes and n_classes != 2:
        print(
            'Binary hole filling can only be used with binary (2-class) segmentation datasets'
        )
        return

    print('Loaded data')

    # Build network
    NetClass = network_architectures.seg.get(arch)

    student_net = NetClass(ds_src.num_classes).to(torch_device)

    # Pre-trained parameters are trained at a tenth of the learning rate used
    # for newly initialised parameters.
    if opt_type == 'adam':
        student_optim = torch.optim.Adam([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ])
    elif opt_type == 'sgd':
        student_optim = torch.optim.SGD([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ],
                                        momentum=sgd_momentum,
                                        nesterov=sgd_nesterov,
                                        weight_decay=sgd_weight_decay)
    else:
        raise ValueError('Unknown opt_type {}'.format(opt_type))

    if model == 'mean_teacher':
        teacher_net = NetClass(ds_src.num_classes).to(torch_device)

        # The teacher is never trained by gradient descent
        for p in teacher_net.parameters():
            p.requires_grad = False

        teacher_optim = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, teacher_alpha)
        eval_net = teacher_net
    elif model == 'pi':
        teacher_net = student_net
        teacher_optim = None
        eval_net = student_net
    else:
        print('Unknown model type {}'.format(model))
        return

    if vat_dir_from_student:
        vat_dir_net = student_net
    else:
        vat_dir_net = teacher_net

    BLOCK_SIZE = student_net.BLOCK_SIZE
    NET_MEAN, NET_STD = seg_transforms.get_mean_std(ds_tgt, student_net)

    if freeze_bn:
        if not hasattr(student_net, 'freeze_batchnorm'):
            raise ValueError(
                'Network {} does not support batchnorm freezing'.format(arch))

    # Label value 255 marks pixels that are excluded from the supervised loss
    clf_crossent_loss = nn.CrossEntropyLoss(ignore_index=255)

    print('Built network')

    if iters_per_epoch == -1:
        iters_per_epoch = len(unsup_ndx) // batch_size
    total_iters = iters_per_epoch * num_epochs

    lr_epoch_scheduler, lr_iter_scheduler = lr_schedules.make_lr_schedulers(
        optimizer=student_optim,
        total_iters=total_iters,
        schedule_type=lr_sched,
        step_epochs=lr_step_epochs,
        step_gamma=lr_step_gamma,
        poly_power=lr_poly_power)

    # Train data pipeline: transforms
    train_transforms = []

    if crop_size is not None:
        if aug_scale_hung:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropScaleHung(
                    crop_size, (0, 0),
                    uniform_scale=not aug_scale_non_uniform))
        elif aug_max_scale != 1.0 or aug_rot_mag != 0.0:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropRotateScale(
                    crop_size, (0, 0),
                    rot_mag=aug_rot_mag,
                    max_scale=aug_max_scale,
                    uniform_scale=not aug_scale_non_uniform,
                    constrain_rot_scale=True))
        else:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCrop(crop_size, (0, 0)))
    else:
        if aug_scale_hung:
            raise NotImplementedError('aug_scale_hung requires a crop_size')

    if aug_hflip or aug_vflip or aug_hvflip:
        train_transforms.append(
            seg_transforms_cv.SegCVTransformRandomFlip(aug_hflip, aug_vflip,
                                                       aug_hvflip))
    train_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))

    # Train data pipeline: supervised and unsupervised data sets
    train_sup_ds = ds_src.dataset(
        labels=True,
        mask=False,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    train_unsup_ds = ds_src.dataset(
        labels=False,
        mask=True,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')

    collate_fn = seg_data.SegCollate(BLOCK_SIZE)

    # Train data pipeline: data loaders
    sup_sampler = seg_data.RepeatSampler(
        torch.utils.data.SubsetRandomSampler(sup_ndx))
    train_sup_loader = torch.utils.data.DataLoader(train_sup_ds,
                                                   batch_size,
                                                   sampler=sup_sampler,
                                                   collate_fn=collate_fn,
                                                   num_workers=num_workers)
    if cons_weight > 0.0:
        unsup_sampler = seg_data.RepeatSampler(
            torch.utils.data.SubsetRandomSampler(unsup_ndx))
        train_unsup_loader = torch.utils.data.DataLoader(
            train_unsup_ds,
            batch_size,
            sampler=unsup_sampler,
            collate_fn=collate_fn,
            num_workers=num_workers)
    else:
        train_unsup_loader = None

    # Eval pipeline
    src_val_loader, tgt_val_loader, test_loader = datasets.eval_data_pipeline(
        ds_src, ds_tgt, src_val_ndx, tgt_val_ndx, test_ndx, batch_size,
        collate_fn, NET_MEAN, NET_STD, num_workers)

    # Report settings
    print('Settings:')
    print(', '.join([
        '{}={}'.format(key, settings[key])
        for key in sorted(list(settings.keys()))
    ]))

    # Report dataset size
    print('Dataset:')
    print('len(sup_ndx)={}'.format(len(sup_ndx)))
    print('len(unsup_ndx)={}'.format(len(unsup_ndx)))
    if ds_src is not ds_tgt:
        # Report the size of each validation split under its own label
        print('len(src_val_ndx)={}'.format(len(src_val_ndx)))
        print('len(tgt_val_ndx)={}'.format(len(tgt_val_ndx)))
    else:
        print('len(val_ndx)={}'.format(len(tgt_val_ndx)))
    if test_ndx is not None:
        print('len(test_ndx)={}'.format(len(test_ndx)))

    if n_sup != -1:
        print('sup_ndx={}'.format(sup_ndx.tolist()))

    def t_dot(a, b):
        """Dot product of `a` and `b` along dimension 1 (kept as size 1)."""
        return (a * b).sum(dim=1, keepdim=True)

    def normalize_eps(x):
        """Scale each sample of the 4D tensor `x` to (approximately) unit L2 norm."""
        x_flat = x.view(len(x), -1)
        mag = torch.sqrt((x_flat * x_flat).sum(dim=1))
        # Small constant avoids division by zero for all-zero samples
        return x / (mag[:, None, None, None] + 1e-12)

    def normalized_noise(x, requires_grad=False, scale=1.0):
        """Gaussian noise shaped like `x`, per-sample L2 norm equal to `scale`."""
        eps = torch.randn(x.shape, dtype=torch.float, device=x.device)
        eps = normalize_eps(eps) * scale
        if requires_grad:
            eps = eps.clone().detach().requires_grad_(True)
        return eps

    def vat_direction(x):
        """
        Compute the VAT perturbation direction vector

        :param x: input image as a `(N, C, H, W)` tensor
        :return: tuple `(direction, logits, probs)` where `direction` is the
            normalized VAT direction as a `(N, C, H, W)` tensor and `logits`
            and `probs` are the predictions of `vat_dir_net` for the
            unperturbed `x`
        """
        # Put the network used to get the VAT direction in eval mode and get the predicted
        # logits and probabilities for the batch of samples x
        # NOTE(review): nothing in this function restores train mode, so
        # `vat_dir_net` stays in eval mode until `.train()` is called at the
        # start of the next epoch - confirm this is intended.
        vat_dir_net.eval()
        with torch.no_grad():
            y_pred_logits = vat_dir_net(x).detach()
        y_pred_prob = F.softmax(y_pred_logits, dim=1)

        # Initial noise offset vector with requires_grad=True
        noise_scale = 1.0e-6 * x.shape[2] * x.shape[3] / 1000
        eps = normalized_noise(x, requires_grad=True, scale=noise_scale)

        # Predict logits and probs for sample perturbed by eps
        eps_pred_logits = vat_dir_net(x.detach() + eps)
        eps_pred_prob = F.softmax(eps_pred_logits, dim=1)

        # Choose our loss function
        if cons_loss_fn == 'var':
            delta = (eps_pred_prob - y_pred_prob)
            loss = (delta * delta).sum()
        elif cons_loss_fn == 'bce':
            loss = network_architectures.robust_binary_crossentropy(
                eps_pred_prob, y_pred_prob).sum()
        elif cons_loss_fn == 'kld':
            # reduction='none' is the supported spelling of the deprecated
            # reduce=False: keep the element-wise loss
            loss = F.kl_div(F.log_softmax(eps_pred_logits, dim=1),
                            y_pred_prob,
                            reduction='none').sum()
        elif cons_loss_fn == 'logits_var':
            delta = (eps_pred_logits - y_pred_logits)
            loss = (delta * delta).sum()
        else:
            raise ValueError(
                'Unknown consistency loss function {}'.format(cons_loss_fn))

        # Differentiate the loss w.r.t. the perturbation
        eps_adv = torch.autograd.grad(outputs=loss,
                                      inputs=eps,
                                      create_graph=True,
                                      retain_graph=True,
                                      only_inputs=True)[0]

        # Normalize the adversarial perturbation
        return normalize_eps(eps_adv), y_pred_logits, y_pred_prob

    def vat_perturbation(x, m):
        """
        Compute the VAT perturbation to add to a batch of images

        :param x: input image as a `(N, C, H, W)` tensor
        :param m: mask for `x`; currently unused
        :return: tuple `(perturbation, logits, probs)`; `perturbation` is
            detached, `logits` and `probs` are those from `vat_direction`
        """
        eps_adv_nrm, y_pred_logits, y_pred_prob = vat_direction(x)

        if adaptive_vat_radius:
            # We view semantic segmentation as predicting the class of a pixel
            # given a patch centred on that pixel.
            # The most similar patch in terms of pixel content to a patch P
            # is a patch Q whose central pixel is an immediate neighbour
            # of the central pixel P.
            # We therefore use the image Jacobian (gradient w.r.t. x and y) to
            # get a sense of the distance between neighbouring patches
            # so we can scale the VAT radius according to the image content.

            # Delta in vertical and horizontal directions
            delta_v = x[:, :, 2:, :] - x[:, :, :-2, :]
            delta_h = x[:, :, :, 2:] - x[:, :, :, :-2]

            # delta_h and delta_v are the difference between pixels where the step size is 2, rather than 1
            # So divide by 2 (the `* 0.5` below) to get the magnitude of the Jacobian

            delta_v = delta_v.view(len(delta_v), -1)
            delta_h = delta_h.view(len(delta_h), -1)
            adv_radius = vat_radius * torch.sqrt(
                (delta_v**2).sum(dim=1) +
                (delta_h**2).sum(dim=1))[:, None, None, None] * 0.5
        else:
            # Fixed radius, scaled by the square root of the number of
            # elements in one sample
            scale = math.sqrt(float(x.shape[1] * x.shape[2] * x.shape[3]))
            adv_radius = vat_radius * scale

        return (eps_adv_nrm * adv_radius).detach(), y_pred_logits, y_pred_prob

    # Create iterators
    train_sup_iter = iter(train_sup_loader)
    train_unsup_iter = iter(
        train_unsup_loader) if train_unsup_loader is not None else None

    iter_i = 0
    print('Training...')
    for epoch_i in range(num_epochs):
        if lr_epoch_scheduler is not None:
            lr_epoch_scheduler.step(epoch_i)

        t1 = time.time()

        if rampup > 0:
            ramp_val = network_architectures.sigmoid_rampup(epoch_i, rampup)
        else:
            ramp_val = 1.0

        student_net.train()
        if teacher_net is not student_net:
            teacher_net.train()

        if freeze_bn:
            student_net.freeze_batchnorm()
            if teacher_net is not student_net:
                teacher_net.freeze_batchnorm()

        sup_loss_acc = 0.0
        consistency_loss_acc = 0.0
        conf_rate_acc = 0.0
        n_sup_batches = 0
        n_unsup_batches = 0

        src_val_iter = iter(
            src_val_loader) if src_val_loader is not None else None
        tgt_val_iter = iter(
            tgt_val_loader) if tgt_val_loader is not None else None

        for sup_batch in itertools.islice(train_sup_iter, iters_per_epoch):
            if lr_iter_scheduler is not None:
                lr_iter_scheduler.step(iter_i)
            student_optim.zero_grad()

            #
            # Supervised branch
            #

            batch_x = sup_batch['image'].to(torch_device)
            batch_y = sup_batch['labels'].to(torch_device)

            logits_sup = student_net(batch_x)
            sup_loss = clf_crossent_loss(logits_sup, batch_y[:, 0, :, :])
            sup_loss.backward()

            if cons_weight > 0.0:
                for _ in range(unsup_batch_ratio):
                    #
                    # Unsupervised branch
                    #

                    unsup_batch = next(train_unsup_iter)

                    # Input images to torch tensor
                    batch_ux = unsup_batch['image'].to(torch_device)
                    batch_um = unsup_batch['mask'].to(torch_device)

                    # batch_um is a mask that is 1 for valid pixels, 0 for invalid pixels.
                    # It is used later on to scale the consistency loss, so that consistency loss is
                    # only computed for valid pixels.
                    # Explanation:
                    # When using geometric augmentations such as rotations, some pixels in the training
                    # crop may come from outside the bounds of the input image. These pixels will have a value
                    # of 0 in these masks. Similarly, when using scaled crops, the size of the crop
                    # from the input image that must be scaled to the size of the training crop may be
                    # larger than one/both of the input image dimensions. Pixels in the training crop
                    # that arise from outside the input image bounds will once again be given a value
                    # of 0 in these masks.

                    # Compute VAT perturbation. The predictions it returns come
                    # from `vat_dir_net`; they are discarded because the teacher
                    # predictions are recomputed below.
                    x_perturb, _, _ = vat_perturbation(batch_ux, batch_um)

                    # Perturb image
                    batch_ux_adv = batch_ux + x_perturb

                    # Get teacher predictions for original image
                    with torch.no_grad():
                        logits_cons_tea = teacher_net(batch_ux).detach()
                    # Get student prediction for perturbed image
                    logits_cons_stu = student_net(batch_ux_adv)

                    # Logits -> probs
                    prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                    prob_cons_stu = F.softmax(logits_cons_stu, dim=1)

                    loss_mask = batch_um

                    # Confidence thresholding
                    if conf_thresh > 0.0:
                        # Compute confidence of teacher predictions
                        conf_tea = prob_cons_tea.max(dim=1)[0]
                        # Compute confidence mask
                        conf_mask = (conf_tea >=
                                     conf_thresh).float()[:, None, :, :]
                        # Record rate for reporting
                        conf_rate_acc += float(conf_mask.mean())
                        # Average confidence mask if requested
                        if not conf_per_pixel:
                            conf_mask = conf_mask.mean()

                        loss_mask = loss_mask * conf_mask
                    elif rampup > 0:
                        conf_rate_acc += ramp_val

                    # Compute per-pixel consistency loss
                    # Note that the way we aggregate the loss across the class/channel dimension (1)
                    # depends on the loss function used. Generally, summing over the class dimension
                    # keeps the magnitude of the gradient of the loss w.r.t. the logits
                    # nearly constant w.r.t. the number of classes. When using logit-variance,
                    # dividing by `sqrt(num_classes)` helps.
                    if cons_loss_fn == 'var':
                        delta_prob = prob_cons_stu - prob_cons_tea
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'logits_var':
                        delta_logits = logits_cons_stu - logits_cons_tea
                        consistency_loss = delta_logits * delta_logits
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'bce':
                        consistency_loss = network_architectures.robust_binary_crossentropy(
                            prob_cons_stu, prob_cons_tea)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'kld':
                        # reduction='none' replaces the deprecated
                        # reduce=False: keep the element-wise loss
                        consistency_loss = F.kl_div(F.log_softmax(
                            logits_cons_stu, dim=1),
                                                    prob_cons_tea,
                                                    reduction='none')
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    else:
                        raise ValueError(
                            'Unknown consistency loss function {}'.format(
                                cons_loss_fn))

                    # Apply consistency loss mask and take the mean over pixels and images
                    consistency_loss = (consistency_loss * loss_mask).mean()

                    # Modulate with rampup if desired
                    if rampup > 0:
                        consistency_loss = consistency_loss * ramp_val

                    # Weight the consistency loss and back-prop
                    unsup_loss = consistency_loss * cons_weight
                    unsup_loss.backward()

                    consistency_loss_val = float(consistency_loss.detach())
                    consistency_loss_acc += consistency_loss_val

                    if np.isnan(consistency_loss_val):
                        print(
                            'NaN detected in consistency loss; bailing out...')
                        return

                    n_unsup_batches += 1

            # Gradients from the supervised and all unsupervised batches have
            # accumulated; apply them in a single step
            student_optim.step()
            if teacher_optim is not None:
                teacher_optim.step()

            sup_loss_val = float(sup_loss.detach())
            sup_loss_acc += sup_loss_val

            if np.isnan(sup_loss_val):
                print('NaN detected in supervised loss; bailing out...')
                return

            n_sup_batches += 1
            iter_i += 1

        # Guard against an epoch with no iterations (iters_per_epoch == 0)
        if n_sup_batches > 0:
            sup_loss_acc /= n_sup_batches
        if n_unsup_batches > 0:
            consistency_loss_acc /= n_unsup_batches
            conf_rate_acc /= n_unsup_batches

        eval_net.eval()

        if src_val_iter is not None:
            src_iou_eval = evaluation.EvaluatorIoU(ds_src.num_classes,
                                                   bin_fill_holes)
            with torch.no_grad():
                for batch in src_val_iter:
                    batch_x = batch['image'].to(torch_device)
                    batch_y = batch['labels'].numpy()

                    logits = eval_net(batch_x)
                    pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                    for sample_i in range(len(batch_y)):
                        src_iou_eval.sample(batch_y[sample_i, 0],
                                            pred_y[sample_i],
                                            ignore_value=255)

            src_iou = src_iou_eval.score()
            src_miou = src_iou.mean()
        else:
            src_iou_eval = src_iou = src_miou = None

        tgt_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                               bin_fill_holes)
        with torch.no_grad():
            for batch in tgt_val_iter:
                batch_x = batch['image'].to(torch_device)
                batch_y = batch['labels'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i in range(len(batch_y)):
                    tgt_iou_eval.sample(batch_y[sample_i, 0],
                                        pred_y[sample_i],
                                        ignore_value=255)

        tgt_iou = tgt_iou_eval.score()
        tgt_miou = tgt_iou.mean()

        t2 = time.time()

        if ds_src is not ds_tgt:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, '
                'SRC VAL mIoU={:.3%}, TGT VAL mIoU={:.3%}'.format(
                    epoch_i + 1, t2 - t1, sup_loss_acc, consistency_loss_acc,
                    conf_rate_acc, src_miou, tgt_miou))
            print('-- SRC {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in src_iou])))
            print('-- TGT {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))
        else:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, VAL mIoU={:.3%}'
                .format(epoch_i + 1, t2 - t1, sup_loss_acc,
                        consistency_loss_acc, conf_rate_acc, tgt_miou))
            print('-- {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))

    if save_model:
        model_path = os.path.join(submit_config.run_dir, "model.pth")
        torch.save(eval_net, model_path)

    if save_preds:
        out_dir = os.path.join(submit_config.run_dir, 'preds')
        os.makedirs(out_dir, exist_ok=True)
        with torch.no_grad():
            for batch in tgt_val_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    ds_tgt.save_prediction_by_index(
                        out_dir, pred_y[sample_i].astype(np.uint32),
                        sample_ndx)
    else:
        out_dir = None

    if test_loader is not None:
        test_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                                bin_fill_holes)
        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['image'].to(torch_device)
                # Ground truth for *this* test batch. Previously `batch_y` was
                # never assigned in this loop, so the stale labels of the last
                # validation batch were scored instead.
                # Assumes test batches carry 'labels' like the validation
                # batches do - TODO confirm against eval_data_pipeline.
                batch_y = batch['labels'].numpy()
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    if save_preds:
                        ds_tgt.save_prediction_by_index(
                            out_dir, pred_y[sample_i].astype(np.uint32),
                            sample_ndx)
                    test_iou_eval.sample(batch_y[sample_i, 0],
                                         pred_y[sample_i],
                                         ignore_value=255)

        test_iou = test_iou_eval.score()
        test_miou = test_iou.mean()

        print('FINAL TEST: mIoU={:.3%}'.format(test_miou))
        print('-- TEST {}'.format(', '.join(
            ['{:.3%}'.format(x) for x in test_iou])))
def train_seg_semisup_aug_mt(
        submit_config: job_helper.SubmitConfig, dataset, model, arch,
        freeze_bn, opt_type, sgd_momentum, sgd_nesterov, sgd_weight_decay,
        learning_rate, lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power,
        teacher_alpha, bin_fill_holes, crop_size, aug_offset_range, aug_hflip,
        aug_vflip, aug_hvflip, aug_scale_hung, aug_max_scale,
        aug_scale_non_uniform, aug_rot_mag, aug_free_scale_rot, cons_loss_fn,
        cons_weight, conf_thresh, conf_per_pixel, rampup, unsup_batch_ratio,
        num_epochs, iters_per_epoch, batch_size, n_sup, n_unsup, n_val,
        split_seed, split_path, val_seed, save_preds, save_model, num_workers):
    settings = locals().copy()
    del settings['submit_config']

    import os
    import math
    import time
    import itertools
    import numpy as np
    import torch.nn as nn, torch.nn.functional as F
    from architectures import network_architectures
    import torch.utils.data
    from datapipe import datasets
    from datapipe import seg_data, seg_transforms, seg_transforms_cv
    import evaluation
    import optim_weight_ema
    import lr_schedules

    from datapipe import torch_utils
    affine_align_corners_kw = torch_utils.affine_align_corners_kw(True)

    if crop_size == '':
        crop_size = None
    else:
        crop_size = [int(x.strip()) for x in crop_size.split(',')]

    torch_device = torch.device('cuda:0')

    #
    # Load data sets
    #
    ds_dict = datasets.load_dataset(dataset, n_val, val_seed, n_sup, n_unsup,
                                    split_seed, split_path)

    ds_src = ds_dict['ds_src']
    ds_tgt = ds_dict['ds_tgt']
    tgt_val_ndx = ds_dict['val_ndx_tgt']
    src_val_ndx = ds_dict['val_ndx_src'] if ds_src is not ds_tgt else None
    test_ndx = ds_dict['test_ndx_tgt']
    sup_ndx = ds_dict['sup_ndx']
    unsup_ndx = ds_dict['unsup_ndx']

    n_classes = ds_src.num_classes
    root_n_classes = math.sqrt(n_classes)

    if bin_fill_holes and n_classes != 2:
        print(
            'Binary hole filling can only be used with binary (2-class) segmentation datasets'
        )
        return

    print('Loaded data')

    # Build network
    NetClass = network_architectures.seg.get(arch)

    student_net = NetClass(ds_src.num_classes).to(torch_device)

    if opt_type == 'adam':
        student_optim = torch.optim.Adam([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ])
    elif opt_type == 'sgd':
        student_optim = torch.optim.SGD([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ],
                                        momentum=sgd_momentum,
                                        nesterov=sgd_nesterov,
                                        weight_decay=sgd_weight_decay)
    else:
        raise ValueError('Unknown opt_type {}'.format(opt_type))

    if model == 'mean_teacher':
        teacher_net = NetClass(ds_src.num_classes).to(torch_device)

        for p in teacher_net.parameters():
            p.requires_grad = False

        teacher_optim = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, teacher_alpha)
        eval_net = teacher_net
    elif model == 'pi':
        teacher_net = student_net
        teacher_optim = None
        eval_net = student_net
    else:
        print('Unknown model type {}'.format(model))
        return

    BLOCK_SIZE = student_net.BLOCK_SIZE
    NET_MEAN, NET_STD = seg_transforms.get_mean_std(ds_tgt, student_net)

    if freeze_bn:
        if not hasattr(student_net, 'freeze_batchnorm'):
            raise ValueError(
                'Network {} does not support batchnorm freezing'.format(arch))

    clf_crossent_loss = nn.CrossEntropyLoss(ignore_index=255)

    print('Built network')

    if iters_per_epoch == -1:
        iters_per_epoch = len(unsup_ndx) // batch_size
    total_iters = iters_per_epoch * num_epochs

    lr_epoch_scheduler, lr_iter_scheduler = lr_schedules.make_lr_schedulers(
        optimizer=student_optim,
        total_iters=total_iters,
        schedule_type=lr_sched,
        step_epochs=lr_step_epochs,
        step_gamma=lr_step_gamma,
        poly_power=lr_poly_power)

    # Train data pipeline: transforms
    train_transforms = []

    if crop_size is not None:
        if aug_scale_hung:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropScaleHung(
                    crop_size, (aug_offset_range, aug_offset_range),
                    uniform_scale=not aug_scale_non_uniform))
        elif aug_max_scale != 1.0 or aug_rot_mag != 0.0:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropRotateScale(
                    crop_size, (aug_offset_range, aug_offset_range),
                    rot_mag=aug_rot_mag,
                    max_scale=aug_max_scale,
                    uniform_scale=not aug_scale_non_uniform,
                    constrain_rot_scale=not aug_free_scale_rot))
        else:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCrop(
                    crop_size, (aug_offset_range, aug_offset_range)))
    else:
        if aug_scale_hung:
            raise NotImplementedError('aug_scale_hung requires a crop_size')

    if aug_hflip or aug_vflip or aug_hvflip:
        train_transforms.append(
            seg_transforms_cv.SegCVTransformRandomFlip(aug_hflip, aug_vflip,
                                                       aug_hvflip))
    train_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))

    # Train data pipeline: supervised and unsupervised data sets
    train_sup_ds = ds_src.dataset(
        labels=True,
        mask=False,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    train_unsup_ds = ds_src.dataset(
        labels=False,
        mask=True,
        xf=True,
        pair=True,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')

    collate_fn = seg_data.SegCollate(BLOCK_SIZE)

    # Train data pipeline: data loaders
    sup_sampler = seg_data.RepeatSampler(
        torch.utils.data.SubsetRandomSampler(sup_ndx))
    train_sup_loader = torch.utils.data.DataLoader(train_sup_ds,
                                                   batch_size,
                                                   sampler=sup_sampler,
                                                   collate_fn=collate_fn,
                                                   num_workers=num_workers)
    if cons_weight > 0.0:
        unsup_sampler = seg_data.RepeatSampler(
            torch.utils.data.SubsetRandomSampler(unsup_ndx))
        train_unsup_loader = torch.utils.data.DataLoader(
            train_unsup_ds,
            batch_size,
            sampler=unsup_sampler,
            collate_fn=collate_fn,
            num_workers=num_workers)
    else:
        train_unsup_loader = None

    # Eval pipeline
    src_val_loader, tgt_val_loader, test_loader = datasets.eval_data_pipeline(
        ds_src, ds_tgt, src_val_ndx, tgt_val_ndx, test_ndx, batch_size,
        collate_fn, NET_MEAN, NET_STD, num_workers)

    # Report settings
    print('Settings:')
    print(', '.join([
        '{}={}'.format(key, settings[key])
        for key in sorted(list(settings.keys()))
    ]))

    # Report dataset size
    print('Dataset:')
    print('len(sup_ndx)={}'.format(len(sup_ndx)))
    print('len(unsup_ndx)={}'.format(len(unsup_ndx)))
    if ds_src is not ds_tgt:
        print('len(src_val_ndx)={}'.format(len(tgt_val_ndx)))
        print('len(tgt_val_ndx)={}'.format(len(tgt_val_ndx)))
    else:
        print('len(val_ndx)={}'.format(len(tgt_val_ndx)))
    if test_ndx is not None:
        print('len(test_ndx)={}'.format(len(test_ndx)))

    if n_sup != -1:
        print('sup_ndx={}'.format(sup_ndx.tolist()))

    # Track mIoU for early stopping
    best_tgt_miou = None
    best_epoch = 0

    eval_net_state = {
        key: value.detach().cpu().numpy()
        for key, value in eval_net.state_dict().items()
    }

    # Create iterators
    train_sup_iter = iter(train_sup_loader)
    train_unsup_iter = iter(
        train_unsup_loader) if train_unsup_loader is not None else None

    iter_i = 0
    print('Training...')
    for epoch_i in range(num_epochs):
        if lr_epoch_scheduler is not None:
            lr_epoch_scheduler.step(epoch_i)

        t1 = time.time()

        if rampup > 0:
            ramp_val = network_architectures.sigmoid_rampup(epoch_i, rampup)
        else:
            ramp_val = 1.0

        student_net.train()
        if teacher_net is not student_net:
            teacher_net.train()

        if freeze_bn:
            student_net.freeze_batchnorm()
            if teacher_net is not student_net:
                teacher_net.freeze_batchnorm()

        sup_loss_acc = 0.0
        consistency_loss_acc = 0.0
        conf_rate_acc = 0.0
        n_sup_batches = 0
        n_unsup_batches = 0

        src_val_iter = iter(
            src_val_loader) if src_val_loader is not None else None
        tgt_val_iter = iter(
            tgt_val_loader) if tgt_val_loader is not None else None

        for sup_batch in itertools.islice(train_sup_iter, iters_per_epoch):
            if lr_iter_scheduler is not None:
                lr_iter_scheduler.step(iter_i)

            student_optim.zero_grad()

            #
            # Supervised branch
            #

            batch_x = sup_batch['image'].to(torch_device)
            batch_y = sup_batch['labels'].to(torch_device)

            logits_sup = student_net(batch_x)
            sup_loss = clf_crossent_loss(logits_sup, batch_y[:, 0, :, :])
            sup_loss.backward()

            if cons_weight > 0.0:
                for _ in range(unsup_batch_ratio):
                    #
                    # Unsupervised branch
                    #

                    # Batch consists of a pair of unsupervised samples (sample0, sample1),
                    # each with an image and a mask, plus the xf0_to_1 transform
                    unsup_batch = next(train_unsup_iter)

                    # Input images to torch tensor
                    batch_ux0 = unsup_batch['sample0']['image'].to(
                        torch_device)
                    batch_um0 = unsup_batch['sample0']['mask'].to(torch_device)
                    batch_ux1 = unsup_batch['sample1']['image'].to(
                        torch_device)
                    batch_um1 = unsup_batch['sample1']['mask'].to(torch_device)
                    batch_ufx0_to_1 = unsup_batch['xf0_to_1'].to(torch_device)

                    # Get teacher predictions for image0
                    with torch.no_grad():
                        logits_cons_tea = teacher_net(batch_ux0).detach()
                    # Get student prediction for image1
                    logits_cons_stu = student_net(batch_ux1)

                    # Transformation from teacher to student space
                    grid_tea_to_stu = F.affine_grid(batch_ufx0_to_1,
                                                    batch_ux0.shape,
                                                    **affine_align_corners_kw)
                    # Transform teacher predicted logits to student space
                    logits_cons_tea_in_stu = F.grid_sample(
                        logits_cons_tea, grid_tea_to_stu,
                        **affine_align_corners_kw)
                    # Transform mask from teacher to student space and multiply by student space mask
                    mask_tea_in_stu = F.grid_sample(
                        batch_um0, grid_tea_to_stu, **
                        affine_align_corners_kw) * batch_um1

                    # Logits -> probs
                    prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                    prob_cons_stu = F.softmax(logits_cons_stu, dim=1)
                    # Transform teacher predicted probabilities to student space
                    prob_cons_tea_in_stu = F.grid_sample(
                        prob_cons_tea, grid_tea_to_stu,
                        **affine_align_corners_kw)

                    # for i in range(len(batch_ux0)):
                    #     plt.figure(figsize=(18, 12))
                    #
                    #     x_0_in_1 = F.grid_sample(batch_ux0, grid_tea_to_stu)
                    #     d_x0_in_1 = torch.abs(x_0_in_1 - batch_ux1) * mask_tea_in_stu
                    #     mask_tea_in_stu_np = mask_tea_in_stu.detach().cpu().numpy()
                    #
                    #     plt.subplot(2, 4, 1)
                    #     plt.imshow(batch_ux0[i].detach().cpu().numpy().transpose(1, 2, 0) * 0.25 + 0.5)
                    #     plt.subplot(2, 4, 2)
                    #     plt.imshow(batch_ux1[i].detach().cpu().numpy().transpose(1, 2, 0) * 0.25 + 0.5)
                    #     plt.subplot(2, 4, 3)
                    #     plt.imshow(x_0_in_1[i].detach().cpu().numpy().transpose(1, 2, 0) * 0.25 + 0.5)
                    #     plt.subplot(2, 4, 4)
                    #     plt.imshow(d_x0_in_1[i].detach().cpu().numpy().transpose(1, 2, 0) * 10 + 0.5, cmap='gray')
                    #
                    #     plt.subplot(2, 4, 5)
                    #     plt.imshow(batch_um0[i,0].detach().cpu().numpy(), cmap='gray')
                    #     plt.subplot(2, 4, 6)
                    #     plt.imshow(batch_um1[i,0].detach().cpu().numpy(), cmap='gray')
                    #     plt.subplot(2, 4, 7)
                    #     plt.imshow(mask_tea_in_stu[i,0].detach().cpu().numpy(), cmap='gray')
                    #
                    #     plt.show()

                    loss_mask = mask_tea_in_stu

                    # Confidence thresholding
                    if conf_thresh > 0.0:
                        # Compute confidence of teacher predictions
                        conf_tea = prob_cons_tea_in_stu.max(dim=1)[0]
                        # Compute confidence mask
                        conf_mask = (conf_tea >=
                                     conf_thresh).float()[:, None, :, :]
                        # Record rate for reporting
                        conf_rate_acc += float(conf_mask.mean())
                        # Average confidence mask if requested
                        if not conf_per_pixel:
                            conf_mask = conf_mask.mean()

                        loss_mask = loss_mask * conf_mask
                    elif rampup > 0:
                        conf_rate_acc += ramp_val

                    # Compute per-pixel consistency loss
                    # Note that the way we aggregate the loss across the class/channel dimension (1)
                    # depends on the loss function used. Generally, summing over the class dimension
                    # keeps the magnitude of the gradient of the loss w.r.t. the logits
                    # nearly constant w.r.t. the number of classes. When using logit-variance,
                    # dividing by `sqrt(num_classes)` helps.
                    if cons_loss_fn == 'var':
                        delta_prob = prob_cons_stu - prob_cons_tea_in_stu
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'logits_var':
                        delta_logits = logits_cons_stu - logits_cons_tea_in_stu
                        consistency_loss = delta_logits * delta_logits
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'logits_smoothl1':
                        consistency_loss = F.smooth_l1_loss(
                            logits_cons_stu,
                            logits_cons_tea_in_stu,
                            reduce=False)
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'bce':
                        consistency_loss = network_architectures.robust_binary_crossentropy(
                            prob_cons_stu, prob_cons_tea_in_stu)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'kld':
                        consistency_loss = F.kl_div(F.log_softmax(
                            logits_cons_stu, dim=1),
                                                    prob_cons_tea_in_stu,
                                                    reduce=False)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    else:
                        raise ValueError(
                            'Unknown consistency loss function {}'.format(
                                cons_loss_fn))

                    # Apply consistency loss mask and take the mean over pixels and images
                    consistency_loss = (consistency_loss * loss_mask).mean()

                    # Modulate with rampup if desired
                    if rampup > 0:
                        consistency_loss = consistency_loss * ramp_val

                    # Weight the consistency loss and back-prop
                    unsup_loss = consistency_loss * cons_weight
                    unsup_loss.backward()

                    consistency_loss_acc += float(consistency_loss.detach())

                    n_unsup_batches += 1

            student_optim.step()
            if teacher_optim is not None:
                teacher_optim.step()

            sup_loss_acc += float(sup_loss.detach())
            n_sup_batches += 1
            iter_i += 1

        sup_loss_acc /= n_sup_batches
        if n_unsup_batches > 0:
            consistency_loss_acc /= n_unsup_batches
            conf_rate_acc /= n_unsup_batches

        eval_net.eval()

        if src_val_iter is not None:
            src_iou_eval = evaluation.EvaluatorIoU(ds_src.num_classes,
                                                   bin_fill_holes)
            with torch.no_grad():
                for batch in src_val_iter:
                    batch_x = batch['image'].to(torch_device)
                    batch_y = batch['labels'].numpy()

                    logits = eval_net(batch_x)
                    pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                    for sample_i in range(len(batch_y)):
                        src_iou_eval.sample(batch_y[sample_i, 0],
                                            pred_y[sample_i],
                                            ignore_value=255)

            src_iou = src_iou_eval.score()
            src_miou = src_iou.mean()
        else:
            src_iou_eval = src_iou = src_miou = None

        tgt_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                               bin_fill_holes)
        with torch.no_grad():
            for batch in tgt_val_iter:
                batch_x = batch['image'].to(torch_device)
                batch_y = batch['labels'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i in range(len(batch_y)):
                    tgt_iou_eval.sample(batch_y[sample_i, 0],
                                        pred_y[sample_i],
                                        ignore_value=255)

        tgt_iou = tgt_iou_eval.score()
        tgt_miou = tgt_iou.mean()

        t2 = time.time()

        if ds_src is not ds_tgt:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, '
                'SRC VAL mIoU={:.3%}, TGT VAL mIoU={:.3%}'.format(
                    epoch_i + 1, t2 - t1, sup_loss_acc, consistency_loss_acc,
                    conf_rate_acc, src_miou, tgt_miou))
            print('-- SRC {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in src_iou])))
            print('-- TGT {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))
        else:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, VAL mIoU={:.3%}'
                .format(epoch_i + 1, t2 - t1, sup_loss_acc,
                        consistency_loss_acc, conf_rate_acc, tgt_miou))
            print('-- {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))

    if save_model:
        model_path = os.path.join(submit_config.run_dir, "model.pth")
        torch.save(eval_net, model_path)

    if save_preds:
        out_dir = os.path.join(submit_config.run_dir, 'preds')
        os.makedirs(out_dir, exist_ok=True)
        with torch.no_grad():
            for batch in tgt_val_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    ds_tgt.save_prediction_by_index(
                        out_dir, pred_y[sample_i].astype(np.uint32),
                        sample_ndx)
    else:
        out_dir = None

    if test_loader is not None:
        test_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                                bin_fill_holes)
        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    if save_preds:
                        ds_tgt.save_prediction_by_index(
                            out_dir, pred_y[sample_i].astype(np.uint32),
                            sample_ndx)
                    test_iou_eval.sample(batch_y[sample_i, 0],
                                         pred_y[sample_i],
                                         ignore_value=255)

        test_iou = test_iou_eval.score()
        test_miou = test_iou.mean()

        print('FINAL TEST: mIoU={:.3%}'.format(test_miou))
        print('-- TEST {}'.format(', '.join(
            ['{:.3%}'.format(x) for x in test_iou])))
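

# Example #3
# (Listing separator left over from extraction; the two bare marker lines that
# stood here were not part of the program.)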
def train_seg_semisup_mask_mt(
        submit_config: job_helper.SubmitConfig, dataset, model, arch,
        freeze_bn, opt_type, sgd_momentum, sgd_nesterov, sgd_weight_decay,
        learning_rate, lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power,
        teacher_alpha, bin_fill_holes, crop_size, aug_hflip, aug_vflip,
        aug_hvflip, aug_scale_hung, aug_max_scale, aug_scale_non_uniform,
        aug_rot_mag, mask_mode, mask_prop_range, boxmask_n_boxes,
        boxmask_fixed_aspect_ratio, boxmask_by_size, boxmask_outside_bounds,
        boxmask_no_invert, cons_loss_fn, cons_weight, conf_thresh,
        conf_per_pixel, rampup, unsup_batch_ratio, num_epochs, iters_per_epoch,
        batch_size, n_sup, n_unsup, n_val, split_seed, split_path, val_seed,
        save_preds, save_model, num_workers):
    settings = locals().copy()
    del settings['submit_config']

    if ':' in mask_prop_range:
        a, b = mask_prop_range.split(':')
        mask_prop_range = (float(a.strip()), float(b.strip()))
        del a, b
    else:
        mask_prop_range = float(mask_prop_range)

    if mask_mode == 'zero':
        mask_mix = False
    elif mask_mode == 'mix':
        mask_mix = True
    else:
        raise ValueError('Unknown mask_mode {}'.format(mask_mode))
    del mask_mode

    import os
    import math
    import time
    import itertools
    import numpy as np
    import torch.nn as nn, torch.nn.functional as F
    from architectures import network_architectures
    import torch.utils.data
    from datapipe import datasets
    from datapipe import seg_data, seg_transforms, seg_transforms_cv
    import evaluation
    import optim_weight_ema
    import mask_gen
    import lr_schedules

    if crop_size == '':
        crop_size = None
    else:
        crop_size = [int(x.strip()) for x in crop_size.split(',')]

    torch_device = torch.device('cuda:0')

    #
    # Load data sets
    #
    ds_dict = datasets.load_dataset(dataset, n_val, val_seed, n_sup, n_unsup,
                                    split_seed, split_path)

    ds_src = ds_dict['ds_src']
    ds_tgt = ds_dict['ds_tgt']
    tgt_val_ndx = ds_dict['val_ndx_tgt']
    src_val_ndx = ds_dict['val_ndx_src'] if ds_src is not ds_tgt else None
    test_ndx = ds_dict['test_ndx_tgt']
    sup_ndx = ds_dict['sup_ndx']
    unsup_ndx = ds_dict['unsup_ndx']

    n_classes = ds_src.num_classes
    root_n_classes = math.sqrt(n_classes)

    if bin_fill_holes and n_classes != 2:
        print(
            'Binary hole filling can only be used with binary (2-class) segmentation datasets'
        )
        return

    print('Loaded data')

    # Build network
    NetClass = network_architectures.seg.get(arch)

    student_net = NetClass(ds_src.num_classes).to(torch_device)

    if opt_type == 'adam':
        student_optim = torch.optim.Adam([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ])
    elif opt_type == 'sgd':
        student_optim = torch.optim.SGD([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ],
                                        momentum=sgd_momentum,
                                        nesterov=sgd_nesterov,
                                        weight_decay=sgd_weight_decay)
    else:
        raise ValueError('Unknown opt_type {}'.format(opt_type))

    if model == 'mean_teacher':
        teacher_net = NetClass(ds_src.num_classes).to(torch_device)

        for p in teacher_net.parameters():
            p.requires_grad = False

        teacher_optim = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, teacher_alpha)
        eval_net = teacher_net
    elif model == 'pi':
        teacher_net = student_net
        teacher_optim = None
        eval_net = student_net
    else:
        print('Unknown model type {}'.format(model))
        return

    BLOCK_SIZE = student_net.BLOCK_SIZE
    NET_MEAN, NET_STD = seg_transforms.get_mean_std(ds_tgt, student_net)

    if freeze_bn:
        if not hasattr(student_net, 'freeze_batchnorm'):
            raise ValueError(
                'Network {} does not support batchnorm freezing'.format(arch))

    clf_crossent_loss = nn.CrossEntropyLoss(ignore_index=255)

    print('Built network')

    mask_generator = mask_gen.BoxMaskGenerator(
        prop_range=mask_prop_range,
        n_boxes=boxmask_n_boxes,
        random_aspect_ratio=not boxmask_fixed_aspect_ratio,
        prop_by_area=not boxmask_by_size,
        within_bounds=not boxmask_outside_bounds,
        invert=not boxmask_no_invert)

    if iters_per_epoch == -1:
        iters_per_epoch = len(unsup_ndx) // batch_size
    total_iters = iters_per_epoch * num_epochs

    lr_epoch_scheduler, lr_iter_scheduler = lr_schedules.make_lr_schedulers(
        optimizer=student_optim,
        total_iters=total_iters,
        schedule_type=lr_sched,
        step_epochs=lr_step_epochs,
        step_gamma=lr_step_gamma,
        poly_power=lr_poly_power)

    train_transforms = []
    eval_transforms = []

    if crop_size is not None:
        if aug_scale_hung:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropScaleHung(
                    crop_size, (0, 0),
                    uniform_scale=not aug_scale_non_uniform))
        elif aug_max_scale != 1.0 or aug_rot_mag != 0.0:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropRotateScale(
                    crop_size, (0, 0),
                    rot_mag=aug_rot_mag,
                    max_scale=aug_max_scale,
                    uniform_scale=not aug_scale_non_uniform,
                    constrain_rot_scale=True))
        else:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCrop(crop_size, (0, 0)))
    else:
        if aug_scale_hung:
            raise NotImplementedError('aug_scale_hung requires a crop_size')

    if aug_hflip or aug_vflip or aug_hvflip:
        train_transforms.append(
            seg_transforms_cv.SegCVTransformRandomFlip(aug_hflip, aug_vflip,
                                                       aug_hvflip))
    train_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))
    eval_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))

    train_sup_ds = ds_src.dataset(
        labels=True,
        mask=False,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    train_unsup_ds = ds_src.dataset(
        labels=False,
        mask=True,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    eval_ds = ds_src.dataset(
        labels=True,
        mask=False,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(eval_transforms),
        pipeline_type='cv')

    add_mask_params_to_batch = mask_gen.AddMaskParamsToBatch(mask_generator)

    collate_fn = seg_data.SegCollate(BLOCK_SIZE)
    mask_collate_fn = seg_data.SegCollate(
        BLOCK_SIZE, batch_aug_fn=add_mask_params_to_batch)

    # Train data pipeline: data loaders
    sup_sampler = seg_data.RepeatSampler(
        torch.utils.data.SubsetRandomSampler(sup_ndx))
    train_sup_loader = torch.utils.data.DataLoader(train_sup_ds,
                                                   batch_size,
                                                   sampler=sup_sampler,
                                                   collate_fn=collate_fn,
                                                   num_workers=num_workers)
    if cons_weight > 0.0:
        unsup_sampler = seg_data.RepeatSampler(
            torch.utils.data.SubsetRandomSampler(unsup_ndx))
        train_unsup_loader_0 = torch.utils.data.DataLoader(
            train_unsup_ds,
            batch_size,
            sampler=unsup_sampler,
            collate_fn=mask_collate_fn,
            num_workers=num_workers)
        if mask_mix:
            train_unsup_loader_1 = torch.utils.data.DataLoader(
                train_unsup_ds,
                batch_size,
                sampler=unsup_sampler,
                collate_fn=collate_fn,
                num_workers=num_workers)
        else:
            train_unsup_loader_1 = None
    else:
        train_unsup_loader_0 = None
        train_unsup_loader_1 = None

    # Eval pipeline
    src_val_loader, tgt_val_loader, test_loader = datasets.eval_data_pipeline(
        ds_src, ds_tgt, src_val_ndx, tgt_val_ndx, test_ndx, batch_size,
        collate_fn, NET_MEAN, NET_STD, num_workers)

    # Report settings
    print('Settings:')
    print(', '.join([
        '{}={}'.format(key, settings[key])
        for key in sorted(list(settings.keys()))
    ]))

    # Report dataset size
    print('Dataset:')
    print('len(sup_ndx)={}'.format(len(sup_ndx)))
    print('len(unsup_ndx)={}'.format(len(unsup_ndx)))
    if ds_src is not ds_tgt:
        print('len(src_val_ndx)={}'.format(len(tgt_val_ndx)))
        print('len(tgt_val_ndx)={}'.format(len(tgt_val_ndx)))
    else:
        print('len(val_ndx)={}'.format(len(tgt_val_ndx)))
    if test_ndx is not None:
        print('len(test_ndx)={}'.format(len(test_ndx)))

    if n_sup != -1:
        print('sup_ndx={}'.format(sup_ndx.tolist()))

    # Track mIoU for early stopping
    best_tgt_miou = None
    best_epoch = 0

    eval_net_state = {
        key: value.detach().cpu().numpy()
        for key, value in eval_net.state_dict().items()
    }

    # Create iterators
    train_sup_iter = iter(train_sup_loader)
    train_unsup_iter_0 = iter(
        train_unsup_loader_0) if train_unsup_loader_0 is not None else None
    train_unsup_iter_1 = iter(
        train_unsup_loader_1) if train_unsup_loader_1 is not None else None

    iter_i = 0
    print('Training...')
    for epoch_i in range(num_epochs):
        if lr_epoch_scheduler is not None:
            lr_epoch_scheduler.step(epoch_i)

        t1 = time.time()

        if rampup > 0:
            ramp_val = network_architectures.sigmoid_rampup(epoch_i, rampup)
        else:
            ramp_val = 1.0

        student_net.train()
        if teacher_net is not student_net:
            teacher_net.train()

        if freeze_bn:
            student_net.freeze_batchnorm()
            if teacher_net is not student_net:
                teacher_net.freeze_batchnorm()

        sup_loss_acc = 0.0
        consistency_loss_acc = 0.0
        conf_rate_acc = 0.0
        n_sup_batches = 0
        n_unsup_batches = 0

        src_val_iter = iter(
            src_val_loader) if src_val_loader is not None else None
        tgt_val_iter = iter(
            tgt_val_loader) if tgt_val_loader is not None else None

        for sup_batch in itertools.islice(train_sup_iter, iters_per_epoch):
            if lr_iter_scheduler is not None:
                lr_iter_scheduler.step(iter_i)
            student_optim.zero_grad()

            #
            # Supervised branch
            #

            batch_x = sup_batch['image'].to(torch_device)
            batch_y = sup_batch['labels'].to(torch_device)

            logits_sup = student_net(batch_x)
            sup_loss = clf_crossent_loss(logits_sup, batch_y[:, 0, :, :])
            sup_loss.backward()

            if cons_weight > 0.0:
                for _ in range(unsup_batch_ratio):
                    #
                    # Unsupervised branch
                    #

                    if mask_mix:
                        # Mix mode: batch consists of paired unsupervised samples and mask parameters
                        unsup_batch0 = next(train_unsup_iter_0)
                        unsup_batch1 = next(train_unsup_iter_1)
                        batch_ux0 = unsup_batch0['image'].to(torch_device)
                        batch_um0 = unsup_batch0['mask'].to(torch_device)
                        batch_ux1 = unsup_batch1['image'].to(torch_device)
                        batch_um1 = unsup_batch1['mask'].to(torch_device)
                        batch_mask_params = unsup_batch0['mask_params'].to(
                            torch_device)

                        # batch_um0 and batch_um1 are masks that are 1 for valid pixels, 0 for invalid pixels.
                        # They are used later on to scale the consistency loss, so that consistency loss is
                        # only computed for valid pixels.
                        # Explanation:
                        # When using geometric augmentations such as rotations, some pixels in the training
                        # crop may come from outside the bounds of the input image. These pixels will have a value
                        # of 0 in these masks. Similarly, when using scaled crops, the size of the crop
                        # from the input image that must be scaled to the size of the training crop may be
                        # larger than one/both of the input image dimensions. Pixels in the training crop
                        # that arise from outside the input image bounds will once again be given a value
                        # of 0 in these masks.

                        # Convert mask parameters to masks of shape (N,1,H,W)
                        batch_mix_masks = mask_generator.torch_masks_from_params(
                            batch_mask_params, batch_ux0.shape[2:4],
                            torch_device)

                        # Mix images with masks
                        batch_ux_mixed = batch_ux0 * (
                            1 - batch_mix_masks) + batch_ux1 * batch_mix_masks
                        batch_um_mixed = batch_um0 * (
                            1 - batch_mix_masks) + batch_um1 * batch_mix_masks

                        # Get teacher predictions for original images
                        with torch.no_grad():
                            logits_u0_tea = teacher_net(batch_ux0).detach()
                            logits_u1_tea = teacher_net(batch_ux1).detach()
                        # Get student prediction for mixed image
                        logits_cons_stu = student_net(batch_ux_mixed)

                        # Mix teacher predictions using same mask
                        # It makes no difference whether we do this with logits or probabilities as
                        # the mask pixels are either 1 or 0
                        logits_cons_tea = logits_u0_tea * (
                            1 -
                            batch_mix_masks) + logits_u1_tea * batch_mix_masks

                        # Logits -> probs
                        prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                        prob_cons_stu = F.softmax(logits_cons_stu, dim=1)

                        loss_mask = batch_um_mixed

                    else:
                        # Cut mode: batch consists of unsupervised samples and mask params
                        unsup_batch = next(train_unsup_iter_0)
                        batch_ux = unsup_batch['image'].to(torch_device)
                        batch_um = unsup_batch['mask'].to(torch_device)
                        batch_mask_params = unsup_batch['mask_params'].to(
                            torch_device)

                        # Convert mask parameters to masks of shape (N,1,H,W)
                        batch_cut_masks = mask_generator.torch_masks_from_params(
                            batch_mask_params, batch_ux.shape[2:4],
                            torch_device)

                        # Cut image with mask (mask regions to zero)
                        batch_ux_cut = batch_ux * batch_cut_masks

                        # Get teacher predictions for original image
                        with torch.no_grad():
                            logits_cons_tea = teacher_net(batch_ux).detach()
                        # Get student prediction for cut image
                        logits_cons_stu = student_net(batch_ux_cut)

                        # Logits -> probs
                        prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                        prob_cons_stu = F.softmax(logits_cons_stu, dim=1)

                        loss_mask = batch_cut_masks * batch_um

                    # -- shared by mix and cut --

                    # Confidence thresholding
                    if conf_thresh > 0.0:
                        # Compute confidence of teacher predictions
                        conf_tea = prob_cons_tea.max(dim=1)[0]
                        # Compute confidence mask
                        conf_mask = (conf_tea >=
                                     conf_thresh).float()[:, None, :, :]
                        # Record rate for reporting
                        conf_rate_acc += float(conf_mask.mean())
                        # Average confidence mask if requested
                        if not conf_per_pixel:
                            conf_mask = conf_mask.mean()

                        loss_mask = loss_mask * conf_mask
                    elif rampup > 0:
                        conf_rate_acc += ramp_val

                    # Compute per-pixel consistency loss
                    # Note that the way we aggregate the loss across the class/channel dimension (1)
                    # depends on the loss function used. Generally, summing over the class dimension
                    # keeps the magnitude of the gradient of the loss w.r.t. the logits
                    # nearly constant w.r.t. the number of classes. When using logit-variance,
                    # dividing by `sqrt(num_classes)` helps.
                    # Every branch yields an un-reduced loss with a singleton class
                    # dimension, so the loss mask can be applied afterwards.
                    if cons_loss_fn == 'var':
                        # Squared difference of class probabilities
                        delta_prob = prob_cons_stu - prob_cons_tea
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'logits_var':
                        # Squared difference of raw logits, scaled by 1/sqrt(num_classes)
                        delta_logits = logits_cons_stu - logits_cons_tea
                        consistency_loss = delta_logits * delta_logits
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'logits_smoothl1':
                        # `reduction='none'` keeps the element-wise loss; it replaces
                        # the deprecated `reduce=False` argument.
                        consistency_loss = F.smooth_l1_loss(logits_cons_stu,
                                                            logits_cons_tea,
                                                            reduction='none')
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'bce':
                        consistency_loss = network_architectures.robust_binary_crossentropy(
                            prob_cons_stu, prob_cons_tea)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'kld':
                        # `F.kl_div` expects log-probabilities as input and
                        # probabilities as target. `reduction='none'` replaces the
                        # deprecated `reduce=False` argument.
                        consistency_loss = F.kl_div(F.log_softmax(
                            logits_cons_stu, dim=1),
                                                    prob_cons_tea,
                                                    reduction='none')
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    else:
                        raise ValueError(
                            'Unknown consistency loss function {}'.format(
                                cons_loss_fn))

                    # Apply consistency loss mask and take the mean over pixels and images
                    consistency_loss = (consistency_loss * loss_mask).mean()

                    # Modulate with rampup if desired
                    if rampup > 0:
                        consistency_loss = consistency_loss * ramp_val

                    # Weight the consistency loss and back-prop
                    unsup_loss = consistency_loss * cons_weight
                    unsup_loss.backward()

                    consistency_loss_acc += float(consistency_loss.detach())

                    n_unsup_batches += 1

            student_optim.step()
            if teacher_optim is not None:
                teacher_optim.step()

            sup_loss_val = float(sup_loss.detach())
            if np.isnan(sup_loss_val):
                print('NaN detected; network dead, bailing.')
                return

            sup_loss_acc += sup_loss_val
            n_sup_batches += 1
            iter_i += 1

        sup_loss_acc /= n_sup_batches
        if n_unsup_batches > 0:
            consistency_loss_acc /= n_unsup_batches
            conf_rate_acc /= n_unsup_batches

        eval_net.eval()

        if ds_src is not ds_tgt:
            src_iou_eval = evaluation.EvaluatorIoU(ds_src.num_classes,
                                                   bin_fill_holes)
            with torch.no_grad():
                for batch in src_val_iter:
                    batch_x = batch['image'].to(torch_device)
                    batch_y = batch['labels'].numpy()

                    logits = eval_net(batch_x)
                    pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                    for sample_i in range(len(batch_y)):
                        src_iou_eval.sample(batch_y[sample_i, 0],
                                            pred_y[sample_i],
                                            ignore_value=255)

            src_iou = src_iou_eval.score()
            src_miou = src_iou.mean()
        else:
            src_iou_eval = src_iou = src_miou = None

        tgt_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                               bin_fill_holes)
        with torch.no_grad():
            for batch in tgt_val_iter:
                batch_x = batch['image'].to(torch_device)
                batch_y = batch['labels'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i in range(len(batch_y)):
                    tgt_iou_eval.sample(batch_y[sample_i, 0],
                                        pred_y[sample_i],
                                        ignore_value=255)

        tgt_iou = tgt_iou_eval.score()
        tgt_miou = tgt_iou.mean()

        t2 = time.time()

        if ds_src is not ds_tgt:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, '
                'SRC VAL mIoU={:.3%}, TGT VAL mIoU={:.3%}'.format(
                    epoch_i + 1, t2 - t1, sup_loss_acc, consistency_loss_acc,
                    conf_rate_acc, src_miou, tgt_miou))
            print('-- SRC {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in src_iou])))
            print('-- TGT {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))
        else:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, VAL mIoU={:.3%}'
                .format(epoch_i + 1, t2 - t1, sup_loss_acc,
                        consistency_loss_acc, conf_rate_acc, tgt_miou))
            print('-- {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))

    if save_model:
        model_path = os.path.join(submit_config.run_dir, "model.pth")
        torch.save(eval_net, model_path)

    if save_preds:
        out_dir = os.path.join(submit_config.run_dir, 'preds')
        os.makedirs(out_dir, exist_ok=True)
        with torch.no_grad():
            for batch in tgt_val_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    ds_tgt.save_prediction_by_index(
                        out_dir, pred_y[sample_i].astype(np.uint32),
                        sample_ndx)
    else:
        out_dir = None

    # Final evaluation on the test split; skipped when there is no test loader.
    if test_loader is not None:
        test_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                                bin_fill_holes)
        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['image'].to(torch_device)
                # Read the ground truth for *this* test batch. `batch_y` was
                # previously never assigned in this loop, so the IoU was computed
                # against the stale labels of the last validation batch.
                # NOTE(review): assumes test batches provide a 'labels' entry
                # like the validation batches do -- confirm against the loader.
                batch_y = batch['labels'].numpy()
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    if save_preds:
                        # `out_dir` is only a real path when `save_preds` is set
                        ds_tgt.save_prediction_by_index(
                            out_dir, pred_y[sample_i].astype(np.uint32),
                            sample_ndx)
                    # Labels are indexed with a channel index of 0, matching the
                    # validation loops; 255 is passed as the ignore value.
                    test_iou_eval.sample(batch_y[sample_i, 0],
                                         pred_y[sample_i],
                                         ignore_value=255)

        test_iou = test_iou_eval.score()
        test_miou = test_iou.mean()

        print('FINAL TEST: mIoU={:.3%}'.format(test_miou))
        print('-- TEST {}'.format(', '.join(
            ['{:.3%}'.format(x) for x in test_iou])))