Example #1
def main(epochs, cpu, cudnn_flag, visdom_port, visdom_freq, temp_dir, seed,
         no_bias_decay, label_smoothing, temperature):
    device = torch.device(
        'cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
    callback = VisdomLogger(port=visdom_port) if visdom_port else None
    if cudnn_flag == 'deterministic':
        # force deterministic cuDNN kernels for reproducibility
        setattr(cudnn, cudnn_flag, True)

    # seed before data loading so the sampling order is reproducible
    torch.manual_seed(seed)
    loaders, recall_ks = get_loaders()

    # re-seed so weight initialization is independent of loader RNG consumption
    torch.manual_seed(seed)
    model = get_model(num_classes=loaders.num_classes)
    class_loss = SmoothCrossEntropy(epsilon=label_smoothing,
                                    temperature=temperature)

    model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    parameters = []
    if no_bias_decay:
        # exclude 1-d parameters (biases, BN weights) from weight decay
        parameters.append(
            {'params': [par for par in model.parameters() if par.dim() != 1]})
        parameters.append({
            'params': [par for par in model.parameters() if par.dim() == 1],
            'weight_decay': 0
        })
    else:
        parameters.append({'params': model.parameters()})
    optimizer, scheduler = get_optimizer_scheduler(
        parameters=parameters, loader_length=len(loaders.train))

    # setup partial function to simplify call
    eval_function = partial(evaluate,
                            model=model,
                            recall=recall_ks,
                            query_loader=loaders.query,
                            gallery_loader=loaders.gallery)

    # initial evaluation; seeds the best-validation tracker below
    metrics = eval_function()
    if callback is not None:
        callback.scalars(
            ['l2', 'cosine'],
            0, [metrics.recall['l2'][1], metrics.recall['cosine'][1]],
            title='Val Recall@1')
    pprint(metrics.recall)
    best_val = (0, metrics.recall, deepcopy(model.state_dict()))

    # fixed RNG state at the start of training
    torch.manual_seed(seed)
    for epoch in range(epochs):
        if cudnn_flag == 'benchmark':
            # enable the cuDNN autotuner while batch shapes are fixed
            setattr(cudnn, cudnn_flag, True)

        train(model=model,
              loader=loaders.train,
              class_loss=class_loss,
              optimizer=optimizer,
              scheduler=scheduler,
              epoch=epoch,
              callback=callback,
              freq=visdom_freq,
              ex=ex)

        # validation
        if cudnn_flag == 'benchmark':
            # disable benchmark mode for evaluation, where input sizes may vary
            setattr(cudnn, cudnn_flag, False)
        metrics = eval_function()
        print('Validation [{:03d}]'.format(epoch))
        pprint(metrics.recall)
        ex.log_scalar('val.recall_l2@1',
                      metrics.recall['l2'][1],
                      step=epoch + 1)
        ex.log_scalar('val.recall_cosine@1',
                      metrics.recall['cosine'][1],
                      step=epoch + 1)

        if callback is not None:
            callback.scalars(
                ['l2', 'cosine'],
                epoch + 1,
                [metrics.recall['l2'][1], metrics.recall['cosine'][1]],
                title='Val Recall@1')

        # save model dict if the chosen validation metric is better
        if metrics.recall['cosine'][1] >= best_val[1]['cosine'][1]:
            best_val = (epoch + 1, metrics.recall,
                        deepcopy(model.state_dict()))

    # logging
    ex.info['recall'] = best_val[1]

    # saving
    save_name = os.path.join(
        temp_dir, '{}_{}.pt'.format(ex.current_run.config['model']['arch'],
                                    ex.current_run.config['dataset']['name']))
    torch.save(state_dict_to_cpu(best_val[2]), save_name)
    ex.add_artifact(save_name)

    if callback is not None:
        save_name = os.path.join(temp_dir, 'visdom_data.pt')
        callback.save(save_name)
        ex.add_artifact(save_name)

    return best_val[1]['cosine'][1]
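
SmoothCrossEntropy above is imported from elsewhere in the repository, so its implementation is not shown. The following is a minimal sketch of a label-smoothing cross-entropy with temperature scaling, assuming the same constructor arguments (epsilon, temperature); the real class may differ in detail:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SmoothCrossEntropy(nn.Module):
    """Hypothetical sketch, not the repository's actual implementation."""

    def __init__(self, epsilon=0.1, temperature=1.0):
        super().__init__()
        self.epsilon = epsilon
        self.temperature = temperature

    def forward(self, logits, target):
        # temperature-scaled log-probabilities
        log_probs = F.log_softmax(logits / self.temperature, dim=1)
        n_classes = logits.size(1)
        with torch.no_grad():
            # smoothed targets: 1 - epsilon on the true class,
            # epsilon spread uniformly over the remaining classes
            smooth = torch.full_like(log_probs,
                                     self.epsilon / (n_classes - 1))
            smooth.scatter_(1, target.unsqueeze(1), 1.0 - self.epsilon)
        return -(smooth * log_probs).sum(dim=1).mean()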
Example #2
def main(args):
    rng = np.random.RandomState(args.seed)

    if args.test:
        assert args.checkpoint is not None, 'Please inform the checkpoint (trained model)'

    if args.logdir is None:
        logdir = get_logdir(args)
    else:
        logdir = pathlib.Path(args.logdir)
    if not logdir.exists():
        logdir.mkdir()

    print('Writing logs to {}'.format(logdir))

    device = (torch.device('cuda', args.gpu_idx)
              if torch.cuda.is_available() else torch.device('cpu'))

    if args.port is not None:
        logger = VisdomLogger(port=args.port)
    else:
        logger = None

    print('Loading Data')
    x, y, yforg, usermapping, filenames = load_dataset(args.dataset_path)

    dev_users = range(args.dev_users[0], args.dev_users[1])
    if args.devset_size is not None:
        # Randomly select users from the dev set
        dev_users = rng.choice(dev_users, args.devset_size, replace=False)

    if args.devset_sk_size is not None:
        assert args.devset_sk_size <= len(
            dev_users), 'devset-sk-size should not be larger than devset-size'

        # Randomly select users from the dev set to have skilled forgeries (others don't)
        dev_sk_users = set(
            rng.choice(dev_users, args.devset_sk_size, replace=False))
    else:
        dev_sk_users = set(dev_users)

    print('{} users in dev set; {} users with skilled forgeries'.format(
        len(dev_users), len(dev_sk_users)))

    if args.exp_users is not None:
        val_users = range(args.exp_users[0], args.exp_users[1])
        print('Testing with users from {} to {}'.format(
            args.exp_users[0], args.exp_users[1]))
    elif args.use_testset:
        val_users = range(0, 300)
        print('Testing with Exploitation set')
    else:
        val_users = range(300, 350)

    print('Initializing model')
    base_model = models.available_models[args.model]().to(device)
    weights = base_model.build_weights(device)
    maml = MAML(base_model,
                args.num_updates,
                args.num_updates,
                args.train_lr,
                args.meta_lr,
                args.meta_min_lr,
                args.epochs,
                args.learn_task_lr,
                weights,
                device,
                logger,
                loss_function=balanced_binary_cross_entropy,
                is_classification=True)

    if args.checkpoint:
        params = torch.load(args.checkpoint)
        maml.load(params)

    if args.test:
        test_and_save(args, device, logdir, maml, val_users, x, y, yforg)
        return

    # Pretraining
    if args.pretrain_epochs > 0:
        print('Pre-training')
        data = util.get_subset((x, y, yforg), subset=range(350, 881))

        wrapped_model = PretrainWrapper(base_model, weights)

        if not args.pretrain_forg:
            data = util.remove_forgeries(data, forg_idx=2)

        train_loader, val_loader = pretrain.setup_data_loaders(
            data, 32, args.input_size)
        n_classes = len(np.unique(y))

        classification_layer = nn.Linear(base_model.feature_space_size,
                                         n_classes).to(device)
        if args.pretrain_forg:
            forg_layer = nn.Linear(base_model.feature_space_size, 1).to(device)
        else:
            forg_layer = nn.Module()  # Stub module with no parameters

        pretrain_args = argparse.Namespace(lr=0.01,
                                           lr_decay=0.1,
                                           lr_decay_times=1,
                                           momentum=0.9,
                                           weight_decay=0.001,
                                           forg=args.pretrain_forg,
                                           lamb=args.pretrain_forg_lambda,
                                           epochs=args.pretrain_epochs)
        print(pretrain_args)
        pretrain.train(wrapped_model,
                       classification_layer,
                       forg_layer,
                       train_loader,
                       val_loader,
                       device,
                       logger,
                       pretrain_args,
                       logdir=None)

    # MAML training

    trainset = MAMLDataSet(data=(x, y, yforg),
                           subset=dev_users,
                           sk_subset=dev_sk_users,
                           num_gen_train=args.num_gen,
                           num_rf_train=args.num_rf,
                           num_gen_test=args.num_gen_test,
                           num_rf_test=args.num_rf_test,
                           num_sk_test=args.num_sk_test,
                           input_shape=args.input_size,
                           test=False,
                           rng=np.random.RandomState(args.seed))

    val_set = MAMLDataSet(data=(x, y, yforg),
                          subset=val_users,
                          num_gen_train=args.num_gen,
                          num_rf_train=args.num_rf,
                          num_gen_test=args.num_gen_test,
                          num_rf_test=args.num_rf_test,
                          num_sk_test=args.num_sk_test,
                          input_shape=args.input_size,
                          test=True,
                          rng=np.random.RandomState(args.seed))

    loader = DataLoader(trainset,
                        batch_size=args.meta_batch_size,
                        shuffle=True,
                        num_workers=2,
                        collate_fn=trainset.collate_fn)

    print('Training')
    best_val_acc = 0
    with tqdm(initial=0, total=len(loader) * args.epochs) as pbar:
        # evaluate a loaded checkpoint before any meta-training updates
        if args.checkpoint is not None:
            postupdate_accs, postupdate_losses, preupdate_losses = test_one_epoch(
                maml, val_set, device, args.num_updates)

            if logger:
                for i in range(args.num_updates):
                    logger.scalar('val_postupdate_loss_{}'.format(i), 0,
                                  np.mean(postupdate_losses, axis=0)[i])

                    logger.scalar('val_postupdate_acc_{}'.format(i), 0,
                                  np.mean(postupdate_accs, axis=0)[i])

        for epoch in range(args.epochs):
            loss_weights = get_per_step_loss_importance_vector(
                args.num_updates, args.msl_epochs, epoch)

            n_batches = len(loader)
            for step, item in enumerate(loader):
                item = move_to_gpu(*item, device=device)
                maml.meta_learning_step((item[0], item[1]), (item[2], item[3]),
                                        loss_weights, epoch + step / n_batches)
                pbar.update(1)

            maml.scheduler.step()

            postupdate_accs, postupdate_losses, preupdate_losses = test_one_epoch(
                maml, val_set, device, args.num_updates)

            if logger:
                for i in range(args.num_updates):
                    logger.scalar('val_postupdate_loss_{}'.format(i),
                                  epoch + 1,
                                  np.mean(postupdate_losses, axis=0)[i])

                    logger.scalar('val_postupdate_acc_{}'.format(i), epoch + 1,
                                  np.mean(postupdate_accs, axis=0)[i])

                logger.save(logdir / 'train_curves.pickle')
            this_val_loss = np.mean(postupdate_losses, axis=0)[-1]
            this_val_acc = np.mean(postupdate_accs, axis=0)[-1]

            # checkpoint the weights whenever validation accuracy improves
            if this_val_acc > best_val_acc:
                best_val_acc = this_val_acc
                torch.save(maml.parameters, logdir / 'best_model.pth')
            print('Epoch {}. Val loss: {:.4f}. Val Acc: {:.2f}%'.format(
                epoch, this_val_loss, this_val_acc * 100))

    # Re-load best parameters and test with 10 folds
    params = torch.load(logdir / 'best_model.pth')
    maml.load(params)

    test_and_save(args, device, logdir, maml, val_users, x, y, yforg)
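
get_per_step_loss_importance_vector is not defined in this snippet. A plausible sketch, assuming it follows the multi-step loss (MSL) annealing schedule from MAML++ (Antoniou et al., "How to train your MAML"), where the per-step weights start uniform and are annealed onto the final inner-loop step over msl_epochs:

import numpy as np


def get_per_step_loss_importance_vector(num_updates, msl_epochs, epoch):
    # Hypothetical sketch following the MAML++ MSL schedule.
    weights = np.ones(num_updates) / num_updates  # start uniform over steps
    decay = 1.0 / num_updates / msl_epochs
    min_weight = 0.03 / num_updates
    # shift importance away from the earlier inner-loop steps...
    weights[:-1] = np.maximum(weights[:-1] - epoch * decay, min_weight)
    # ...and concentrate it on the final adapted step
    weights[-1] = 1.0 - weights[:-1].sum()
    return weights

Early in training the meta-loss is averaged over every inner-loop step, which stabilizes the higher-order gradients; by epoch msl_epochs essentially all of the weight sits on the last step.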