コード例 #1
0
ファイル: run_maml.py プロジェクト: iamsimha/pytorch-maml
def main():
    """Parse CLI arguments, build the network, and run MAML training + testing.

    Relies on ``build_network`` and ``MAML`` defined elsewhere in this file.
    No return value; side effects are model training/evaluation on
    ``args.device``.
    """
    parser = argparse.ArgumentParser()
    # FIX: help text was copy-pasted from --num-samples-per-class.
    parser.add_argument("--num-classes",
                        type=int,
                        help="Number of classes per task")
    parser.add_argument("--num-samples-per-class",
                        type=int,
                        help="Number of samples per class")
    # FIX: typo "ot" -> "to" in user-facing help text.
    parser.add_argument("--data-folder", help="path to omniglot folder")
    # Implicit string concatenation replaces the backslash continuation,
    # which embedded a long run of stray spaces into the help text.
    parser.add_argument("--batch-size",
                        type=int,
                        help="Batch size: This is equal to the "
                             "number of tasks per episode")
    parser.add_argument("--inner-update-lr",
                        default=0.4,
                        type=float,
                        help="Learning rate for the inner update")
    parser.add_argument("--meta-lr",
                        type=float,
                        default=0.001,
                        help="Learning rate for the meta learner")
    # FIX: typo "pf" -> "of".
    parser.add_argument("--num-meta-train-iterations",
                        type=int,
                        default=1000,
                        help="Number of meta training iterations")
    parser.add_argument(
        "--num-inner-updates",
        type=int,
        default=1,
        help="Number of inner gradient steps, during train time")
    parser.add_argument(
        "--meta-test-num-inner-updates",
        type=int,
        default=1,
        help="Number of inner gradient steps during meta test time")
    # FIX: typo "convlution" -> "convolution".
    parser.add_argument("--dim-hidden",
                        type=int,
                        default=16,
                        help="Number of convolution filters")
    parser.add_argument("--num-meta-test-classes",
                        type=int,
                        help="Number of classes in meta test time")
    parser.add_argument("--num-meta-test-samples-per-class",
                        type=int,
                        help="Number of samples per class, during test time")
    # FIX: typo "epsiodes" -> "episodes".
    parser.add_argument("--num-meta-validation-iterations",
                        type=int,
                        help="Number of episodes for validation.")
    parser.add_argument("--num-meta-test-iterations",
                        type=int,
                        help="Number of iterations during meta test time")
    parser.add_argument("--validation-frequency",
                        type=int,
                        help="Validation Frequency")
    parser.add_argument("--device", default="cuda")
    args = parser.parse_args()

    model = build_network(args)
    model.to(args.device)
    maml = MAML(args, model)
    maml.train()
    maml.test()
コード例 #2
0
# Meta-training driver: each epoch draws one fresh batch from the train
# iterator, runs an adaptation step, logs loss/accuracy, then pulls an
# eval batch with the model in eval mode.
loss_fn = torch.nn.CrossEntropyLoss().to(device)
model_path = "./model/"
result_path = "./log/train"

# NOTE(review): plain iterators are exhausted after one pass over the
# loader — this assumes epochs <= len(loader); otherwise next() raises
# StopIteration. Confirm against the loader sizes used by the caller.
trainiter = iter(trainloader)
evaliter = iter(testloader)

train_loss_log = []
train_acc_log = []
test_loss_log = []
test_acc_log = []

for epoch in range(epochs):
    # train
    # FIX: iterator.next() is Python 2 syntax and raises AttributeError
    # on Python 3 — use the next() builtin instead.
    trainbatch = next(trainiter)
    model.train()
    loss, acc = adaptation(model,
                           optimizer,
                           trainbatch,
                           loss_fn,
                           lr=0.01,
                           train_step=5,
                           train=True,
                           device=device)

    train_loss_log.append(loss.item())
    train_acc_log.append(acc)

    # test
    evalbatch = next(evaliter)
    model.eval()
