def train(model: nn.Module,
          loader: DataLoader,
          class_loss: nn.Module,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          epoch: int,
          callback: VisdomLogger,
          freq: int,
          ex: Experiment = None) -> None:
    """Run one training epoch.

    Performs a full pass over ``loader``, stepping the optimizer and the
    (per-batch) scheduler after every batch.  Running loss/accuracy averages
    are pushed to the visdom ``callback`` every ``freq`` batches, and the full
    per-batch history is logged to the sacred experiment ``ex`` at the end.

    Args:
        model: network returning ``(logits, features)`` for a batch.
        loader: training loader yielding ``(batch, labels, indices)``.
        class_loss: per-sample classification loss (reduced with ``.mean()``).
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        epoch: current epoch index (used as the integer part of the step).
        callback: visdom logger, or None to disable live plots.
        freq: plot every ``freq`` batches.
        ex: optional sacred experiment for scalar logging.
    """
    model.train()
    device = next(model.parameters()).device

    def transfer(tensor):
        # Async H2D copy; safe with pinned-memory loaders.
        return tensor.to(device, non_blocking=True)

    n_batches = len(loader)
    loss_meter = AverageMeter(device=device, length=n_batches)
    acc_meter = AverageMeter(device=device, length=n_batches)

    progress = tqdm(loader, ncols=80, desc='Training   [{:03d}]'.format(epoch))
    for batch_idx, (batch, labels, indices) in enumerate(progress):
        batch, labels, indices = (transfer(t) for t in (batch, labels, indices))
        logits, features = model(batch)
        loss = class_loss(logits, labels).mean()
        acc = (logits.detach().argmax(1) == labels).float().mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_meter.append(loss)
        acc_meter.append(acc)

        if callback is not None and (batch_idx + 1) % freq == 0:
            step = epoch + batch_idx / n_batches
            callback.scalar('xent',
                            step,
                            loss_meter.last_avg,
                            title='Train Losses')
            callback.scalar('train_acc',
                            step,
                            acc_meter.last_avg,
                            title='Train Acc')

    if ex is not None:
        # Replay the whole epoch's history into the sacred run.
        history = zip(loss_meter.values_list, acc_meter.values_list)
        for batch_idx, (loss_value, acc_value) in enumerate(history):
            step = epoch + batch_idx / n_batches
            ex.log_scalar('train.loss', loss_value, step=step)
            ex.log_scalar('train.acc', acc_value, step=step)
