def main(args):
    """Meta-train a MetaLearingClassification model with periodic sanity checks.

    Reads ``seed``, ``name``, ``commit``, ``dataset``, ``rln``, ``steps``,
    ``tasks``, ``update_step`` and ``no_reset`` from ``args``. Sets
    ``args.classes`` and ``args.traj_classes`` as a side effect.

    Training accuracy is logged every 40 steps. Accuracy on both the test and
    train loaders is logged every 300 steps.
    """
    utils.set_seed(args.seed)

    my_experiment = experiment(args.name, args, "results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')

    # The first 963 Omniglot classes form the meta-training set; trajectory
    # classes are drawn only from the upper half of that range.
    total_classes = 963
    args.classes = list(range(total_classes))
    args.traj_classes = list(range(int(total_classes / 2), total_classes))

    dataset = df.DatasetFactory.get_dataset(args.dataset, background=True,
                                            train=True, all=True)
    dataset_test = df.DatasetFactory.get_dataset(args.dataset, background=True,
                                                 train=False, all=True)

    def _eval_loader(source):
        # Small shuffled loader used only for accuracy evaluation.
        return torch.utils.data.DataLoader(source, batch_size=5,
                                           shuffle=True, num_workers=1)

    iterator_test = _eval_loader(dataset_test)
    iterator_train = _eval_loader(dataset)

    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)

    config = mf.ModelFactory.get_model("na", args.dataset)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    maml = MetaLearingClassification(args, config).to(device)
    utils.freeze_layers(args.rln, maml)

    for step in range(args.steps):
        chosen_classes = np.random.choice(args.traj_classes, args.tasks,
                                          replace=False)
        traj_iters = [sampler.sample_task([cls]) for cls in chosen_classes]
        rand_iter = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            traj_iters,
            rand_iter,
            steps=args.update_step,
            reset=not args.no_reset)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = (
                tensor.cuda() for tensor in (x_spt, y_spt, x_qry, y_qry))

        accs, _ = maml(x_spt, y_spt, x_qry, y_qry)

        # Sanity-check logging on the last step of every 40-step window.
        if (step + 1) % 40 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        # Full evaluation on the last step of every 300-step window.
        if (step + 1) % 300 == 0:
            for loader in (iterator_test, iterator_train):
                utils.log_accuracy(maml, my_experiment, loader, device,
                                   writer, step)
def main(args):
    """Meta-train a classification learner, optionally resuming from a checkpoint.

    Reads ``seed``, ``name``, ``commit``, ``dataset``, ``prefetch_gpu``,
    ``treatment``, ``checkpoint``, ``saved_model``, ``rln``, ``steps``,
    ``tasks``, ``update_step``, ``no_reset`` and ``model_name`` from ``args``.
    Sets ``args.classes`` as a side effect.

    ``maml.net`` is written to ``args.model_name`` every 100 steps and after
    the final step. Accuracy is logged every 2000 steps (excluding step 0).

    Raises:
        ValueError: if a checkpoint's parameter count differs from the model's.
    """
    utils.set_seed(args.seed)

    my_experiment = experiment(args.name,
                               args,
                               "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger("experiment")

    # Using first 963 classes of the omniglot as the meta-training set
    args.classes = list(range(963))

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    dataset = df.DatasetFactory.get_dataset(
        args.dataset,
        background=True,
        train=True,
        all=True,
        prefetch_gpu=args.prefetch_gpu,
        device=device,
    )
    # NOTE(review): the held-out split (train=False) is disabled, so the
    # "test" accuracy logged below is measured on the training data.
    dataset_test = dataset

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(
        args.dataset,
        args.classes,
        dataset,
        dataset_test,
        prefetch_gpu=args.prefetch_gpu,
        use_cuda=use_cuda,
    )

    config = mf.ModelFactory.get_model(args.treatment, args.dataset)

    maml = MetaLearingClassification(args, config, args.treatment).to(device)

    if args.checkpoint:
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # trusted checkpoint files.
        checkpoint = torch.load(args.saved_model, map_location="cpu")

        # The checkpoint is presumably a whole pickled network (as written by
        # the torch.save(maml.net, ...) call below), not a state dict, so its
        # parameters are copied over pairwise.
        own_params = list(maml.net.parameters())
        saved_params = list(checkpoint.parameters())
        if len(own_params) != len(saved_params):
            raise ValueError(
                "Checkpoint %s has %d parameter tensors but the model has %d"
                % (args.saved_model, len(saved_params), len(own_params)))
        for own, saved in zip(own_params, saved_params):
            own.data = saved.data

    # Move again: weights copied from the checkpoint were loaded on the CPU.
    maml = maml.to(device)

    utils.freeze_layers(args.rln, maml)

    last_step = args.steps - 1
    for step in range(args.steps):

        t1 = np.random.choice(args.classes, args.tasks, replace=False)

        d_traj_iterators = [sampler.sample_task([t]) for t in t1]

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset,
        )
        if use_cuda:
            x_spt, y_spt, x_qry, y_qry = (
                x_spt.cuda(),
                y_spt.cuda(),
                x_qry.cuda(),
                y_qry.cuda(),
            )

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Evaluation during training for sanity checks
        if step % 40 == 0:
            logger.info("step: %d \t training acc %s", step, str(accs))
        # Save periodically and always after the final step, whatever
        # args.steps is.
        if step % 100 == 0 or step == last_step:
            torch.save(maml.net, args.model_name)
        if step % 2000 == 0 and step != 0:
            utils.log_accuracy(maml, my_experiment, iterator_test, device,
                               writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device,
                               writer, step)
