Example No. 1
def run_mnist_experiment(seed, permuted, sleep_t):
    torch.manual_seed(seed)
    if not permuted:
        train_loader, test_loader = data_loaders.load_mnist(
            batch_size_train, batch_size_test, norm_mean, norm_std)
        file_nm = f'../experiments/seed{seed}/mnist_forgetting.csv'
    else:
        train_loader, test_loader = data_loaders.load_permuted_mnist(
            batch_size_train, batch_size_test, norm_mean, norm_std)
        file_nm = f'../experiments/seed{seed}/permuted_mnist_forgetting.csv'

    print('running seed', seed, 'on', file_nm)
    net = MNISTNet()  # renamed from `nn` to avoid shadowing torch.nn
    train_model(net,
                train_loader,
                n_epochs,
                verbose=True,
                less_intensive=True,
                sleep_time=sleep_t)

    write_forgetting_events_mnist(file_nm, net)
    print("finished writing!")
    accuracy = verify_model(net, test_loader)
    print("final accuracy", accuracy, ", seed", seed, ", is permuted",
          permuted)
    print("seed", seed, "has", generate_forgetting_events_stats(net))
    print("finished", seed)
Example No. 2
def load_mod_dl(args):
    """

    :param args:
    :return:
    """
    if args.dataset == 'cifar10':
        imsize, in_channel, num_classes = 32, 3, 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation,
            subset=[args.train_size, args.val_size, args.test_size])
    elif args.dataset == 'cifar100':
        imsize, in_channel, num_classes = 32, 3, 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation,
            subset=[args.train_size, args.val_size, args.test_size])
    elif args.dataset == 'mnist':
        imsize, in_channel, num_classes = 28, 1, 10
        num_train = 50000
        train_loader, val_loader, test_loader = data_loaders.load_mnist(
            args.batch_size,
            subset=[args.train_size, args.val_size, args.test_size],
            num_train=num_train,
            only_split_train=False)
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'cbr':
        cnn = CBRStudent(in_channel, num_classes)
    else:
        raise ValueError(f"Unknown model: {args.model}")

    # This essentially does no mixup.
    mixup_mat = -100 * torch.ones([num_classes, num_classes]).cuda()

    checkpoint = None
    if args.load_checkpoint:
        checkpoint = torch.load(args.load_checkpoint)
        mixup_mat = checkpoint['mixup_grid']
        print(f"loaded mixupmat from {args.load_checkpoint}")

        if args.rand_mixup:
            # Randomise mixup grid
            rng = np.random.RandomState(args.seed)
            mixup_mat = rng.uniform(
                0.5, 1.0, (num_classes, num_classes)).astype(np.float32)
            print("Randomised the mixup mat")
        mixup_mat = torch.from_numpy(
            mixup_mat.reshape(num_classes, num_classes)).cuda()

    model = cnn.cuda()
    model.train()

    return model, mixup_mat, train_loader, val_loader, test_loader, checkpoint
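
A hypothetical invocation, assuming an argparse-style namespace carrying the fields the function reads (the function moves the model and mixup matrix to CUDA, so a GPU is required):

    import argparse

    args = argparse.Namespace(
        dataset='cifar10', model='resnet18', batch_size=128,
        data_augmentation=True, train_size=-1, val_size=5000, test_size=-1,
        load_checkpoint=None, rand_mixup=False, seed=0)
    model, mixup_mat, train_loader, val_loader, test_loader, ckpt = load_mod_dl(args)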
Example No. 3
def load_dataset(dataset):
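    # Relies on module-level settings: test_size, random_state, n_samples,
    # n_features, n_classes and binary.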

    if dataset == "Bank":
        X_train, X_test, y_train, y_test = load_bank(test_size=test_size,
                                                     random_state=random_state)
    elif dataset == "Adult":
        X_train, X_test, y_train, y_test = load_adult(
            test_size=test_size, random_state=random_state)
    elif dataset == "Heart":
        X_train, X_test, y_train, y_test = load_heart(
            test_size=test_size, random_state=random_state)
    elif dataset == "Stroke":
        X_train, X_test, y_train, y_test = load_stroke(
            test_size=test_size, random_state=random_state)
    elif dataset == "weatherAUS":
        X_train, X_test, y_train, y_test = load_aus(test_size=test_size,
                                                    random_state=random_state)
    elif dataset == "htru2":
        X_train, X_test, y_train, y_test = load_htru2(
            test_size=test_size, random_state=random_state)
    elif dataset == "MNIST":
        X_train, X_test, y_train, y_test = load_mnist(
            test_size=test_size, random_state=random_state)
    elif dataset == "iris":
        X_train, X_test, y_train, y_test = load__iris(
            test_size=test_size, random_state=random_state)
    elif dataset == "simulated":
        X_train, X_test, y_train, y_test = load_simulated(
            n_samples,
            n_features,
            n_classes,
            test_size=test_size,
            random_state=random_state)
    else:
        raise ValueError("unknown dataset")

    std_scaler = StandardScaler()

    X_train = std_scaler.fit_transform(X_train)
    X_test = std_scaler.transform(X_test)

    if binary:
        y_train = 2 * y_train - 1
        y_test = 2 * y_test - 1
    X_train, X_test, y_train, y_test = map(
        np.ascontiguousarray, (X_train, X_test, y_train, y_test))
    print("n_features : %d" % X_train.shape[1])

    return X_train, X_test, y_train, y_test
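
A minimal usage sketch (hypothetical; assumes the module-level settings noted in the comment above are defined):

    X_train, X_test, y_train, y_test = load_dataset("MNIST")
    print("train:", X_train.shape, "test:", X_test.shape)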
Example No. 4
def load_mod_dl(args):
    """
    :param args:
    :return:
    """
    if args.dataset == 'cifar10':
        imsize, in_channel, num_classes = 32, 3, 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation,
            subset=[args.train_size, args.val_size, args.test_size])
    elif args.dataset == 'cifar100':
        imsize, in_channel, num_classes = 32, 3, 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation,
            subset=[args.train_size, args.val_size, args.test_size])
    elif args.dataset == 'mnist':
        imsize, in_channel, num_classes = 28, 1, 10
        num_train = 50000
        train_loader, val_loader, test_loader = data_loaders.load_mnist(
            args.batch_size,
            subset=[args.train_size, args.val_size, args.test_size],
            num_train=num_train,
            only_split_train=False)
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes, num_channels=in_channel)
    elif args.model == 'cbr':
        cnn = CBRStudent(in_channel, num_classes)
    else:
        raise ValueError(f"Unknown model: {args.model}")

    mixup_mat = -1 * torch.ones([num_classes, num_classes]).cuda()
    mixup_mat.requires_grad = True

    checkpoint = None
    if args.load_baseline_checkpoint:
        checkpoint = torch.load(args.load_baseline_checkpoint)
        cnn.load_state_dict(checkpoint['model_state_dict'])

    model = cnn.cuda()
    model.train()
    return model, mixup_mat, train_loader, val_loader, test_loader, checkpoint
Example No. 5
def load_baseline_model(args):
    """

    :param args:
    :return:
    """
    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(
            args.batch_size, val_split=True,
            augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(
            args.batch_size, val_split=True,
            augmentation=args.data_augmentation)
    elif args.dataset == 'mnist':
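        # NB: these hard-coded subset sizes override any command-line values.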
        args.datasize, args.valsize, args.testsize = 100, 100, 100
        num_train = args.datasize
        if args.datasize == -1:
            num_train = 50000

        from data_loaders import load_mnist
        train_loader, val_loader, test_loader = load_mnist(
            args.batch_size,
            subset=[args.datasize, args.valsize, args.testsize],
            num_train=num_train)

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'wideresnet':
        cnn = WideResNet(depth=28, num_classes=num_classes,
                         widen_factor=10, dropRate=0.3)
    else:
        raise ValueError(f"Unknown model: {args.model}")

    checkpoint = None
    if args.load_baseline_checkpoint:
        checkpoint = torch.load(args.load_baseline_checkpoint)
        cnn.load_state_dict(checkpoint['model_state_dict'])

    model = cnn.cuda()
    model.train()
    return model, train_loader, val_loader, test_loader, checkpoint
Example No. 6
def experiment(exp, arch, loss, double_softmax, confidence_thresh, rampup,
               teacher_alpha, fix_ema, unsup_weight, cls_bal_scale,
               cls_bal_scale_range, cls_balance, cls_balance_loss,
               combine_batches, learning_rate, standardise_samples,
               src_affine_std, src_xlat_range, src_hflip, src_intens_flip,
               src_intens_scale_range, src_intens_offset_range,
               src_gaussian_noise_std, tgt_affine_std, tgt_xlat_range,
               tgt_hflip, tgt_intens_flip, tgt_intens_scale_range,
               tgt_intens_offset_range, tgt_gaussian_noise_std, num_epochs,
               batch_size, epoch_size, seed, log_file, model_file, device):
    settings = locals().copy()

    import os
    import sys
    import pickle
    import cmdline_helpers

    if log_file == '':
        log_file = 'output_aug_log_{}.txt'.format(exp)
    elif log_file == 'none':
        log_file = None

    if log_file is not None:
        if os.path.exists(log_file):
            print('Output log file {} already exists'.format(log_file))
            return

    use_rampup = rampup > 0

    src_intens_scale_range_lower, src_intens_scale_range_upper, src_intens_offset_range_lower, src_intens_offset_range_upper = \
        cmdline_helpers.intens_aug_options(src_intens_scale_range, src_intens_offset_range)
    tgt_intens_scale_range_lower, tgt_intens_scale_range_upper, tgt_intens_offset_range_lower, tgt_intens_offset_range_upper = \
        cmdline_helpers.intens_aug_options(tgt_intens_scale_range, tgt_intens_offset_range)

    import time
    import math
    import numpy as np
    from batchup import data_source, work_pool
    import data_loaders
    import standardisation
    import network_architectures
    import augmentation
    import torch, torch.cuda
    from torch import nn
    from torch.nn import functional as F
    import optim_weight_ema

    torch_device = torch.device(device)
    pool = work_pool.WorkerThreadPool(2)

    n_chn = 0

    if exp == 'svhn_mnist':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=True)
        d_target = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True,
                                           val=False)
    elif exp == 'mnist_svhn':
        d_source = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True)
        d_target = data_loaders.load_svhn(zero_centre=False,
                                          greyscale=True,
                                          val=False)
    elif exp == 'svhn_mnist_rgb':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=False)
        d_target = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True,
                                           val=False,
                                           rgb=True)
    elif exp == 'mnist_svhn_rgb':
        d_source = data_loaders.load_mnist(invert=False,
                                           zero_centre=False,
                                           pad32=True,
                                           rgb=True)
        d_target = data_loaders.load_svhn(zero_centre=False,
                                          greyscale=False,
                                          val=False)
    elif exp == 'cifar_stl':
        d_source = data_loaders.load_cifar10(range_01=False)
        d_target = data_loaders.load_stl(zero_centre=False, val=False)
    elif exp == 'stl_cifar':
        d_source = data_loaders.load_stl(zero_centre=False)
        d_target = data_loaders.load_cifar10(range_01=False, val=False)
    elif exp == 'mnist_usps':
        d_source = data_loaders.load_mnist(zero_centre=False)
        d_target = data_loaders.load_usps(zero_centre=False,
                                          scale28=True,
                                          val=False)
    elif exp == 'usps_mnist':
        d_source = data_loaders.load_usps(zero_centre=False, scale28=True)
        d_target = data_loaders.load_mnist(zero_centre=False, val=False)
    elif exp == 'syndigits_svhn':
        d_source = data_loaders.load_syn_digits(zero_centre=False)
        d_target = data_loaders.load_svhn(zero_centre=False, val=False)
    elif exp == 'synsigns_gtsrb':
        d_source = data_loaders.load_syn_signs(zero_centre=False)
        d_target = data_loaders.load_gtsrb(zero_centre=False, val=False)
    else:
        print('Unknown experiment type \'{}\''.format(exp))
        return

    # Delete the training ground truths as we should not be using them
    del d_target.train_y

    if standardise_samples:
        standardisation.standardise_dataset(d_source)
        standardisation.standardise_dataset(d_target)

    n_classes = d_source.n_classes

    print('Loaded data')

    if arch == '':
        if exp in {'mnist_usps', 'usps_mnist'}:
            arch = 'mnist-bn-32-64-256'
        if exp in {'svhn_mnist', 'mnist_svhn'}:
            arch = 'grey-32-64-128-gp'
        if exp in {
                'cifar_stl', 'stl_cifar', 'syndigits_svhn', 'svhn_mnist_rgb',
                'mnist_svhn_rgb'
        }:
            arch = 'rgb-128-256-down-gp'
        if exp in {'synsigns_gtsrb'}:
            arch = 'rgb40-96-192-384-gp'

    net_class, expected_shape = network_architectures.get_net_and_shape_for_architecture(
        arch)

    if expected_shape != d_source.train_X.shape[1:]:
        print(
            'Architecture {} not compatible with experiment {}; it needs samples of shape {}, '
            'data has samples of shape {}'.format(arch, exp, expected_shape,
                                                  d_source.train_X.shape[1:]))
        return

    student_net = net_class(n_classes).to(torch_device)
    teacher_net = net_class(n_classes).to(torch_device)
    student_params = list(student_net.parameters())
    teacher_params = list(teacher_net.parameters())
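    # The teacher is not trained by backprop; its weights track the student's
    # via an exponential moving average (see the EMA optimizers below).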
    for param in teacher_params:
        param.requires_grad = False

    student_optimizer = torch.optim.Adam(student_params, lr=learning_rate)
    if fix_ema:
        teacher_optimizer = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, alpha=teacher_alpha)
    else:
        teacher_optimizer = optim_weight_ema.OldWeightEMA(teacher_net,
                                                          student_net,
                                                          alpha=teacher_alpha)
    classification_criterion = nn.CrossEntropyLoss()

    print('Built network')

    src_aug = augmentation.ImageAugmentation(
        src_hflip,
        src_xlat_range,
        src_affine_std,
        intens_flip=src_intens_flip,
        intens_scale_range_lower=src_intens_scale_range_lower,
        intens_scale_range_upper=src_intens_scale_range_upper,
        intens_offset_range_lower=src_intens_offset_range_lower,
        intens_offset_range_upper=src_intens_offset_range_upper,
        gaussian_noise_std=src_gaussian_noise_std)
    tgt_aug = augmentation.ImageAugmentation(
        tgt_hflip,
        tgt_xlat_range,
        tgt_affine_std,
        intens_flip=tgt_intens_flip,
        intens_scale_range_lower=tgt_intens_scale_range_lower,
        intens_scale_range_upper=tgt_intens_scale_range_upper,
        intens_offset_range_lower=tgt_intens_offset_range_lower,
        intens_offset_range_upper=tgt_intens_offset_range_upper,
        gaussian_noise_std=tgt_gaussian_noise_std)

    if combine_batches:

        def augment(X_sup, y_src, X_tgt):
            X_src_stu, X_src_tea = src_aug.augment_pair(X_sup)
            X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
            return X_src_stu, X_src_tea, y_src, X_tgt_stu, X_tgt_tea
    else:

        def augment(X_src, y_src, X_tgt):
            X_src = src_aug.augment(X_src)
            X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
            return X_src, y_src, X_tgt_stu, X_tgt_tea

    rampup_weight_in_list = [0]

    cls_bal_fn = network_architectures.get_cls_bal_function(cls_balance_loss)

    def compute_aug_loss(stu_out, tea_out):
        # Augmentation loss
        if use_rampup:
            unsup_mask = None
            conf_mask_count = None
            unsup_mask_count = None
        else:
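            # Confidence thresholding: only samples where the teacher's max
            # predicted probability exceeds confidence_thresh contribute.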
            conf_tea = torch.max(tea_out, 1)[0]
            unsup_mask = conf_mask = (conf_tea > confidence_thresh).float()
            unsup_mask_count = conf_mask_count = conf_mask.sum()

        if loss == 'bce':
            aug_loss = network_architectures.robust_binary_crossentropy(
                stu_out, tea_out)
        else:
            d_aug_loss = stu_out - tea_out
            aug_loss = d_aug_loss * d_aug_loss

        # Class balance scaling
        if cls_bal_scale:
            if use_rampup:
                n_samples = float(aug_loss.shape[0])
            else:
                n_samples = unsup_mask.sum()
            avg_pred = n_samples / float(n_classes)
            bal_scale = avg_pred / torch.clamp(tea_out.sum(dim=0), min=1.0)
            if cls_bal_scale_range != 0.0:
                bal_scale = torch.clamp(bal_scale,
                                        min=1.0 / cls_bal_scale_range,
                                        max=cls_bal_scale_range)
            bal_scale = bal_scale.detach()
            aug_loss = aug_loss * bal_scale[None, :]

        aug_loss = aug_loss.mean(dim=1)

        if use_rampup:
            unsup_loss = aug_loss.mean() * rampup_weight_in_list[0]
        else:
            unsup_loss = (aug_loss * unsup_mask).mean()

        # Class balance loss
        if cls_balance > 0.0:
            # Compute per-sample average predicted probability
            # Average over samples to get average class prediction
            avg_cls_prob = stu_out.mean(dim=0)
            # Compute loss
            equalise_cls_loss = cls_bal_fn(avg_cls_prob,
                                           float(1.0 / n_classes))

            equalise_cls_loss = equalise_cls_loss.mean() * n_classes

            if use_rampup:
                equalise_cls_loss = equalise_cls_loss * rampup_weight_in_list[0]
            else:
                if rampup == 0:
                    equalise_cls_loss = equalise_cls_loss * unsup_mask.mean(
                        dim=0)

            unsup_loss += equalise_cls_loss * cls_balance

        return unsup_loss, conf_mask_count, unsup_mask_count

    if combine_batches:

        def f_train(X_src0, X_src1, y_src, X_tgt0, X_tgt1):
            X_src0 = torch.tensor(X_src0,
                                  dtype=torch.float,
                                  device=torch_device)
            X_src1 = torch.tensor(X_src1,
                                  dtype=torch.float,
                                  device=torch_device)
            y_src = torch.tensor(y_src, dtype=torch.long, device=torch_device)
            X_tgt0 = torch.tensor(X_tgt0,
                                  dtype=torch.float,
                                  device=torch_device)
            X_tgt1 = torch.tensor(X_tgt1,
                                  dtype=torch.float,
                                  device=torch_device)

            n_samples = X_src0.size()[0]
            n_total = n_samples + X_tgt0.size()[0]

            student_optimizer.zero_grad()
            student_net.train()
            teacher_net.train()

            # Concatenate source and target mini-batches
            X0 = torch.cat([X_src0, X_tgt0], 0)
            X1 = torch.cat([X_src1, X_tgt1], 0)

            student_logits_out = student_net(X0)
            student_prob_out = F.softmax(student_logits_out, dim=1)

            src_logits_out = student_logits_out[:n_samples]
            src_prob_out = student_prob_out[:n_samples]

            teacher_logits_out = teacher_net(X1)
            teacher_prob_out = F.softmax(teacher_logits_out, dim=1)

            # Supervised classification loss
            if double_softmax:
                clf_loss = classification_criterion(src_prob_out, y_src)
            else:
                clf_loss = classification_criterion(src_logits_out, y_src)

            unsup_loss, conf_mask_count, unsup_mask_count = compute_aug_loss(
                student_prob_out, teacher_prob_out)

            loss_expr = clf_loss + unsup_loss * unsup_weight

            loss_expr.backward()
            student_optimizer.step()
            teacher_optimizer.step()

            outputs = [
                float(clf_loss) * n_samples,
                float(unsup_loss) * n_total
            ]
            if not use_rampup:
                mask_count = float(conf_mask_count) * 0.5
                unsup_count = float(unsup_mask_count) * 0.5

                outputs.append(mask_count)
                outputs.append(unsup_count)
            return tuple(outputs)
    else:

        def f_train(X_src, y_src, X_tgt0, X_tgt1):
            X_src = torch.tensor(X_src, dtype=torch.float, device=torch_device)
            y_src = torch.tensor(y_src, dtype=torch.long, device=torch_device)
            X_tgt0 = torch.tensor(X_tgt0,
                                  dtype=torch.float,
                                  device=torch_device)
            X_tgt1 = torch.tensor(X_tgt1,
                                  dtype=torch.float,
                                  device=torch_device)

            student_optimizer.zero_grad()
            student_net.train()
            teacher_net.train()

            src_logits_out = student_net(X_src)
            student_tgt_logits_out = student_net(X_tgt0)
            student_tgt_prob_out = F.softmax(student_tgt_logits_out, dim=1)
            teacher_tgt_logits_out = teacher_net(X_tgt1)
            teacher_tgt_prob_out = F.softmax(teacher_tgt_logits_out, dim=1)

            # Supervised classification loss
            if double_softmax:
                clf_loss = classification_criterion(
                    F.softmax(src_logits_out, dim=1), y_src)
            else:
                clf_loss = classification_criterion(src_logits_out, y_src)

            unsup_loss, conf_mask_count, unsup_mask_count = compute_aug_loss(
                student_tgt_prob_out, teacher_tgt_prob_out)

            loss_expr = clf_loss + unsup_loss * unsup_weight

            loss_expr.backward()
            student_optimizer.step()
            teacher_optimizer.step()

            n_samples = X_src.size()[0]

            outputs = [
                float(clf_loss) * n_samples,
                float(unsup_loss) * n_samples
            ]
            if not use_rampup:
                mask_count = float(conf_mask_count)
                unsup_count = float(unsup_mask_count)

                outputs.append(mask_count)
                outputs.append(unsup_count)
            return tuple(outputs)

    print('Compiled training function')

    def f_pred_src(X_sup):
        X_var = torch.tensor(X_sup, dtype=torch.float, device=torch_device)
        student_net.eval()
        teacher_net.eval()
        return (F.softmax(student_net(X_var), dim=1).detach().cpu().numpy(),
                F.softmax(teacher_net(X_var), dim=1).detach().cpu().numpy())

    def f_pred_tgt(X_sup):
        X_var = torch.tensor(X_sup, dtype=torch.float, device=torch_device)
        student_net.eval()
        teacher_net.eval()
        return (F.softmax(student_net(X_var), dim=1).detach().cpu().numpy(),
                F.softmax(teacher_net(X_var), dim=1).detach().cpu().numpy())

    def f_eval_src(X_sup, y_sup):
        y_pred_prob_stu, y_pred_prob_tea = f_pred_src(X_sup)
        y_pred_stu = np.argmax(y_pred_prob_stu, axis=1)
        y_pred_tea = np.argmax(y_pred_prob_tea, axis=1)
        return (float(
            (y_pred_stu != y_sup).sum()), float((y_pred_tea != y_sup).sum()))

    def f_eval_tgt(X_sup, y_sup):
        y_pred_prob_stu, y_pred_prob_tea = f_pred_tgt(X_sup)
        y_pred_stu = np.argmax(y_pred_prob_stu, axis=1)
        y_pred_tea = np.argmax(y_pred_prob_tea, axis=1)
        return (float(
            (y_pred_stu != y_sup).sum()), float((y_pred_tea != y_sup).sum()))

    print('Compiled evaluation function')

    # Setup output
    def log(text):
        print(text)
        if log_file is not None:
            with open(log_file, 'a') as f:
                f.write(text + '\n')

    cmdline_helpers.ensure_containing_dir_exists(log_file)

    # Report settings
    log('Settings: {}'.format(', '.join([
        '{}={}'.format(key, settings[key])
        for key in sorted(list(settings.keys()))
    ])))

    # Report dataset size
    log('Dataset:')
    log('SOURCE Train: X.shape={}, y.shape={}'.format(d_source.train_X.shape,
                                                      d_source.train_y.shape))
    log('SOURCE Test: X.shape={}, y.shape={}'.format(d_source.test_X.shape,
                                                     d_source.test_y.shape))
    log('TARGET Train: X.shape={}'.format(d_target.train_X.shape))
    log('TARGET Test: X.shape={}, y.shape={}'.format(d_target.test_X.shape,
                                                     d_target.test_y.shape))

    print('Training...')
    sup_ds = data_source.ArrayDataSource([d_source.train_X, d_source.train_y],
                                         repeats=-1)
    tgt_train_ds = data_source.ArrayDataSource([d_target.train_X], repeats=-1)
    train_ds = data_source.CompositeDataSource([sup_ds,
                                                tgt_train_ds]).map(augment)
    train_ds = pool.parallel_data_source(train_ds)
    if epoch_size == 'large':
        n_samples = max(d_source.train_X.shape[0], d_target.train_X.shape[0])
    elif epoch_size == 'small':
        n_samples = min(d_source.train_X.shape[0], d_target.train_X.shape[0])
    elif epoch_size == 'target':
        n_samples = d_target.train_X.shape[0]
    else:
        raise ValueError("epoch_size must be 'large', 'small' or 'target'")
    n_train_batches = n_samples // batch_size

    source_test_ds = data_source.ArrayDataSource(
        [d_source.test_X, d_source.test_y])
    target_test_ds = data_source.ArrayDataSource(
        [d_target.test_X, d_target.test_y])

    if seed != 0:
        shuffle_rng = np.random.RandomState(seed)
    else:
        shuffle_rng = np.random

    train_batch_iter = train_ds.batch_iterator(batch_size=batch_size,
                                               shuffle=shuffle_rng)

    best_teacher_model_state = {
        k: v.cpu().numpy()
        for k, v in teacher_net.state_dict().items()
    }

    best_conf_mask_rate = 0.0
    best_src_test_err = 1.0
    for epoch in range(num_epochs):
        t1 = time.time()

        if use_rampup:
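            # Gaussian ramp-up exp(-5 * (1 - epoch/rampup)^2), as used in
            # temporal-ensembling-style consistency training.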
            if epoch < rampup:
                p = max(0.0, float(epoch)) / float(rampup)
                p = 1.0 - p
                rampup_value = math.exp(-p * p * 5.0)
            else:
                rampup_value = 1.0

            rampup_weight_in_list[0] = rampup_value

        train_res = data_source.batch_map_mean(f_train,
                                               train_batch_iter,
                                               n_batches=n_train_batches)

        train_clf_loss = train_res[0]
        if combine_batches:
            unsup_loss_string = 'unsup (both) loss={:.6f}'.format(train_res[1])
        else:
            unsup_loss_string = 'unsup (tgt) loss={:.6f}'.format(train_res[1])

        src_test_err_stu, src_test_err_tea = source_test_ds.batch_map_mean(
            f_eval_src, batch_size=batch_size * 2)
        tgt_test_err_stu, tgt_test_err_tea = target_test_ds.batch_map_mean(
            f_eval_tgt, batch_size=batch_size * 2)

        if use_rampup:
            unsup_loss_string = '{}, rampup={:.3%}'.format(
                unsup_loss_string, rampup_value)
            if src_test_err_stu < best_src_test_err:
                best_src_test_err = src_test_err_stu
                best_teacher_model_state = {
                    k: v.cpu().numpy()
                    for k, v in teacher_net.state_dict().items()
                }
                improve = '*** '
            else:
                improve = ''
        else:
            conf_mask_rate = train_res[-2]
            unsup_mask_rate = train_res[-1]
            if conf_mask_rate > best_conf_mask_rate:
                best_conf_mask_rate = conf_mask_rate
                improve = '*** '
                best_teacher_model_state = {
                    k: v.cpu().numpy()
                    for k, v in teacher_net.state_dict().items()
                }
            else:
                improve = ''
            unsup_loss_string = '{}, conf mask={:.3%}, unsup mask={:.3%}'.format(
                unsup_loss_string, conf_mask_rate, unsup_mask_rate)

        t2 = time.time()

        log('{}Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}, {}; '
            'SRC TEST ERR={:.3%}, TGT TEST student err={:.3%}, TGT TEST teacher err={:.3%}'
            .format(improve, epoch, t2 - t1, train_clf_loss, unsup_loss_string,
                    src_test_err_stu, tgt_test_err_stu, tgt_test_err_tea))

    # Save network
    if model_file != '':
        cmdline_helpers.ensure_containing_dir_exists(model_file)
        with open(model_file, 'wb') as f:
            torch.save(best_teacher_model_state, f)
Example No. 7
def experiment(plot_path, ds_name, no_aug, affine_std, scale_u_range,
               scale_x_range, scale_y_range, xlat_range, hflip, intens_flip,
               intens_scale_range, intens_offset_range, grid_h, grid_w, seed):
    settings = locals().copy()

    import os
    import sys
    import cmdline_helpers

    intens_scale_range_lower, intens_scale_range_upper = cmdline_helpers.colon_separated_range(
        intens_scale_range)
    intens_offset_range_lower, intens_offset_range_upper = cmdline_helpers.colon_separated_range(
        intens_offset_range)
    scale_u_range = cmdline_helpers.colon_separated_range(scale_u_range)
    scale_x_range = cmdline_helpers.colon_separated_range(scale_x_range)
    scale_y_range = cmdline_helpers.colon_separated_range(scale_y_range)

    import numpy as np
    # from skimage.util import montage2d
    from skimage.util import montage as montage2d
    from PIL import Image
    from batchup import data_source
    import data_loaders
    import augmentation

    n_chn = 0

    if ds_name == 'mnist':
        d_source = data_loaders.load_mnist(zero_centre=False)
    elif ds_name == 'usps':
        d_source = data_loaders.load_usps(zero_centre=False, scale28=True)
    elif ds_name == 'svhn_grey':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=True)
    elif ds_name == 'svhn':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=False)
    elif ds_name == 'cifar':
        d_source = data_loaders.load_cifar10()
    elif ds_name == 'stl':
        d_source = data_loaders.load_stl()
    elif ds_name == 'syndigits':
        d_source = data_loaders.load_syn_digits(zero_centre=False,
                                                greyscale=False)
    elif ds_name == 'synsigns':
        d_source = data_loaders.load_syn_signs(zero_centre=False,
                                               greyscale=False)
    elif ds_name == 'gtsrb':
        d_source = data_loaders.load_gtsrb(zero_centre=False, greyscale=False)
    else:
        print('Unknown dataset \'{}\''.format(ds_name))
        return

    # Delete the training ground truths as we should not be using them
    del d_source.train_y

    n_classes = d_source.n_classes

    print('Loaded data')

    src_aug = augmentation.ImageAugmentation(
        hflip,
        xlat_range,
        affine_std,
        intens_flip=intens_flip,
        intens_scale_range_lower=intens_scale_range_lower,
        intens_scale_range_upper=intens_scale_range_upper,
        intens_offset_range_lower=intens_offset_range_lower,
        intens_offset_range_upper=intens_offset_range_upper,
        scale_u_range=scale_u_range,
        scale_x_range=scale_x_range,
        scale_y_range=scale_y_range)

    def augment(X):
        if not no_aug:
            X = src_aug.augment(X)
        return X,

    rampup_weight_in_list = [0]

    print('Rendering...')
    train_ds = data_source.ArrayDataSource([d_source.train_X],
                                           repeats=-1).map(augment)
    n_samples = len(d_source.train_X)

    if seed != 0:
        shuffle_rng = np.random.RandomState(seed)
    else:
        shuffle_rng = np.random

    batch_size = grid_h * grid_w
    display_batch_iter = train_ds.batch_iterator(batch_size=batch_size,
                                                 shuffle=shuffle_rng)

    best_src_test_err = 1.0

    x_batch, = next(display_batch_iter)

    montage = []
    for chn_i in range(x_batch.shape[1]):
        m = montage2d(x_batch[:, chn_i, :, :], grid_shape=(grid_h, grid_w))
        montage.append(m[:, :, None])
    montage = np.concatenate(montage, axis=2)

    if montage.shape[2] == 1:
        montage = montage[:, :, 0]

    lower = min(0.0, montage.min())
    upper = max(1.0, montage.max())
    montage = (montage - lower) / (upper - lower)
    montage = (np.clip(montage, 0.0, 1.0) * 255.0).astype(np.uint8)

    Image.fromarray(montage).save(plot_path)
Example No. 8
args.cuda = not args.no_cuda and torch.cuda.is_available()
use_device = 'cuda:0' if args.cuda else 'cpu'

torch.manual_seed(args.seed)
np.random.seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)


def half_image_noise(image):
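    # Replace the bottom half (rows 14-27) of a 1x28x28 image with Gaussian
    # noise, in place.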
    image[0, 14:] = torch.randn(14, 28)
    return image


if args.dataset == 'MNIST':
    train_loader, val_loader, test_loader = data_loaders.load_mnist(
        args.batch_size, val_split=True)
    in_channel = 1
    fc_shape = 800
elif args.dataset == 'CIFAR10':
    train_loader, val_loader, test_loader = data_loaders.load_cifar10(
        args.batch_size, val_split=True)
    in_channel = 3
    fc_shape = 1250
else:
    raise ValueError(f"Unknown dataset: {args.dataset}")

###############################################################################
# Saving
###############################################################################
short_args = {'num_layers': 'l', 'dropout': 'drop', 'input_dropout': 'indrop'}
flags = {}
subdir = "normal_mnist"
files_used = ['mnist/train']
Example No. 9
        default=None,
        help="The path of the saved weights. Should be specified when testing")
    parser.add_argument(
        '-w2',
        '--weights2',
        default=None,
        help="The path of the saved weights. Should be specified when testing")
    args = parser.parse_args()
    print(args)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # load data
    if args.dataset == 0:
        (x_train, y_train), (x_test, y_test) = load_mnist()
    elif args.dataset == 1:
        (x_train, y_train), (x_test, y_test) = load_fashion_mnist()
    elif args.dataset == 2:
        (x_train, y_train), (x_test, y_test) = load_svhn()
    else:
        raise ValueError(f"Unknown dataset index: {args.dataset}")
    # define model
    model, eval_model, manipulate_model = CapsNet(
        input_shape=x_train.shape[1:],
        n_class=len(np.unique(np.argmax(y_train, 1))),
        routings=args.routings,
        l1=args.l1)
    model.summary()

    flags = [0] * 10
    index = [0] * 10
    digits = np.where(y_test == 1)[1]
Example No. 10
def experiment(exp, arch, learning_rate, standardise_samples, affine_std,
               xlat_range, hflip, intens_flip, intens_scale_range,
               intens_offset_range, gaussian_noise_std, num_epochs, batch_size,
               seed, log_file, device):
    import os
    import sys
    import cmdline_helpers

    if log_file == '':
        log_file = 'output_aug_log_{}.txt'.format(exp)
    elif log_file == 'none':
        log_file = None

    if log_file is not None:
        if os.path.exists(log_file):
            print('Output log file {} already exists'.format(log_file))
            return

    intens_scale_range_lower, intens_scale_range_upper, intens_offset_range_lower, intens_offset_range_upper = \
        cmdline_helpers.intens_aug_options(intens_scale_range, intens_offset_range)

    import time
    import math
    import numpy as np
    from batchup import data_source, work_pool
    import data_loaders
    import standardisation
    import network_architectures
    import augmentation
    import torch, torch.cuda
    from torch import nn
    from torch.nn import functional as F

    with torch.cuda.device(device):
        pool = work_pool.WorkerThreadPool(2)

        n_chn = 0

        if exp == 'svhn_mnist':
            d_source = data_loaders.load_svhn(zero_centre=False,
                                              greyscale=True)
            d_target = data_loaders.load_mnist(invert=False,
                                               zero_centre=False,
                                               pad32=True,
                                               val=False)
        elif exp == 'mnist_svhn':
            d_source = data_loaders.load_mnist(invert=False,
                                               zero_centre=False,
                                               pad32=True)
            d_target = data_loaders.load_svhn(zero_centre=False,
                                              greyscale=True,
                                              val=False)
        elif exp == 'svhn_mnist_rgb':
            d_source = data_loaders.load_svhn(zero_centre=False,
                                              greyscale=False)
            d_target = data_loaders.load_mnist(invert=False,
                                               zero_centre=False,
                                               pad32=True,
                                               val=False,
                                               rgb=True)
        elif exp == 'mnist_svhn_rgb':
            d_source = data_loaders.load_mnist(invert=False,
                                               zero_centre=False,
                                               pad32=True,
                                               rgb=True)
            d_target = data_loaders.load_svhn(zero_centre=False,
                                              greyscale=False,
                                              val=False)
        elif exp == 'cifar_stl':
            d_source = data_loaders.load_cifar10(range_01=False)
            d_target = data_loaders.load_stl(zero_centre=False, val=False)
        elif exp == 'stl_cifar':
            d_source = data_loaders.load_stl(zero_centre=False)
            d_target = data_loaders.load_cifar10(range_01=False, val=False)
        elif exp == 'mnist_usps':
            d_source = data_loaders.load_mnist(zero_centre=False)
            d_target = data_loaders.load_usps(zero_centre=False,
                                              scale28=True,
                                              val=False)
        elif exp == 'usps_mnist':
            d_source = data_loaders.load_usps(zero_centre=False, scale28=True)
            d_target = data_loaders.load_mnist(zero_centre=False, val=False)
        elif exp == 'syndigits_svhn':
            d_source = data_loaders.load_syn_digits(zero_centre=False)
            d_target = data_loaders.load_svhn(zero_centre=False, val=False)
        elif exp == 'svhn_syndigits':
            d_source = data_loaders.load_svhn(zero_centre=False, val=False)
            d_target = data_loaders.load_syn_digits(zero_centre=False)
        elif exp == 'synsigns_gtsrb':
            d_source = data_loaders.load_syn_signs(zero_centre=False)
            d_target = data_loaders.load_gtsrb(zero_centre=False, val=False)
        elif exp == 'gtsrb_synsigns':
            d_source = data_loaders.load_gtsrb(zero_centre=False, val=False)
            d_target = data_loaders.load_syn_signs(zero_centre=False)
        else:
            print('Unknown experiment type \'{}\''.format(exp))
            return

        # Delete the training ground truths as we should not be using them
        del d_target.train_y

        if standardise_samples:
            standardisation.standardise_dataset(d_source)
            standardisation.standardise_dataset(d_target)

        n_classes = d_source.n_classes

        print('Loaded data')

        if arch == '':
            if exp in {'mnist_usps', 'usps_mnist'}:
                arch = 'mnist-bn-32-64-256'
            if exp in {'svhn_mnist', 'mnist_svhn'}:
                arch = 'grey-32-64-128-gp'
            if exp in {
                    'cifar_stl', 'stl_cifar', 'syndigits_svhn',
                    'svhn_syndigits', 'svhn_mnist_rgb', 'mnist_svhn_rgb'
            }:
                arch = 'rgb-48-96-192-gp'
            if exp in {'synsigns_gtsrb', 'gtsrb_synsigns'}:
                arch = 'rgb40-48-96-192-384-gp'

        net_class, expected_shape = network_architectures.get_net_and_shape_for_architecture(
            arch)

        if expected_shape != d_source.train_X.shape[1:]:
            print(
                'Architecture {} not compatible with experiment {}; it needs samples of shape {}, '
                'data has samples of shape {}'.format(
                    arch, exp, expected_shape, d_source.train_X.shape[1:]))
            return

        net = net_class(n_classes).cuda()
        params = list(net.parameters())

        optimizer = torch.optim.Adam(params, lr=learning_rate)
        classification_criterion = nn.CrossEntropyLoss()

        print('Built network')

        aug = augmentation.ImageAugmentation(
            hflip,
            xlat_range,
            affine_std,
            intens_scale_range_lower=intens_scale_range_lower,
            intens_scale_range_upper=intens_scale_range_upper,
            intens_offset_range_lower=intens_offset_range_lower,
            intens_offset_range_upper=intens_offset_range_upper,
            intens_flip=intens_flip,
            gaussian_noise_std=gaussian_noise_std)

        def augment(X_sup, y_sup):
            X_sup = aug.augment(X_sup)
            return [X_sup, y_sup]

        def f_train(X_sup, y_sup):
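            # NB: torch.autograd.Variable is legacy API; since PyTorch 0.4 it
            # simply returns a tensor.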
            X_sup = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            y_sup = torch.autograd.Variable(
                torch.from_numpy(y_sup).long().cuda())

            optimizer.zero_grad()
            net.train(mode=True)

            sup_logits_out = net(X_sup)

            # Supervised classification loss
            clf_loss = classification_criterion(sup_logits_out, y_sup)

            loss_expr = clf_loss

            loss_expr.backward()
            optimizer.step()

            n_samples = X_sup.size()[0]

            return float(clf_loss.data.cpu().numpy()) * n_samples

        print('Compiled training function')

        def f_pred_src(X_sup):
            X_var = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            net.train(mode=False)
            return F.softmax(net(X_var), dim=1).data.cpu().numpy()

        def f_pred_tgt(X_sup):
            X_var = torch.autograd.Variable(torch.from_numpy(X_sup).cuda())
            net.train(mode=False)
            return F.softmax(net(X_var), dim=1).data.cpu().numpy()

        def f_eval_src(X_sup, y_sup):
            y_pred_prob = f_pred_src(X_sup)
            y_pred = np.argmax(y_pred_prob, axis=1)
            return float((y_pred != y_sup).sum())

        def f_eval_tgt(X_sup, y_sup):
            y_pred_prob = f_pred_tgt(X_sup)
            y_pred = np.argmax(y_pred_prob, axis=1)
            return float((y_pred != y_sup).sum())

        print('Compiled evaluation function')

        # Setup output
        def log(text):
            print(text)
            if log_file is not None:
                with open(log_file, 'a') as f:
                    f.write(text + '\n')

        cmdline_helpers.ensure_containing_dir_exists(log_file)

        # Report settings
        log('sys.argv={}'.format(sys.argv))

        # Report dataset size
        log('Dataset:')
        log('SOURCE Train: X.shape={}, y.shape={}'.format(
            d_source.train_X.shape, d_source.train_y.shape))
        log('SOURCE Test: X.shape={}, y.shape={}'.format(
            d_source.test_X.shape, d_source.test_y.shape))
        log('TARGET Train: X.shape={}'.format(d_target.train_X.shape))
        log('TARGET Test: X.shape={}, y.shape={}'.format(
            d_target.test_X.shape, d_target.test_y.shape))

        print('Training...')
        train_ds = data_source.ArrayDataSource(
            [d_source.train_X, d_source.train_y]).map(augment)

        source_test_ds = data_source.ArrayDataSource(
            [d_source.test_X, d_source.test_y])
        target_test_ds = data_source.ArrayDataSource(
            [d_target.test_X, d_target.test_y])

        if seed != 0:
            shuffle_rng = np.random.RandomState(seed)
        else:
            shuffle_rng = np.random

        best_src_test_err = 1.0
        for epoch in range(num_epochs):
            t1 = time.time()

            train_res = train_ds.batch_map_mean(f_train,
                                                batch_size=batch_size,
                                                shuffle=shuffle_rng)

            train_clf_loss = train_res[0]
            src_test_err, = source_test_ds.batch_map_mean(
                f_eval_src, batch_size=batch_size * 4)
            tgt_test_err, = target_test_ds.batch_map_mean(
                f_eval_tgt, batch_size=batch_size * 4)

            t2 = time.time()

            if src_test_err < best_src_test_err:
                best_src_test_err = src_test_err
                improve = '*** '
            else:
                improve = ''
            log('{}Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}; '
                'SRC TEST ERR={:.3%}, TGT TEST err={:.3%}'.format(
                    improve, epoch, t2 - t1, train_clf_loss, src_test_err,
                    tgt_test_err))
Example No. 11
def experiment(args):
    print(f"Running experiment with args: {args}")
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.train_batch_num -= 1
    args.val_batch_num -= 1
    args.eval_batch_num -= 1

    # TODO (JON): What is yaml for right now?
    yaml = YAML()
    yaml.preserve_quotes = True
    yaml.boolean_representation = ['False', 'True']

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if args.cuda: torch.cuda.manual_seed(args.seed)

    ########## Setup dataset
    # TODO (JON): Verify that the loaders are shuffling the validation / test sets.
    if args.dataset == 'MNIST':
        num_train = args.datasize
        if num_train == -1: num_train = 50000
        train_loader, val_loader, test_loader = load_mnist(
            args.batch_size,
            subset=[args.datasize, args.valsize, args.testsize],
            num_train=num_train)
        in_channel = 1
        imsize = 28
        fc_shape = 800
        num_classes = 10
    else:
        raise Exception("Must choose MNIST dataset")
    # TODO (JON): Right now we are not using the test loader for anything.  Should evaluate it occasionally.

    ##################### Setup model
    if args.model == "mlp":
        model = Net(args.num_layers,
                    args.dropout,
                    imsize,
                    in_channel,
                    args.l2,
                    num_classes=num_classes)
    else:
        raise Exception("bad model")

    hyper = init_hyper_train(args, model)  # We need this when doing all_weight
    if args.cuda:
        model = model.cuda()
        model.weight_decay = model.weight_decay.cuda()
        # model.Gaussian.dropout = model.Gaussian.dropout.cuda()

    ############ Setup Optimizer
    # TODO (JON):  Add argument for other optimizers?
    init_optimizer = torch.optim.Adam(model.parameters(),
                                      lr=args.lr)  # , momentum=0.9)
    hyper_optimizer = torch.optim.RMSprop([get_hyper_train(
        args, model)])  # , lr=args.lrh)  # try 0.1 as lr

    ############## Setup Inversion Algorithms
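    # K-FAC (Kronecker-factored approximate curvature); the damping term
    # regularises the curvature approximation.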
    KFAC_damping = 1e-2
    kfac_opt = KFACOptimizer(model, damping=KFAC_damping)  # sec_optimizer

    ########### Perform the training
    global_step = 0
    hp_k, update = 0, 0
    for epoch_h in range(0, args.hepochs + 1):
        print(f"Hyper epoch: {epoch_h}")
        if epoch_h % args.hyper_log_interval == 0:
            if args.hyper_train == 'opt_data':
                if args.dataset == 'MNIST':
                    save_learned(
                        get_hyper_train(args,
                                        model).reshape(args.batch_size, imsize,
                                                       imsize), True,
                        args.batch_size, args)
            elif args.hyper_train == 'various':
                print(
                    f"saturation: {torch.sigmoid(model.various[0])}, brightness: {torch.sigmoid(model.various[1])}, decay: {torch.exp(model.various[2])}"
                )
            eval_train_corr, eval_train_loss = evaluate(
                args, model, global_step, train_loader, 'train')
            # TODO (JON):  I don't know if we want normal train loss, or eval?
            eval_val_corr, eval_val_loss = evaluate(args, model, epoch_h,
                                                    val_loader, 'valid')
            eval_test_corr, eval_test_loss = evaluate(args, model, epoch_h,
                                                      test_loader, 'test')
            if args.break_perfect_val and eval_val_corr >= 0.999 and eval_train_corr >= 0.999:
                break

        min_loss = 10e8
        elementary_epochs = args.epochs
        if epoch_h == 0:
            elementary_epochs = args.init_epochs
        if True:  # epoch_h == 0:
            optimizer = init_optimizer
        # else:
        #    optimizer = sec_optimizer
        for epoch in range(1, elementary_epochs + 1):
            global_step, epoch_train_loss = train(args, model, train_loader,
                                                  optimizer, train_loss_func,
                                                  kfac_opt, epoch, global_step)
            if np.isnan(epoch_train_loss):
                print("Loss is nan, stop the loop")
                break
            elif False:  # epoch_train_loss >= min_loss:
                print(
                    f"Breaking on epoch {epoch}. train_loss = {epoch_train_loss}, min_loss = {min_loss}"
                )
                break
            min_loss = epoch_train_loss
        # if epoch_h == 0:
        #     continue

        hp_k, update = KFAC_optimize(args, model, train_loader, val_loader,
                                     hyper_optimizer, kfac_opt, KFAC_damping,
                                     epoch_h)