Example #1
    def __init__(self, data_dir, train_txt: str, batch_size: int = 32, show=False):
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.dataset_sequences = []
        self.show = show
        self.image_augmentation = augmentation.ImageAugmentation()

        mean_path = get_file('c3d_mean.npy',
                             C3D_MEAN_PATH,
                             cache_subdir='models',
                             md5_hash='08a07d9761e76097985124d9e8b2fe34')
        # Mean to subtract from the inputs
        self.mean = numpy.load(mean_path)

        with open(os.path.join(data_dir, train_txt), 'r') as video_fn:
            train_files = list(map(lambda l: l.strip(), video_fn.readlines()))

        for video_fn in train_files:
            if video_fn.startswith("no_smoking_videos"):
                # hack: skip the no_smoking videos
                continue

            human_y = os.path.join(self.data_dir, "jsonl.byhuman", video_fn, 'result.jsonl')
            # machine_y = os.path.join(self.data_dir, "jsonl", video_fn, 'result.jsonl')

            with open(human_y) as f:
                predictions_by_frame = [json.loads(s.strip()) for s in f]
                for cls_id, fnumbers in jsonl_to_sequences(predictions_by_frame):
                    _start_fn = fnumbers[0]
                    _end_fn = fnumbers[-1]
                    if _end_fn > _start_fn:
                        self.dataset_sequences.append((cls_id, video_fn, _start_fn, _end_fn))

        shuffle(self.dataset_sequences)
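Example #2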
def experiment(exp, arch, loss, double_softmax, confidence_thresh, rampup,
               teacher_alpha, fix_ema, unsup_weight, cls_bal_scale,
               cls_bal_scale_range, cls_balance, cls_balance_loss,
               combine_batches, learning_rate, standardise_samples,
               src_affine_std, src_xlat_range, src_hflip, src_intens_flip,
               src_intens_scale_range, src_intens_offset_range,
               src_gaussian_noise_std, tgt_affine_std, tgt_xlat_range,
               tgt_hflip, tgt_intens_flip, tgt_intens_scale_range,
               tgt_intens_offset_range, tgt_gaussian_noise_std, num_epochs,
               batch_size, epoch_size, seed, log_file, model_file, device):
    settings = locals().copy()

    import os
    import sys
    import pickle
    import cmdline_helpers

    if log_file == '':
        log_file = 'output_aug_log_{}.txt'.format(exp)
    elif log_file == 'none':
        log_file = None

    if log_file is not None:
        if os.path.exists(log_file):
            print('Output log file {} already exists'.format(log_file))
            return

    use_rampup = rampup > 0

    src_intens_scale_range_lower, src_intens_scale_range_upper, src_intens_offset_range_lower, src_intens_offset_range_upper = \
        cmdline_helpers.intens_aug_options(src_intens_scale_range, src_intens_offset_range)
    tgt_intens_scale_range_lower, tgt_intens_scale_range_upper, tgt_intens_offset_range_lower, tgt_intens_offset_range_upper = \
        cmdline_helpers.intens_aug_options(tgt_intens_scale_range, tgt_intens_offset_range)

    import time
    import math
    import numpy as np
    from batchup import data_source, work_pool
    import data_loaders
    import standardisation
    import network_architectures
    import augmentation
    import torch, torch.cuda
    from torch import nn
    from torch.nn import functional as F
    import optim_weight_ema

    torch_device = torch.device(device)
    pool = work_pool.WorkerThreadPool(2)

    n_chn = 0

    if exp == 'svhn_mnist':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=True)
        d_target = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True,
                                           val=False)
    elif exp == 'mnist_svhn':
        d_source = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True)
        d_target = data_loaders.load_svhn(zero_centre=False,
                                          greyscale=True,
                                          val=False)
    elif exp == 'svhn_mnist_rgb':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=False)
        d_target = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True,
                                           val=False,
                                           rgb=True)
    elif exp == 'mnist_svhn_rgb':
        d_source = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True,
                                           rgb=True)
        d_target = data_loaders.load_svhn(zero_centre=False,
                                          greyscale=False,
                                          val=False)
    elif exp == 'cifar_stl':
        d_source = data_loaders.load_cifar10(range_01=False)
        d_target = data_loaders.load_stl(zero_centre=False, val=False)
    elif exp == 'stl_cifar':
        d_source = data_loaders.load_stl(zero_centre=False)
        d_target = data_loaders.load_cifar10(range_01=False, val=False)
    elif exp == 'mnist_usps':
        d_source = data_loaders.load_mnist(zero_centre=False)
        d_target = data_loaders.load_usps(zero_centre=False,
                                          scale28=True,
                                          val=False)
    elif exp == 'usps_mnist':
        d_source = data_loaders.load_usps(zero_centre=False, scale28=True)
        d_target = data_loaders.load_mnist(zero_centre=False, val=False)
    elif exp == 'syndigits_svhn':
        d_source = data_loaders.load_syn_digits(zero_centre=False)
        d_target = data_loaders.load_svhn(zero_centre=False, val=False)
    elif exp == 'synsigns_gtsrb':
        d_source = data_loaders.load_syn_signs(zero_centre=False)
        d_target = data_loaders.load_gtsrb(zero_centre=False, val=False)
    else:
        print('Unknown experiment type \'{}\''.format(exp))
        return

    # Delete the training ground truths as we should not be using them
    del d_target.train_y

    if standardise_samples:
        standardisation.standardise_dataset(d_source)
        standardisation.standardise_dataset(d_target)

    n_classes = d_source.n_classes

    print('Loaded data')

    if arch == '':
        if exp in {'mnist_usps', 'usps_mnist'}:
            arch = 'mnist-bn-32-64-256'
        if exp in {'svhn_mnist', 'mnist_svhn'}:
            arch = 'grey-32-64-128-gp'
        if exp in {
                'cifar_stl', 'stl_cifar', 'syndigits_svhn', 'svhn_mnist_rgb',
                'mnist_svhn_rgb'
        }:
            arch = 'rgb-128-256-down-gp'
        if exp in {'synsigns_gtsrb'}:
            arch = 'rgb40-96-192-384-gp'

    net_class, expected_shape = network_architectures.get_net_and_shape_for_architecture(
        arch)

    if expected_shape != d_source.train_X.shape[1:]:
        print(
            'Architecture {} not compatible with experiment {}; it needs samples of shape {}, '
            'data has samples of shape {}'.format(arch, exp, expected_shape,
                                                  d_source.train_X.shape[1:]))
        return

    student_net = net_class(n_classes).to(torch_device)
    teacher_net = net_class(n_classes).to(torch_device)
    student_params = list(student_net.parameters())
    teacher_params = list(teacher_net.parameters())
    for param in teacher_params:
        param.requires_grad = False

    student_optimizer = torch.optim.Adam(student_params, lr=learning_rate)
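    # The teacher's weights are an exponential moving average (EMA) of the
    # student's weights, with decay factor teacher_alpha.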
    if fix_ema:
        teacher_optimizer = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, alpha=teacher_alpha)
    else:
        teacher_optimizer = optim_weight_ema.OldWeightEMA(teacher_net,
                                                          student_net,
                                                          alpha=teacher_alpha)
    classification_criterion = nn.CrossEntropyLoss()

    print('Built network')

    src_aug = augmentation.ImageAugmentation(
        src_hflip,
        src_xlat_range,
        src_affine_std,
        intens_flip=src_intens_flip,
        intens_scale_range_lower=src_intens_scale_range_lower,
        intens_scale_range_upper=src_intens_scale_range_upper,
        intens_offset_range_lower=src_intens_offset_range_lower,
        intens_offset_range_upper=src_intens_offset_range_upper,
        gaussian_noise_std=src_gaussian_noise_std)
    tgt_aug = augmentation.ImageAugmentation(
        tgt_hflip,
        tgt_xlat_range,
        tgt_affine_std,
        intens_flip=tgt_intens_flip,
        intens_scale_range_lower=tgt_intens_scale_range_lower,
        intens_scale_range_upper=tgt_intens_scale_range_upper,
        intens_offset_range_lower=tgt_intens_offset_range_lower,
        intens_offset_range_upper=tgt_intens_offset_range_upper,
        gaussian_noise_std=tgt_gaussian_noise_std)

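    # With combine_batches, source samples also receive a student/teacher
    # augmentation pair; otherwise only the target samples are augmented twice.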
    if combine_batches:

        def augment(X_sup, y_src, X_tgt):
            X_src_stu, X_src_tea = src_aug.augment_pair(X_sup)
            X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
            return X_src_stu, X_src_tea, y_src, X_tgt_stu, X_tgt_tea
    else:

        def augment(X_src, y_src, X_tgt):
            X_src = src_aug.augment(X_src)
            X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
            return X_src, y_src, X_tgt_stu, X_tgt_tea

    rampup_weight_in_list = [0]

    cls_bal_fn = network_architectures.get_cls_bal_function(cls_balance_loss)

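    # Self-ensembling consistency loss: penalise disagreement between the
    # student's and the teacher's predicted class probabilities.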
    def compute_aug_loss(stu_out, tea_out):
        # Augmentation loss
        if use_rampup:
            unsup_mask = None
            conf_mask_count = None
            unsup_mask_count = None
        else:
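            # Confidence thresholding: only samples whose maximum teacher
            # probability exceeds confidence_thresh contribute to the loss.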
            conf_tea = torch.max(tea_out, 1)[0]
            unsup_mask = conf_mask = (conf_tea > confidence_thresh).float()
            unsup_mask_count = conf_mask_count = conf_mask.sum()

        if loss == 'bce':
            aug_loss = network_architectures.robust_binary_crossentropy(
                stu_out, tea_out)
        else:
            d_aug_loss = stu_out - tea_out
            aug_loss = d_aug_loss * d_aug_loss

        # Class balance scaling
        if cls_bal_scale:
            if use_rampup:
                n_samples = float(aug_loss.shape[0])
            else:
                n_samples = unsup_mask.sum()
            avg_pred = n_samples / float(n_classes)
            bal_scale = avg_pred / torch.clamp(tea_out.sum(dim=0), min=1.0)
            if cls_bal_scale_range != 0.0:
                bal_scale = torch.clamp(bal_scale,
                                        min=1.0 / cls_bal_scale_range,
                                        max=cls_bal_scale_range)
            bal_scale = bal_scale.detach()
            aug_loss = aug_loss * bal_scale[None, :]

        aug_loss = aug_loss.mean(dim=1)

        if use_rampup:
            unsup_loss = aug_loss.mean() * rampup_weight_in_list[0]
        else:
            unsup_loss = (aug_loss * unsup_mask).mean()

        # Class balance loss
        if cls_balance > 0.0:
            # Compute per-sample average predicted probability
            # Average over samples to get average class prediction
            avg_cls_prob = stu_out.mean(dim=0)
            # Compute loss
            equalise_cls_loss = cls_bal_fn(avg_cls_prob,
                                           float(1.0 / n_classes))

            equalise_cls_loss = equalise_cls_loss.mean() * n_classes

            if use_rampup:
                equalise_cls_loss = equalise_cls_loss * rampup_weight_in_list[0]
            else:
                if rampup == 0:
                    equalise_cls_loss = equalise_cls_loss * unsup_mask.mean(
                        dim=0)

            unsup_loss += equalise_cls_loss * cls_balance

        return unsup_loss, conf_mask_count, unsup_mask_count

    if combine_batches:

        def f_train(X_src0, X_src1, y_src, X_tgt0, X_tgt1):
            X_src0 = torch.tensor(X_src0,
                                  dtype=torch.float,
                                  device=torch_device)
            X_src1 = torch.tensor(X_src1,
                                  dtype=torch.float,
                                  device=torch_device)
            y_src = torch.tensor(y_src, dtype=torch.long, device=torch_device)
            X_tgt0 = torch.tensor(X_tgt0,
                                  dtype=torch.float,
                                  device=torch_device)
            X_tgt1 = torch.tensor(X_tgt1,
                                  dtype=torch.float,
                                  device=torch_device)

            n_samples = X_src0.size()[0]
            n_total = n_samples + X_tgt0.size()[0]

            student_optimizer.zero_grad()
            student_net.train()
            teacher_net.train()

            # Concatenate source and target mini-batches
            X0 = torch.cat([X_src0, X_tgt0], 0)
            X1 = torch.cat([X_src1, X_tgt1], 0)

            student_logits_out = student_net(X0)
            student_prob_out = F.softmax(student_logits_out, dim=1)

            src_logits_out = student_logits_out[:n_samples]
            src_prob_out = student_prob_out[:n_samples]

            teacher_logits_out = teacher_net(X1)
            teacher_prob_out = F.softmax(teacher_logits_out, dim=1)

            # Supervised classification loss
            if double_softmax:
                clf_loss = classification_criterion(src_prob_out, y_src)
            else:
                clf_loss = classification_criterion(src_logits_out, y_src)

            unsup_loss, conf_mask_count, unsup_mask_count = compute_aug_loss(
                student_prob_out, teacher_prob_out)

            loss_expr = clf_loss + unsup_loss * unsup_weight

            loss_expr.backward()
            student_optimizer.step()
            teacher_optimizer.step()

            outputs = [
                float(clf_loss) * n_samples,
                float(unsup_loss) * n_total
            ]
            if not use_rampup:
                mask_count = float(conf_mask_count) * 0.5
                unsup_count = float(unsup_mask_count) * 0.5

                outputs.append(mask_count)
                outputs.append(unsup_count)
            return tuple(outputs)
    else:

        def f_train(X_src, y_src, X_tgt0, X_tgt1):
            X_src = torch.tensor(X_src, dtype=torch.float, device=torch_device)
            y_src = torch.tensor(y_src, dtype=torch.long, device=torch_device)
            X_tgt0 = torch.tensor(X_tgt0,
                                  dtype=torch.float,
                                  device=torch_device)
            X_tgt1 = torch.tensor(X_tgt1,
                                  dtype=torch.float,
                                  device=torch_device)

            student_optimizer.zero_grad()
            student_net.train()
            teacher_net.train()

            src_logits_out = student_net(X_src)
            student_tgt_logits_out = student_net(X_tgt0)
            student_tgt_prob_out = F.softmax(student_tgt_logits_out, dim=1)
            teacher_tgt_logits_out = teacher_net(X_tgt1)
            teacher_tgt_prob_out = F.softmax(teacher_tgt_logits_out, dim=1)

            # Supervised classification loss
            if double_softmax:
                clf_loss = classification_criterion(
                    F.softmax(src_logits_out, dim=1), y_src)
            else:
                clf_loss = classification_criterion(src_logits_out, y_src)

            unsup_loss, conf_mask_count, unsup_mask_count = compute_aug_loss(
                student_tgt_prob_out, teacher_tgt_prob_out)

            loss_expr = clf_loss + unsup_loss * unsup_weight

            loss_expr.backward()
            student_optimizer.step()
            teacher_optimizer.step()

            n_samples = X_src.size()[0]

            outputs = [
                float(clf_loss) * n_samples,
                float(unsup_loss) * n_samples
            ]
            if not use_rampup:
                mask_count = float(conf_mask_count)
                unsup_count = float(unsup_mask_count)

                outputs.append(mask_count)
                outputs.append(unsup_count)
            return tuple(outputs)

    print('Compiled training function')

    def f_pred_src(X_sup):
        X_var = torch.tensor(X_sup, dtype=torch.float, device=torch_device)
        student_net.eval()
        teacher_net.eval()
        return (F.softmax(student_net(X_var), dim=1).detach().cpu().numpy(),
                F.softmax(teacher_net(X_var), dim=1).detach().cpu().numpy())

    def f_pred_tgt(X_sup):
        X_var = torch.tensor(X_sup, dtype=torch.float, device=torch_device)
        student_net.eval()
        teacher_net.eval()
        return (F.softmax(student_net(X_var), dim=1).detach().cpu().numpy(),
                F.softmax(teacher_net(X_var), dim=1).detach().cpu().numpy())

    def f_eval_src(X_sup, y_sup):
        y_pred_prob_stu, y_pred_prob_tea = f_pred_src(X_sup)
        y_pred_stu = np.argmax(y_pred_prob_stu, axis=1)
        y_pred_tea = np.argmax(y_pred_prob_tea, axis=1)
        return (float(
            (y_pred_stu != y_sup).sum()), float((y_pred_tea != y_sup).sum()))

    def f_eval_tgt(X_sup, y_sup):
        y_pred_prob_stu, y_pred_prob_tea = f_pred_tgt(X_sup)
        y_pred_stu = np.argmax(y_pred_prob_stu, axis=1)
        y_pred_tea = np.argmax(y_pred_prob_tea, axis=1)
        return (float(
            (y_pred_stu != y_sup).sum()), float((y_pred_tea != y_sup).sum()))

    print('Compiled evaluation function')

    # Setup output
    def log(text):
        print(text)
        if log_file is not None:
            with open(log_file, 'a') as f:
                f.write(text + '\n')

    cmdline_helpers.ensure_containing_dir_exists(log_file)

    # Report settings
    log('Settings: {}'.format(', '.join([
        '{}={}'.format(key, settings[key])
        for key in sorted(list(settings.keys()))
    ])))

    # Report dataset size
    log('Dataset:')
    log('SOURCE Train: X.shape={}, y.shape={}'.format(d_source.train_X.shape,
                                                      d_source.train_y.shape))
    log('SOURCE Test: X.shape={}, y.shape={}'.format(d_source.test_X.shape,
                                                     d_source.test_y.shape))
    log('TARGET Train: X.shape={}'.format(d_target.train_X.shape))
    log('TARGET Test: X.shape={}, y.shape={}'.format(d_target.test_X.shape,
                                                     d_target.test_y.shape))

    print('Training...')
    sup_ds = data_source.ArrayDataSource([d_source.train_X, d_source.train_y],
                                         repeats=-1)
    tgt_train_ds = data_source.ArrayDataSource([d_target.train_X], repeats=-1)
    train_ds = data_source.CompositeDataSource([sup_ds,
                                                tgt_train_ds]).map(augment)
    train_ds = pool.parallel_data_source(train_ds)
    if epoch_size == 'large':
        n_samples = max(d_source.train_X.shape[0], d_target.train_X.shape[0])
    elif epoch_size == 'small':
        n_samples = min(d_source.train_X.shape[0], d_target.train_X.shape[0])
    elif epoch_size == 'target':
        n_samples = d_target.train_X.shape[0]
    n_train_batches = n_samples // batch_size

    source_test_ds = data_source.ArrayDataSource(
        [d_source.test_X, d_source.test_y])
    target_test_ds = data_source.ArrayDataSource(
        [d_target.test_X, d_target.test_y])

    if seed != 0:
        shuffle_rng = np.random.RandomState(seed)
    else:
        shuffle_rng = np.random

    train_batch_iter = train_ds.batch_iterator(batch_size=batch_size,
                                               shuffle=shuffle_rng)

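    # Keep a CPU copy of the teacher weights; it is refreshed below whenever the
    # model improves (lowest source test error with ramp-up, otherwise highest
    # confidence mask rate).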
    best_teacher_model_state = {
        k: v.cpu().numpy()
        for k, v in teacher_net.state_dict().items()
    }

    best_conf_mask_rate = 0.0
    best_src_test_err = 1.0
    for epoch in range(num_epochs):
        t1 = time.time()

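        # Gaussian ramp-up of the unsupervised loss weight over the first
        # `rampup` epochs: exp(-5 * (1 - epoch / rampup)^2).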
        if use_rampup:
            if epoch < rampup:
                p = max(0.0, float(epoch)) / float(rampup)
                p = 1.0 - p
                rampup_value = math.exp(-p * p * 5.0)
            else:
                rampup_value = 1.0

            rampup_weight_in_list[0] = rampup_value

        train_res = data_source.batch_map_mean(f_train,
                                               train_batch_iter,
                                               n_batches=n_train_batches)

        train_clf_loss = train_res[0]
        if combine_batches:
            unsup_loss_string = 'unsup (both) loss={:.6f}'.format(train_res[1])
        else:
            unsup_loss_string = 'unsup (tgt) loss={:.6f}'.format(train_res[1])

        src_test_err_stu, src_test_err_tea = source_test_ds.batch_map_mean(
            f_eval_src, batch_size=batch_size * 2)
        tgt_test_err_stu, tgt_test_err_tea = target_test_ds.batch_map_mean(
            f_eval_tgt, batch_size=batch_size * 2)

        if use_rampup:
            unsup_loss_string = '{}, rampup={:.3%}'.format(
                unsup_loss_string, rampup_value)
            if src_test_err_stu < best_src_test_err:
                best_src_test_err = src_test_err_stu
                best_teacher_model_state = {
                    k: v.cpu().numpy()
                    for k, v in teacher_net.state_dict().items()
                }
                improve = '*** '
            else:
                improve = ''
        else:
            conf_mask_rate = train_res[-2]
            unsup_mask_rate = train_res[-1]
            if conf_mask_rate > best_conf_mask_rate:
                best_conf_mask_rate = conf_mask_rate
                improve = '*** '
                best_teacher_model_state = {
                    k: v.cpu().numpy()
                    for k, v in teacher_net.state_dict().items()
                }
            else:
                improve = ''
            unsup_loss_string = '{}, conf mask={:.3%}, unsup mask={:.3%}'.format(
                unsup_loss_string, conf_mask_rate, unsup_mask_rate)

        t2 = time.time()

        log('{}Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}, {}; '
            'SRC TEST ERR={:.3%}, TGT TEST student err={:.3%}, TGT TEST teacher err={:.3%}'
            .format(improve, epoch, t2 - t1, train_clf_loss, unsup_loss_string,
                    src_test_err_stu, tgt_test_err_stu, tgt_test_err_tea))

    # Save network
    if model_file != '':
        cmdline_helpers.ensure_containing_dir_exists(model_file)
        with open(model_file, 'wb') as f:
            torch.save(best_teacher_model_state, f)
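Example #3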
def experiment(exp, arch, rnd_init, img_size, confidence_thresh, teacher_alpha,
               unsup_weight, cls_balance, cls_balance_loss, learning_rate,
               pretrained_lr_factor, fix_layers, double_softmax, use_dropout,
               src_scale_u_range, src_scale_x_range, src_scale_y_range,
               src_affine_std, src_xlat_range, src_rot_std, src_hflip,
               src_intens_scale_range, src_colour_rot_std, src_colour_off_std,
               src_greyscale, src_cutout_prob, src_cutout_size,
               tgt_scale_u_range, tgt_scale_x_range, tgt_scale_y_range,
               tgt_affine_std, tgt_xlat_range, tgt_rot_std, tgt_hflip,
               tgt_intens_scale_range, tgt_colour_rot_std, tgt_colour_off_std,
               tgt_greyscale, tgt_cutout_prob, tgt_cutout_size, constrain_crop,
               img_pad_width, num_epochs, batch_size, epoch_size, seed,
               log_file, skip_epoch_eval, result_file, record_history,
               model_file, hide_progress_bar, subsetsize, subsetseed, device,
               num_threads):
    settings = locals().copy()

    if rnd_init:
        if fix_layers != '':
            print('`rnd_init` and `fix_layers` are mutually exclusive')
            return

    if epoch_size not in {'source', 'target'}:
        try:
            epoch_size = int(epoch_size)
        except ValueError:
            print(
                'epoch_size should be an integer, \'source\', or \'target\', not {}'
                .format(epoch_size))
            return

    import os
    import sys
    import pickle
    import cmdline_helpers

    fix_layers = [lyr.strip() for lyr in fix_layers.split(',')]

    if log_file == '':
        log_file = 'output_aug_log_{}.txt'.format(exp)
    elif log_file == 'none':
        log_file = None

    if log_file is not None:
        if os.path.exists(log_file):
            print('Output log file {} already exists'.format(log_file))
            return

    src_intens_scale_range_lower, src_intens_scale_range_upper = cmdline_helpers.colon_separated_range(
        src_intens_scale_range)
    tgt_intens_scale_range_lower, tgt_intens_scale_range_upper = cmdline_helpers.colon_separated_range(
        tgt_intens_scale_range)
    src_scale_u_range = cmdline_helpers.colon_separated_range(
        src_scale_u_range)
    tgt_scale_u_range = cmdline_helpers.colon_separated_range(
        tgt_scale_u_range)
    src_scale_x_range = cmdline_helpers.colon_separated_range(
        src_scale_x_range)
    tgt_scale_x_range = cmdline_helpers.colon_separated_range(
        tgt_scale_x_range)
    src_scale_y_range = cmdline_helpers.colon_separated_range(
        src_scale_y_range)
    tgt_scale_y_range = cmdline_helpers.colon_separated_range(
        tgt_scale_y_range)

    import time
    import tqdm
    import math
    import tables
    import numpy as np
    from batchup import data_source, work_pool
    import image_dataset, visda17_dataset, office_dataset
    import network_architectures
    import augmentation
    import image_transforms
    from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit
    import torch, torch.cuda
    from torch import nn
    from torch.nn import functional as F
    import optim_weight_ema

    if hide_progress_bar:
        progress_bar = None
    else:
        progress_bar = tqdm.tqdm

    with torch.cuda.device(device):
        pool = work_pool.WorkerThreadPool(num_threads)

        n_chn = 0
        half_batch_size = batch_size // 2

        if arch == '':
            if exp in {'train_val', 'train_test'}:
                arch = 'resnet50'

        if rnd_init:
            mean_value = np.array([0.5, 0.5, 0.5])
            std_value = np.array([0.5, 0.5, 0.5])
        else:
            mean_value = np.array([0.485, 0.456, 0.406])
            std_value = np.array([0.229, 0.224, 0.225])

        img_shape = (img_size, img_size)
        img_padding = (img_pad_width, img_pad_width)

        if exp == 'visda_train_val':
            d_source = visda17_dataset.TrainDataset(img_size=img_shape,
                                                    range01=True,
                                                    rgb_order=True)
            d_target = visda17_dataset.ValidationDataset(img_size=img_shape,
                                                         range01=True,
                                                         rgb_order=True)
        elif exp == 'visda_train_test':
            d_source = visda17_dataset.TrainDataset(img_size=img_shape,
                                                    range01=True,
                                                    rgb_order=True)
            d_target = visda17_dataset.TestDataset(img_size=img_shape,
                                                   range01=True,
                                                   rgb_order=True)

            if not skip_epoch_eval:
                print('WARNING: setting skip_epoch_eval to True')
                skip_epoch_eval = True
        elif exp == 'office_amazon_dslr':
            d_source = office_dataset.OfficeAmazonDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
            d_target = office_dataset.OfficeDSLRDataset(img_size=img_shape,
                                                        range01=True,
                                                        rgb_order=True)
        elif exp == 'office_amazon_webcam':
            d_source = office_dataset.OfficeAmazonDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
            d_target = office_dataset.OfficeWebcamDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
        elif exp == 'office_dslr_amazon':
            d_source = office_dataset.OfficeDSLRDataset(img_size=img_shape,
                                                        range01=True,
                                                        rgb_order=True)
            d_target = office_dataset.OfficeAmazonDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
        elif exp == 'office_dslr_webcam':
            d_source = office_dataset.OfficeDSLRDataset(img_size=img_shape,
                                                        range01=True,
                                                        rgb_order=True)
            d_target = office_dataset.OfficeWebcamDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
        elif exp == 'office_webcam_amazon':
            d_source = office_dataset.OfficeWebcamDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
            d_target = office_dataset.OfficeAmazonDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
        elif exp == 'office_webcam_dslr':
            d_source = office_dataset.OfficeWebcamDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
            d_target = office_dataset.OfficeDSLRDataset(img_size=img_shape,
                                                        range01=True,
                                                        rgb_order=True)
        else:
            print('Unknown experiment type \'{}\''.format(exp))
            return

        # Tensorboard log

        # Subset
        source_indices, target_indices, n_src, n_tgt = image_dataset.subset_indices(
            d_source, d_target, subsetsize, subsetseed)

        #
        # Result file
        #

        if result_file != '':
            cmdline_helpers.ensure_containing_dir_exists(result_file)
            h5_filters = tables.Filters(complevel=9, complib='blosc')
            f_target_pred = tables.open_file(result_file, mode='w')
            g_tgt_pred = f_target_pred.create_group(f_target_pred.root,
                                                    'target_pred_y',
                                                    'Target prediction')
            if record_history:
                arr_tgt_pred_history = f_target_pred.create_earray(
                    g_tgt_pred,
                    'y_prob_history',
                    tables.Float32Atom(), (0, n_tgt, d_target.n_classes),
                    filters=h5_filters)
            else:
                arr_tgt_pred_history = None
        else:
            arr_tgt_pred_history = None
            f_target_pred = None
            g_tgt_pred = None

        n_classes = d_source.n_classes

        print('Loaded data')

        net_class = network_architectures.get_build_fn_for_architecture(arch)

        student_net = net_class(n_classes, img_size, use_dropout,
                                not rnd_init).cuda()
        teacher_net = net_class(n_classes, img_size, use_dropout,
                                not rnd_init).cuda()
        student_params = list(student_net.parameters())
        teacher_params = list(teacher_net.parameters())
        for param in teacher_params:
            param.requires_grad = False

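        # With a randomly initialised network all parameters share one optimiser;
        # with a pretrained network, parameters named 'new_*' use the full learning
        # rate, the remaining pretrained parameters use
        # learning_rate * pretrained_lr_factor, and layers listed in fix_layers
        # are frozen.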
        if rnd_init:
            new_student_optimizer = torch.optim.Adam(student_params,
                                                     lr=learning_rate)
            pretrained_student_optimizer = None
        else:
            named_params = list(student_net.named_parameters())
            new_params = []
            pretrained_params = []
            for name, param in named_params:
                if name.startswith('new_'):
                    new_params.append(param)
                else:
                    fix = False
                    for lyr in fix_layers:
                        if name.startswith(lyr + '.'):
                            fix = True
                            break
                    if not fix:
                        pretrained_params.append(param)
                    else:
                        print('Fixing param {}'.format(name))
                        param.requires_grad = False

            new_student_optimizer = torch.optim.Adam(new_params,
                                                     lr=learning_rate)
            if len(pretrained_params) > 0:
                pretrained_student_optimizer = torch.optim.Adam(
                    pretrained_params, lr=learning_rate * pretrained_lr_factor)
            else:
                pretrained_student_optimizer = None
        teacher_optimizer = optim_weight_ema.WeightEMA(teacher_params,
                                                       student_params,
                                                       alpha=teacher_alpha)
        classification_criterion = nn.CrossEntropyLoss()

        print('Built network')

        # Image augmentation

        src_aug = augmentation.ImageAugmentation(
            src_hflip,
            src_xlat_range,
            src_affine_std,
            rot_std=src_rot_std,
            intens_scale_range_lower=src_intens_scale_range_lower,
            intens_scale_range_upper=src_intens_scale_range_upper,
            colour_rot_std=src_colour_rot_std,
            colour_off_std=src_colour_off_std,
            greyscale=src_greyscale,
            scale_u_range=src_scale_u_range,
            scale_x_range=src_scale_x_range,
            scale_y_range=src_scale_y_range,
            cutout_probability=src_cutout_prob,
            cutout_size=src_cutout_size)

        tgt_aug = augmentation.ImageAugmentation(
            tgt_hflip,
            tgt_xlat_range,
            tgt_affine_std,
            rot_std=tgt_rot_std,
            intens_scale_range_lower=tgt_intens_scale_range_lower,
            intens_scale_range_upper=tgt_intens_scale_range_upper,
            colour_rot_std=tgt_colour_rot_std,
            colour_off_std=tgt_colour_off_std,
            greyscale=tgt_greyscale,
            scale_u_range=tgt_scale_u_range,
            scale_x_range=tgt_scale_x_range,
            scale_y_range=tgt_scale_y_range,
            cutout_probability=tgt_cutout_prob,
            cutout_size=tgt_cutout_size)

        test_aug = augmentation.ImageAugmentation(
            tgt_hflip,
            tgt_xlat_range,
            0.0,
            rot_std=0.0,
            scale_u_range=tgt_scale_u_range,
            scale_x_range=tgt_scale_x_range,
            scale_y_range=tgt_scale_y_range)

        border_value = int(np.mean(mean_value) * 255 + 0.5)

        sup_xf = image_transforms.Compose(
            image_transforms.ScaleCropAndAugmentAffine(img_shape, img_padding,
                                                       True, src_aug,
                                                       border_value,
                                                       mean_value, std_value),
            image_transforms.ToTensor(),
        )

        if constrain_crop >= 0:
            unsup_xf = image_transforms.Compose(
                image_transforms.ScaleCropAndAugmentAffinePair(
                    img_shape, img_padding, constrain_crop, True, tgt_aug,
                    border_value, mean_value, std_value),
                image_transforms.ToTensor(),
            )
        else:
            unsup_xf = image_transforms.Compose(
                image_transforms.ScaleCropAndAugmentAffine(
                    img_shape, img_padding, True, tgt_aug, border_value,
                    mean_value, std_value),
                image_transforms.ToTensor(),
            )

        test_xf = image_transforms.Compose(
            image_transforms.ScaleAndCrop(img_shape, img_padding, False),
            image_transforms.ToTensor(),
            image_transforms.Standardise(mean_value, std_value),
        )

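        # Test-time augmentation transform: 16 augmented copies per image, whose
        # predictions are averaged in f_pred_tgt_mult below.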
        test_xf_aug_mult = image_transforms.Compose(
            image_transforms.ScaleCropAndAugmentAffineMultiple(
                16, img_shape, img_padding, True, test_aug, border_value,
                mean_value, std_value),
            image_transforms.ToTensorMultiple(),
        )

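        # With constrain_crop >= 0 the student/teacher crops of each target image
        # are produced together by a paired transform; otherwise two independent
        # augmentations of the same image are used.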
        if constrain_crop >= 0:

            def augment(X_sup, y_sup, X_tgt):
                X_sup = sup_xf(X_sup)[0]
                X_unsup_both = unsup_xf(X_tgt)[0]
                X_unsup_stu = X_unsup_both[:len(X_tgt)]
                X_unsup_tea = X_unsup_both[len(X_tgt):]
                return X_sup, y_sup, X_unsup_stu, X_unsup_tea
        else:

            def augment(X_sup, y_sup, X_tgt):
                X_sup = sup_xf(X_sup)[0]
                X_unsup_stu = unsup_xf(X_tgt)[0]
                X_unsup_tea = unsup_xf(X_tgt)[0]
                return X_sup, y_sup, X_unsup_stu, X_unsup_tea

        cls_bal_fn = network_architectures.get_cls_bal_function(
            cls_balance_loss)

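        # Consistency loss: squared difference between student and teacher
        # target-domain probabilities, masked by teacher confidence.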
        def compute_aug_loss(stu_out, tea_out):
            # Augmentation loss
            conf_tea = torch.max(tea_out, 1)[0]
            conf_mask = torch.gt(conf_tea, confidence_thresh).float()

            d_aug_loss = stu_out - tea_out
            aug_loss = d_aug_loss * d_aug_loss

            aug_loss = torch.mean(aug_loss, 1) * conf_mask

            # Class balance loss
            if cls_balance > 0.0:
                # Average over samples to get average class prediction
                avg_cls_prob = torch.mean(stu_out, 0)
                # Compute loss
                equalise_cls_loss = cls_bal_fn(avg_cls_prob,
                                               float(1.0 / n_classes))

                equalise_cls_loss = torch.mean(equalise_cls_loss) * n_classes

                equalise_cls_loss = equalise_cls_loss * torch.mean(
                    conf_mask, 0)
            else:
                equalise_cls_loss = None

            return aug_loss, conf_mask, equalise_cls_loss

        _one = torch.autograd.Variable(
            torch.from_numpy(np.array([1.0]).astype(np.float32)).cuda())

        def f_train(X_sup, y_sup, X_unsup0, X_unsup1):
            X_sup = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            y_sup = torch.autograd.Variable(
                torch.from_numpy(y_sup).long().cuda())
            X_unsup0 = torch.autograd.Variable(
                torch.from_numpy(X_unsup0).cuda())
            X_unsup1 = torch.autograd.Variable(
                torch.from_numpy(X_unsup1).cuda())

            if pretrained_student_optimizer is not None:
                pretrained_student_optimizer.zero_grad()
            new_student_optimizer.zero_grad()
            student_net.train(mode=True)
            teacher_net.train(mode=True)

            sup_logits_out = student_net(X_sup)
            student_unsup_logits_out = student_net(X_unsup0)
            student_unsup_prob_out = F.softmax(student_unsup_logits_out)
            teacher_unsup_logits_out = teacher_net(X_unsup1)
            teacher_unsup_prob_out = F.softmax(teacher_unsup_logits_out)

            # Supervised classification loss
            if double_softmax:
                clf_loss = classification_criterion(F.softmax(sup_logits_out),
                                                    y_sup)
            else:
                clf_loss = classification_criterion(sup_logits_out, y_sup)

            aug_loss, conf_mask, cls_bal_loss = compute_aug_loss(
                student_unsup_prob_out, teacher_unsup_prob_out)

            conf_mask_count = torch.sum(conf_mask)

            unsup_loss = torch.mean(aug_loss)
            loss_expr = clf_loss + unsup_loss * unsup_weight
            if cls_bal_loss is not None:
                loss_expr = loss_expr + cls_bal_loss * cls_balance * unsup_weight

            loss_expr.backward()
            if pretrained_student_optimizer is not None:
                pretrained_student_optimizer.step()
            new_student_optimizer.step()
            teacher_optimizer.step()

            n_samples = X_sup.size()[0]

            mask_count = conf_mask_count.data.cpu()[0]

            outputs = [
                float(clf_loss.data.cpu()[0]) * n_samples,
                float(unsup_loss.data.cpu()[0]) * n_samples, mask_count
            ]
            return tuple(outputs)

        print('Compiled training function')

        def f_pred_src(X_sup):
            X_var = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            teacher_net.train(mode=False)
            return (F.softmax(teacher_net(X_var)).data.cpu().numpy(), )

        def f_pred_tgt(X_sup):
            X_var = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            teacher_net.train(mode=False)
            return (F.softmax(teacher_net(X_var)).data.cpu().numpy(), )

        def f_pred_tgt_mult(X_sup):
            teacher_net.train(mode=False)
            y_pred_aug = []
            for aug_i in range(len(X_sup)):
                X_var = torch.autograd.Variable(
                    torch.from_numpy(X_sup[aug_i, ...]).cuda())
                y_pred = F.softmax(teacher_net(X_var)).data.cpu().numpy()
                y_pred_aug.append(y_pred[None, ...])
            y_pred_aug = np.concatenate(y_pred_aug, axis=0)
            return (y_pred_aug.mean(axis=0), )

        print('Compiled evaluation function')

        # Setup output
        cmdline_helpers.ensure_containing_dir_exists(log_file)

        def log(text):
            print(text)
            if log_file is not None:
                with open(log_file, 'a') as f:
                    f.write(text + '\n')

        # Report settings
        log('Program = {}'.format(sys.argv[0]))
        log('Settings: {}'.format(', '.join([
            '{}={}'.format(key, settings[key])
            for key in sorted(list(settings.keys()))
        ])))

        # Report dataset size
        log('Dataset:')
        log('SOURCE len(X)={}, y.shape={}'.format(len(d_source.images),
                                                  d_source.y.shape))
        log('TARGET len(X)={}'.format(len(d_target.images)))

        if epoch_size == 'source':
            n_samples = n_src
        elif epoch_size == 'target':
            n_samples = n_tgt
        else:
            n_samples = epoch_size
        n_train_batches = n_samples // batch_size
        n_test_batches = n_tgt // (batch_size * 2) + 1

        print('Training...')
        sup_ds = data_source.ArrayDataSource([d_source.images, d_source.y],
                                             repeats=-1,
                                             indices=source_indices)
        tgt_train_ds = data_source.ArrayDataSource([d_target.images],
                                                   repeats=-1,
                                                   indices=target_indices)
        train_ds = data_source.CompositeDataSource([sup_ds,
                                                    tgt_train_ds]).map(augment)
        train_ds = pool.parallel_data_source(train_ds,
                                             batch_buffer_size=min(
                                                 20, n_train_batches))

        target_ds_for_test = data_source.ArrayDataSource(
            [d_target.images], indices=target_indices)
        target_test_ds = target_ds_for_test.map(test_xf)
        target_test_ds = pool.parallel_data_source(target_test_ds,
                                                   batch_buffer_size=min(
                                                       20, n_test_batches))
        target_mult_test_ds = target_ds_for_test.map(test_xf_aug_mult)
        target_mult_test_ds = pool.parallel_data_source(target_mult_test_ds,
                                                        batch_buffer_size=min(
                                                            20,
                                                            n_test_batches))

        if seed != 0:
            shuffle_rng = np.random.RandomState(seed)
        else:
            shuffle_rng = np.random

        if d_target.has_ground_truth:
            evaluator = d_target.prediction_evaluator(target_indices)
        else:
            evaluator = None

        best_mask_rate = 0.0
        best_teacher_model_state = {
            k: v.cpu().numpy()
            for k, v in teacher_net.state_dict().items()
        }

        train_batch_iter = train_ds.batch_iterator(batch_size=batch_size,
                                                   shuffle=shuffle_rng)

        for epoch in range(num_epochs):
            t1 = time.time()

            if not skip_epoch_eval:
                test_batch_iter = target_test_ds.batch_iterator(
                    batch_size=batch_size * 2)
            else:
                test_batch_iter = None

            train_clf_loss, train_unsup_loss, mask_rate = data_source.batch_map_mean(
                f_train,
                train_batch_iter,
                progress_iter_func=progress_bar,
                n_batches=n_train_batches)

            # train_clf_loss, train_unsup_loss, mask_rate = train_ds.batch_map_mean(
            #     f_train, batch_size=batch_size, shuffle=shuffle_rng, n_batches=n_train_batches,
            #     progress_iter_func=progress_bar)

            if mask_rate > best_mask_rate:
                best_mask_rate = mask_rate
                improve = True
                improve_str = '*** '
                best_teacher_model_state = {
                    k: v.cpu().numpy()
                    for k, v in teacher_net.state_dict().items()
                }
            else:
                improve = False
                improve_str = ''

            if not skip_epoch_eval:
                tgt_pred_prob_y, = data_source.batch_map_concat(
                    f_pred_tgt,
                    test_batch_iter,
                    progress_iter_func=progress_bar)
                mean_class_acc, cls_acc_str = evaluator.evaluate(
                    tgt_pred_prob_y)
                t2 = time.time()

                log('{}Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}, unsup loss={:.6f}, mask={:.3%}; '
                    'TGT mean class acc={:.3%}'.format(improve_str, epoch,
                                                       t2 - t1, train_clf_loss,
                                                       train_unsup_loss,
                                                       mask_rate,
                                                       mean_class_acc))
                log('  per class:  {}'.format(cls_acc_str))

                # Save results
                if arr_tgt_pred_history is not None:
                    arr_tgt_pred_history.append(
                        tgt_pred_prob_y[None, ...].astype(np.float32))
            else:
                t2 = time.time()
                log('{}Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}, unsup loss={:.6f}, mask={:.3%}'
                    .format(improve_str, epoch, t2 - t1, train_clf_loss,
                            train_unsup_loss, mask_rate))

        # Save network
        if model_file != '':
            cmdline_helpers.ensure_containing_dir_exists(model_file)
            with open(model_file, 'wb') as f:
                pickle.dump(best_teacher_model_state, f)

        # Restore network to best state
        teacher_net.load_state_dict({
            k: torch.from_numpy(v)
            for k, v in best_teacher_model_state.items()
        })

        # Predict on test set, without augmentation
        tgt_pred_prob_y, = target_test_ds.batch_map_concat(
            f_pred_tgt, batch_size=batch_size, progress_iter_func=progress_bar)

        if d_target.has_ground_truth:
            mean_class_acc, cls_acc_str = evaluator.evaluate(tgt_pred_prob_y)

            log('FINAL: TGT mean class acc={:.3%}'.format(mean_class_acc))
            log('  per class:  {}'.format(cls_acc_str))

        # Predict on test set, using augmentation
        tgt_aug_pred_prob_y, = target_mult_test_ds.batch_map_concat(
            f_pred_tgt_mult,
            batch_size=batch_size,
            progress_iter_func=progress_bar)
        if d_target.has_ground_truth:
            aug_mean_class_acc, aug_cls_acc_str = evaluator.evaluate(
                tgt_aug_pred_prob_y)

            log('FINAL: TGT AUG mean class acc={:.3%}'.format(
                aug_mean_class_acc))
            log('  per class:  {}'.format(aug_cls_acc_str))

        if f_target_pred is not None:
            f_target_pred.create_array(g_tgt_pred, 'y_prob', tgt_pred_prob_y)
            f_target_pred.create_array(g_tgt_pred, 'y_prob_aug',
                                       tgt_aug_pred_prob_y)
            f_target_pred.close()
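Example #4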
def experiment(plot_path, ds_name, no_aug, affine_std, scale_u_range,
               scale_x_range, scale_y_range, xlat_range, hflip, intens_flip,
               intens_scale_range, intens_offset_range, grid_h, grid_w, seed):
    settings = locals().copy()

    import os
    import sys
    import cmdline_helpers

    intens_scale_range_lower, intens_scale_range_upper = cmdline_helpers.colon_separated_range(
        intens_scale_range)
    intens_offset_range_lower, intens_offset_range_upper = cmdline_helpers.colon_separated_range(
        intens_offset_range)
    scale_u_range = cmdline_helpers.colon_separated_range(scale_u_range)
    scale_x_range = cmdline_helpers.colon_separated_range(scale_x_range)
    scale_y_range = cmdline_helpers.colon_separated_range(scale_y_range)

    import numpy as np
    # from skimage.util import montage2d
    from skimage.util import montage as montage2d
    from PIL import Image
    from batchup import data_source
    import data_loaders
    import augmentation

    n_chn = 0

    if ds_name == 'mnist':
        d_source = data_loaders.load_mnist(zero_centre=False)
    elif ds_name == 'usps':
        d_source = data_loaders.load_usps(zero_centre=False, scale28=True)
    elif ds_name == 'svhn_grey':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=True)
    elif ds_name == 'svhn':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=False)
    elif ds_name == 'cifar':
        d_source = data_loaders.load_cifar10()
    elif ds_name == 'stl':
        d_source = data_loaders.load_stl()
    elif ds_name == 'syndigits':
        d_source = data_loaders.load_syn_digits(zero_centre=False,
                                                greyscale=False)
    elif ds_name == 'synsigns':
        d_source = data_loaders.load_syn_signs(zero_centre=False,
                                               greyscale=False)
    elif ds_name == 'gtsrb':
        d_source = data_loaders.load_gtsrb(zero_centre=False, greyscale=False)
    else:
        print('Unknown dataset \'{}\''.format(ds_name))
        return

    # Delete the training ground truths as we should not be using them
    del d_source.train_y

    n_classes = d_source.n_classes

    print('Loaded data')

    src_aug = augmentation.ImageAugmentation(
        hflip,
        xlat_range,
        affine_std,
        intens_flip=intens_flip,
        intens_scale_range_lower=intens_scale_range_lower,
        intens_scale_range_upper=intens_scale_range_upper,
        intens_offset_range_lower=intens_offset_range_lower,
        intens_offset_range_upper=intens_offset_range_upper,
        scale_u_range=scale_u_range,
        scale_x_range=scale_x_range,
        scale_y_range=scale_y_range)

    def augment(X):
        if not no_aug:
            X = src_aug.augment(X)
        return X,

    rampup_weight_in_list = [0]

    print('Rendering...')
    train_ds = data_source.ArrayDataSource([d_source.train_X],
                                           repeats=-1).map(augment)
    n_samples = len(d_source.train_X)

    if seed != 0:
        shuffle_rng = np.random.RandomState(seed)
    else:
        shuffle_rng = np.random

    batch_size = grid_h * grid_w
    display_batch_iter = train_ds.batch_iterator(batch_size=batch_size,
                                                 shuffle=shuffle_rng)

    best_src_test_err = 1.0

    x_batch, = next(display_batch_iter)

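    # Assemble one grid montage per channel, then stack channels back into an image.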
    montage = []
    for chn_i in range(x_batch.shape[1]):
        m = montage2d(x_batch[:, chn_i, :, :], grid_shape=(grid_h, grid_w))
        montage.append(m[:, :, None])
    montage = np.concatenate(montage, axis=2)

    if montage.shape[2] == 1:
        montage = montage[:, :, 0]

    lower = min(0.0, montage.min())
    upper = max(1.0, montage.max())
    montage = (montage - lower) / (upper - lower)
    montage = (np.clip(montage, 0.0, 1.0) * 255.0).astype(np.uint8)

    Image.fromarray(montage).save(plot_path)
Example #5
def experiment(exp, arch, learning_rate, standardise_samples, affine_std,
               xlat_range, hflip, intens_flip, intens_scale_range,
               intens_offset_range, gaussian_noise_std, num_epochs, batch_size,
               seed, log_file, device):
    import os
    import sys
    import cmdline_helpers

    if log_file == '':
        log_file = 'output_aug_log_{}.txt'.format(exp)
    elif log_file == 'none':
        log_file = None

    if log_file is not None:
        if os.path.exists(log_file):
            print('Output log file {} already exists'.format(log_file))
            return

    intens_scale_range_lower, intens_scale_range_upper, intens_offset_range_lower, intens_offset_range_upper = \
        cmdline_helpers.intens_aug_options(intens_scale_range, intens_offset_range)

    import time
    import math
    import numpy as np
    from batchup import data_source, work_pool
    import data_loaders
    import standardisation
    import network_architectures
    import augmentation
    import torch, torch.cuda
    from torch import nn
    from torch.nn import functional as F

    with torch.cuda.device(device):
        pool = work_pool.WorkerThreadPool(2)

        n_chn = 0

        if exp == 'svhn_mnist':
            d_source = data_loaders.load_svhn(zero_centre=False,
                                              greyscale=True)
            d_target = data_loaders.load_mnist(invert=False,
                                               zero_centre=False,
                                               pad32=True,
                                               val=False)
        elif exp == 'mnist_svhn':
            d_source = data_loaders.load_mnist(invert=False,
                                               zero_centre=False,
                                               pad32=True)
            d_target = data_loaders.load_svhn(zero_centre=False,
                                              greyscale=True,
                                              val=False)
        elif exp == 'svhn_mnist_rgb':
            d_source = data_loaders.load_svhn(zero_centre=False,
                                              greyscale=False)
            d_target = data_loaders.load_mnist(invert=False,
                                               zero_centre=False,
                                               pad32=True,
                                               val=False,
                                               rgb=True)
        elif exp == 'mnist_svhn_rgb':
            d_source = data_loaders.load_mnist(invert=False,
                                               zero_centre=False,
                                               pad32=True,
                                               rgb=True)
            d_target = data_loaders.load_svhn(zero_centre=False,
                                              greyscale=False,
                                              val=False)
        elif exp == 'cifar_stl':
            d_source = data_loaders.load_cifar10(range_01=False)
            d_target = data_loaders.load_stl(zero_centre=False, val=False)
        elif exp == 'stl_cifar':
            d_source = data_loaders.load_stl(zero_centre=False)
            d_target = data_loaders.load_cifar10(range_01=False, val=False)
        elif exp == 'mnist_usps':
            d_source = data_loaders.load_mnist(zero_centre=False)
            d_target = data_loaders.load_usps(zero_centre=False,
                                              scale28=True,
                                              val=False)
        elif exp == 'usps_mnist':
            d_source = data_loaders.load_usps(zero_centre=False, scale28=True)
            d_target = data_loaders.load_mnist(zero_centre=False, val=False)
        elif exp == 'syndigits_svhn':
            d_source = data_loaders.load_syn_digits(zero_centre=False)
            d_target = data_loaders.load_svhn(zero_centre=False, val=False)
        elif exp == 'svhn_syndigits':
            d_source = data_loaders.load_svhn(zero_centre=False, val=False)
            d_target = data_loaders.load_syn_digits(zero_centre=False)
        elif exp == 'synsigns_gtsrb':
            d_source = data_loaders.load_syn_signs(zero_centre=False)
            d_target = data_loaders.load_gtsrb(zero_centre=False, val=False)
        elif exp == 'gtsrb_synsigns':
            d_source = data_loaders.load_gtsrb(zero_centre=False, val=False)
            d_target = data_loaders.load_syn_signs(zero_centre=False)
        else:
            print('Unknown experiment type \'{}\''.format(exp))
            return

        # Delete the training ground truths as we should not be using them
        del d_target.train_y

        if standardise_samples:
            standardisation.standardise_dataset(d_source)
            standardisation.standardise_dataset(d_target)

        n_classes = d_source.n_classes

        print('Loaded data')

        if arch == '':
            if exp in {'mnist_usps', 'usps_mnist'}:
                arch = 'mnist-bn-32-64-256'
            if exp in {'svhn_mnist', 'mnist_svhn'}:
                arch = 'grey-32-64-128-gp'
            if exp in {
                    'cifar_stl', 'stl_cifar', 'syndigits_svhn',
                    'svhn_syndigits', 'svhn_mnist_rgb', 'mnist_svhn_rgb'
            }:
                arch = 'rgb-48-96-192-gp'
            if exp in {'synsigns_gtsrb', 'gtsrb_synsigns'}:
                arch = 'rgb40-48-96-192-384-gp'

        net_class, expected_shape = network_architectures.get_net_and_shape_for_architecture(
            arch)

        if expected_shape != d_source.train_X.shape[1:]:
            print(
                'Architecture {} not compatible with experiment {}; it needs samples of shape {}, '
                'data has samples of shape {}'.format(
                    arch, exp, expected_shape, d_source.train_X.shape[1:]))
            return

        net = net_class(n_classes).cuda()
        params = list(net.parameters())

        optimizer = torch.optim.Adam(params, lr=learning_rate)
        classification_criterion = nn.CrossEntropyLoss()

        print('Built network')

        aug = augmentation.ImageAugmentation(
            hflip,
            xlat_range,
            affine_std,
            intens_scale_range_lower=intens_scale_range_lower,
            intens_scale_range_upper=intens_scale_range_upper,
            intens_offset_range_lower=intens_offset_range_lower,
            intens_offset_range_upper=intens_offset_range_upper,
            intens_flip=intens_flip,
            gaussian_noise_std=gaussian_noise_std)

        def augment(X_sup, y_sup):
            X_sup = aug.augment(X_sup)
            return [X_sup, y_sup]

        def f_train(X_sup, y_sup):
            X_sup = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            y_sup = torch.autograd.Variable(
                torch.from_numpy(y_sup).long().cuda())

            optimizer.zero_grad()
            net.train(mode=True)

            sup_logits_out = net(X_sup)

            # Supervised classification loss
            clf_loss = classification_criterion(sup_logits_out, y_sup)

            loss_expr = clf_loss

            loss_expr.backward()
            optimizer.step()

            n_samples = X_sup.size()[0]

            return float(clf_loss.data.cpu().numpy()) * n_samples

        print('Compiled training function')

        def f_pred_src(X_sup):
            X_var = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            net.train(mode=False)
            return F.softmax(net(X_var)).data.cpu().numpy()

        def f_pred_tgt(X_sup):
            X_var = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            net.train(mode=False)
            return F.softmax(net(X_var)).data.cpu().numpy()
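        # f_pred_src and f_pred_tgt are identical here: switch the network to
        # eval mode and return softmax class probabilities for a NumPy batch.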

        def f_eval_src(X_sup, y_sup):
            y_pred_prob = f_pred_src(X_sup)
            y_pred = np.argmax(y_pred_prob, axis=1)
            return float((y_pred != y_sup).sum())

        def f_eval_tgt(X_sup, y_sup):
            y_pred_prob = f_pred_tgt(X_sup)
            y_pred = np.argmax(y_pred_prob, axis=1)
            return float((y_pred != y_sup).sum())

        print('Compiled evaluation function')

        # Setup output
        def log(text):
            print(text)
            if log_file is not None:
                with open(log_file, 'a') as f:
                    f.write(text + '\n')

        if log_file is not None:
            cmdline_helpers.ensure_containing_dir_exists(log_file)

        # Report settings
        log('sys.argv={}'.format(sys.argv))

        # Report dataset size
        log('Dataset:')
        log('SOURCE Train: X.shape={}, y.shape={}'.format(
            d_source.train_X.shape, d_source.train_y.shape))
        log('SOURCE Test: X.shape={}, y.shape={}'.format(
            d_source.test_X.shape, d_source.test_y.shape))
        log('TARGET Train: X.shape={}'.format(d_target.train_X.shape))
        log('TARGET Test: X.shape={}, y.shape={}'.format(
            d_target.test_X.shape, d_target.test_y.shape))

        print('Training...')
        train_ds = data_source.ArrayDataSource(
            [d_source.train_X, d_source.train_y]).map(augment)
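        # Note: only the (augmented) source data is used for training; the
        # target set is held out purely for evaluation, so this example is a
        # supervised source-only baseline.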

        source_test_ds = data_source.ArrayDataSource(
            [d_source.test_X, d_source.test_y])
        target_test_ds = data_source.ArrayDataSource(
            [d_target.test_X, d_target.test_y])

        if seed != 0:
            shuffle_rng = np.random.RandomState(seed)
        else:
            shuffle_rng = np.random
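        # seed == 0 means "use the global NumPy RNG", i.e. non-deterministic
        # batch shuffling.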

        best_src_test_err = 1.0
        for epoch in range(num_epochs):
            t1 = time.time()

            train_res = train_ds.batch_map_mean(f_train,
                                                batch_size=batch_size,
                                                shuffle=shuffle_rng)

            train_clf_loss = train_res[0]
            src_test_err, = source_test_ds.batch_map_mean(
                f_eval_src, batch_size=batch_size * 4)
            tgt_test_err, = target_test_ds.batch_map_mean(
                f_eval_tgt, batch_size=batch_size * 4)

            t2 = time.time()

            if src_test_err < best_src_test_err:
                log('*** Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}; '
                    'SRC TEST ERR={:.3%}, TGT TEST err={:.3%}'.format(
                        epoch, t2 - t1, train_clf_loss, src_test_err,
                        tgt_test_err))
                best_src_test_err = src_test_err
            else:
                log('Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}; '
                    'SRC TEST ERR={:.3%}, TGT TEST err={:.3%}'.format(
                        epoch, t2 - t1, train_clf_loss, src_test_err,
                        tgt_test_err))
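
The examples in this file revolve around the project's augmentation.ImageAugmentation class: each one constructs it and pushes NumPy image batches through its augment / augment_pair methods. The minimal sketch below isolates that usage pattern; the constructor arguments and the N x C x H x W float32 batch layout are inferred from the calls in these examples rather than from any library documentation, and the parameter values are purely illustrative (they mirror the hard-coded settings in Example #6 below).

import numpy as np
import augmentation  # the project's augmentation module used throughout these examples

# Illustrative settings only.
aug = augmentation.ImageAugmentation(
    False,   # hflip
    2.0,     # xlat_range
    0.1,     # affine_std
    intens_flip=False,
    intens_scale_range_lower=-1.5,
    intens_scale_range_upper=1.5,
    intens_offset_range_lower=-0.5,
    intens_offset_range_upper=0.5,
    gaussian_noise_std=0.1)

X = np.random.rand(8, 3, 32, 32).astype(np.float32)  # assumed N x C x H x W batch
X_aug = aug.augment(X)                 # one stochastically augmented copy
X_stu, X_tea = aug.augment_pair(X)     # two independently augmented views of the same batch

The paired form is what the self-ensembling scripts use to give the student and teacher networks two differently perturbed views of the same batch (note the _stu / _tea variable names in Example #6).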
Example #6
    def train(self):

        curr_iter = 0

        reallabel = torch.FloatTensor(self.opt.batchSize).fill_(
            self.real_label_val)
        fakelabel = torch.FloatTensor(self.opt.batchSize).fill_(
            self.fake_label_val)
        if self.opt.gpu >= 0:
            reallabel, fakelabel = reallabel.cuda(), fakelabel.cuda()
        reallabelv = Variable(reallabel)
        fakelabelv = Variable(fakelabel)
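        # reallabelv / fakelabelv are the fixed real/fake targets fed to the
        # discriminator's real-vs-fake criterion (criterion_s) below.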

        # parameters
        src_hflip = False
        src_xlat_range = 2.0
        src_affine_std = 0.1
        src_intens_flip = False
        src_intens_scale_range_lower = -1.5
        src_intens_scale_range_upper = 1.5
        src_intens_offset_range_lower = -0.5
        src_intens_offset_range_upper = 0.5
        src_gaussian_noise_std = 0.1
        tgt_hflip = False
        tgt_xlat_range = 2.0
        tgt_affine_std = 0.1
        tgt_intens_flip = False
        tgt_intens_scale_range_lower = -1.5
        tgt_intens_scale_range_upper = 1.5
        tgt_intens_offset_range_lower = -0.5
        tgt_intens_offset_range_upper = 0.5
        tgt_gaussian_noise_std = 0.1

        # augmentation function
        src_aug = augmentation.ImageAugmentation(
            src_hflip,
            src_xlat_range,
            src_affine_std,
            intens_flip=src_intens_flip,
            intens_scale_range_lower=src_intens_scale_range_lower,
            intens_scale_range_upper=src_intens_scale_range_upper,
            intens_offset_range_lower=src_intens_offset_range_lower,
            intens_offset_range_upper=src_intens_offset_range_upper,
            gaussian_noise_std=src_gaussian_noise_std)
        tgt_aug = augmentation.ImageAugmentation(
            tgt_hflip,
            tgt_xlat_range,
            tgt_affine_std,
            intens_flip=tgt_intens_flip,
            intens_scale_range_lower=tgt_intens_scale_range_lower,
            intens_scale_range_upper=tgt_intens_scale_range_upper,
            intens_offset_range_lower=tgt_intens_offset_range_lower,
            intens_offset_range_upper=tgt_intens_offset_range_upper,
            gaussian_noise_std=tgt_gaussian_noise_std)

        combine_batches = False

        if combine_batches:

            def augment(X_sup, y_src, X_tgt):
                X_src_stu, X_src_tea = src_aug.augment_pair(X_sup)
                X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
                return X_src_stu, X_src_tea, y_src, X_tgt_stu, X_tgt_tea
        else:

            def augment(X_src, y_src, X_tgt):
                X_src = src_aug.augment(X_src)
                X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
                return X_src, y_src, X_tgt_stu, X_tgt_tea
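        # With combine_batches hard-coded to False above, only the target batch
        # receives a student/teacher pair of augmentations; the source batch
        # gets a single augmented copy.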

        for epoch in range(self.opt.nepochs):

            self.netG.train()
            self.netF.train()
            self.netC.train()
            self.netD.train()

            for i, (datas, datat) in enumerate(
                    zip(self.source_trainloader, self.targetloader)):

                ###########################
                # Forming input variables
                ###########################

                src_inputs, src_labels = datas
                tgt_inputs, __ = datat
                if self.augment:
                    if combine_batches:
                        src_inputs, _, src_labels, tgt_inputs, _ = augment(
                            src_inputs, src_labels, tgt_inputs)
                    else:
                        src_inputs, src_labels, tgt_inputs, _ = augment(
                            src_inputs.numpy(), src_labels.numpy(),
                            tgt_inputs.numpy())
                    src_inputs = torch.FloatTensor(src_inputs)
                    src_labels = torch.LongTensor(src_labels)
                    tgt_inputs = torch.FloatTensor(tgt_inputs)

                src_inputs_unnorm = ((
                    (src_inputs * self.std[0]) + self.mean[0]) - 0.5) * 2

                # Creating one hot vector
                labels_onehot = np.zeros(
                    (self.opt.batchSize, self.nclasses + 1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, src_labels[num]] = 1
                src_labels_onehot = torch.from_numpy(labels_onehot)

                labels_onehot = np.zeros(
                    (self.opt.batchSize, self.nclasses + 1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, self.nclasses] = 1
                tgt_labels_onehot = torch.from_numpy(labels_onehot)
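                # Both one-hot matrices have nclasses + 1 columns; the extra
                # column marks samples of unknown class, and it is the one set
                # for every target sample.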

                if self.opt.gpu >= 0:
                    src_inputs, src_labels = src_inputs.cuda(
                    ), src_labels.cuda()
                    src_inputs_unnorm = src_inputs_unnorm.cuda()
                    tgt_inputs = tgt_inputs.cuda()
                    src_labels_onehot = src_labels_onehot.cuda()
                    tgt_labels_onehot = tgt_labels_onehot.cuda()

                # Wrapping in variable
                src_inputsv, src_labelsv = Variable(src_inputs), Variable(
                    src_labels)
                src_inputs_unnormv = Variable(src_inputs_unnorm)
                tgt_inputsv = Variable(tgt_inputs)
                src_labels_onehotv = Variable(src_labels_onehot)
                tgt_labels_onehotv = Variable(tgt_labels_onehot)

                ###########################
                # Updates
                ###########################

                # Updating D network

                self.netD.zero_grad()
                src_emb = self.netF(src_inputsv)
                src_emb_cat = torch.cat((src_labels_onehotv, src_emb), 1)
                src_gen = self.netG(src_emb_cat)

                tgt_emb = self.netF(tgt_inputsv)
                tgt_emb_cat = torch.cat((tgt_labels_onehotv, tgt_emb), 1)
                tgt_gen = self.netG(tgt_emb_cat)

                src_realoutputD_s, src_realoutputD_c = self.netD(
                    src_inputs_unnormv)
                errD_src_real_s = self.criterion_s(src_realoutputD_s,
                                                   reallabelv)
                errD_src_real_c = self.criterion_c(src_realoutputD_c,
                                                   src_labelsv)

                src_fakeoutputD_s, src_fakeoutputD_c = self.netD(src_gen)
                errD_src_fake_s = self.criterion_s(src_fakeoutputD_s,
                                                   fakelabelv)

                tgt_fakeoutputD_s, tgt_fakeoutputD_c = self.netD(tgt_gen)
                errD_tgt_fake_s = self.criterion_s(tgt_fakeoutputD_s,
                                                   fakelabelv)

                errD = errD_src_real_c + errD_src_real_s + errD_src_fake_s + errD_tgt_fake_s
                #TODO add CBL to D loss
                if self.class_balance > 0.0:
                    avg_cls_prob = torch.mean(tgt_fakeoutputD_c, 0)
                    equalise_cls_loss = self.cls_bal_fn(
                        avg_cls_prob, float(1.0 / self.nclasses))
                    equalise_cls_loss = torch.mean(
                        equalise_cls_loss) * self.nclasses
                    errD += equalise_cls_loss * self.class_balance
                errD.backward(retain_graph=True)
                self.optimizerD.step()

                # Updating G network

                self.netG.zero_grad()
                src_fakeoutputD_s, src_fakeoutputD_c = self.netD(src_gen)
                errG_c = self.criterion_c(src_fakeoutputD_c, src_labelsv)
                errG_s = self.criterion_s(src_fakeoutputD_s, reallabelv)
                errG = errG_c + errG_s
                errG.backward(retain_graph=True)
                self.optimizerG.step()

                # Updating C network

                self.netC.zero_grad()
                outC = self.netC(src_emb)
                errC = self.criterion_c(outC, src_labelsv)
                errC.backward(retain_graph=True)
                self.optimizerC.step()

                # Updating F network

                self.netF.zero_grad()
                errF_fromC = self.criterion_c(outC, src_labelsv)

                src_fakeoutputD_s, src_fakeoutputD_c = self.netD(src_gen)
                errF_src_fromD = self.criterion_c(
                    src_fakeoutputD_c, src_labelsv) * (self.opt.adv_weight)

                tgt_fakeoutputD_s, tgt_fakeoutputD_c = self.netD(tgt_gen)

                #TODO add CBL to D gradient
                errF_tgt_fromD = self.criterion_s(
                    tgt_fakeoutputD_s,
                    reallabelv) * (self.opt.adv_weight * self.opt.alpha)

                errF = errF_fromC + errF_src_fromD + errF_tgt_fromD
                if self.class_balance > 0.0:
                    avg_cls_prob = torch.mean(tgt_fakeoutputD_c, 0)
                    equalise_cls_loss = self.cls_bal_fn(
                        avg_cls_prob, float(1.0 / self.nclasses))
                    equalise_cls_loss = torch.mean(
                        equalise_cls_loss) * self.nclasses
                    errF += equalise_cls_loss * self.class_balance

                errF.backward()
                self.optimizerF.step()

                curr_iter += 1

                # Visualization
                if i == 1:
                    vutils.save_image((src_gen.data / 2) + 0.5,
                                      '%s/visualization/source_gen_%d.png' %
                                      (self.opt.outf, epoch))
                    vutils.save_image((tgt_gen.data / 2) + 0.5,
                                      '%s/visualization/target_gen_%d.png' %
                                      (self.opt.outf, epoch))

                # Learning rate scheduling
                if self.opt.lrd:
                    self.optimizerD = utils.exp_lr_scheduler(
                        self.optimizerD, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    self.optimizerF = utils.exp_lr_scheduler(
                        self.optimizerF, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    self.optimizerC = utils.exp_lr_scheduler(
                        self.optimizerC, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)

            # Validate every epoch
            self.validate(epoch + 1)

Example #7
def experiment(exp, arch, rnd_init, img_size, standardise_samples,
               learning_rate, pretrained_lr_factor, fix_layers, double_softmax,
               use_dropout, scale_u_range, scale_x_range, scale_y_range,
               affine_std, xlat_range, rot_std, hflip, intens_scale_range,
               colour_rot_std, colour_off_std, greyscale, img_pad_width,
               num_epochs, batch_size, seed, log_file, result_file,
               hide_progress_bar, subsetsize, subsetseed, device):
    settings = locals().copy()

    if rnd_init:
        if fix_layers != '':
            print('`rnd_init` and `fix_layers` are mutually exclusive')
            return

    import os
    import sys
    import cmdline_helpers

    fix_layers = [lyr.strip() for lyr in fix_layers.split(',')]
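    # Note: splitting an empty string yields [''], which matches no layer name
    # prefix below, so by default no pre-trained layers are frozen.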

    if log_file == '':
        log_file = 'output_aug_log_{}.txt'.format(exp)
    elif log_file == 'none':
        log_file = None

    if log_file is not None:
        if os.path.exists(log_file):
            print('Output log file {} already exists'.format(log_file))
            return

    intens_scale_range_lower, intens_scale_range_upper = cmdline_helpers.colon_separated_range(
        intens_scale_range)
    scale_u_range = cmdline_helpers.colon_separated_range(scale_u_range)
    scale_x_range = cmdline_helpers.colon_separated_range(scale_x_range)
    scale_y_range = cmdline_helpers.colon_separated_range(scale_y_range)

    import time
    import tqdm
    import math
    import tables
    import numpy as np
    from batchup import data_source, work_pool
    import image_dataset, visda17_dataset, office_dataset
    import network_architectures
    import augmentation
    import image_transforms
    from sklearn.model_selection import StratifiedShuffleSplit
    import torch, torch.cuda
    from torch import nn
    from torch.nn import functional as F

    if hide_progress_bar:
        progress_bar = None
    else:
        progress_bar = tqdm.tqdm

    with torch.cuda.device(device):
        pool = work_pool.WorkerThreadPool(4)

        n_chn = 0
        half_batch_size = batch_size // 2

        RESNET_ARCHS = {'resnet50', 'resnet101', 'resnet152'}
        RNDINIT_ARCHS = {'vgg13_48_gp'}

        if arch == '':
            if exp in {'train_val', 'train_test'}:
                arch = 'resnet50'

        if arch in RESNET_ARCHS and not rnd_init:
            mean_value = np.array([0.485, 0.456, 0.406])
            std_value = np.array([0.229, 0.224, 0.225])
        elif arch in RNDINIT_ARCHS:
            mean_value = np.array([0.5, 0.5, 0.5])
            std_value = np.array([0.5, 0.5, 0.5])
            rnd_init = True
        else:
            mean_value = std_value = None
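        # Caution: for any other architecture mean_value / std_value stay None,
        # while the border_value computation further down appears to assume
        # they are set.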

        img_shape = (img_size, img_size)
        img_padding = (img_pad_width, img_pad_width)

        if exp == 'visda_train_val':
            d_source = visda17_dataset.TrainDataset(img_size=img_shape,
                                                    range01=True,
                                                    rgb_order=True)
            d_target = visda17_dataset.ValidationDataset(img_size=img_shape,
                                                         range01=True,
                                                         rgb_order=True)
        elif exp == 'visda_train_test':
            d_source = visda17_dataset.TrainDataset(img_size=img_shape,
                                                    range01=True,
                                                    rgb_order=True)
            d_target = visda17_dataset.TestDataset(img_size=img_shape,
                                                   range01=True,
                                                   rgb_order=True)
        elif exp == 'office_amazon_dslr':
            d_source = office_dataset.OfficeAmazonDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
            d_target = office_dataset.OfficeDSLRDataset(img_size=img_shape,
                                                        range01=True,
                                                        rgb_order=True)
        elif exp == 'office_amazon_webcam':
            d_source = office_dataset.OfficeAmazonDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
            d_target = office_dataset.OfficeWebcamDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
        elif exp == 'office_dslr_amazon':
            d_source = office_dataset.OfficeDSLRDataset(img_size=img_shape,
                                                        range01=True,
                                                        rgb_order=True)
            d_target = office_dataset.OfficeAmazonDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
        elif exp == 'office_dslr_webcam':
            d_source = office_dataset.OfficeDSLRDataset(img_size=img_shape,
                                                        range01=True,
                                                        rgb_order=True)
            d_target = office_dataset.OfficeWebcamDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
        elif exp == 'office_webcam_amazon':
            d_source = office_dataset.OfficeWebcamDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
            d_target = office_dataset.OfficeAmazonDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
        elif exp == 'office_webcam_dslr':
            d_source = office_dataset.OfficeWebcamDataset(img_size=img_shape,
                                                          range01=True,
                                                          rgb_order=True)
            d_target = office_dataset.OfficeDSLRDataset(img_size=img_shape,
                                                        range01=True,
                                                        rgb_order=True)
        else:
            print('Unknown experiment type \'{}\''.format(exp))
            return

        #
        # Result file
        #

        if result_file != '':
            cmdline_helpers.ensure_containing_dir_exists(result_file)
            h5_filters = tables.Filters(complevel=9, complib='blosc')
            f_target_pred = tables.open_file(result_file, mode='w')
            g_tgt_pred = f_target_pred.create_group(f_target_pred.root,
                                                    'target_pred_y',
                                                    'Target prediction')
            arr_tgt_pred = f_target_pred.create_earray(
                g_tgt_pred,
                'y',
                tables.Float32Atom(),
                (0, len(d_target.images), d_target.n_classes),
                filters=h5_filters)
        else:
            f_target_pred = None
            g_tgt_pred = None
            arr_tgt_pred = None

        # Delete the training ground truths as we should not be using them
        # del d_target.y

        n_classes = d_source.n_classes

        print('Loaded data')

        net_class = network_architectures.get_build_fn_for_architecture(arch)

        net = net_class(n_classes, img_size, use_dropout, not rnd_init).cuda()

        if arch in RESNET_ARCHS and not rnd_init:
            named_params = list(net.named_parameters())
            new_params = []
            pretrained_params = []
            for name, param in named_params:
                if name.startswith('new_'):
                    new_params.append(param)
                else:
                    fix = False
                    for lyr in fix_layers:
                        if name.startswith(lyr + '.'):
                            fix = True
                            break
                    if not fix:
                        pretrained_params.append(param)
                    else:
                        print('Fixing param {}'.format(name))
                        param.requires_grad = False

            new_optimizer = torch.optim.Adam(new_params, lr=learning_rate)
            if len(pretrained_params) > 0:
                pretrained_optimizer = torch.optim.Adam(pretrained_params,
                                                        lr=learning_rate *
                                                        pretrained_lr_factor)
            else:
                pretrained_optimizer = None
        else:
            new_optimizer = torch.optim.Adam(net.parameters(),
                                             lr=learning_rate)
            pretrained_optimizer = None
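        # For pre-trained ResNet variants, two optimisers are used: newly added
        # layers (names prefixed 'new_') train at the full learning rate, while
        # the remaining pre-trained layers use learning_rate *
        # pretrained_lr_factor (or are frozen via fix_layers). Otherwise a
        # single optimiser covers all parameters.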
        classification_criterion = nn.CrossEntropyLoss()

        print('Built network')

        # Image augmentation

        aug = augmentation.ImageAugmentation(
            hflip,
            xlat_range,
            affine_std,
            rot_std=rot_std,
            intens_scale_range_lower=intens_scale_range_lower,
            intens_scale_range_upper=intens_scale_range_upper,
            colour_rot_std=colour_rot_std,
            colour_off_std=colour_off_std,
            greyscale=greyscale,
            scale_u_range=scale_u_range,
            scale_x_range=scale_x_range,
            scale_y_range=scale_y_range)

        test_aug = augmentation.ImageAugmentation(hflip,
                                                  xlat_range,
                                                  0.0,
                                                  rot_std=0.0,
                                                  scale_u_range=scale_u_range,
                                                  scale_x_range=scale_x_range,
                                                  scale_y_range=scale_y_range)
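        # test_aug keeps the flip/translation/scale settings but zeroes the
        # affine jitter and rotation; it drives the multiple-augmentation
        # test-time pipeline (test_xf_aug_mult) defined below.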

        border_value = int(np.mean(mean_value) * 255 + 0.5)

        sup_xf = image_transforms.Compose(
            image_transforms.ScaleCropAndAugmentAffine(img_shape, img_padding,
                                                       True, aug, border_value,
                                                       mean_value, std_value),
            image_transforms.ToTensor(),
        )

        test_xf = image_transforms.Compose(
            image_transforms.ScaleAndCrop(img_shape, img_padding, False),
            image_transforms.ToTensor(),
            image_transforms.Standardise(mean_value, std_value),
        )

        test_xf_aug_mult = image_transforms.Compose(
            image_transforms.ScaleCropAndAugmentAffineMultiple(
                16, img_shape, img_padding, True, test_aug, border_value,
                mean_value, std_value),
            image_transforms.ToTensorMultiple(),
        )

        def augment(X_sup, y_sup):
            X_sup = sup_xf(X_sup)[0]
            return X_sup, y_sup

        _one = torch.autograd.Variable(
            torch.from_numpy(np.array([1.0]).astype(np.float32)).cuda())

        def f_train(X_sup, y_sup):
            X_sup = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            y_sup = torch.autograd.Variable(
                torch.from_numpy(y_sup).long().cuda())

            if pretrained_optimizer is not None:
                pretrained_optimizer.zero_grad()
            new_optimizer.zero_grad()
            net.train(mode=True)

            sup_logits_out = net(X_sup)

            # Supervised classification loss
            if double_softmax:
                clf_loss = classification_criterion(F.softmax(sup_logits_out),
                                                    y_sup)
            else:
                clf_loss = classification_criterion(sup_logits_out, y_sup)

            loss_expr = clf_loss

            loss_expr.backward()
            if pretrained_optimizer is not None:
                pretrained_optimizer.step()
            new_optimizer.step()

            n_samples = X_sup.size()[0]

            return (float(clf_loss.data.cpu()[0]) * n_samples, )

        print('Compiled training function')

        def f_pred(X_sup):
            X_var = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            net.train(mode=False)
            return F.softmax(net(X_var)).data.cpu().numpy()

        def f_pred_tgt_mult(X_sup):
            net.train(mode=False)
            y_pred_aug = []
            for aug_i in range(len(X_sup)):
                X_var = torch.autograd.Variable(
                    torch.from_numpy(X_sup[aug_i, ...]).cuda())
                y_pred = F.softmax(net(X_var)).data.cpu().numpy()
                y_pred_aug.append(y_pred[None, ...])
            y_pred_aug = np.concatenate(y_pred_aug, axis=0)
            return (y_pred_aug.mean(axis=0), )
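        # Test-time augmentation: test_xf_aug_mult produces 16 augmented copies
        # per image, and the softmax predictions are averaged over them.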

        print('Compiled evaluation function')

        # Setup output
        def log(text):
            print(text)
            if log_file is not None:
                with open(log_file, 'a') as f:
                    f.write(text + '\n')

        if log_file is not None:
            cmdline_helpers.ensure_containing_dir_exists(log_file)

        # Report settings
        log('Program = {}'.format(sys.argv[0]))
        log('Settings: {}'.format(', '.join([
            '{}={}'.format(key, settings[key])
            for key in sorted(list(settings.keys()))
        ])))

        # Report dataset size
        log('Dataset:')
        print('SOURCE len(X)={}, y.shape={}'.format(len(d_source.images),
                                                    d_source.y.shape))
        print('TARGET len(X)={}'.format(len(d_target.images)))

        # Subset
        source_indices, target_indices, n_src, n_tgt = image_dataset.subset_indices(
            d_source, d_target, subsetsize, subsetseed)

        n_train_batches = n_src // batch_size + 1
        n_test_batches = n_tgt // (batch_size * 2) + 1

        print('Training...')
        train_ds = data_source.ArrayDataSource([d_source.images, d_source.y],
                                               indices=source_indices)
        train_ds = train_ds.map(augment)
        train_ds = pool.parallel_data_source(train_ds,
                                             batch_buffer_size=min(
                                                 20, n_train_batches))

        # source_test_ds = data_source.ArrayDataSource([d_source.images])
        # source_test_ds = pool.parallel_data_source(source_test_ds)
        target_ds_for_test = data_source.ArrayDataSource(
            [d_target.images], indices=target_indices)
        target_test_ds = target_ds_for_test.map(test_xf)
        target_test_ds = pool.parallel_data_source(target_test_ds,
                                                   batch_buffer_size=min(
                                                       20, n_test_batches))
        target_mult_test_ds = target_ds_for_test.map(test_xf_aug_mult)
        target_mult_test_ds = pool.parallel_data_source(target_mult_test_ds,
                                                        batch_buffer_size=min(
                                                            20,
                                                            n_test_batches))

        if seed != 0:
            shuffle_rng = np.random.RandomState(seed)
        else:
            shuffle_rng = np.random

        if d_target.has_ground_truth:
            evaluator = d_target.prediction_evaluator(target_indices)
        else:
            evaluator = None

        train_batch_iter = train_ds.batch_iterator(batch_size=batch_size,
                                                   shuffle=shuffle_rng)

        for epoch in range(num_epochs):
            t1 = time.time()

            test_batch_iter = target_test_ds.batch_iterator(
                batch_size=batch_size)

            train_clf_loss, = data_source.batch_map_mean(
                f_train,
                train_batch_iter,
                n_batches=n_train_batches,
                progress_iter_func=progress_bar)
            # train_clf_loss, train_unsup_loss, mask_rate, train_align_loss = train_ds.batch_map_mean(
            #     lambda *x: 1.0, batch_size=batch_size, shuffle=shuffle_rng, n_batches=n_train_batches,
            #     progress_iter_func=progress_bar)

            if d_target.has_ground_truth or arr_tgt_pred is not None:
                tgt_pred_prob_y, = data_source.batch_map_concat(
                    f_pred, test_batch_iter, progress_iter_func=progress_bar)
            else:
                tgt_pred_prob_y = None

            train_batch_iter = train_ds.batch_iterator(batch_size=batch_size,
                                                       shuffle=shuffle_rng)

            if d_target.has_ground_truth:
                mean_class_acc, cls_acc_str = evaluator.evaluate(
                    tgt_pred_prob_y)

                t2 = time.time()

                log('Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}; '
                    'TGT mean class acc={:.3%}'.format(epoch, t2 - t1,
                                                       train_clf_loss,
                                                       mean_class_acc))
                log('  per class:  {}'.format(cls_acc_str))
            else:
                t2 = time.time()

                log('Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}'.format(
                    epoch, t2 - t1, train_clf_loss))

            # Save results
            if arr_tgt_pred is not None:
                arr_tgt_pred.append(tgt_pred_prob_y[None,
                                                    ...].astype(np.float32))

        # Predict on test set, using augmentation
        tgt_aug_pred_prob_y, = target_mult_test_ds.batch_map_concat(
            f_pred_tgt_mult,
            batch_size=batch_size,
            progress_iter_func=progress_bar)
        if d_target.has_ground_truth:
            aug_mean_class_acc, aug_cls_acc_str = evaluator.evaluate(
                tgt_aug_pred_prob_y)

            log('FINAL: TGT AUG mean class acc={:.3%}'.format(
                aug_mean_class_acc))
            log('  per class:  {}'.format(aug_cls_acc_str))

        if f_target_pred is not None:
            f_target_pred.create_array(g_tgt_pred, 'y_prob', tgt_pred_prob_y)
            f_target_pred.create_array(g_tgt_pred, 'y_prob_aug',
                                       tgt_aug_pred_prob_y)
            f_target_pred.close()
Example #8
def experiment(exp, scale_u_range, scale_x_range, scale_y_range, affine_std,
               xlat_range, rot_std, hflip, intens_scale_range, colour_rot_std,
               colour_off_std, greyscale, cutout_prob, cutout_size, batch_size,
               n_batches, seed):
    import os
    import sys
    import cmdline_helpers
    intens_scale_range_lower, intens_scale_range_upper = cmdline_helpers.colon_separated_range(
        intens_scale_range)
    scale_u_range = cmdline_helpers.colon_separated_range(scale_u_range)
    scale_x_range = cmdline_helpers.colon_separated_range(scale_x_range)
    scale_y_range = cmdline_helpers.colon_separated_range(scale_y_range)

    import time
    import tqdm
    import math
    import numpy as np
    from matplotlib import pyplot as plt
    from batchup import data_source, work_pool
    import visda17_dataset
    import augmentation, image_transforms
    import itertools

    n_chn = 0

    mean_value = np.array([0.485, 0.456, 0.406])
    std_value = np.array([0.229, 0.224, 0.225])
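    # ImageNet channel mean / std, used for standardisation and to
    # de-standardise the images for display further down.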

    if exp == 'train_val':
        d_source = visda17_dataset.TrainDataset(img_size=(96, 96),
                                                mean_value=mean_value,
                                                std_value=std_value,
                                                range01=True,
                                                rgb_order=True,
                                                random_crop=False)
        d_target = visda17_dataset.ValidationDataset(img_size=(96, 96),
                                                     mean_value=mean_value,
                                                     std_value=std_value,
                                                     range01=True,
                                                     rgb_order=True,
                                                     random_crop=False)
        d_target_test = visda17_dataset.ValidationDataset(
            img_size=(96, 96),
            mean_value=mean_value,
            std_value=std_value,
            range01=True,
            rgb_order=True,
            random_crop=False)
    elif exp == 'train_test':
        print('train_test experiment not supported yet')
        return
    else:
        print('Unknown experiment type \'{}\''.format(exp))
        return

    n_classes = d_source.n_classes
    n_domains = 2

    print('Loaded data')

    arch = 'show-images'

    # Image augmentation

    aug = augmentation.ImageAugmentation(
        hflip,
        xlat_range,
        affine_std,
        rot_std=rot_std,
        intens_scale_range_lower=intens_scale_range_lower,
        intens_scale_range_upper=intens_scale_range_upper,
        colour_rot_std=colour_rot_std,
        colour_off_std=colour_off_std,
        greyscale=greyscale,
        scale_u_range=scale_u_range,
        scale_x_range=scale_x_range,
        scale_y_range=scale_y_range,
        cutout_size=cutout_size,
        cutout_probability=cutout_prob)

    # Report settings
    print('sys.argv={}'.format(sys.argv))

    # Report dataset size
    print('Dataset:')
    print('SOURCE len(X)={}, y.shape={}'.format(len(d_source.images),
                                                d_source.y.shape))
    print('TARGET len(X)={}'.format(len(d_target.images)))

    print('Building data sources...')
    source_train_ds = data_source.ArrayDataSource(
        [d_source.images, d_source.y], repeats=-1)
    target_train_ds = data_source.ArrayDataSource([d_target.images],
                                                  repeats=-1)
    train_ds = data_source.CompositeDataSource(
        [source_train_ds, target_train_ds])

    border_value = int(np.mean(mean_value) * 255 + 0.5)

    train_xf = image_transforms.Compose(
        image_transforms.ScaleCropAndAugmentAffine(
            (96, 96), (16, 16), True, aug, border_value, mean_value,
            std_value),
        image_transforms.ToTensor(),
    )

    test_xf = image_transforms.Compose(
        image_transforms.ScaleAndCrop((96, 96), (16, 16), False),
        image_transforms.ToTensor(),
        image_transforms.Standardise(mean_value, std_value),
    )

    def augment(X_sup, y_sup, X_tgt):
        X_sup = train_xf(X_sup)[0]
        X_tgt_0 = train_xf(X_tgt)[0]
        X_tgt_1 = train_xf(X_tgt)[0]
        return [X_sup, y_sup, X_tgt_0, X_tgt_1]
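    # The target batch is transformed twice to produce two independently
    # augmented copies (X_tgt_0, X_tgt_1), mirroring the paired student/teacher
    # views used by the training scripts above.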

    train_ds = train_ds.map(augment)

    test_ds = data_source.ArrayDataSource([d_target_test.images]).map(test_xf)

    if seed != 0:
        shuffle_rng = np.random.RandomState(seed)
    else:
        shuffle_rng = np.random

    print('Showing...')

    n_shown = 0
    for (src_X, src_y, tgt_X0, tgt_X1), (te_X, ) in zip(
            train_ds.batch_iterator(batch_size=batch_size,
                                    shuffle=shuffle_rng),
            test_ds.batch_iterator(batch_size=batch_size)):
        print('Batch')
        tgt_X = np.zeros(
            (tgt_X0.shape[0] + tgt_X1.shape[0], ) + tgt_X0.shape[1:],
            dtype=np.float32)
        tgt_X[0::2] = tgt_X0
        tgt_X[1::2] = tgt_X1
        x = np.concatenate([src_X, tgt_X, te_X], axis=0)
        n = x.shape[0]
        n_sup = src_X.shape[0] + tgt_X.shape[0]
        across = int(math.ceil(math.sqrt(float(n))))
        plt.figure(figsize=(16, 16))

        for i in tqdm.tqdm(range(n)):
            plt.subplot(across, across, i + 1)
            im_x = x[i] * std_value[:, None, None] + mean_value[:, None, None]
            im_x = np.clip(im_x, 0.0, 1.0)
            plt.imshow(im_x.transpose(1, 2, 0))
            if i < src_y.shape[0]:
                plt.title(str(src_y[i]))
            elif i < n_sup:
                plt.title('target')
            else:
                plt.title('test')
        plt.show()
        n_shown += 1
        if n_shown >= n_batches:
            break