Example #1
def test_CIFAR10():
    cifar = cifar10.CIFAR10()  # just make sure we can create the class
    cifar.DOWNLOAD_IF_MISSING = False
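    # meta_const describes fixed per-image properties; meta lists one dict (id, label, split) per example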
    assert cifar.meta_const['image']['shape'] == (32, 32, 3)
    assert cifar.meta_const['image']['dtype'] == 'uint8'
    assert cifar.descr['n_classes'] == 10
    assert cifar.meta[0] == dict(id=0, label='frog', split='train')
    assert cifar.meta[49999] == dict(id=49999,
                                     label='automobile',
                                     split='train')
    assert cifar.meta[50000] == dict(id=50000, label='cat', split='test')
    assert cifar.meta[59999] == dict(id=59999, label='horse', split='test')
    assert len(cifar.meta) == 60000
Example #2
def main():
    args = get_args()
    # CUDA setting
    if not torch.cuda.is_available():
        raise ValueError("Should buy GPU!")
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.benchmark = True
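    # cudnn.benchmark lets cuDNN auto-tune convolution algorithms; a good fit here since input sizes are fixed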

    def _rescale(img):
        return img * 2.0 - 1.0

    def _noise_adder(img):
        return torch.empty_like(img, dtype=img.dtype).uniform_(0.0, 1/128.0) + img

    minority_class_labels = [0, 1, 2, 3, 4]
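    # classes 0-4 are down-sampled to 5% of their images (keep_ratio=0.05 below) to simulate class imbalance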

    train_transform = transforms.Compose([
                    transforms.Resize(64),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                    ])

    train_dataset = cifar10.CIFAR10(root='./data',
                                    train=True,
                                    download=True,
                                    transform=train_transform,
                                    minority_classes=minority_class_labels,
                                    keep_ratio=0.05)

    if args.oversample == 1:
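        # createOverSampler presumably builds a weighted sampler so minority-class images are drawn more often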
        oversampler = createOverSampler(train_dataset)
        train_loader = cycle(torch.utils.data.DataLoader(train_dataset,
                                                         batch_size=args.batch_size,
                                                         shuffle=False,
                                                         num_workers=args.num_workers,
                                                         sampler=oversampler))
    else:
        train_loader = iter(torch.utils.data.DataLoader(train_dataset,
                                                        batch_size=args.batch_size,
                                                        sampler=InfiniteSamplerWrapper(train_dataset),
                                                        num_workers=args.num_workers,
                                                        pin_memory=True))

    if args.calc_FID:
        eval_dataset = cifar10.CIFAR10(root=args.data_root,
                                       train=False,
                                       download=True,
                                       transform=train_transform,
                                       minority_classes=None,
                                       keep_ratio=None)
        eval_loader = iter(torch.utils.data.DataLoader(eval_dataset,
                                                       batch_size=args.batch_size,
                                                       sampler=InfiniteSamplerWrapper(eval_dataset),
                                                       num_workers=args.num_workers,
                                                       pin_memory=True))
    else:
        eval_loader = None
    num_classes = len(train_dataset.classes)

    print(' prepared datasets...')
    print(' Number of training images: {}'.format(len(train_dataset)))

    # Prepare directories.
    args.num_classes = num_classes
    args, writer = prepare_results_dir(args)
    # initialize models.
    _n_cls = num_classes if args.cGAN else 0
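    # a conditional GAN (cGAN) gives both networks the class label; _n_cls == 0 disables conditioning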
    gen = ResNetGenerator(
        args.gen_num_features, args.gen_dim_z, args.gen_bottom_width,
        activation=F.relu, num_classes=_n_cls, distribution=args.gen_distribution
    ).to(device)
    if args.dis_arch_concat:
        dis = SNResNetConcatDiscriminator(args.dis_num_features, _n_cls, F.relu, args.dis_emb).to(device)
    else:
        dis = SNResNetProjectionDiscriminator(args.dis_num_features, _n_cls, F.relu, args.transform_space).to(device)
    inception_model = inception.InceptionV3().to(device) if args.calc_FID else None

    if inception_model is not None:
        inception_model = torch.nn.DataParallel(inception_model)
    gen = torch.nn.DataParallel(gen)
    dis = torch.nn.DataParallel(dis)

    opt_gen = optim.Adam(gen.parameters(), args.lr, (args.beta1, args.beta2))
    opt_dis = optim.Adam(dis.parameters(), args.lr, (args.beta1, args.beta2))

    gen_criterion = L.GenLoss(args.loss_type, args.relativistic_loss)
    dis_criterion = L.DisLoss(args.loss_type, args.relativistic_loss)

    print(' Initialized models...\n')

    if args.args_path is not None:
        print(' Load weights...\n')
        prev_args, gen, opt_gen, dis, opt_dis = utils.resume_from_args(
            args.args_path, args.gen_ckpt_path, args.dis_ckpt_path
        )

    # Training loop
    for n_iter in tqdm.tqdm(range(1, args.max_iteration + 1)):

        if n_iter >= args.lr_decay_start:
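            # decay_lr presumably anneals both learning rates toward zero over the remaining iterations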
            decay_lr(opt_gen, args.max_iteration, args.lr_decay_start, args.lr)
            decay_lr(opt_dis, args.max_iteration, args.lr_decay_start, args.lr)

        # ==================== Beginning of 1 iteration. ====================
        _l_g = 0.0
        cumulative_loss_dis = 0.0
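        # standard GAN schedule: one generator update per n_dis discriminator updates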
        for i in range(args.n_dis):
            if i == 0:
                fake, pseudo_y, _ = sample_from_gen(args, device, num_classes, gen)
                dis_fake = dis(fake, pseudo_y)
                if args.relativistic_loss:
                    real, y = sample_from_data(args, device, train_loader)
                    dis_real = dis(real, y)
                else:
                    dis_real = None

                loss_gen = gen_criterion(dis_fake, dis_real)

                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()
                _l_g += loss_gen.item()
                if n_iter % 10 == 0 and writer is not None:
                    writer.add_scalar('gen', _l_g, n_iter)

            fake, pseudo_y, _ = sample_from_gen(args, device, num_classes, gen)
            real, y = sample_from_data(args, device, train_loader)

            dis_fake, dis_real = dis(fake, pseudo_y), dis(real, y)
            loss_dis = dis_criterion(dis_fake, dis_real)

            dis.zero_grad()
            loss_dis.backward()
            opt_dis.step()

            cumulative_loss_dis += loss_dis.item()
            if n_iter % 10 == 0 and i == args.n_dis - 1 and writer is not None:
                cumulative_loss_dis /= args.n_dis
                writer.add_scalar('dis', cumulative_loss_dis, n_iter)
        # ==================== End of 1 iteration. ====================

        if n_iter % args.log_interval == 0:
            tqdm.tqdm.write(
                'iteration: {:07d}/{:07d}, loss gen: {:05f}, loss dis {:05f}'.format(
                    n_iter, args.max_iteration, _l_g, cumulative_loss_dis))
            if not args.no_image and writer is not None:
                writer.add_image(
                    'fake', torchvision.utils.make_grid(
                        fake, nrow=4, normalize=True, scale_each=True))
                writer.add_image(
                    'real', torchvision.utils.make_grid(
                        real, nrow=4, normalize=True, scale_each=True))
            # Save previews

            utils.save_images(
                n_iter, n_iter // args.checkpoint_interval, args.results_root,
                args.train_image_root, fake[:32], real[:32]
            )

        if n_iter % args.checkpoint_interval == 0:
            # Save checkpoints!
            utils.save_checkpoints(
                args, n_iter, n_iter // args.checkpoint_interval,
                gen, opt_gen, dis, opt_dis
            )
        if n_iter % args.eval_interval == 0:
            # TODO (crcrpar): implement Inception score, FID, and Geometry score
            # Once these criteria are prepared, val_loader will be used.

            fid_score = evaluation.evaluate(
                args, n_iter, gen, device, inception_model, eval_loader, to_save=True
            )

            tqdm.tqdm.write(
                '[Eval] iteration: {:07d}/{:07d}, FID: {:07f}'.format(
                    n_iter, args.max_iteration, fid_score))
            if writer is not None:
                writer.add_scalar("FID", fid_score, n_iter)
                # Project the class-embedding weights if the discriminator has them.
                # dis is wrapped in DataParallel above, so the attribute lives on dis.module.
                embedding_layer = getattr(dis.module, 'l_y', None)
                if embedding_layer is not None:
                    writer.add_embedding(
                        embedding_layer.weight.data,
                        list(range(args.num_classes)),
                        global_step=n_iter
                    )
    if args.test:
        shutil.rmtree(args.results_root)
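The training loop above calls sample_from_gen and sample_from_data without showing them. A minimal sketch of what they plausibly do, inferred from the call sites: a standard normal latent is assumed for z, pseudo_y is None when the GAN is unconditional, and both loaders are the infinite iterators built earlier.

import torch

def sample_from_gen(args, device, num_classes, gen):
    # Sample latent vectors (and random class labels for a cGAN), then generate a fake batch.
    z = torch.randn(args.batch_size, args.gen_dim_z, device=device)
    pseudo_y = None
    if num_classes > 0:
        pseudo_y = torch.randint(0, num_classes, (args.batch_size,), device=device)
    fake = gen(z, pseudo_y)
    return fake, pseudo_y, z

def sample_from_data(args, device, data_loader):
    # The loader is an infinite iterator, so next() never raises StopIteration.
    real, y = next(data_loader)
    return real.to(device), y.to(device)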
Example #3
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=100, metavar='N',
                        help='total batch size for training (default: 100)')
    parser.add_argument('--labeled-batch-size', type=int, default=50, metavar='N',
                        help='labeled input batch size (default: 50)')
    parser.add_argument('--n-labeled', type=int, default=3000, metavar='N',
                        help='number of labelled data (default: 3000)')
    parser.add_argument('--alpha', type=float, default=0, metavar='ALPHA',
                        help='Hyperparameter for the loss (default: 0)')
    parser.add_argument('--beta', type=float, default=1, metavar='BETA',
                        help='Hyperparameter for the distance to weight function (default: 1.0)')
    parser.add_argument('--pure', default=False, type=str2bool, metavar='BOOL',
                        help='Is the unlabelled data pure')
    parser.add_argument('--weights', default='none', choices=['encoding', 'raw', 'none'], type=str, metavar='S',
                        help='What weights to use.')
    parser.add_argument('--encoder', default=None, type=str, metavar='S',
                        help='File name for the pretrained autoencoder.')
    parser.add_argument('--output', default='default_output.csv', type=str, metavar='S',
                        help='File name for the output.')
    parser.add_argument('--exclude-unlabeled', default=False, type=str2bool, metavar='BOOL',
                        help='exclude unlabeled examples from the training set')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--momentum', type=float, default=0.0, metavar='M',
                        help='SGD momentum (default: 0.0)')
    parser.add_argument('--gamma', type=float, default=0.99, metavar='GAMMA',
                        help='Gamma for learning rate decay (default: 0.99)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=0, metavar='S',
                        help='random seed (default: 0)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--runs', type=int, default=10, metavar='N',
                        help='Number of runs')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--validation', default=False, type=str2bool, metavar='BOOL',
                        help='Use the validation set instead of the test set')

    args = parser.parse_args()
    torch.manual_seed(args.seed) # set seed for pytorch
    use_cuda = torch.cuda.is_available()

    validation_test = 'validation' if args.validation else 'test'
    wanted_classes = {0, 1, 2, 3, 4}  # only the first 5 classes are kept; the rest are treated as unanticipated classes
    args.num_classes = len(wanted_classes)

    folder = os.path.expanduser('./cifar10_results')
    os.makedirs(folder, exist_ok=True)
    output_path = os.path.join(folder, args.output)


    for seed in range(args.runs):  # seed for splitting the training data into labelled and unlabelled subsets

        device = torch.device("cuda" if use_cuda else "cpu")

        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
        
        train_dataset = cifar10.CIFAR10('../cifar10', dataset='train',
                                       weights=args.weights,
                                       encoder=args.encoder,
                                       n_labeled=args.n_labeled,
                                       wanted_classes=wanted_classes,
                                       pure=args.pure,
                                       download=True,
                                       transform=transforms.Compose([
                                       transforms.ToTensor()]),
                                       seed=seed,
                                       alpha=args.beta,
                                       func='exp')

        if args.exclude_unlabeled:
            sampler = SubsetRandomSampler(range(args.n_labeled))
            batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
        else:
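            # TwoStreamBatchSampler presumably mixes labeled_batch_size labelled examples with unlabelled ones in every batch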
            batch_sampler = data.TwoStreamBatchSampler(
                range(args.n_labeled, len(train_dataset)), range(args.n_labeled),
                args.batch_size, args.labeled_batch_size)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_sampler=batch_sampler,
            **kwargs)

        test_loader = torch.utils.data.DataLoader(
            cifar10.CIFAR10('../cifar10', dataset=validation_test, wanted_classes=wanted_classes, transform=transforms.Compose([
                transforms.ToTensor()])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)

        model = Cifar10CNN(args.num_classes).to(device)
        # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum)

        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.gamma)
        for epoch in range(1, args.epochs + 1):
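            # ramp the unlabelled-loss weight alpha up over epochs (the 10 and 40 presumably bound the ramp-up window)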
            alpha = utils.alpha_ramp_up(args.alpha, epoch, 10, 40)
            if not args.exclude_unlabeled:
                pseudo_label.assign_labels(args, model, device, train_dataset, range(args.n_labeled, len(train_dataset)))
            
            start = time.time()
            pseudo_label.train(args, model, device, train_loader, optimizer, epoch, alpha)
            print('\nTraining one epoch took: {:.4f} seconds.\n'.format(time.time()-start))
            accuracy = pseudo_label.test(args, model, device, test_loader)
            lr_scheduler.step()

        with open(output_path, 'a') as writeFile:
            writer = csv.writer(writeFile)
            writer.writerow([seed, accuracy])


        if args.save_model:
            torch.save(model.state_dict(), "cifar10_model.pt")
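This example and Example #8 both register boolean flags through a str2bool converter that is not shown. A common minimal implementation (an assumption, not necessarily the authors' code):

import argparse

def str2bool(v):
    # Accept the usual command-line spellings of true/false.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')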
Example #4
def main():
    args = get_args()
    # CUDA setting
    if not torch.cuda.is_available():
        raise ValueError("Should buy GPU!")
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.benchmark = True

    def _rescale(img):
        return img * 2.0 - 1.0

    def _noise_adder(img):
        return torch.empty_like(img, dtype=img.dtype).uniform_(0.0,
                                                               1 / 128.0) + img

    eval_dataset = cifar10.CIFAR10(root=args.data_root,
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(64),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5))
                                   ]),
                                   minority_classes=None,
                                   keep_ratio=None)
    eval_loader = iter(
        torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=args.batch_size,
            sampler=InfiniteSamplerWrapper(eval_dataset),
            num_workers=args.num_workers,
            pin_memory=True))

    print(' prepared datasets...')

    # Prepare directories.
    num_classes = len(eval_dataset.classes)
    args.num_classes = num_classes

    # initialize models.
    _n_cls = num_classes if args.cGAN else 0
    gen = ResNetGenerator(args.gen_num_features,
                          args.gen_dim_z,
                          args.gen_bottom_width,
                          activation=F.relu,
                          num_classes=_n_cls,
                          distribution=args.gen_distribution).to(device)
    if args.dis_arch_concat:
        dis = SNResNetConcatDiscriminator(args.dis_num_features, _n_cls,
                                          F.relu, args.dis_emb).to(device)
    else:
        dis = SNResNetProjectionDiscriminator(args.dis_num_features, _n_cls,
                                              F.relu,
                                              args.transform_space).to(device)
    inception_model = inception.InceptionV3().to(
        device) if args.calc_FID else None

    gen = torch.nn.DataParallel(gen)

    opt_gen = optim.Adam(gen.parameters(), args.lr, (args.beta1, args.beta2))
    opt_dis = optim.Adam(dis.parameters(), args.lr, (args.beta1, args.beta2))

    gen_criterion = L.GenLoss(args.loss_type, args.relativistic_loss)
    dis_criterion = L.DisLoss(args.loss_type, args.relativistic_loss)

    print(' Initialized models...\n')

    if args.args_path is None:
        print("Please specify weights to load")
        exit(1)
    print(' Load weights...\n')
    prev_args, gen, opt_gen, dis, opt_dis = utils.resume_from_args(
        args.args_path, args.gen_ckpt_path, args.dis_ckpt_path)
    args.n_fid_batches = args.n_eval_batches
    fid_score = evaluation.evaluate(args,
                                    0,
                                    gen,
                                    device,
                                    inception_model,
                                    eval_loader,
                                    to_save=False)
    print(fid_score)
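Examples #2 and #4 both hand their DataLoaders an InfiniteSamplerWrapper so that iter(...)/next(...) can run for an unbounded number of iterations. The class is not shown; a minimal sketch of one common implementation (names inferred from usage above):

import numpy as np
from torch.utils.data.sampler import Sampler

def _infinite_indices(n):
    # Yield shuffled indices of [0, n) forever, reshuffling after each pass.
    order = np.random.permutation(n)
    i = 0
    while True:
        yield order[i]
        i += 1
        if i >= n:
            order = np.random.permutation(n)
            i = 0

class InfiniteSamplerWrapper(Sampler):
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return _infinite_indices(self.num_samples)

    def __len__(self):
        # Effectively infinite; the DataLoader never exhausts this sampler.
        return 2 ** 31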
Example #5
def test_meta_cache():
    a = cifar10.CIFAR10()
    b = cifar10.CIFAR10()
    assert a.meta == b.meta
Example #6
def test_latent_structure():
    cifar = cifar10.CIFAR10()  # just make sure we can create the class
    cifar.DOWNLOAD_IF_MISSING = False
    X = cifar.latent_structure_task()
    tasks.assert_latent_structure(X, 60000)
Example #7
def test_classification():
    cifar = cifar10.CIFAR10()  # just make sure we can create the class
    cifar.DOWNLOAD_IF_MISSING = False
    X, y = cifar.classification_task()
    tasks.assert_classification(X, y, 60000)
Example #8
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CIFAR10 Example')
    parser.add_argument('--root',
                        type=str,
                        metavar='S',
                        help='Path to the root.')
    parser.add_argument('--init-num-labelled',
                        type=int,
                        default=None,
                        metavar='N',
                        help='Initial number of labelled examples.')
    parser.add_argument('--batch-size',
                        type=int,
                        default=100,
                        metavar='N',
                        help='total batch size for training (default: 100)')
    parser.add_argument('--init-epochs',
                        type=int,
                        metavar='N',
                        help='number of epochs to train for active learning.')
    parser.add_argument('--train-on-updated',
                        default=False,
                        type=str2bool,
                        metavar='BOOL',
                        help='Train on updated data? (default: False)')
    parser.add_argument('--active-learning',
                        default=False,
                        type=str2bool,
                        metavar='BOOL',
                        help='Run proposed active learning? (default: False)')
    parser.add_argument(
        '--skip',
        type=int,
        default=0,
        metavar='N',
        help=
        'Skip the first N epochs when computing the accumulated prediction changes.'
    )

    parser.add_argument('--test-batch-size',
                        type=int,
                        default=500,
                        metavar='N',
                        help='input batch size for testing (default: 500)')
    parser.add_argument('--epochs',
                        type=int,
                        metavar='N',
                        help='number of epochs to train.')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.0,
                        metavar='M',
                        help='SGD momentum (default: 0.0)')
    parser.add_argument('--seed',
                        type=int,
                        metavar='S',
                        help='Seed for random number generator.')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=1,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--num-workers',
                        type=int,
                        default=1,
                        metavar='N',
                        help='Number of workers for dataloader (default: 1)')

    parser.add_argument('--num-to-sample',
                        type=int,
                        metavar='N',
                        help='Number of unlabelled examples to be sampled')
    parser.add_argument(
        '--validate',
        default=False,
        type=str2bool,
        metavar='BOOL',
        help='Use validation set instead of test set? (default: False)')
    parser.add_argument('--output',
                        default='default_output.csv',
                        type=str,
                        metavar='S',
                        help='File name for the output.')

    args = parser.parse_args()
    torch.manual_seed(args.seed)  # set seed for pytorch
    use_cuda = torch.cuda.is_available()

    args.num_classes = 10

    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {
        'num_workers': args.num_workers,
        'pin_memory': True
    } if use_cuda else {}

    ##################### Active learning sampling using prediction change (fluctuation) ###############################
    if args.active_learning:
        train_dataset = cifar10.CIFAR10(root=args.root,
                                        dataset='train',
                                        init_n_labeled=args.init_num_labelled,
                                        seed=args.seed,
                                        download=True,
                                        transform=transforms.Compose(
                                            [transforms.ToTensor()]),
                                        target_transform=None,
                                        indices_name=None)  # initialise indices

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            sampler=SubsetRandomSampler(train_dataset.l_indices),
            **kwargs)

        test_loader = torch.utils.data.DataLoader(
            cifar10.CIFAR10(args.root,
                            'test',
                            seed=args.seed,
                            transform=transforms.Compose(
                                [transforms.ToTensor()])),
            batch_size=args.test_batch_size,
            shuffle=False,
            **kwargs)

        # train on the initial labelled set
        global_step = 0
        model = resnet18(pretrained=False,
                         progress=False,
                         num_classes=args.num_classes).to(device)

        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum)
        method = PredictionChange(u_indices=train_dataset.u_indices,
                                  model=model,
                                  dataset=train_dataset,
                                  data_name='cifar10')
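        # PredictionChange presumably records, per unlabelled example, how often the model's prediction flips between epochs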

        logs = {
            "train_losses": [],
            "train_acces": [],
            "test_acces": [],
            "pred_changes": []
        }
        logger = Logger(logs)

        for epoch in range(1, args.init_epochs + 1):
            start = time.time()
            global_step, train_loss, train_acc = utils.train(
                args, model, device, train_loader, optimizer, epoch,
                global_step)
            print('Training one epoch took: {:.4f} seconds.\n'.format(
                time.time() - start))
            test_acc, _ = utils.test(model, device, test_loader)

            print('Computing prediction changes...')
            pred_change = method.compute_pred_changes(model)

            logger.append(train_losses=train_loss,
                          train_acces=train_acc,
                          test_acces=test_acc,
                          pred_changes=pred_change)

        # save the logs
        train_dataset.save_logs(logger.logs)


    ############################### Training on updated indices #########################################
    if args.train_on_updated:
        import os
        # create Dataset object and load initial indices.
        train_dataset = cifar10.CIFAR10(
            root=args.root,
            dataset='train',
            init_n_labeled=args.init_num_labelled,
            seed=args.seed,
            download=False,
            transform=transforms.Compose([transforms.ToTensor()]),
            target_transform=None,
            indices_name="init_indices.npz")  #load initial indices from file

        logs_path = os.path.join(train_dataset.init_folder, 'logs.npz')
        print("Updating indices using log file: {}...".format(logs_path))
        start = time.time()
        # sampling using proposed prediction change method
        method = PredictionChange(u_indices=train_dataset.u_indices,
                                  dataset=train_dataset,
                                  data_name='CIFAR-10')

        sample = method.select_batch_from_logs(N=args.num_to_sample,
                                               skip=args.skip,
                                               path=logs_path,
                                               key="pred_changes")
        # update and save updated indices
        filename_updated_indices = "updated_indices_N_{}_skip_{}".format(
            args.num_to_sample, args.skip)
        method.update_indices(dataset=train_dataset,
                              indices=sample,
                              filename=filename_updated_indices)
        print('Active learning sampling took: {:.4f} seconds.\n'.format(
            time.time() - start))

        print("Training on updated labelled training set...")
        train_dataset = cifar10.CIFAR10(
            root=args.root,
            dataset='train',
            init_n_labeled=args.init_num_labelled,
            seed=args.seed,
            download=False,
            transform=transforms.Compose([transforms.ToTensor()]),
            target_transform=None,
            indices_name=filename_updated_indices + ".npz")  # load updated indices from file

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            sampler=SubsetRandomSampler(train_dataset.l_indices),
            **kwargs)

        if args.validate:
            test_or_validate = 'validation'
        else:
            test_or_validate = 'test'
        test_loader = torch.utils.data.DataLoader(
            cifar10.CIFAR10(args.root,
                            test_or_validate,
                            seed=args.seed,
                            transform=transforms.Compose(
                                [transforms.ToTensor()])),
            batch_size=args.test_batch_size,
            shuffle=False,
            **kwargs)

        model = resnet18(pretrained=False,
                         progress=False,
                         num_classes=args.num_classes).to(device)

        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum)
        global_step = 0

        for epoch in range(1, args.epochs + 1):
            start = time.time()
            global_step, _, _ = utils.train(args, model, device, train_loader,
                                            optimizer, epoch, global_step)
            print('\nTraining one epoch took: {:.4f} seconds.\n'.format(
                time.time() - start))
            test_acc, _ = utils.test(model, device, test_loader)

        with open(args.output, 'a') as write_file:
            writer = csv.writer(write_file)
            writer.writerow([args.seed, test_acc])