コード例 #3
0
ファイル: train_recall.py プロジェクト: nikitadhawan/arm
def main():
    """Train a MAML model plus an auxiliary z-model and evaluate by
    worst-case group accuracy.

    Reads configuration from the module-level ``args`` namespace and
    mutates it (``args.device``, ``args.cuda``, ``args.n_groups``).
    Saves the best checkpoint (by worst-case validation accuracy) via
    ``save_model`` and optionally logs to wandb.
    """
    if torch.cuda.is_available():
        args.device = torch.device('cuda')
        args.cuda = True
    else:
        args.device = torch.device('cpu')
        args.cuda = False

    # Make as reproducible as possible.
    # Please note that pytorch does not let us make things completely reproducible across machines.
    # See https://pytorch.org/docs/stable/notes/randomness.html
    if args.seed is not None:
        print('setting seed', args.seed)
        torch.manual_seed(args.seed)
        if args.cuda:
            torch.cuda.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    # Get data
    train_loader, train_eval_loader, val_loader, _ = data.get_loaders(args)
    z_loader = data.get_z_loader(args)
    args.n_groups = train_loader.dataset.n_groups

    # Get model.
    # FIX: `args.dataset in 'celeba'` was a substring test (true for 'ce',
    # 'a', or even ''); use equality. An unknown dataset now fails loudly
    # instead of hitting a NameError on num_classes below.
    if args.dataset == 'mnist':
        num_classes = 10
    elif args.dataset == 'celeba':
        num_classes = 4
    elif args.dataset == 'femnist':
        num_classes = 62
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    model = utils.MetaConvModel(train_loader.dataset.image_shape[0],
                                num_classes,
                                hidden_size=128,
                                feature_size=128)
    z_model = utils.MetaConvModel(train_loader.dataset.image_shape[0],
                                  train_loader.dataset.n_groups,
                                  hidden_size=128,
                                  feature_size=128)

    # Loss Fn
    loss_fn = nn.CrossEntropyLoss()

    # Optimizer
    if args.optimizer == 'adam':  # This is used for MNIST.
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
        z_optimizer = torch.optim.Adam(z_model.parameters(), lr=1e-3)
    elif args.optimizer == 'sgd':
        # From DRNN paper
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                           model.parameters()),
                                    lr=args.learning_rate,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
        z_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             z_model.parameters()),
                                      lr=args.learning_rate,
                                      momentum=0.9,
                                      weight_decay=args.weight_decay)

    maml_model = MAML(model,
                      z_model,
                      z_loader,
                      optimizer,
                      z_optimizer,
                      num_adaptation_steps=args.num_steps,
                      step_size=args.step_size,
                      loss_function=loss_fn,
                      device=args.device)

    # Pre-train the z-model before the main meta-training loop.
    z_epochs = 50
    for epoch in trange(z_epochs):
        loss, accuracy = maml_model.train_z_iter(train_loader)
        print(epoch, 'epoch ,', loss, 'z_loss ,', accuracy, 'z_accuracy')

    # Train loop
    best_worst_case_acc = 0
    best_worst_case_acc_epoch = 0
    avg_val_acc = 0
    empirical_val_acc = 0

    for epoch in trange(args.num_epochs):

        train_results = maml_model.train(train_loader,
                                         verbose=True,
                                         desc='Training',
                                         leave=False)

        # Decay learning rate after one epoch
        if args.use_lr_schedule:
            if (args.dataset == 'celeba' and epoch == 0):
                for param_group in optimizer.param_groups:
                    param_group['lr'] = 1e-5

        if epoch % args.epochs_per_eval == 0:

            # validation
            worst_case_acc, stats = maml_model.evaluate(
                val_loader,
                epoch=epoch,
                log_wandb=args.log_wandb,
                n_samples_per_dist=args.n_test_per_dist,
                split='val')

            # Track early stopping values with respect to worst case.
            if worst_case_acc > best_worst_case_acc:
                best_worst_case_acc = worst_case_acc
                save_model(model, ckpt_dir, epoch, args.device)

            # Log early stopping values
            if args.log_wandb:
                wandb.log({
                    "Train Loss": train_results['mean_outer_loss'],
                    "Best Worst Case Val Acc": best_worst_case_acc,
                    "Train Accuracy": train_results['accuracy_after'],
                    "epoch": epoch
                })

            # FIX: dropped the pointless f-prefix on a placeholder-free string.
            print("Epoch:", epoch, "Worst Case Acc:", worst_case_acc)