def main(args):
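    # Build the meta-learning dataset and MAML model from the CLI arguments,
    # then either run meta-training or evaluate from a saved checkpoint.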
    np.random.seed(args.seed)
    dataset = get_dataset(args.dataset, args.K)
    model = MAML(dataset, args.model_type, args.loss_type, dataset.dim_input,
                 dataset.dim_output, args.alpha, args.beta, args.K,
                 args.batch_size, args.is_train, args.num_updates, args.norm)
    if args.is_train:
        model.learn(args.batch_size, dataset, args.max_steps)
    else:
        model.evaluate(dataset,
                       args.test_sample,
                       args.draw,
                       restore_checkpoint=args.restore_checkpoint,
                       restore_dir=args.restore_dir)
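
# A minimal sketch of the argparse entry point this main(args) assumes. The flag
# names mirror the attributes accessed above; types and defaults are assumptions,
# not the original CLI.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--K', type=int, default=10)            # samples per task
    parser.add_argument('--model_type', type=str, default='fc')
    parser.add_argument('--loss_type', type=str, default='mse')
    parser.add_argument('--alpha', type=float, default=1e-2)    # inner-loop step size
    parser.add_argument('--beta', type=float, default=1e-3)     # outer-loop (meta) step size
    parser.add_argument('--batch_size', type=int, default=25)   # tasks per meta-batch
    parser.add_argument('--num_updates', type=int, default=1)
    parser.add_argument('--norm', type=str, default='None')
    parser.add_argument('--max_steps', type=int, default=70000)
    parser.add_argument('--is_train', action='store_true')
    parser.add_argument('--test_sample', type=int, default=100)
    parser.add_argument('--draw', action='store_true')
    parser.add_argument('--restore_checkpoint', type=str, default=None)
    parser.add_argument('--restore_dir', type=str, default=None)
    parser.add_argument('--seed', type=int, default=1)
    main(parser.parse_args())
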
Example 2
def main():
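    # Trains a MAML model jointly with a group-inference (z) model and tracks
    # the worst-case group accuracy on the validation set for early stopping.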

    if torch.cuda.is_available():
        args.device = torch.device('cuda')
        args.cuda = True
    else:
        args.device = torch.device('cpu')
        args.cuda = False

    # Make as reproducible as possible.
    # Please note that pytorch does not let us make things completely reproducible across machines.
    # See https://pytorch.org/docs/stable/notes/randomness.html
    if args.seed is not None:
        print('setting seed', args.seed)
        torch.manual_seed(args.seed)
        if args.cuda:
            torch.cuda.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)
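        # Optional extra determinism settings from the PyTorch randomness notes
        # (not in the original code; they can slow training):
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False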

    # Get data
    train_loader, train_eval_loader, val_loader, _ = data.get_loaders(args)
    z_loader = data.get_z_loader(args)
    args.n_groups = train_loader.dataset.n_groups

    # Get model
    if args.dataset == 'mnist':
        num_classes = 10
    elif args.dataset == 'celeba':
        num_classes = 4
    elif args.dataset == 'femnist':
        num_classes = 62
    else:
        raise ValueError(f'Unsupported dataset: {args.dataset}')
    model = utils.MetaConvModel(train_loader.dataset.image_shape[0],
                                num_classes,
                                hidden_size=128,
                                feature_size=128)
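    # A second network with one output per group; presumably used to infer the
    # group assignment (z) of each example rather than its class label.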
    z_model = utils.MetaConvModel(train_loader.dataset.image_shape[0],
                                  train_loader.dataset.n_groups,
                                  hidden_size=128,
                                  feature_size=128)
    #     model = utils.get_model(args, image_shape=train_loader.dataset.image_shape)

    # Loss Fn
    loss_fn = nn.CrossEntropyLoss()

    # Optimizer
    if args.optimizer == 'adam':  # This is used for MNIST.
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

        z_optimizer = torch.optim.Adam(z_model.parameters(), lr=1e-3)
    elif args.optimizer == 'sgd':

        # From DRNN paper
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                           model.parameters()),
                                    lr=args.learning_rate,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)

        z_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             z_model.parameters()),
                                      lr=args.learning_rate,
                                      momentum=0.9,
                                      weight_decay=args.weight_decay)

    maml_model = MAML(model,
                      z_model,
                      z_loader,
                      optimizer,
                      z_optimizer,
                      num_adaptation_steps=args.num_steps,
                      step_size=args.step_size,
                      loss_function=loss_fn,
                      device=args.device)

    z_epochs = 50
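    # Pre-train the z (group-inference) model before the main MAML training loop.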
    for epoch in trange(z_epochs):
        loss, accuracy = maml_model.train_z_iter(train_loader)
        print(f'epoch {epoch}: z_loss {loss:.4f}, z_accuracy {accuracy:.4f}')

    # Train loop
    best_worst_case_acc = 0
    best_worst_case_acc_epoch = 0
    avg_val_acc = 0
    empirical_val_acc = 0

    for epoch in trange(args.num_epochs):

        train_results = maml_model.train(train_loader,
                                         verbose=True,
                                         desc='Training',
                                         leave=False)

        # Decay the learning rate after the first epoch (CelebA only)
        if args.use_lr_schedule:
            if (args.dataset == 'celeba' and epoch == 0):
                for param_group in optimizer.param_groups:
                    param_group['lr'] = 1e-5


        if epoch % args.epochs_per_eval == 0:

            # validation
            worst_case_acc, stats = maml_model.evaluate(
                val_loader,
                epoch=epoch,
                log_wandb=args.log_wandb,
                n_samples_per_dist=args.n_test_per_dist,
                split='val')

            # Track early stopping values with respect to worst case.
            if worst_case_acc > best_worst_case_acc:
                best_worst_case_acc = worst_case_acc

                save_model(model, ckpt_dir, epoch, args.device)

            # Log early stopping values
            if args.log_wandb:
                wandb.log({
                    "Train Loss": train_results['mean_outer_loss'],
                    "Best Worst Case Val Acc": best_worst_case_acc,
                    "Train Accuracy": train_results['accuracy_after'],
                    "epoch": epoch
                })

            print(f"Epoch: ", epoch, "Worst Case Acc: ", worst_case_acc)