Example #1
def load_mod_dl(args):
    """

    :param args:
    :return:
    """
    if args.dataset == 'cifar10':
        imsize, in_channel, num_classes = 32, 3, 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation,
            subset=[args.train_size, args.val_size, args.test_size])
    elif args.dataset == 'cifar100':
        imsize, in_channel, num_classes = 32, 3, 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation,
            subset=[args.train_size, args.val_size, args.test_size])
    elif args.dataset == 'mnist':
        imsize, in_channel, num_classes = 28, 1, 10
        num_train = 50000
        train_loader, val_loader, test_loader = data_loaders.load_mnist(
            args.batch_size,
            subset=[args.train_size, args.val_size, args.test_size],
            num_train=num_train,
            only_split_train=False)

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'cbr':
        cnn = CBRStudent(in_channel, num_classes)

    # This essentially does no mixup.
    mixup_mat = -100 * torch.ones([num_classes, num_classes]).cuda()

    checkpoint = None
    if args.load_checkpoint:
        checkpoint = torch.load(args.load_checkpoint)
        mixup_mat = checkpoint['mixup_grid']
        print(f"loaded mixupmat from {args.load_checkpoint}")

        if args.rand_mixup:
            # Randomise mixup grid
            rng = np.random.RandomState(args.seed)
            mixup_mat = rng.uniform(
                0.5, 1.0, (num_classes, num_classes)).astype(np.float32)
            print("Randomised the mixup mat")
        mixup_mat = torch.from_numpy(
            mixup_mat.reshape(num_classes, num_classes)).cuda()

    model = cnn.cuda()
    model.train()

    return model, mixup_mat, train_loader, val_loader, test_loader, checkpoint
Example #2
def load_mod_dl(args):
    """
    :param args:
    :return:
    """
    if args.dataset == 'cifar10':
        imsize, in_channel, num_classes = 32, 3, 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation,
            subset=[args.train_size, args.val_size, args.test_size])
    elif args.dataset == 'cifar100':
        imsize, in_channel, num_classes = 32, 3, 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation,
            subset=[args.train_size, args.val_size, args.test_size])
    elif args.dataset == 'mnist':
        imsize, in_channel, num_classes = 28, 1, 10
        num_train = 50000
        train_loader, val_loader, test_loader = data_loaders.load_mnist(
            args.batch_size,
            subset=[args.train_size, args.val_size, args.test_size],
            num_train=num_train,
            only_split_train=False)


    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes, num_channels=in_channel)
    elif args.model == 'cbr':
        cnn = CBRStudent(in_channel, num_classes)
        
    # Per-class-pair mixup parameters; requires_grad is enabled below so they can be optimised directly.
    mixup_mat = -1 * torch.ones([num_classes, num_classes]).cuda()
    mixup_mat.requires_grad = True

    checkpoint = None
    if args.load_baseline_checkpoint:
        checkpoint = torch.load(args.load_baseline_checkpoint)
        cnn.load_state_dict(checkpoint['model_state_dict'])

    model = cnn.cuda()
    model.train()
    return model, mixup_mat, train_loader, val_loader, test_loader, checkpoint
Example #3
def load_baseline_model(args):
    """

    :param args:
    :return:
    """
    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(
            args.batch_size, val_split=True, augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(
            args.batch_size, val_split=True, augmentation=args.data_augmentation)
    elif args.dataset == 'mnist':
        args.datasize, args.valsize, args.testsize = 100, 100, 100
        num_train = args.datasize
        if args.datasize == -1:
            num_train = 50000

        from data_loaders import load_mnist
        train_loader, val_loader, test_loader = load_mnist(
            args.batch_size,
            subset=[args.datasize, args.valsize, args.testsize],
            num_train=num_train)

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'wideresnet':
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10, dropRate=0.3)

    checkpoint = None
    if args.load_baseline_checkpoint:
        checkpoint = torch.load(args.load_baseline_checkpoint)
        cnn.load_state_dict(checkpoint['model_state_dict'])

    model = cnn.cuda()
    model.train()
    return model, train_loader, val_loader, test_loader, checkpoint
Example #4
def experiment(exp, arch, loss, double_softmax, confidence_thresh, rampup,
               teacher_alpha, fix_ema, unsup_weight, cls_bal_scale,
               cls_bal_scale_range, cls_balance, cls_balance_loss,
               combine_batches, learning_rate, standardise_samples,
               src_affine_std, src_xlat_range, src_hflip, src_intens_flip,
               src_intens_scale_range, src_intens_offset_range,
               src_gaussian_noise_std, tgt_affine_std, tgt_xlat_range,
               tgt_hflip, tgt_intens_flip, tgt_intens_scale_range,
               tgt_intens_offset_range, tgt_gaussian_noise_std, num_epochs,
               batch_size, epoch_size, seed, log_file, model_file, device):
    settings = locals().copy()

    import os
    import sys
    import pickle
    import cmdline_helpers

    if log_file == '':
        log_file = 'output_aug_log_{}.txt'.format(exp)
    elif log_file == 'none':
        log_file = None

    if log_file is not None:
        if os.path.exists(log_file):
            print('Output log file {} already exists'.format(log_file))
            return

    use_rampup = rampup > 0

    (src_intens_scale_range_lower, src_intens_scale_range_upper,
     src_intens_offset_range_lower, src_intens_offset_range_upper) = \
        cmdline_helpers.intens_aug_options(src_intens_scale_range, src_intens_offset_range)
    (tgt_intens_scale_range_lower, tgt_intens_scale_range_upper,
     tgt_intens_offset_range_lower, tgt_intens_offset_range_upper) = \
        cmdline_helpers.intens_aug_options(tgt_intens_scale_range, tgt_intens_offset_range)

    import time
    import math
    import numpy as np
    from batchup import data_source, work_pool
    import data_loaders
    import standardisation
    import network_architectures
    import augmentation
    import torch, torch.cuda
    from torch import nn
    from torch.nn import functional as F
    import optim_weight_ema

    torch_device = torch.device(device)
    pool = work_pool.WorkerThreadPool(2)

    n_chn = 0

    if exp == 'svhn_mnist':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=True)
        d_target = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True,
                                           val=False)
    elif exp == 'mnist_svhn':
        d_source = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True)
        d_target = data_loaders.load_svhn(zero_centre=False,
                                          greyscale=True,
                                          val=False)
    elif exp == 'svhn_mnist_rgb':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=False)
        d_target = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True,
                                           val=False,
                                           rgb=True)
    elif exp == 'mnist_svhn_rgb':
        d_source = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True,
                                           rgb=True)
        d_target = data_loaders.load_svhn(zero_centre=False,
                                          greyscale=False,
                                          val=False)
    elif exp == 'cifar_stl':
        d_source = data_loaders.load_cifar10(range_01=False)
        d_target = data_loaders.load_stl(zero_centre=False, val=False)
    elif exp == 'stl_cifar':
        d_source = data_loaders.load_stl(zero_centre=False)
        d_target = data_loaders.load_cifar10(range_01=False, val=False)
    elif exp == 'mnist_usps':
        d_source = data_loaders.load_mnist(zero_centre=False)
        d_target = data_loaders.load_usps(zero_centre=False,
                                          scale28=True,
                                          val=False)
    elif exp == 'usps_mnist':
        d_source = data_loaders.load_usps(zero_centre=False, scale28=True)
        d_target = data_loaders.load_mnist(zero_centre=False, val=False)
    elif exp == 'syndigits_svhn':
        d_source = data_loaders.load_syn_digits(zero_centre=False)
        d_target = data_loaders.load_svhn(zero_centre=False, val=False)
    elif exp == 'synsigns_gtsrb':
        d_source = data_loaders.load_syn_signs(zero_centre=False)
        d_target = data_loaders.load_gtsrb(zero_centre=False, val=False)
    else:
        print('Unknown experiment type \'{}\''.format(exp))
        return

    # Delete the training ground truths as we should not be using them
    del d_target.train_y

    if standardise_samples:
        standardisation.standardise_dataset(d_source)
        standardisation.standardise_dataset(d_target)

    n_classes = d_source.n_classes

    print('Loaded data')

    if arch == '':
        if exp in {'mnist_usps', 'usps_mnist'}:
            arch = 'mnist-bn-32-64-256'
        if exp in {'svhn_mnist', 'mnist_svhn'}:
            arch = 'grey-32-64-128-gp'
        if exp in {
                'cifar_stl', 'stl_cifar', 'syndigits_svhn', 'svhn_mnist_rgb',
                'mnist_svhn_rgb'
        }:
            arch = 'rgb-128-256-down-gp'
        if exp in {'synsigns_gtsrb'}:
            arch = 'rgb40-96-192-384-gp'

    net_class, expected_shape = network_architectures.get_net_and_shape_for_architecture(
        arch)

    if expected_shape != d_source.train_X.shape[1:]:
        print(
            'Architecture {} not compatible with experiment {}; it needs samples of shape {}, '
            'data has samples of shape {}'.format(arch, exp, expected_shape,
                                                  d_source.train_X.shape[1:]))
        return

    student_net = net_class(n_classes).to(torch_device)
    teacher_net = net_class(n_classes).to(torch_device)
    student_params = list(student_net.parameters())
    teacher_params = list(teacher_net.parameters())
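    # The teacher is updated as an exponential moving average of the student's weights
    # (see the EMA optimizer below), so its parameters are frozen and receive no gradients.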
    for param in teacher_params:
        param.requires_grad = False

    student_optimizer = torch.optim.Adam(student_params, lr=learning_rate)
    if fix_ema:
        teacher_optimizer = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, alpha=teacher_alpha)
    else:
        teacher_optimizer = optim_weight_ema.OldWeightEMA(teacher_net,
                                                          student_net,
                                                          alpha=teacher_alpha)
    classification_criterion = nn.CrossEntropyLoss()

    print('Built network')

    src_aug = augmentation.ImageAugmentation(
        src_hflip,
        src_xlat_range,
        src_affine_std,
        intens_flip=src_intens_flip,
        intens_scale_range_lower=src_intens_scale_range_lower,
        intens_scale_range_upper=src_intens_scale_range_upper,
        intens_offset_range_lower=src_intens_offset_range_lower,
        intens_offset_range_upper=src_intens_offset_range_upper,
        gaussian_noise_std=src_gaussian_noise_std)
    tgt_aug = augmentation.ImageAugmentation(
        tgt_hflip,
        tgt_xlat_range,
        tgt_affine_std,
        intens_flip=tgt_intens_flip,
        intens_scale_range_lower=tgt_intens_scale_range_lower,
        intens_scale_range_upper=tgt_intens_scale_range_upper,
        intens_offset_range_lower=tgt_intens_offset_range_lower,
        intens_offset_range_upper=tgt_intens_offset_range_upper,
        gaussian_noise_std=tgt_gaussian_noise_std)

    if combine_batches:

        def augment(X_sup, y_src, X_tgt):
            X_src_stu, X_src_tea = src_aug.augment_pair(X_sup)
            X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
            return X_src_stu, X_src_tea, y_src, X_tgt_stu, X_tgt_tea
    else:

        def augment(X_src, y_src, X_tgt):
            X_src = src_aug.augment(X_src)
            X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
            return X_src, y_src, X_tgt_stu, X_tgt_tea

    rampup_weight_in_list = [0]

    cls_bal_fn = network_architectures.get_cls_bal_function(cls_balance_loss)

    def compute_aug_loss(stu_out, tea_out):
        # Augmentation loss
        if use_rampup:
            unsup_mask = None
            conf_mask_count = None
            unsup_mask_count = None
        else:
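            # Without ramp-up, only keep samples where the teacher's top predicted
            # probability exceeds the confidence threshold.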
            conf_tea = torch.max(tea_out, 1)[0]
            unsup_mask = conf_mask = (conf_tea > confidence_thresh).float()
            unsup_mask_count = conf_mask_count = conf_mask.sum()

        if loss == 'bce':
            aug_loss = network_architectures.robust_binary_crossentropy(
                stu_out, tea_out)
        else:
            d_aug_loss = stu_out - tea_out
            aug_loss = d_aug_loss * d_aug_loss

        # Class balance scaling
        if cls_bal_scale:
            if use_rampup:
                n_samples = float(aug_loss.shape[0])
            else:
                n_samples = unsup_mask.sum()
            avg_pred = n_samples / float(n_classes)
            bal_scale = avg_pred / torch.clamp(tea_out.sum(dim=0), min=1.0)
            if cls_bal_scale_range != 0.0:
                bal_scale = torch.clamp(bal_scale,
                                        min=1.0 / cls_bal_scale_range,
                                        max=cls_bal_scale_range)
            bal_scale = bal_scale.detach()
            aug_loss = aug_loss * bal_scale[None, :]

        aug_loss = aug_loss.mean(dim=1)

        if use_rampup:
            unsup_loss = aug_loss.mean() * rampup_weight_in_list[0]
        else:
            unsup_loss = (aug_loss * unsup_mask).mean()

        # Class balance loss
        if cls_balance > 0.0:
            # Compute per-sample average predicted probability
            # Average over samples to get average class prediction
            avg_cls_prob = stu_out.mean(dim=0)
            # Compute loss
            equalise_cls_loss = cls_bal_fn(avg_cls_prob,
                                           float(1.0 / n_classes))

            equalise_cls_loss = equalise_cls_loss.mean() * n_classes

            if use_rampup:
                equalise_cls_loss = equalise_cls_loss * rampup_weight_in_list[0]
            else:
                if rampup == 0:
                    equalise_cls_loss = equalise_cls_loss * unsup_mask.mean(
                        dim=0)

            unsup_loss += equalise_cls_loss * cls_balance

        return unsup_loss, conf_mask_count, unsup_mask_count

    if combine_batches:

        def f_train(X_src0, X_src1, y_src, X_tgt0, X_tgt1):
            X_src0 = torch.tensor(X_src0,
                                  dtype=torch.float,
                                  device=torch_device)
            X_src1 = torch.tensor(X_src1,
                                  dtype=torch.float,
                                  device=torch_device)
            y_src = torch.tensor(y_src, dtype=torch.long, device=torch_device)
            X_tgt0 = torch.tensor(X_tgt0,
                                  dtype=torch.float,
                                  device=torch_device)
            X_tgt1 = torch.tensor(X_tgt1,
                                  dtype=torch.float,
                                  device=torch_device)

            n_samples = X_src0.size()[0]
            n_total = n_samples + X_tgt0.size()[0]

            student_optimizer.zero_grad()
            student_net.train()
            teacher_net.train()

            # Concatenate source and target mini-batches
            X0 = torch.cat([X_src0, X_tgt0], 0)
            X1 = torch.cat([X_src1, X_tgt1], 0)

            student_logits_out = student_net(X0)
            student_prob_out = F.softmax(student_logits_out, dim=1)

            src_logits_out = student_logits_out[:n_samples]
            src_prob_out = student_prob_out[:n_samples]

            teacher_logits_out = teacher_net(X1)
            teacher_prob_out = F.softmax(teacher_logits_out, dim=1)

            # Supervised classification loss
            if double_softmax:
                clf_loss = classification_criterion(src_prob_out, y_src)
            else:
                clf_loss = classification_criterion(src_logits_out, y_src)

            unsup_loss, conf_mask_count, unsup_mask_count = compute_aug_loss(
                student_prob_out, teacher_prob_out)

            loss_expr = clf_loss + unsup_loss * unsup_weight

            loss_expr.backward()
            student_optimizer.step()
            teacher_optimizer.step()

            outputs = [
                float(clf_loss) * n_samples,
                float(unsup_loss) * n_total
            ]
            if not use_rampup:
                mask_count = float(conf_mask_count) * 0.5
                unsup_count = float(unsup_mask_count) * 0.5

                outputs.append(mask_count)
                outputs.append(unsup_count)
            return tuple(outputs)
    else:

        def f_train(X_src, y_src, X_tgt0, X_tgt1):
            X_src = torch.tensor(X_src, dtype=torch.float, device=torch_device)
            y_src = torch.tensor(y_src, dtype=torch.long, device=torch_device)
            X_tgt0 = torch.tensor(X_tgt0,
                                  dtype=torch.float,
                                  device=torch_device)
            X_tgt1 = torch.tensor(X_tgt1,
                                  dtype=torch.float,
                                  device=torch_device)

            student_optimizer.zero_grad()
            student_net.train()
            teacher_net.train()

            src_logits_out = student_net(X_src)
            student_tgt_logits_out = student_net(X_tgt0)
            student_tgt_prob_out = F.softmax(student_tgt_logits_out, dim=1)
            teacher_tgt_logits_out = teacher_net(X_tgt1)
            teacher_tgt_prob_out = F.softmax(teacher_tgt_logits_out, dim=1)

            # Supervised classification loss
            if double_softmax:
                clf_loss = classification_criterion(
                    F.softmax(src_logits_out, dim=1), y_src)
            else:
                clf_loss = classification_criterion(src_logits_out, y_src)

            unsup_loss, conf_mask_count, unsup_mask_count = compute_aug_loss(
                student_tgt_prob_out, teacher_tgt_prob_out)

            loss_expr = clf_loss + unsup_loss * unsup_weight

            loss_expr.backward()
            student_optimizer.step()
            teacher_optimizer.step()

            n_samples = X_src.size()[0]

            outputs = [
                float(clf_loss) * n_samples,
                float(unsup_loss) * n_samples
            ]
            if not use_rampup:
                mask_count = float(conf_mask_count)
                unsup_count = float(unsup_mask_count)

                outputs.append(mask_count)
                outputs.append(unsup_count)
            return tuple(outputs)

    print('Compiled training function')

    def f_pred_src(X_sup):
        X_var = torch.tensor(X_sup, dtype=torch.float, device=torch_device)
        student_net.eval()
        teacher_net.eval()
        return (F.softmax(student_net(X_var), dim=1).detach().cpu().numpy(),
                F.softmax(teacher_net(X_var), dim=1).detach().cpu().numpy())

    def f_pred_tgt(X_sup):
        X_var = torch.tensor(X_sup, dtype=torch.float, device=torch_device)
        student_net.eval()
        teacher_net.eval()
        return (F.softmax(student_net(X_var), dim=1).detach().cpu().numpy(),
                F.softmax(teacher_net(X_var), dim=1).detach().cpu().numpy())

    def f_eval_src(X_sup, y_sup):
        y_pred_prob_stu, y_pred_prob_tea = f_pred_src(X_sup)
        y_pred_stu = np.argmax(y_pred_prob_stu, axis=1)
        y_pred_tea = np.argmax(y_pred_prob_tea, axis=1)
        return (float(
            (y_pred_stu != y_sup).sum()), float((y_pred_tea != y_sup).sum()))

    def f_eval_tgt(X_sup, y_sup):
        y_pred_prob_stu, y_pred_prob_tea = f_pred_tgt(X_sup)
        y_pred_stu = np.argmax(y_pred_prob_stu, axis=1)
        y_pred_tea = np.argmax(y_pred_prob_tea, axis=1)
        return (float(
            (y_pred_stu != y_sup).sum()), float((y_pred_tea != y_sup).sum()))

    print('Compiled evaluation function')

    # Setup output
    def log(text):
        print(text)
        if log_file is not None:
            with open(log_file, 'a') as f:
                f.write(text + '\n')
                f.flush()
                f.close()

    cmdline_helpers.ensure_containing_dir_exists(log_file)

    # Report settings
    log('Settings: {}'.format(', '.join([
        '{}={}'.format(key, settings[key])
        for key in sorted(list(settings.keys()))
    ])))

    # Report dataset size
    log('Dataset:')
    log('SOURCE Train: X.shape={}, y.shape={}'.format(d_source.train_X.shape,
                                                      d_source.train_y.shape))
    log('SOURCE Test: X.shape={}, y.shape={}'.format(d_source.test_X.shape,
                                                     d_source.test_y.shape))
    log('TARGET Train: X.shape={}'.format(d_target.train_X.shape))
    log('TARGET Test: X.shape={}, y.shape={}'.format(d_target.test_X.shape,
                                                     d_target.test_y.shape))

    print('Training...')
    sup_ds = data_source.ArrayDataSource([d_source.train_X, d_source.train_y],
                                         repeats=-1)
    tgt_train_ds = data_source.ArrayDataSource([d_target.train_X], repeats=-1)
    train_ds = data_source.CompositeDataSource([sup_ds,
                                                tgt_train_ds]).map(augment)
    train_ds = pool.parallel_data_source(train_ds)
    if epoch_size == 'large':
        n_samples = max(d_source.train_X.shape[0], d_target.train_X.shape[0])
    elif epoch_size == 'small':
        n_samples = min(d_source.train_X.shape[0], d_target.train_X.shape[0])
    elif epoch_size == 'target':
        n_samples = d_target.train_X.shape[0]
    n_train_batches = n_samples // batch_size

    source_test_ds = data_source.ArrayDataSource(
        [d_source.test_X, d_source.test_y])
    target_test_ds = data_source.ArrayDataSource(
        [d_target.test_X, d_target.test_y])

    if seed != 0:
        shuffle_rng = np.random.RandomState(seed)
    else:
        shuffle_rng = np.random

    train_batch_iter = train_ds.batch_iterator(batch_size=batch_size,
                                               shuffle=shuffle_rng)

    best_teacher_model_state = {
        k: v.cpu().numpy()
        for k, v in teacher_net.state_dict().items()
    }

    best_conf_mask_rate = 0.0
    best_src_test_err = 1.0
    for epoch in range(num_epochs):
        t1 = time.time()

        if use_rampup:
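            # Ramp the unsupervised loss weight from ~0 to 1 over the first `rampup`
            # epochs using the Gaussian ramp-up exp(-5 * (1 - p)^2).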
            if epoch < rampup:
                p = max(0.0, float(epoch)) / float(rampup)
                p = 1.0 - p
                rampup_value = math.exp(-p * p * 5.0)
            else:
                rampup_value = 1.0

            rampup_weight_in_list[0] = rampup_value

        train_res = data_source.batch_map_mean(f_train,
                                               train_batch_iter,
                                               n_batches=n_train_batches)

        train_clf_loss = train_res[0]
        if combine_batches:
            unsup_loss_string = 'unsup (both) loss={:.6f}'.format(train_res[1])
        else:
            unsup_loss_string = 'unsup (tgt) loss={:.6f}'.format(train_res[1])

        src_test_err_stu, src_test_err_tea = source_test_ds.batch_map_mean(
            f_eval_src, batch_size=batch_size * 2)
        tgt_test_err_stu, tgt_test_err_tea = target_test_ds.batch_map_mean(
            f_eval_tgt, batch_size=batch_size * 2)

        if use_rampup:
            unsup_loss_string = '{}, rampup={:.3%}'.format(
                unsup_loss_string, rampup_value)
            if src_test_err_stu < best_src_test_err:
                best_src_test_err = src_test_err_stu
                best_teacher_model_state = {
                    k: v.cpu().numpy()
                    for k, v in teacher_net.state_dict().items()
                }
                improve = '*** '
            else:
                improve = ''
        else:
            conf_mask_rate = train_res[-2]
            unsup_mask_rate = train_res[-1]
            if conf_mask_rate > best_conf_mask_rate:
                best_conf_mask_rate = conf_mask_rate
                improve = '*** '
                best_teacher_model_state = {
                    k: v.cpu().numpy()
                    for k, v in teacher_net.state_dict().items()
                }
            else:
                improve = ''
            unsup_loss_string = '{}, conf mask={:.3%}, unsup mask={:.3%}'.format(
                unsup_loss_string, conf_mask_rate, unsup_mask_rate)

        t2 = time.time()

        log('{}Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}, {}; '
            'SRC TEST ERR={:.3%}, TGT TEST student err={:.3%}, TGT TEST teacher err={:.3%}'
            .format(improve, epoch, t2 - t1, train_clf_loss, unsup_loss_string,
                    src_test_err_stu, tgt_test_err_stu, tgt_test_err_tea))

    # Save network
    if model_file != '':
        cmdline_helpers.ensure_containing_dir_exists(model_file)
        with open(model_file, 'wb') as f:
            torch.save(best_teacher_model_state, f)
Example #5
def experiment(plot_path, ds_name, no_aug, affine_std, scale_u_range,
               scale_x_range, scale_y_range, xlat_range, hflip, intens_flip,
               intens_scale_range, intens_offset_range, grid_h, grid_w, seed):
    settings = locals().copy()

    import os
    import sys
    import cmdline_helpers

    intens_scale_range_lower, intens_scale_range_upper = cmdline_helpers.colon_separated_range(
        intens_scale_range)
    intens_offset_range_lower, intens_offset_range_upper = cmdline_helpers.colon_separated_range(
        intens_offset_range)
    scale_u_range = cmdline_helpers.colon_separated_range(scale_u_range)
    scale_x_range = cmdline_helpers.colon_separated_range(scale_x_range)
    scale_y_range = cmdline_helpers.colon_separated_range(scale_y_range)

    import numpy as np
    # from skimage.util import montage2d
    from skimage.util import montage as montage2d
    from PIL import Image
    from batchup import data_source
    import data_loaders
    import augmentation

    n_chn = 0

    if ds_name == 'mnist':
        d_source = data_loaders.load_mnist(zero_centre=False)
    elif ds_name == 'usps':
        d_source = data_loaders.load_usps(zero_centre=False, scale28=True)
    elif ds_name == 'svhn_grey':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=True)
    elif ds_name == 'svhn':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=False)
    elif ds_name == 'cifar':
        d_source = data_loaders.load_cifar10()
    elif ds_name == 'stl':
        d_source = data_loaders.load_stl()
    elif ds_name == 'syndigits':
        d_source = data_loaders.load_syn_digits(zero_centre=False,
                                                greyscale=False)
    elif ds_name == 'synsigns':
        d_source = data_loaders.load_syn_signs(zero_centre=False,
                                               greyscale=False)
    elif ds_name == 'gtsrb':
        d_source = data_loaders.load_gtsrb(zero_centre=False, greyscale=False)
    else:
        print('Unknown dataset \'{}\''.format(ds_name))
        return

    # Delete the training ground truths as we should not be using them
    del d_source.train_y

    n_classes = d_source.n_classes

    print('Loaded data')

    src_aug = augmentation.ImageAugmentation(
        hflip,
        xlat_range,
        affine_std,
        intens_flip=intens_flip,
        intens_scale_range_lower=intens_scale_range_lower,
        intens_scale_range_upper=intens_scale_range_upper,
        intens_offset_range_lower=intens_offset_range_lower,
        intens_offset_range_upper=intens_offset_range_upper,
        scale_u_range=scale_u_range,
        scale_x_range=scale_x_range,
        scale_y_range=scale_y_range)

    def augment(X):
        if not no_aug:
            X = src_aug.augment(X)
        return X,

    rampup_weight_in_list = [0]

    print('Rendering...')
    train_ds = data_source.ArrayDataSource([d_source.train_X],
                                           repeats=-1).map(augment)
    n_samples = len(d_source.train_X)

    if seed != 0:
        shuffle_rng = np.random.RandomState(seed)
    else:
        shuffle_rng = np.random

    batch_size = grid_h * grid_w
    display_batch_iter = train_ds.batch_iterator(batch_size=batch_size,
                                                 shuffle=shuffle_rng)

    best_src_test_err = 1.0

    x_batch, = next(display_batch_iter)

    montage = []
    for chn_i in range(x_batch.shape[1]):
        m = montage2d(x_batch[:, chn_i, :, :], grid_shape=(grid_h, grid_w))
        montage.append(m[:, :, None])
    montage = np.concatenate(montage, axis=2)

    if montage.shape[2] == 1:
        montage = montage[:, :, 0]

    lower = min(0.0, montage.min())
    upper = max(1.0, montage.max())
    montage = (montage - lower) / (upper - lower)
    montage = (np.clip(montage, 0.0, 1.0) * 255.0).astype(np.uint8)

    Image.fromarray(montage).save(plot_path)
Example #6
if args.cuda:
    torch.cuda.manual_seed(args.seed)


def half_image_noise(image):
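    # Replace the bottom 14 rows (the lower half of a 1x28x28 image, e.g. MNIST)
    # with unit Gaussian noise.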
    image[0, 14:] = torch.randn(14, 28)
    return image


if args.dataset == 'MNIST':
    train_loader, val_loader, test_loader = data_loaders.load_mnist(
        args.batch_size, val_split=True)
    in_channel = 1
    fc_shape = 800
elif args.dataset == 'CIFAR10':
    train_loader, val_loader, test_loader = data_loaders.load_cifar10(
        args.batch_size, val_split=True)
    in_channel = 3
    fc_shape = 1250

###############################################################################
# Saving
###############################################################################
short_args = {'num_layers': 'l', 'dropout': 'drop', 'input_dropout': 'indrop'}
flags = {}
subdir = "normal_mnist"
files_used = ['mnist/train']

train_labels = ("global_step", "epoch", "batch", "loss")
valid_labels = ("global_step", "loss", "acc")
stats = {"train": train_labels, "valid": valid_labels}
logger = Logger(sys.argv, args, stats)
Example #7
def experiment(exp, arch, learning_rate, standardise_samples, affine_std,
               xlat_range, hflip, intens_flip, intens_scale_range,
               intens_offset_range, gaussian_noise_std, num_epochs, batch_size,
               seed, log_file, device):
    import os
    import sys
    import cmdline_helpers

    if log_file == '':
        log_file = 'output_aug_log_{}.txt'.format(exp)
    elif log_file == 'none':
        log_file = None

    if log_file is not None:
        if os.path.exists(log_file):
            print('Output log file {} already exists'.format(log_file))
            return

    (intens_scale_range_lower, intens_scale_range_upper,
     intens_offset_range_lower, intens_offset_range_upper) = \
        cmdline_helpers.intens_aug_options(intens_scale_range, intens_offset_range)

    import time
    import math
    import numpy as np
    from batchup import data_source, work_pool
    import data_loaders
    import standardisation
    import network_architectures
    import augmentation
    import torch, torch.cuda
    from torch import nn
    from torch.nn import functional as F

    with torch.cuda.device(device):
        pool = work_pool.WorkerThreadPool(2)

        n_chn = 0

        if exp == 'svhn_mnist':
            d_source = data_loaders.load_svhn(zero_centre=False,
                                              greyscale=True)
            d_target = data_loaders.load_mnist(invert=False,
                                               zero_centre=False,
                                               pad32=True,
                                               val=False)
        elif exp == 'mnist_svhn':
            d_source = data_loaders.load_mnist(invert=False,
                                               zero_centre=False,
                                               pad32=True)
            d_target = data_loaders.load_svhn(zero_centre=False,
                                              greyscale=True,
                                              val=False)
        elif exp == 'svhn_mnist_rgb':
            d_source = data_loaders.load_svhn(zero_centre=False,
                                              greyscale=False)
            d_target = data_loaders.load_mnist(invert=False,
                                               zero_centre=False,
                                               pad32=True,
                                               val=False,
                                               rgb=True)
        elif exp == 'mnist_svhn_rgb':
            d_source = data_loaders.load_mnist(invert=False,
                                               zero_centre=False,
                                               pad32=True,
                                               rgb=True)
            d_target = data_loaders.load_svhn(zero_centre=False,
                                              greyscale=False,
                                              val=False)
        elif exp == 'cifar_stl':
            d_source = data_loaders.load_cifar10(range_01=False)
            d_target = data_loaders.load_stl(zero_centre=False, val=False)
        elif exp == 'stl_cifar':
            d_source = data_loaders.load_stl(zero_centre=False)
            d_target = data_loaders.load_cifar10(range_01=False, val=False)
        elif exp == 'mnist_usps':
            d_source = data_loaders.load_mnist(zero_centre=False)
            d_target = data_loaders.load_usps(zero_centre=False,
                                              scale28=True,
                                              val=False)
        elif exp == 'usps_mnist':
            d_source = data_loaders.load_usps(zero_centre=False, scale28=True)
            d_target = data_loaders.load_mnist(zero_centre=False, val=False)
        elif exp == 'syndigits_svhn':
            d_source = data_loaders.load_syn_digits(zero_centre=False)
            d_target = data_loaders.load_svhn(zero_centre=False, val=False)
        elif exp == 'svhn_syndigits':
            d_source = data_loaders.load_svhn(zero_centre=False, val=False)
            d_target = data_loaders.load_syn_digits(zero_centre=False)
        elif exp == 'synsigns_gtsrb':
            d_source = data_loaders.load_syn_signs(zero_centre=False)
            d_target = data_loaders.load_gtsrb(zero_centre=False, val=False)
        elif exp == 'gtsrb_synsigns':
            d_source = data_loaders.load_gtsrb(zero_centre=False, val=False)
            d_target = data_loaders.load_syn_signs(zero_centre=False)
        else:
            print('Unknown experiment type \'{}\''.format(exp))
            return

        # Delete the training ground truths as we should not be using them
        del d_target.train_y

        if standardise_samples:
            standardisation.standardise_dataset(d_source)
            standardisation.standardise_dataset(d_target)

        n_classes = d_source.n_classes

        print('Loaded data')

        if arch == '':
            if exp in {'mnist_usps', 'usps_mnist'}:
                arch = 'mnist-bn-32-64-256'
            if exp in {'svhn_mnist', 'mnist_svhn'}:
                arch = 'grey-32-64-128-gp'
            if exp in {
                    'cifar_stl', 'stl_cifar', 'syndigits_svhn',
                    'svhn_syndigits', 'svhn_mnist_rgb', 'mnist_svhn_rgb'
            }:
                arch = 'rgb-48-96-192-gp'
            if exp in {'synsigns_gtsrb', 'gtsrb_synsigns'}:
                arch = 'rgb40-48-96-192-384-gp'

        net_class, expected_shape = network_architectures.get_net_and_shape_for_architecture(
            arch)

        if expected_shape != d_source.train_X.shape[1:]:
            print(
                'Architecture {} not compatible with experiment {}; it needs samples of shape {}, '
                'data has samples of shape {}'.format(
                    arch, exp, expected_shape, d_source.train_X.shape[1:]))
            return

        net = net_class(n_classes).cuda()
        params = list(net.parameters())

        optimizer = torch.optim.Adam(params, lr=learning_rate)
        classification_criterion = nn.CrossEntropyLoss()

        print('Built network')

        aug = augmentation.ImageAugmentation(
            hflip,
            xlat_range,
            affine_std,
            intens_scale_range_lower=intens_scale_range_lower,
            intens_scale_range_upper=intens_scale_range_upper,
            intens_offset_range_lower=intens_offset_range_lower,
            intens_offset_range_upper=intens_offset_range_upper,
            intens_flip=intens_flip,
            gaussian_noise_std=gaussian_noise_std)

        def augment(X_sup, y_sup):
            X_sup = aug.augment(X_sup)
            return [X_sup, y_sup]

        def f_train(X_sup, y_sup):
            X_sup = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            y_sup = torch.autograd.Variable(
                torch.from_numpy(y_sup).long().cuda())

            optimizer.zero_grad()
            net.train(mode=True)

            sup_logits_out = net(X_sup)

            # Supervised classification loss
            clf_loss = classification_criterion(sup_logits_out, y_sup)

            loss_expr = clf_loss

            loss_expr.backward()
            optimizer.step()

            n_samples = X_sup.size()[0]

            return float(clf_loss.data.cpu().numpy()) * n_samples

        print('Compiled training function')

        def f_pred_src(X_sup):
            X_var = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            net.train(mode=False)
            return F.softmax(net(X_var)).data.cpu().numpy()

        def f_pred_tgt(X_sup):
            X_var = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            net.train(mode=False)
            return F.softmax(net(X_var)).data.cpu().numpy()

        def f_eval_src(X_sup, y_sup):
            y_pred_prob = f_pred_src(X_sup)
            y_pred = np.argmax(y_pred_prob, axis=1)
            return float((y_pred != y_sup).sum())

        def f_eval_tgt(X_sup, y_sup):
            y_pred_prob = f_pred_tgt(X_sup)
            y_pred = np.argmax(y_pred_prob, axis=1)
            return float((y_pred != y_sup).sum())

        print('Compiled evaluation function')

        # Setup output
        def log(text):
            print(text)
            if log_file is not None:
                with open(log_file, 'a') as f:
                    f.write(text + '\n')
                    f.flush()
                    f.close()

        cmdline_helpers.ensure_containing_dir_exists(log_file)

        # Report settings
        log('sys.argv={}'.format(sys.argv))

        # Report dataset size
        log('Dataset:')
        log('SOURCE Train: X.shape={}, y.shape={}'.format(
            d_source.train_X.shape, d_source.train_y.shape))
        log('SOURCE Test: X.shape={}, y.shape={}'.format(
            d_source.test_X.shape, d_source.test_y.shape))
        log('TARGET Train: X.shape={}'.format(d_target.train_X.shape))
        log('TARGET Test: X.shape={}, y.shape={}'.format(
            d_target.test_X.shape, d_target.test_y.shape))

        print('Training...')
        train_ds = data_source.ArrayDataSource(
            [d_source.train_X, d_source.train_y]).map(augment)

        source_test_ds = data_source.ArrayDataSource(
            [d_source.test_X, d_source.test_y])
        target_test_ds = data_source.ArrayDataSource(
            [d_target.test_X, d_target.test_y])

        if seed != 0:
            shuffle_rng = np.random.RandomState(seed)
        else:
            shuffle_rng = np.random

        best_src_test_err = 1.0
        for epoch in range(num_epochs):
            t1 = time.time()

            train_res = train_ds.batch_map_mean(f_train,
                                                batch_size=batch_size,
                                                shuffle=shuffle_rng)

            train_clf_loss = train_res[0]
            src_test_err, = source_test_ds.batch_map_mean(
                f_eval_src, batch_size=batch_size * 4)
            tgt_test_err, = target_test_ds.batch_map_mean(
                f_eval_tgt, batch_size=batch_size * 4)

            t2 = time.time()

            if src_test_err < best_src_test_err:
                log('*** Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}; '
                    'SRC TEST ERR={:.3%}, TGT TEST err={:.3%}'.format(
                        epoch, t2 - t1, train_clf_loss, src_test_err,
                        tgt_test_err))
                best_src_test_err = src_test_err
            else:
                log('Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}; '
                    'SRC TEST ERR={:.3%}, TGT TEST err={:.3%}'.format(
                        epoch, t2 - t1, train_clf_loss, src_test_err,
                        tgt_test_err))
Example #8
    import os
    #import argparse

    args = get_args()
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # load data
    if args.dataset == 0:
        (x_train, y_train), (x_test, y_test) = load_mnist()
    elif args.dataset == 1:
        (x_train, y_train), (x_test, y_test) = load_fashion_mnist()
    elif args.dataset == 2:
        (x_train, y_train), (x_test, y_test) = load_svhn()
    elif args.dataset == 3:
        (x_train, y_train), (x_test, y_test) = load_cifar10()
    elif args.dataset == 4:
        (x_train, y_train), (x_test, y_test) = load_food101()

    x_train = x_train[:args.train_num]
    y_train = y_train[:args.train_num]
    # define model
    if args.dataset != 4:
        model, eval_model, manipulate_model = CapsNet(
            input_shape=x_train.shape[1:],
            n_class=len(np.unique(np.argmax(y_train, 1))),
            routings=args.routings,
            l1=args.l1)
    else:
        model, eval_model, manipulate_model = CapsNet_for_big(
            input_shape=x_train.shape[1:],
Example #9
train_transform = transforms.Compose([])
if args.data_augmentation:
    train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
    train_transform.transforms.append(transforms.RandomHorizontalFlip())
train_transform.transforms.append(transforms.ToTensor())
train_transform.transforms.append(normalize)
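# Cutout regularisation: mask out `n_holes` random square patches of side `length`
# from each training image.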
if args.cutout:
    train_transform.transforms.append(
        Cutout(n_holes=args.n_holes, length=args.length))

test_transform = transforms.Compose([transforms.ToTensor(), normalize])

if args.dataset == 'cifar10':
    num_classes = 10
    train_loader, val_loader, test_loader = data_loaders.load_cifar10(
        args.batch_size, val_split=True, augmentation=args.data_augmentation)
elif args.dataset == 'cifar100':
    num_classes = 100
    train_loader, val_loader, test_loader = data_loaders.load_cifar100(
        args.batch_size, val_split=True, augmentation=args.data_augmentation)

if args.model == 'resnet18':
    cnn = ResNet18(num_classes=num_classes)
elif args.model == 'wideresnet':
    cnn = WideResNet(depth=28,
                     num_classes=num_classes,
                     widen_factor=10,
                     dropRate=0.3)

cnn = cnn.cuda()
criterion = nn.CrossEntropyLoss().cuda()
Example #10
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_mean, std=cifar_std)
])

tf_test = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_mean, std=cifar_std)
])

channels = 3
imsize = 32
num_classes = 10

train_loader, test_loader, init_loader = data_loaders.load_cifar10(
    args.batch_size, tf, tf_test, shuffle=False)


def run_flow(img):
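    # Start the log-likelihood objective with the uniform-dequantization correction
    # of -log(n_bins) per dimension, as in Glow-style bits-per-dim objectives.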
    objective = torch.zeros_like(img[:, 0, 0, 0])
    objective += float(-np.log(args.n_bins) * np.prod(img.shape[1:]))
    return model(img, objective)


if args.load == '':
    print('You need to pass in a checkpoint to load with --load')
    sys.exit(0)

load = args.load

checkpoints = [
Example #11
def cnn_val_loss(config={}, reporter=None, callback=None, return_all=False):
    print("Starting cnn_val_loss...")

    ###############################################################################
    # Arguments
    ###############################################################################
    dataset_options = ['cifar10', 'cifar100', 'fashion']

    ## Tuning parameters: all of the dropouts
    parser = argparse.ArgumentParser(description='CNN')
    parser.add_argument('--dataset',
                        default='cifar10',
                        choices=dataset_options,
                        help='Choose a dataset (cifar10, cifar100, fashion)')
    parser.add_argument(
        '--model',
        default='resnet32',
        choices=['resnet32', 'wideresnet', 'simpleconvnet'],
        help='Choose a model (resnet32, wideresnet, simpleconvnet)')

    #### Optimization hyperparameters
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='Input batch size for training (default: 128)')
    parser.add_argument('--epochs',
                        type=int,
                        default=int(config['epochs']),
                        help='Number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=float(config['lr']),
                        help='Learning rate')
    parser.add_argument('--momentum',
                        type=float,
                        default=float(config['momentum']),
                        help='Nesterov momentum')
    parser.add_argument('--lr_decay',
                        type=float,
                        default=float(config['lr_decay']),
                        help='Factor by which to multiply the learning rate.')

    # parser.add_argument('--weight_decay', type=float, default=float(config['weight_decay']),
    #                     help='Amount of weight decay to use.')
    # parser.add_argument('--dropout', type=float, default=config['dropout'] if 'dropout' in config else 0.0,
    #                     help='Amount of dropout for wideresnet')
    # parser.add_argument('--dropout1', type=float, default=config['dropout1'] if 'dropout1' in config else -1,
    #                     help='Amount of dropout for wideresnet')
    # parser.add_argument('--dropout2', type=float, default=config['dropout2'] if 'dropout2' in config else -1,
    #                     help='Amount of dropout for wideresnet')
    # parser.add_argument('--dropout3', type=float, default=config['dropout3'] if 'dropout3' in config else -1,
    #                     help='Amount of dropout for wideresnet')
    parser.add_argument('--dropout_type',
                        type=str,
                        default=config['dropout_type'],
                        help='Type of dropout (bernoulli or gaussian)')

    # Data augmentation hyperparameters
    parser.add_argument(
        '--inscale',
        type=float,
        default=0 if 'inscale' not in config else config['inscale'],
        help='defines input scaling factor')
    parser.add_argument('--hue',
                        type=float,
                        default=0. if 'hue' not in config else config['hue'],
                        help='hue jitter rate')
    parser.add_argument(
        '--brightness',
        type=float,
        default=0. if 'brightness' not in config else config['brightness'],
        help='brightness jitter rate')
    parser.add_argument(
        '--saturation',
        type=float,
        default=0. if 'saturation' not in config else config['saturation'],
        help='saturation jitter rate')
    parser.add_argument(
        '--contrast',
        type=float,
        default=0. if 'contrast' not in config else config['contrast'],
        help='contrast jitter rate')

    # Weight decay and dropout hyperparameters for each layer
    parser.add_argument(
        '--weight_decays',
        type=str,
        default='0.0',
        help=
        'Amount of weight decay to use for each layer, represented as a comma-separated string of floats.'
    )
    parser.add_argument(
        '--dropouts',
        type=str,
        default='0.0',
        help=
        'Dropout rates for each layer, represented as a comma-separated string of floats'
    )

    parser.add_argument(
        '--nonmono',
        '-nonm',
        type=int,
        default=60,
        help='how many previous epochs to consider for nonmonotonic criterion')
    parser.add_argument(
        '--patience',
        type=int,
        default=75,
        help=
        'How long to wait for the val loss to improve before early stopping.')

    parser.add_argument(
        '--data_augmentation',
        action='store_true',
        default=config['data_augmentation'],
        help='Augment data by cropping and horizontal flipping')

    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many steps before logging stats from training set')
    parser.add_argument(
        '--valid_log_interval',
        type=int,
        default=50,
        help='how many steps before logging stats from the validation set')
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--save',
                        action='store_true',
                        default=False,
                        help='whether to save current run')
    parser.add_argument('--seed',
                        type=int,
                        default=11,
                        help='random seed (default: 11)')
    parser.add_argument(
        '--save_dir',
        default=config['save_dir'],
        help=
        'subdirectory of logdir/savedir to save in (default changes to date/time)'
    )
    parser.add_argument(
        '--overwrite',
        action='store_true',
        default=False,
        help='overwrite an existing result file instead of aborting')

    args, unknown = parser.parse_known_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")
    cudnn.benchmark = True  # Should make training go faster for large models

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print(args)
    sys.stdout.flush()

    # args.dropout1 = args.dropout1 if args.dropout1 != -1 else args.dropout
    # args.dropout2 = args.dropout2 if args.dropout2 != -1 else args.dropout
    # args.dropout3 = args.dropout3 if args.dropout3 != -1 else args.dropout

    ###############################################################################
    # Saving
    ###############################################################################
    timestamp = '{:%Y-%m-%d}'.format(datetime.datetime.now())
    random_hash = random.getrandbits(16)
    exp_name = '{}-dset:{}-model:{}-seed:{}-hash:{}'.format(
        timestamp, args.dataset, args.model,
        args.seed if args.seed else 'None', random_hash)

    dropout_rates = [float(value) for value in args.dropouts.split(',')]
    weight_decays = [float(value) for value in args.weight_decays.split(',')]

    # Create log folder
    BASE_SAVE_DIR = 'experiments'
    save_dir = os.path.join(BASE_SAVE_DIR, args.save_dir, exp_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Check whether the result.csv file exists already
    if os.path.exists(os.path.join(save_dir, 'result.csv')):
        if not args.overwrite:
            print(
                'The result file {} exists! Run with --overwrite to overwrite this experiment.'
                .format(os.path.join(save_dir, 'result.csv')))
            sys.exit(0)

    # Save command-line arguments
    with open(os.path.join(save_dir, 'args.yaml'), 'w') as f:
        yaml.dump(vars(args), f)

    epoch_csv_logger = CSVLogger(
        fieldnames=['epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc'],
        filename=os.path.join(save_dir, 'epoch_log.csv'))

    ###############################################################################
    # Data Loading/Model/Optimizer
    ###############################################################################

    if args.dataset == 'cifar10':
        train_loader, valid_loader, test_loader = data_loaders.load_cifar10(
            args,
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_loader, valid_loader, test_loader = data_loaders.load_cifar100(
            args,
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation)
        num_classes = 100
    elif args.dataset == 'fashion':
        train_loader, valid_loader, test_loader = data_loaders.load_fashion_mnist(
            args.batch_size, val_split=True)
        num_classes = 10

    if args.model == 'resnet32':
        cnn = resnet_cifar.resnet32(dropRates=dropout_rates)
    elif args.model == 'wideresnet':
        cnn = wide_resnet.WideResNet(depth=16,
                                     num_classes=num_classes,
                                     widen_factor=8,
                                     dropRates=dropout_rates,
                                     dropType=args.dropout_type)
        # cnn = wide_resnet.WideResNet(depth=28, num_classes=num_classes, widen_factor=10, dropRate=args.dropout)
    elif args.model == 'simpleconvnet':
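        # Note: args.dropout1/2/3 are defined by the commented-out dropout arguments
        # above, so this branch requires restoring them.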
        cnn = models.SimpleConvNet(dropType=args.dropout_type,
                                   conv_drop1=args.dropout1,
                                   conv_drop2=args.dropout2,
                                   fc_drop=args.dropout3)

    def optim_parameters(model):
        module_list = [
            m for m in model.modules()
            if type(m) == nn.Linear or type(m) == nn.Conv2d
        ]
        weight_decays = [1e-4] * len(module_list)
        return [{
            'params': layer.parameters(),
            'weight_decay': wdecay
        } for (layer, wdecay) in zip(module_list, weight_decays)]

    cnn = cnn.to(device)
    criterion = nn.CrossEntropyLoss()
    # cnn_optimizer = torch.optim.SGD(cnn.parameters(),
    #                                 lr=args.lr,
    #                                 momentum=args.momentum,
    #                                 nesterov=True,
    #                                 weight_decay=args.weight_decay)
    cnn_optimizer = torch.optim.SGD(optim_parameters(cnn),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    nesterov=True)

    ###############################################################################
    # Training/Evaluation
    ###############################################################################
    def evaluate(loader):
        """Returns the loss and accuracy on the entire validation/test set."""
        cnn.eval()
        correct = total = loss = 0.
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                pred = cnn(images)
                loss += F.cross_entropy(pred, labels, reduction='sum').item()
                hard_pred = torch.max(pred, 1)[1]
                total += labels.size(0)
                correct += (hard_pred == labels).sum().item()

        accuracy = correct / total
        mean_loss = loss / total
        cnn.train()
        return mean_loss, accuracy

    epoch = 1
    global_step = 0
    patience_elapsed = 0
    stored_loss = 1e8
    best_val_loss = []
    start_time = time.time()

    # This is based on the schedule used for WideResNets. The gamma (decay factor) can also be 0.2 (= 5x decay)
    # Right now we're not using the scheduler because we use nonmonotonic lr decay (based on validation performance)
    # scheduler = MultiStepLR(cnn_optimizer, milestones=[60,120,160], gamma=args.lr_decay)
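    # Instead, the loop below multiplies the learning rate by args.lr_decay whenever the
    # current validation loss is worse than the best loss from more than args.nonmono
    # epochs ago.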

    while epoch < args.epochs + 1 and patience_elapsed < args.patience:

        running_xentropy = correct = total = 0.

        progress_bar = tqdm(train_loader)
        for i, (images, labels) in enumerate(progress_bar):
            progress_bar.set_description('Epoch ' + str(epoch))
            images, labels = images.to(device), labels.to(device)

            if args.inscale > 0:
                noise = torch.rand(images.size(0), device=device)
                scaled_noise = (((1 + args.inscale) - (1 / (1 + args.inscale))) * noise
                                + (1 / (1 + args.inscale)))
                images = images * scaled_noise[:, None, None, None]
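                # The affine map above sends noise ~ U[0, 1] to [1/(1 + inscale), 1 + inscale],
                # so each image in the batch is brightened or darkened by a random factor.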

            # images = F.dropout(images, p=args.indropout, training=True)  # TODO: Incorporate input dropout
            cnn.zero_grad()
            pred = cnn(images)

            xentropy_loss = criterion(pred, labels)
            xentropy_loss.backward()
            cnn_optimizer.step()

            running_xentropy += xentropy_loss.item()

            # Calculate running average of accuracy
            _, hard_pred = torch.max(pred, 1)
            total += labels.size(0)
            correct += (hard_pred == labels).sum().item()
            accuracy = correct / float(total)

            global_step += 1
            progress_bar.set_postfix(
                xentropy='%.3f' % (running_xentropy / (i + 1)),
                acc='%.3f' % accuracy,
                lr='%.3e' % cnn_optimizer.param_groups[0]['lr'])

        val_loss, val_acc = evaluate(valid_loader)
        print('Val loss: {:6.4f} | Val acc: {:6.4f}'.format(val_loss, val_acc))
        sys.stdout.flush()
        stats = {
            'global_step': global_step,
            'time': time.time() - start_time,
            'loss': val_loss,
            'acc': val_acc
        }
        # logger.write('valid', stats)

        if (len(best_val_loss) > args.nonmono
                and val_loss > min(best_val_loss[:-args.nonmono])):
            cnn_optimizer.param_groups[0]['lr'] *= args.lr_decay
            print('Decaying the learning rate to {}'.format(
                cnn_optimizer.param_groups[0]['lr']))
            sys.stdout.flush()

        if val_loss < stored_loss:
            with open(os.path.join(save_dir, 'best_checkpoint.pt'), 'wb') as f:
                torch.save(cnn.state_dict(), f)
            print('Saving model (new best validation)')
            sys.stdout.flush()
            stored_loss = val_loss
            patience_elapsed = 0
        else:
            patience_elapsed += 1

        best_val_loss.append(val_loss)

        # scheduler.step(epoch)

        avg_xentropy = running_xentropy / (i + 1)
        train_acc = correct / float(total)

        if callback is not None:
            callback(epoch, avg_xentropy, train_acc, val_loss, val_acc, config)

        if reporter is not None:
            reporter(timesteps_total=epoch, mean_loss=val_loss)

        # Another stopping criterion: stop once the learning rate has decayed below 1e-7.
        if cnn_optimizer.param_groups[0]['lr'] < 1e-7:
            break

        epoch_row = {
            'epoch': str(epoch),
            'train_loss': str(avg_xentropy),
            'train_acc': str(train_acc),
            'val_loss': str(val_loss),
            'val_acc': str(val_acc)
        }
        epoch_csv_logger.writerow(epoch_row)

        epoch += 1

    # Load best model and run on test
    with open(os.path.join(save_dir, 'best_checkpoint.pt'), 'rb') as f:
        cnn.load_state_dict(torch.load(f))

    train_loss = avg_xentropy
    train_acc = correct / float(total)

    # Run on val and test data.
    val_loss, val_acc = evaluate(valid_loader)
    test_loss, test_acc = evaluate(test_loader)

    print('=' * 89)
    print(
        '| End of training | trn loss: {:8.5f} | trn acc {:8.5f} | val loss {:8.5f} | val acc {:8.5f} | test loss {:8.5f} | test acc {:8.5f}'
        .format(train_loss, train_acc, val_loss, val_acc, test_loss, test_acc))
    print('=' * 89)
    sys.stdout.flush()

    # Save the final val and test performance to a timestamped results CSV file
    with open(os.path.join(save_dir, 'result_{}.csv'.format(time.time())),
              'w') as result_file:
        result_writer = csv.DictWriter(result_file,
                                       fieldnames=[
                                           'train_loss', 'train_acc',
                                           'val_loss', 'val_acc', 'test_loss',
                                           'test_acc'
                                       ])
        result_writer.writeheader()
        result_writer.writerow({
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'test_loss': test_loss,
            'test_acc': test_acc
        })
        result_file.flush()

    if return_all:
        print("RETURNING ", train_loss, train_acc, val_loss, val_acc,
              test_loss, test_acc)
        sys.stdout.flush()
        return train_loss, train_acc, val_loss, val_acc, test_loss, test_acc
    else:
        print("RETURNING ", stored_loss)
        sys.stdout.flush()
        return stored_loss


def experiment():
    parser = argparse.ArgumentParser(description='CNN Hyperparameter Fine-tuning')
    parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100'],
                        help='Choose a dataset')
    parser.add_argument('--model', default='resnet18', choices=['resnet18', 'wideresnet'],
                        help='Choose a model')
    parser.add_argument('--num_finetune_epochs', type=int, default=200,
                        help='Number of fine-tuning epochs')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='Learning rate')
    parser.add_argument('--optimizer', type=str, default='sgdm',
                        help='Choose an optimizer')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Mini-batch size')
    parser.add_argument('--data_augmentation', action='store_true', default=True,
                        help='Whether to use data augmentation')
    parser.add_argument('--wdecay', type=float, default=5e-4,
                        help='Amount of weight decay')
    parser.add_argument('--load_checkpoint', type=str,
                        help='Path to pre-trained checkpoint to load and finetune')
    parser.add_argument('--save_dir', type=str, default='finetuned_checkpoints',
                        help='Save directory for the fine-tuned checkpoint')
    args = parser.parse_args()
    if args.load_checkpoint is None:
        args.load_checkpoint = '/h/lorraine/PycharmProjects/CG_IFT_test/baseline_checkpoints/cifar10_resnet18_sgdm_lr0.1_wd0.0005_aug0.pt'

    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(args.batch_size, val_split=True,
                                                                          augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(args.batch_size, val_split=True,
                                                                           augmentation=args.data_augmentation)

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'wideresnet':
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10, dropRate=0.3)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    test_id = '{}_{}_{}_lr{}_wd{}_aug{}'.format(args.dataset, args.model, args.optimizer, args.lr, args.wdecay,
                                                int(args.data_augmentation))
    filename = os.path.join(args.save_dir, test_id + '.csv')
    csv_logger = CSVLogger(
        fieldnames=['epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc', 'test_loss', 'test_acc'],
        filename=filename)

    checkpoint = torch.load(args.load_checkpoint)
    init_epoch = checkpoint['epoch']
    cnn.load_state_dict(checkpoint['model_state_dict'])
    model = cnn.cuda()
    model.train()

    args.hyper_train = 'augment'  # One of 'weight', 'all_weight', or 'augment'

    def init_hyper_train(model):
        """Initialise the hyperparameters to be tuned on the model.

        Weight decay is stored as sqrt(wdecay); the training loss squares it, so the
        effective decay coefficient stays non-negative as the hyperparameter is updated.

        :return: The initial hyperparameter value (None for the 'augment' mode).
        """
        init_hyper = None
        if args.hyper_train == 'weight':
            init_hyper = np.sqrt(args.wdecay)
            model.weight_decay = Variable(torch.FloatTensor([init_hyper]).cuda(), requires_grad=True)
            model.weight_decay = model.weight_decay.cuda()
        elif args.hyper_train == 'all_weight':
            num_p = sum(p.numel() for p in model.parameters())
            weights = np.ones(num_p) * np.sqrt(args.wdecay)
            model.weight_decay = Variable(torch.FloatTensor(weights).cuda(), requires_grad=True)
            model.weight_decay = model.weight_decay.cuda()
        model = model.cuda()
        return init_hyper

    if args.hyper_train == 'augment':  # Don't create this inside init_hyper_train, otherwise the closure scope is wrong
        augment_net = UNet(in_channels=3,
                           n_classes=3,
                           depth=5,
                           wf=6,
                           padding=True,
                           batch_norm=False,
                           up_mode='upconv')  # TODO(PV): Initialize UNet properly
        augment_net = augment_net.cuda()
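        # The augmentation network is an image-to-image UNet (3 input channels, 3 output
        # channels); train_loss_func below feeds its output to the student model, and its
        # parameters are the hyperparameters tuned in the 'augment' mode.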

    def get_hyper_train():
        """Return the hyperparameters currently being tuned.

        :return: The weight-decay tensor(s) or the augmentation network's parameters.
        """
        if args.hyper_train == 'weight' or args.hyper_train == 'all_weight':
            return [model.weight_decay]
        if args.hyper_train == 'augment':
            return augment_net.parameters()

    def get_hyper_train_flat():
        """Return all tuned hyperparameters flattened into a single vector."""
        return torch.cat([p.view(-1) for p in get_hyper_train()])

    # TODO: Check this size

    init_hyper_train(model)

    if args.hyper_train == 'all_weight':
        wdecay = 0.0
    else:
        wdecay = args.wdecay
    optimizer = optim.SGD(model.parameters(), lr=args.lr * 0.2 * 0.2, momentum=0.9, nesterov=True,
                          weight_decay=wdecay)
    # print(checkpoint['optimizer_state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler = MultiStepLR(optimizer, milestones=[60, 120], gamma=0.2)  # [60, 120, 160]
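    # Note: the fine-tuning learning rate above is args.lr * 0.2 * 0.2, i.e. the base rate
    # after two gamma=0.2 decays, presumably to roughly match where the loaded checkpoint's
    # schedule left off.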
    hyper_optimizer = torch.optim.Adam(get_hyper_train(), lr=1e-3)  # try 0.1 as lr

    # Set random regularization hyperparameters
    # data_augmentation_hparams = {}  # Random values for hue, saturation, brightness, contrast, rotation, etc.
    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(args.batch_size, val_split=True,
                                                                          augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(args.batch_size, val_split=True,
                                                                           augmentation=args.data_augmentation)

    def test(loader):
        model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
        correct = 0.
        total = 0.
        losses = []
        for images, labels in loader:
            images = images.cuda()
            labels = labels.cuda()

            with torch.no_grad():
                pred = model(images)

            xentropy_loss = F.cross_entropy(pred, labels)
            losses.append(xentropy_loss.item())

            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels).sum().item()

        avg_loss = float(np.mean(losses))
        acc = correct / total
        model.train()
        return avg_loss, acc

    def prepare_data(x, y):
        """Move a batch onto the GPU.

        :param x: Input images.
        :param y: Target labels.
        :return: The (x, y) pair on the GPU.
        """
        x, y = x.cuda(), y.cuda()

        # x, y = Variable(x), Variable(y)
        return x, y

    def train_loss_func(x, y):
        """Compute the training loss for a batch: cross-entropy plus the
        hyperparameter-dependent regularisation or augmentation.

        :param x: Input images.
        :param y: Target labels.
        :return: A (loss, prediction) tuple.
        """
        x, y = prepare_data(x, y)

        reg_loss = 0.0
        if args.hyper_train == 'weight':
            pred = model(x)
            xentropy_loss = F.cross_entropy(pred, y)
            # L2 penalty of 0.5 * weight_decay**2 * ||p||^2 per parameter tensor; the
            # decay coefficient is the hyperparameter being tuned in this mode.
            for p in model.parameters():
                reg_loss = reg_loss + .5 * (model.weight_decay ** 2) * torch.sum(p ** 2)
        elif args.hyper_train == 'all_weight':
            pred = model(x)
            xentropy_loss = F.cross_entropy(pred, y)
            count = 0
            for p in model.parameters():
                reg_loss = reg_loss + .5 * torch.sum(
                    (model.weight_decay[count: count + p.numel()] ** 2) * torch.flatten(p ** 2))
                count += p.numel()
        elif args.hyper_train == 'augment':
            augmented_x = augment_net(x)
            pred = model(augmented_x)
            xentropy_loss = F.cross_entropy(pred, y)
        return xentropy_loss + reg_loss, pred

    def val_loss_func(x, y):
        """Compute the validation loss (plain cross-entropy) for a batch.

        :param x: Input images.
        :param y: Target labels.
        :return: The cross-entropy loss.
        """
        x, y = prepare_data(x, y)
        pred = model(x)
        xentropy_loss = F.cross_entropy(pred, y)
        return xentropy_loss

    for epoch in range(init_epoch, init_epoch + args.num_finetune_epochs):
        xentropy_loss_avg = 0.
        total_val_loss = 0.
        correct = 0.
        total = 0.

        progress_bar = tqdm(train_loader)
        for i, (images, labels) in enumerate(progress_bar):
            progress_bar.set_description('Finetune Epoch ' + str(epoch))

            # TODO: Take a hyperparameter step here
            optimizer.zero_grad(), hyper_optimizer.zero_grad()
            val_loss, weight_norm, grad_norm = hyper_step(1, 1, get_hyper_train, get_hyper_train_flat,
                                                          model, val_loss_func, val_loader,
                                                          train_loss_func, train_loader, hyper_optimizer)
            # del val_loss
            # print(f"hyper: {get_hyper_train()}")
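            # hyper_step above is expected to take a single hyperparameter update (via
            # hyper_optimizer) using the validation loss; weight_norm and grad_norm are
            # only used for the progress-bar display below.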

            images, labels = images.cuda(), labels.cuda()
            # pred = model(images)
            # xentropy_loss = F.cross_entropy(pred, labels)
            xentropy_loss, pred = train_loss_func(images, labels)

            optimizer.zero_grad(), hyper_optimizer.zero_grad()
            xentropy_loss.backward()
            optimizer.step()

            xentropy_loss_avg += xentropy_loss.item()

            # Calculate running average of accuracy
            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels.data).sum().item()
            accuracy = correct / total

            progress_bar.set_postfix(
                train='%.5f' % (xentropy_loss_avg / (i + 1)),
                val='%.4f' % (total_val_loss / (i + 1)),
                acc='%.4f' % accuracy,
                weight='%.2f' % weight_norm,
                update='%.3f' % grad_norm)

        val_loss, val_acc = test(val_loader)
        test_loss, test_acc = test(test_loader)
        tqdm.write('val loss: {:6.4f} | val acc: {:6.4f} | test loss: {:6.4f} | test acc: {:6.4f}'.format(
            val_loss, val_acc, test_loss, test_acc))

        scheduler.step(epoch)

        row = {'epoch': str(epoch),
               'train_loss': str(xentropy_loss_avg / (i + 1)), 'train_acc': str(accuracy),
               'val_loss': str(val_loss), 'val_acc': str(val_acc),
               'test_loss': str(test_loss), 'test_acc': str(test_acc)}
        csv_logger.writerow(row)