# Example #3
# 0
def main():
    """Parse CLI arguments for this rank and meta-train a classifier.

    Arguments are parsed once with ``class_parser.Parser``. The run
    configuration is selected by ``utils.get_run`` using ``rank``. Training
    runs on ``cuda:(rank % gpus)`` when CUDA is available and ``gpus > 0``;
    otherwise it runs on the CPU.
    """
    p = class_parser.Parser()
    # Parse once: every field below comes from the same namespace.
    parsed_args = p.parse_known_args()[0]
    rank = parsed_args.rank
    all_args = vars(parsed_args)
    print("All args = ", all_args)

    args = utils.get_run(all_args, rank)

    utils.set_seed(args['seed'])

    # NOTE(review): rank and seed are hard-coded to 0/1 here rather than taken
    # from the run; confirm whether per-rank result directories were intended.
    my_experiment = experiment(args['name'],
                               args,
                               "../results/",
                               commit_changes=False,
                               rank=0,
                               seed=1)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set;
    # trajectories may be drawn from any of them.
    args['classes'] = list(range(963))
    args['traj_classes'] = list(range(963))

    dataset = df.DatasetFactory.get_dataset(args['dataset'],
                                            background=True,
                                            train=True,
                                            path=args["path"],
                                            all=True)
    dataset_test = df.DatasetFactory.get_dataset(args['dataset'],
                                                 background=True,
                                                 train=False,
                                                 path=args["path"],
                                                 all=True)

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(args['dataset'], args['classes'],
                                            dataset, dataset_test)

    config = mf.ModelFactory.get_model("na",
                                       args['dataset'],
                                       output_dimension=1000)

    # Only compute a GPU index when a GPU will actually be used; this avoids
    # a ZeroDivisionError for gpus == 0 (e.g. CPU-only runs).
    if torch.cuda.is_available() and args["gpus"] > 0:
        gpu_to_use = rank % args["gpus"]
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)

    for step in range(args['steps']):

        t1 = np.random.choice(args['traj_classes'],
                              args['tasks'],
                              replace=False)

        d_traj_iterators = [sampler.sample_task([t]) for t in t1]

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args['update_step'],
            reset=not args['no_reset'])
        # .to(device) is a no-op when the data already lives on the CPU device.
        x_spt, y_spt, x_qry, y_qry = (x_spt.to(device), y_spt.to(device),
                                      x_qry.to(device), y_qry.to(device))

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Evaluation during training for sanity checks
        if step % 40 == 5:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if step % 300 == 3:
            utils.log_accuracy(maml, my_experiment, iterator_test, device,
                               writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device,
                               writer, step)
# Example #4
# 0
def main(args):
    """Meta-train an Omniglot classifier and persist the trained network.

    Reads ``seed``, ``name``, ``commit``, ``dataset``, ``rln``, ``steps``,
    ``tasks``, ``update_step`` and ``no_reset`` from ``args``. Sets
    ``args.classes`` and ``args.traj_classes`` as a side effect.

    ``maml.net`` is saved to ``<experiment path>/omniglot_classifier.model``
    every 100 steps and after the final step. Writing only periodically gives
    the same final file as saving on every step, without serializing the
    network each iteration.
    """
    # Placeholder state passed to the model on every step; this function
    # never updates it.
    old_accs = [0]
    old_meta_losses = [2.**30, 0]

    utils.set_seed(args.seed)

    my_experiment = experiment(args.name,
                               args,
                               "./results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set;
    # trajectory classes come from the upper half (481-962).
    args.classes = list(range(963))
    args.traj_classes = list(range(int(963 / 2), 963))

    dataset = df.DatasetFactory.get_dataset(args.dataset,
                                            background=True,
                                            train=True,
                                            all=True)
    dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                 background=True,
                                                 train=False,
                                                 all=True)
    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)
    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)
    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)
    config = mf.ModelFactory.get_model("na", args.dataset)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)
    utils.freeze_layers(args.rln, maml)

    model_path = my_experiment.path + "omniglot_classifier.model"
    last_step = args.steps - 1
    for step in range(args.steps):  # one meta-training step per iteration
        logger.debug("STEP: %d", step)
        # Sample the classes whose trajectories make up this step's tasks.
        t1 = np.random.choice(args.traj_classes, args.tasks, replace=False)
        d_traj_iterators = [sampler.sample_task([t]) for t in t1]
        d_rand_iterator = sampler.get_complete_iterator()

        # Sample trajectory and random batch (support and query)
        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
            ), x_qry.cuda(), y_qry.cuda()

        # One training loop
        accs, loss = maml(x_spt, y_spt, x_qry, y_qry, step, old_accs,
                          old_meta_losses, args, config)

        # Evaluation during training for sanity checks
        if step % 40 == 39:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if step % 300 == 299:
            utils.log_accuracy(maml, my_experiment, iterator_test, device,
                               writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device,
                               writer, step)

        # Persist periodically and after the last step instead of every step.
        if step % 100 == 99 or step == last_step:
            torch.save(maml.net, model_path)