예제 #2
0
def episodic_validate(
        args: argparse.Namespace,
        val_loader: torch.utils.data.DataLoader,
        model: DDP,
        use_callback: bool,
        suffix: str = 'test') -> Tuple[torch.tensor, torch.tensor]:
    """Evaluate the model episodically with RePRI few-shot inference.

    Runs ``args.n_runs`` independent evaluation runs, each made of
    ``test_num / batch_size_val`` episodes.  For every task in an episode a
    support/query pair is drawn from ``val_loader``, features are extracted
    with the frozen backbone, and a per-task linear classifier is initialized
    and refined with RePRI before computing IoU/CE metrics.

    Args:
        args: experiment configuration (n_runs, test_num, batch_size_val,
            shot, image_size, norm_feat, visdom_port, save_oracle, ...).
        val_loader: loader yielding (query img, query label, support imgs,
            support labels, task classes, _, _) tuples.
        model: DDP-wrapped model exposing ``module.extract_features``,
            ``module.bottleneck_dim`` and ``module.feature_res``.
        use_callback: if True, plot intermediate state through VisdomLogger.
        suffix: tag for callers to distinguish evaluation phases (unused here).

    Returns:
        Tuple (mean mIoU over runs, mean CE loss over runs).
    """

    print('==> Start testing')

    model.eval()
    nb_episodes = int(args.test_num / args.batch_size_val)

    # ========== Metrics initialization  ==========

    H, W = args.image_size, args.image_size
    c = model.module.bottleneck_dim
    h = model.module.feature_res[0]
    w = model.module.feature_res[1]

    runtimes = torch.zeros(args.n_runs)
    deltas_init = torch.zeros((args.n_runs, nb_episodes, args.batch_size_val))
    deltas_final = torch.zeros((args.n_runs, nb_episodes, args.batch_size_val))
    val_IoUs = np.zeros(args.n_runs)
    val_losses = np.zeros(args.n_runs)

    # Cycling iterator over the loader, re-created whenever it is exhausted.
    # (Previously created lazily inside a bare `except:` around the Py2-style
    # `iter_loader.next()`, which silently swallowed unrelated errors.)
    iter_loader = iter(val_loader)

    # ========== Perform the runs  ==========
    for run in tqdm(range(args.n_runs)):

        # =============== Initialize the metric dictionaries ===============

        loss_meter = AverageMeter()
        iter_num = 0
        cls_intersection = defaultdict(int)  # Default value is 0
        cls_union = defaultdict(int)
        IoU = defaultdict(int)

        # =============== episode = group of tasks ===============
        runtime = 0
        for e in tqdm(range(nb_episodes)):
            t0 = time.time()
            features_s = torch.zeros(args.batch_size_val, args.shot, c, h,
                                     w).to(dist.get_rank())
            features_q = torch.zeros(args.batch_size_val, 1, c, h,
                                     w).to(dist.get_rank())
            # 255 is the ignore label for ground-truth masks.
            gt_s = 255 * torch.ones(args.batch_size_val, args.shot,
                                    args.image_size,
                                    args.image_size).long().to(dist.get_rank())
            gt_q = 255 * torch.ones(args.batch_size_val, 1, args.image_size,
                                    args.image_size).long().to(dist.get_rank())
            n_shots = torch.zeros(args.batch_size_val).to(dist.get_rank())
            classes = []  # All classes considered in the tasks

            # =========== Generate tasks and extract features for each task ===============
            for i in range(args.batch_size_val):
                try:
                    qry_img, q_label, spprt_imgs, s_label, subcls, _, _ = next(
                        iter_loader)
                except StopIteration:
                    # Loader exhausted: restart from the beginning.
                    iter_loader = iter(val_loader)
                    qry_img, q_label, spprt_imgs, s_label, subcls, _, _ = next(
                        iter_loader)
                iter_num += 1

                q_label = q_label.to(dist.get_rank(), non_blocking=True)
                spprt_imgs = spprt_imgs.to(dist.get_rank(), non_blocking=True)
                s_label = s_label.to(dist.get_rank(), non_blocking=True)
                qry_img = qry_img.to(dist.get_rank(), non_blocking=True)

                f_s = model.module.extract_features(spprt_imgs.squeeze(0))
                f_q = model.module.extract_features(qry_img)

                shot = f_s.size(0)
                n_shots[i] = shot
                features_s[i, :shot] = f_s.detach()
                features_q[i] = f_q.detach()
                gt_s[i, :shot] = s_label
                gt_q[i, 0] = q_label
                classes.append([class_.item() for class_ in subcls])

            # =========== Normalize features along channel dimension ===============
            if args.norm_feat:
                features_s = F.normalize(features_s, dim=2)
                features_q = F.normalize(features_q, dim=2)

            # =========== Create a callback is args.visdom_port != -1 ===============
            callback = VisdomLogger(
                port=args.visdom_port) if use_callback else None

            # ===========  Initialize the classifier + prototypes + F/B parameter Π ===============
            classifier = Classifier(args)
            classifier.init_prototypes(features_s, features_q, gt_s, gt_q,
                                       classes, callback)
            batch_deltas = classifier.compute_FB_param(features_q=features_q,
                                                       gt_q=gt_q)
            deltas_init[run, e, :] = batch_deltas.cpu()

            # =========== Perform RePRI inference ===============
            batch_deltas = classifier.RePRI(features_s, features_q, gt_s, gt_q,
                                            classes, n_shots, callback)
            # .cpu() for consistency with deltas_init above.
            deltas_final[run, e, :] = batch_deltas.cpu()
            t1 = time.time()
            runtime += t1 - t0
            logits = classifier.get_logits(features_q)  # [n_tasks, shot, h, w]
            logits = F.interpolate(logits,
                                   size=(H, W),
                                   mode='bilinear',
                                   align_corners=True)
            probas = classifier.get_probas(logits).detach()
            intersection, union, _ = batch_intersectionAndUnionGPU(
                probas, gt_q, 2)  # [n_tasks, shot, num_class]
            intersection, union = intersection.cpu(), union.cpu()

            # ================== Log metrics ==================
            one_hot_gt = to_one_hot(gt_q, 2)
            valid_pixels = gt_q != 255
            loss = classifier.get_ce(probas,
                                     valid_pixels,
                                     one_hot_gt,
                                     reduction='mean')
            loss_meter.update(loss.item())
            for i, task_classes in enumerate(classes):
                for j, class_ in enumerate(task_classes):
                    cls_intersection[class_] += intersection[
                        i, 0, j + 1]  # Do not count background
                    cls_union[class_] += union[i, 0, j + 1]

            for class_ in cls_union:
                IoU[class_] = cls_intersection[class_] / (cls_union[class_] +
                                                          1e-10)

            if iter_num % 200 == 0:
                mIoU = np.mean([IoU[i] for i in IoU])
                print(
                    'Test: [{}/{}] '
                    'mIoU {:.4f} '
                    'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) '.format(
                        iter_num,
                        args.test_num,
                        mIoU,
                        loss_meter=loss_meter,
                    ))
        runtimes[run] = runtime
        mIoU = np.mean(list(IoU.values()))
        print('mIoU---Val result: mIoU {:.4f}.'.format(mIoU))
        for class_ in cls_union:
            print("Class {} : {:.4f}".format(class_, IoU[class_]))

        val_IoUs[run] = mIoU
        val_losses[run] = loss_meter.avg

    # ================== Save metrics ==================
    if args.save_oracle:
        root = os.path.join('plots', 'oracle')
        os.makedirs(root, exist_ok=True)
        np.save(os.path.join(root, 'delta_init.npy'), deltas_init.numpy())
        np.save(os.path.join(root, 'delta_final.npy'), deltas_final.numpy())

    print('Average mIoU over {} runs --- {:.4f}.'.format(
        args.n_runs, val_IoUs.mean()))
    print('Average runtime / run --- {:.4f}.'.format(runtimes.mean()))

    return val_IoUs.mean(), val_losses.mean()
예제 #3
0
                                           download=True,
                                           transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=test_batch_size,
                                             shuffle=False,
                                             num_workers=0)

    model_pgp = alexnet().to(cuda)

    epochs = 170
    prune_rate = 0.05
    remove_ratio = 0.5
    optimizer_pgp = optim.SGD(model_pgp.parameters(), lr=lr, momentum=momentum)
    zero_initializer = functools.partial(torch.nn.init.constant_, val=0)

    logger = VisdomLogger(port=10999)
    logger = LoggerForSacred(logger)

    if not os.path.exists("temp_target.p"):

        loss_acc, model_architecture = pgp.pgp(epochs,
                                               prune_rate,
                                               remove_ratio,
                                               testloader,
                                               gng_criterion,
                                               cuda=cuda,
                                               model=model_pgp,
                                               train_loader=trainloader,
                                               train_ratio=1,
                                               prune_once=False,
                                               initializer_fn=zero_initializer,
예제 #4
0
def main(epochs, cpu, cudnn_flag, visdom_port, visdom_freq, temp_dir, seed,
         no_bias_decay, label_smoothing, temperature):
    """Train a retrieval model and keep the best validation recall.

    Trains for ``epochs`` epochs with smoothed cross-entropy, evaluates
    Recall@1 (l2 and cosine) after each epoch, stores the best state dict
    (selected by cosine Recall@1) as a sacred artifact, and returns the
    best cosine Recall@1.

    NOTE(review): relies on module-level names defined elsewhere in this
    file (``ex``, ``get_loaders``, ``get_model``, ``SmoothCrossEntropy``,
    ``get_optimizer_scheduler``, ``evaluate``, ``train``,
    ``state_dict_to_cpu``) — presumably a sacred captured function; verify
    against the decorator at the definition site.
    """
    device = torch.device(
        'cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
    callback = VisdomLogger(port=visdom_port) if visdom_port else None
    if cudnn_flag == 'deterministic':
        setattr(cudnn, cudnn_flag, True)

    # Re-seed before each randomness-consuming stage so that every stage is
    # reproducible independently of the previous ones.
    torch.manual_seed(seed)
    loaders, recall_ks = get_loaders()

    torch.manual_seed(seed)
    model = get_model(num_classes=loaders.num_classes)
    class_loss = SmoothCrossEntropy(epsilon=label_smoothing,
                                    temperature=temperature)

    model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    # Optionally exempt 1-D parameters (biases, norm scales) from weight decay.
    parameters = []
    if no_bias_decay:
        parameters.append(
            {'params': [par for par in model.parameters() if par.dim() != 1]})
        parameters.append({
            'params': [par for par in model.parameters() if par.dim() == 1],
            'weight_decay':
            0
        })
    else:
        parameters.append({'params': model.parameters()})
    optimizer, scheduler = get_optimizer_scheduler(parameters=parameters,
                                                   loader_length=len(
                                                       loaders.train))

    # setup partial function to simplify call
    eval_function = partial(evaluate,
                            model=model,
                            recall=recall_ks,
                            query_loader=loaders.query,
                            gallery_loader=loaders.gallery)

    # setup best validation logger
    # Epoch-0 evaluation establishes the pre-training baseline.
    metrics = eval_function()
    if callback is not None:
        callback.scalars(
            ['l2', 'cosine'],
            0, [metrics.recall['l2'][1], metrics.recall['cosine'][1]],
            title='Val Recall@1')
    pprint(metrics.recall)
    best_val = (0, metrics.recall, deepcopy(model.state_dict()))

    torch.manual_seed(seed)
    for epoch in range(epochs):
        # cudnn benchmark is enabled only for the fixed-shape training pass.
        if cudnn_flag == 'benchmark':
            setattr(cudnn, cudnn_flag, True)

        train(model=model,
              loader=loaders.train,
              class_loss=class_loss,
              optimizer=optimizer,
              scheduler=scheduler,
              epoch=epoch,
              callback=callback,
              freq=visdom_freq,
              ex=ex)

        # validation
        if cudnn_flag == 'benchmark':
            setattr(cudnn, cudnn_flag, False)
        metrics = eval_function()
        # Tuple expression below is just two calls; the tuple is discarded.
        print('Validation [{:03d}]'.format(epoch)), pprint(metrics.recall)
        ex.log_scalar('val.recall_l2@1',
                      metrics.recall['l2'][1],
                      step=epoch + 1)
        ex.log_scalar('val.recall_cosine@1',
                      metrics.recall['cosine'][1],
                      step=epoch + 1)

        if callback is not None:
            callback.scalars(
                ['l2', 'cosine'],
                epoch + 1,
                [metrics.recall['l2'][1], metrics.recall['cosine'][1]],
                title='Val Recall')

        # save model dict if the chosen validation metric is better
        if metrics.recall['cosine'][1] >= best_val[1]['cosine'][1]:
            best_val = (epoch + 1, metrics.recall,
                        deepcopy(model.state_dict()))

    # logging
    ex.info['recall'] = best_val[1]

    # saving
    save_name = os.path.join(
        temp_dir, '{}_{}.pt'.format(ex.current_run.config['model']['arch'],
                                    ex.current_run.config['dataset']['name']))
    torch.save(state_dict_to_cpu(best_val[2]), save_name)
    ex.add_artifact(save_name)

    if callback is not None:
        save_name = os.path.join(temp_dir, 'visdom_data.pt')
        callback.save(save_name)
        ex.add_artifact(save_name)

    return best_val[1]['cosine'][1]
def main_worker(rank: int, world_size: int, args: argparse.Namespace) -> None:
    """Entry point of a single DDP training process.

    Builds the model, optimizer and scheduler, trains for ``args.epochs``
    epochs, validates after every epoch, keeps the best checkpoint by
    validation mIoU on the main process and dumps per-epoch metrics to .npy
    files in the model directory.

    Args:
        rank: rank of this process, also used as its CUDA device index.
        world_size: total number of distributed processes.
        args: full experiment configuration.
    """

    print(f"==> Running process rank {rank}.")
    setup(args, rank, world_size)

    if args.manual_seed is not None:
        # Deterministic mode; each rank derives its own seed.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.cuda.manual_seed(args.manual_seed + rank)
        np.random.seed(args.manual_seed + rank)
        torch.manual_seed(args.manual_seed + rank)
        torch.cuda.manual_seed_all(args.manual_seed + rank)
        random.seed(args.manual_seed + rank)

    callback = None if args.visdom_port == -1 else VisdomLogger(
        port=args.visdom_port)

    # ========== Model + Optimizer ==========
    model = get_model(args).to(rank)
    modules_ori = [
        model.layer0, model.layer1, model.layer2, model.layer3, model.layer4
    ]
    modules_new = [model.ppm, model.bottleneck, model.classifier]

    # Newly added heads train with a scaled learning rate.
    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.lr))
    for module in modules_new:
        params_list.append(
            dict(params=module.parameters(), lr=args.lr * args.scale_lr))
    optimizer = get_optimizer(args, params_list)

    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DDP(model, device_ids=[rank])

    savedir = get_model_dir(args)

    # ========== Validation ==================
    validate_fn = episodic_validate if args.episodic_val else standard_validate

    # ========== Data  =====================
    train_loader, train_sampler = get_train_loader(args)
    val_loader, _ = get_val_loader(
        args
    )  # mode='train' means that we will validate on images from validation set, but with the bases classes

    # ========== Scheduler  ================
    scheduler = get_scheduler(args, optimizer, len(train_loader))

    # ========== Metrics initialization ====
    max_val_mIoU = 0.
    if args.debug:
        iter_per_epoch = 5
    else:
        iter_per_epoch = len(train_loader)
    log_iter = int(iter_per_epoch / args.log_freq) + 1

    metrics: Dict[str, Tensor] = {
        "val_mIou": torch.zeros((args.epochs, 1)).type(torch.float32),
        "val_loss": torch.zeros((args.epochs, 1)).type(torch.float32),
        "train_mIou": torch.zeros((args.epochs, log_iter)).type(torch.float32),
        "train_loss": torch.zeros((args.epochs, log_iter)).type(torch.float32),
    }

    # ========== Training  =================
    for epoch in tqdm(range(args.epochs)):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        train_mIou, train_loss = do_epoch(args=args,
                                          train_loader=train_loader,
                                          iter_per_epoch=iter_per_epoch,
                                          model=model,
                                          optimizer=optimizer,
                                          scheduler=scheduler,
                                          epoch=epoch,
                                          callback=callback,
                                          log_iter=log_iter)

        val_mIou, val_loss = validate_fn(args=args,
                                         val_loader=val_loader,
                                         model=model,
                                         use_callback=False,
                                         suffix=f'train_{epoch}')
        if args.distributed:
            # Average the validation metrics across all processes.
            dist.all_reduce(val_mIou), dist.all_reduce(val_loss)
            val_mIou /= world_size
            val_loss /= world_size

        if main_process(args):
            # Live plot if desired with visdom
            if callback is not None:
                callback.scalar('val_loss',
                                epoch,
                                val_loss,
                                title='Validiation Loss')
                callback.scalar('mIoU_val',
                                epoch,
                                val_mIou,
                                title='Val metrics')

            # Model selection
            if val_mIou.item() > max_val_mIoU:
                max_val_mIoU = val_mIou.item()
                os.makedirs(savedir, exist_ok=True)
                filename = os.path.join(savedir, 'best.pth')
                if args.save_models:
                    print('Saving checkpoint to: ' + filename)
                    torch.save(
                        {
                            'epoch': epoch,
                            'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        }, filename)
            print("=> Max_mIoU = {:.3f}".format(max_val_mIoU))

            # Store this epoch's metrics.  Explicit mapping replaces the
            # previous `eval(k)` lookup of local variables by name, which was
            # fragile (silent breakage on a rename) and a code smell.
            epoch_values = {
                "val_mIou": val_mIou,
                "val_loss": val_loss,
                "train_mIou": train_mIou,
                "train_loss": train_loss,
            }
            for k in metrics:
                metrics[k][epoch] = epoch_values[k]

            for k, e in metrics.items():
                path = os.path.join(savedir, f"{k}.npy")
                np.save(path, e.cpu().numpy())

    if args.save_models and main_process(args):
        filename = os.path.join(savedir, 'final.pth')
        # Report the actual destination (the message previously omitted it).
        print('Saving checkpoint to: ' + filename)
        torch.save(
            {
                'epoch': args.epochs,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, filename)

    cleanup()
def do_epoch(args: argparse.Namespace,
             train_loader: torch.utils.data.DataLoader, model: DDP,
             optimizer: torch.optim.Optimizer,
             scheduler: torch.optim.lr_scheduler, epoch: int,
             callback: VisdomLogger, iter_per_epoch: int,
             log_iter: int) -> Tuple[torch.tensor, torch.tensor]:
    """Train ``model`` for one epoch and return logged training metrics.

    Runs ``iter_per_epoch`` optimization steps; every ``args.log_freq``
    steps it evaluates batch mIoU/accuracy (all-reduced across processes
    when distributed) and optionally plots through ``callback``.

    Args:
        args: experiment configuration (log_freq, num_classes_tr, ...).
        train_loader: training data loader yielding (images, gt) pairs.
        model: DDP-wrapped segmentation model.
        optimizer: optimizer stepped once per batch.
        scheduler: stepped per batch for 'cosine', else once per epoch.
        epoch: current epoch index.
        callback: visdom logger, or None to disable live plots.
        iter_per_epoch: number of batches to process this epoch.
        log_iter: length of the returned metric tensors.

    Returns:
        Tuple (train_mIous, train_losses), each of shape [log_iter], on the
        local rank's device.
    """
    loss_meter = AverageMeter()
    train_losses = torch.zeros(log_iter).to(dist.get_rank())
    train_mIous = torch.zeros(log_iter).to(dist.get_rank())

    iterable_train_loader = iter(train_loader)

    if main_process(args):
        bar = tqdm(range(iter_per_epoch))
    else:
        bar = range(iter_per_epoch)

    for i in bar:
        model.train()
        current_iter = epoch * len(train_loader) + i + 1

        # Py3 iterator protocol: DataLoader iterators have no `.next()`
        # method, so use the `next()` builtin (was `...loader.next()`).
        images, gt = next(iterable_train_loader)
        images = images.to(dist.get_rank(), non_blocking=True)
        gt = gt.to(dist.get_rank(), non_blocking=True)

        loss = compute_loss(
            args=args,
            model=model,
            images=images,
            targets=gt.long(),
            num_classes=args.num_classes_tr,
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if args.scheduler == 'cosine':
            scheduler.step()

        if i % args.log_freq == 0:
            # Periodic metric snapshot on the current batch.
            model.eval()
            logits = model(images)
            intersection, union, target = intersectionAndUnionGPU(
                logits.argmax(1), gt, args.num_classes_tr, 255)
            if args.distributed:
                dist.all_reduce(loss)
                dist.all_reduce(intersection)
                dist.all_reduce(union)
                dist.all_reduce(target)

            allAcc = (intersection.sum() / (target.sum() + 1e-10))  # scalar
            mAcc = (intersection / (target + 1e-10)).mean()
            mIoU = (intersection / (union + 1e-10)).mean()
            loss_meter.update(loss.item() / dist.get_world_size())

            if main_process(args):
                if callback is not None:
                    t = current_iter / len(train_loader)
                    callback.scalar('loss_train_batch',
                                    t,
                                    loss_meter.avg,
                                    title='Loss')
                    callback.scalars(['mIoU', 'mAcc', 'allAcc'],
                                     t, [mIoU, mAcc, allAcc],
                                     title='Training metrics')
                    # Only the first param group's LR is plotted.
                    for index, param_group in enumerate(
                            optimizer.param_groups):
                        lr = param_group['lr']
                        callback.scalar('lr', t, lr, title='Learning rate')
                        break

                train_losses[int(i / args.log_freq)] = loss_meter.avg
                train_mIous[int(i / args.log_freq)] = mIoU

    if args.scheduler != 'cosine':
        scheduler.step()

    return train_mIous, train_losses
예제 #7
0
def main(seed, pretrain, resume, evaluate, print_runtime, epochs, disable_tqdm,
         visdom_port, ckpt_path, make_plot, cuda):
    """Train or evaluate a few-shot classification model.

    Optionally loads pretrained weights, resumes from a checkpoint, then
    either runs a full evaluation (``evaluate=True``) or trains for
    ``epochs`` epochs with meta-validation after each epoch, checkpointing
    the best model before a final evaluation.

    NOTE(review): relies on module-level names defined elsewhere in this
    file (``ex``, ``get_model``, ``get_optimizer``, ``get_scheduler``,
    ``Evaluator``, ``Trainer``, ``warp_tqdm``, ``save_checkpoint``).

    Returns:
        Evaluation results from ``Evaluator.run_full_evaluation``.
    """
    device = torch.device("cuda" if cuda else "cpu")
    callback = None if visdom_port is None else VisdomLogger(port=visdom_port)
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
    torch.cuda.set_device(0)
    # create model
    print("=> Creating model '{}'".format(
        ex.current_run.config['model']['arch']))
    model = torch.nn.DataParallel(get_model()).cuda()

    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    optimizer = get_optimizer(model)

    if pretrain:
        pretrain = os.path.join(pretrain, 'checkpoint.pth.tar')
        if os.path.isfile(pretrain):
            print("=> loading pretrained weight '{}'".format(pretrain))
            checkpoint = torch.load(pretrain)
            model_dict = model.state_dict()
            params = checkpoint['state_dict']
            # Keep only weights whose names match the current architecture.
            params = {k: v for k, v in params.items() if k in model_dict}
            model_dict.update(params)
            model.load_state_dict(model_dict)
        else:
            print(
                '[Warning]: Did not find pretrained model {}'.format(pretrain))

    # Defaults first: the original left start_epoch/best_prec1 unbound (and
    # crashed with NameError later) when `resume` was set but the checkpoint
    # file was missing.
    start_epoch = 0
    best_prec1 = -1
    if resume:
        resume_path = ckpt_path + '/checkpoint.pth.tar'
        if os.path.isfile(resume_path):
            print("=> loading checkpoint '{}'".format(resume_path))
            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            # scheduler.load_state_dict(checkpoint['scheduler'])
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_path, checkpoint['epoch']))
        else:
            print('[Warning]: Did not find checkpoint {}'.format(resume_path))

    cudnn.benchmark = True

    # Data loading code
    evaluator = Evaluator(device=device, ex=ex)
    if evaluate:
        print("Evaluating")
        results = evaluator.run_full_evaluation(model=model,
                                                model_path=ckpt_path,
                                                callback=callback)
        return results

    # If this line is reached, then training the model
    trainer = Trainer(device=device, ex=ex)
    scheduler = get_scheduler(optimizer=optimizer,
                              num_batches=len(trainer.train_loader),
                              epochs=epochs)
    tqdm_loop = warp_tqdm(list(range(start_epoch, epochs)),
                          disable_tqdm=disable_tqdm)
    for epoch in tqdm_loop:
        # Do one epoch
        trainer.do_epoch(model=model,
                         optimizer=optimizer,
                         epoch=epoch,
                         scheduler=scheduler,
                         disable_tqdm=disable_tqdm,
                         callback=callback)

        # Evaluation on validation set
        prec1 = trainer.meta_val(model=model,
                                 disable_tqdm=disable_tqdm,
                                 epoch=epoch,
                                 callback=callback)
        print('Meta Val {}: {}'.format(epoch, prec1))
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if not disable_tqdm:
            tqdm_loop.set_description('Best Acc {:.2f}'.format(best_prec1 *
                                                               100.))

        # Save checkpoint
        save_checkpoint(state={
            'epoch': epoch + 1,
            'arch': ex.current_run.config['model']['arch'],
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict()
        },
                        is_best=is_best,
                        folder=ckpt_path)
        if scheduler is not None:
            scheduler.step()

    # Final evaluation on test set
    results = evaluator.run_full_evaluation(model=model, model_path=ckpt_path)
    return results
예제 #8
0
def main(args):
    """Train and evaluate a MAML-based signature verification model.

    Pipeline: load the signature dataset, split users into development
    (meta-train) and validation/exploitation sets, optionally pretrain the
    backbone with a classification (and forgery) head, meta-train with MAML
    while tracking post-update validation accuracy, then reload the best
    parameters and run the final test.

    NOTE(review): relies on module-level names defined elsewhere in this
    file (``get_logdir``, ``load_dataset``, ``models``, ``MAML``,
    ``PretrainWrapper``, ``pretrain``, ``MAMLDataSet``, ``test_one_epoch``,
    ``move_to_gpu``, ``get_per_step_loss_importance_vector``,
    ``test_and_save``, ``util``, ``balanced_binary_cross_entropy``).
    """
    rng = np.random.RandomState(args.seed)

    if args.test:
        assert args.checkpoint is not None, 'Please inform the checkpoint (trained model)'

    if args.logdir is None:
        logdir = get_logdir(args)
    else:
        logdir = pathlib.Path(args.logdir)
    if not logdir.exists():
        logdir.mkdir()

    print('Writing logs to {}'.format(logdir))

    device = torch.device(
        'cuda',
        args.gpu_idx) if torch.cuda.is_available() else torch.device('cpu')

    if args.port is not None:
        logger = VisdomLogger(port=args.port)
    else:
        logger = None

    print('Loading Data')
    x, y, yforg, usermapping, filenames = load_dataset(args.dataset_path)

    # Development (meta-train) users, optionally subsampled.
    dev_users = range(args.dev_users[0], args.dev_users[1])
    if args.devset_size is not None:
        # Randomly select users from the dev set
        dev_users = rng.choice(dev_users, args.devset_size, replace=False)

    if args.devset_sk_size is not None:
        assert args.devset_sk_size <= len(
            dev_users), 'devset-sk-size should be smaller than devset-size'

        # Randomly select users from the dev set to have skilled forgeries (others don't)
        dev_sk_users = set(
            rng.choice(dev_users, args.devset_sk_size, replace=False))
    else:
        dev_sk_users = set(dev_users)

    print('{} users in dev set; {} users with skilled forgeries'.format(
        len(dev_users), len(dev_sk_users)))

    # Choose the user range used for validation / final testing.
    if args.exp_users is not None:
        val_users = range(args.exp_users[0], args.exp_users[1])
        print('Testing with users from {} to {}'.format(
            args.exp_users[0], args.exp_users[1]))
    elif args.use_testset:
        val_users = range(0, 300)
        print('Testing with Exploitation set')
    else:
        val_users = range(300, 350)

    print('Initializing model')
    base_model = models.available_models[args.model]().to(device)
    weights = base_model.build_weights(device)
    maml = MAML(base_model,
                args.num_updates,
                args.num_updates,
                args.train_lr,
                args.meta_lr,
                args.meta_min_lr,
                args.epochs,
                args.learn_task_lr,
                weights,
                device,
                logger,
                loss_function=balanced_binary_cross_entropy,
                is_classification=True)

    if args.checkpoint:
        params = torch.load(args.checkpoint)
        maml.load(params)

    if args.test:
        # Test-only mode: evaluate the loaded checkpoint and exit.
        test_and_save(args, device, logdir, maml, val_users, x, y, yforg)
        return

    # Pretraining
    if args.pretrain_epochs > 0:
        print('Pre-training')
        # Users 350-880 are held out for backbone pretraining.
        data = util.get_subset((x, y, yforg), subset=range(350, 881))

        wrapped_model = PretrainWrapper(base_model, weights)

        if not args.pretrain_forg:
            data = util.remove_forgeries(data, forg_idx=2)

        train_loader, val_loader = pretrain.setup_data_loaders(
            data, 32, args.input_size)
        n_classes = len(np.unique(y))

        classification_layer = nn.Linear(base_model.feature_space_size,
                                         n_classes).to(device)
        if args.pretrain_forg:
            forg_layer = nn.Linear(base_model.feature_space_size, 1).to(device)
        else:
            forg_layer = nn.Module()  # Stub module with no parameters

        pretrain_args = argparse.Namespace(lr=0.01,
                                           lr_decay=0.1,
                                           lr_decay_times=1,
                                           momentum=0.9,
                                           weight_decay=0.001,
                                           forg=args.pretrain_forg,
                                           lamb=args.pretrain_forg_lambda,
                                           epochs=args.pretrain_epochs)
        print(pretrain_args)
        pretrain.train(wrapped_model,
                       classification_layer,
                       forg_layer,
                       train_loader,
                       val_loader,
                       device,
                       logger,
                       pretrain_args,
                       logdir=None)

    # MAML training

    trainset = MAMLDataSet(data=(x, y, yforg),
                           subset=dev_users,
                           sk_subset=dev_sk_users,
                           num_gen_train=args.num_gen,
                           num_rf_train=args.num_rf,
                           num_gen_test=args.num_gen_test,
                           num_rf_test=args.num_rf_test,
                           num_sk_test=args.num_sk_test,
                           input_shape=args.input_size,
                           test=False,
                           rng=np.random.RandomState(args.seed))

    val_set = MAMLDataSet(data=(x, y, yforg),
                          subset=val_users,
                          num_gen_train=args.num_gen,
                          num_rf_train=args.num_rf,
                          num_gen_test=args.num_gen_test,
                          num_rf_test=args.num_rf_test,
                          num_sk_test=args.num_sk_test,
                          input_shape=args.input_size,
                          test=True,
                          rng=np.random.RandomState(args.seed))

    loader = DataLoader(trainset,
                        batch_size=args.meta_batch_size,
                        shuffle=True,
                        num_workers=2,
                        collate_fn=trainset.collate_fn)

    print('Training')
    best_val_acc = 0
    with tqdm(initial=0, total=len(loader) * args.epochs) as pbar:
        # When resuming from a checkpoint, log its validation metrics at
        # step 0 so the curves start from the loaded model's performance.
        if args.checkpoint is not None:
            postupdate_accs, postupdate_losses, preupdate_losses = test_one_epoch(
                maml, val_set, device, args.num_updates)

            if logger:
                for i in range(args.num_updates):
                    logger.scalar('val_postupdate_loss_{}'.format(i), 0,
                                  np.mean(postupdate_losses, axis=0)[i])

                    logger.scalar('val_postupdate_acc_{}'.format(i), 0,
                                  np.mean(postupdate_accs, axis=0)[i])

        for epoch in range(args.epochs):
            # Per-step loss weights (multi-step loss annealing schedule).
            loss_weights = get_per_step_loss_importance_vector(
                args.num_updates, args.msl_epochs, epoch)

            n_batches = len(loader)
            for step, item in enumerate(loader):
                item = move_to_gpu(*item, device=device)
                maml.meta_learning_step((item[0], item[1]), (item[2], item[3]),
                                        loss_weights, epoch + step / n_batches)
                pbar.update(1)

            maml.scheduler.step()

            # Validate after every epoch; metrics are averaged over tasks.
            postupdate_accs, postupdate_losses, preupdate_losses = test_one_epoch(
                maml, val_set, device, args.num_updates)

            if logger:
                for i in range(args.num_updates):
                    logger.scalar('val_postupdate_loss_{}'.format(i),
                                  epoch + 1,
                                  np.mean(postupdate_losses, axis=0)[i])

                    logger.scalar('val_postupdate_acc_{}'.format(i), epoch + 1,
                                  np.mean(postupdate_accs, axis=0)[i])

                logger.save(logdir / 'train_curves.pickle')
            # Model selection uses the accuracy after the LAST inner update.
            this_val_loss = np.mean(postupdate_losses, axis=0)[-1]
            this_val_acc = np.mean(postupdate_accs, axis=0)[-1]

            if this_val_acc > best_val_acc:
                best_val_acc = this_val_acc
                # NOTE(review): saves maml.parameters as-is — presumably a
                # weights attribute, not nn.Module.parameters; confirm in MAML.
                torch.save(maml.parameters, logdir / 'best_model.pth')
            print('Epoch {}. Val loss: {:.4f}. Val Acc: {:.2f}%'.format(
                epoch, this_val_loss, this_val_acc * 100))

    # Re-load best parameters and test with 10 folds
    params = torch.load(logdir / 'best_model.pth')
    maml.load(params)

    test_and_save(args, device, logdir, maml, val_users, x, y, yforg)