Example #1
def main(args):
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    if (args.output_folder is not None):
        if not os.path.exists(args.output_folder):
            os.makedirs(args.output_folder)
            logging.debug('Creating folder `{0}`'.format(args.output_folder))

        folder = os.path.join(args.output_folder,
                              time.strftime('%Y-%m-%d_%H%M%S'))
        os.makedirs(folder)
        logging.debug('Creating folder `{0}`'.format(folder))

        args.folder = os.path.abspath(args.folder)
        args.model_path = os.path.abspath(os.path.join(folder, 'model.th'))
        # Save the configuration in a config.json file
        with open(os.path.join(folder, 'config.json'), 'w') as f:
            json.dump(vars(args), f, indent=2)
        logging.info('Saving configuration file in `{0}`'.format(
            os.path.abspath(os.path.join(folder, 'config.json'))))

    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=args.num_shots,
                                      num_test_per_class=args.num_shots_test)
    class_augmentations = [Rotation([90, 180, 270])]
    if args.dataset == 'sinusoid':
        transform = ToTensor()

        meta_train_dataset = Sinusoid(args.num_shots + args.num_shots_test,
                                      num_tasks=1000000,
                                      transform=transform,
                                      target_transform=transform,
                                      dataset_transform=dataset_transform)
        meta_val_dataset = Sinusoid(args.num_shots + args.num_shots_test,
                                    num_tasks=1000000,
                                    transform=transform,
                                    target_transform=transform,
                                    dataset_transform=dataset_transform)

        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif args.dataset == 'omniglot':
        transform = Compose([Resize(28), ToTensor()])

        meta_train_dataset = Omniglot(args.folder,
                                      transform=transform,
                                      target_transform=Categorical(
                                          args.num_ways),
                                      num_classes_per_task=args.num_ways,
                                      meta_train=True,
                                      class_augmentations=class_augmentations,
                                      dataset_transform=dataset_transform,
                                      download=True)
        meta_val_dataset = Omniglot(args.folder,
                                    transform=transform,
                                    target_transform=Categorical(
                                        args.num_ways),
                                    num_classes_per_task=args.num_ways,
                                    meta_val=True,
                                    class_augmentations=class_augmentations,
                                    dataset_transform=dataset_transform)

        model = ModelConvOmniglot(args.num_ways, hidden_size=args.hidden_size)
        loss_function = F.cross_entropy

    elif args.dataset == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])

        meta_train_dataset = MiniImagenet(
            args.folder,
            transform=transform,
            target_transform=Categorical(args.num_ways),
            num_classes_per_task=args.num_ways,
            meta_train=True,
            class_augmentations=class_augmentations,
            dataset_transform=dataset_transform,
            download=True)
        meta_val_dataset = MiniImagenet(
            args.folder,
            transform=transform,
            target_transform=Categorical(args.num_ways),
            num_classes_per_task=args.num_ways,
            meta_val=True,
            class_augmentations=class_augmentations,
            dataset_transform=dataset_transform)

        model = ModelConvMiniImagenet(args.num_ways,
                                      hidden_size=args.hidden_size)
        loss_function = F.cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(
            args.dataset))

    meta_train_dataloader = BatchMetaDataLoader(meta_train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=args.num_workers,
                                                pin_memory=True)
    meta_val_dataloader = BatchMetaDataLoader(meta_val_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    meta_optimizer = torch.optim.Adam(model.parameters(), lr=args.meta_lr)
    metalearner = ModelAgnosticMetaLearning(
        model,
        meta_optimizer,
        first_order=args.first_order,
        num_adaptation_steps=args.num_steps,
        step_size=args.step_size,
        loss_function=loss_function,
        device=device)

    best_val_accuracy = None

    # Training loop
    epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 +
                                             int(math.log10(args.num_epochs)))
    for epoch in range(args.num_epochs):
        metalearner.train(meta_train_dataloader,
                          max_batches=args.num_batches,
                          verbose=args.verbose,
                          desc='Training',
                          leave=False)
        results = metalearner.evaluate(meta_val_dataloader,
                                       max_batches=args.num_batches,
                                       verbose=args.verbose,
                                       desc=epoch_desc.format(epoch + 1))

        if (best_val_accuracy is None) \
                or (best_val_accuracy < results['accuracies_after']):
            best_val_accuracy = results['accuracies_after']
            if args.output_folder is not None:
                with open(args.model_path, 'wb') as f:
                    torch.save(model.state_dict(), f)

    if hasattr(meta_train_dataset, 'close'):
        meta_train_dataset.close()
        meta_val_dataset.close()
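
Each example receives a pre-parsed `args` namespace. A minimal `argparse` setup covering the attributes Example #1 actually reads; the flag names and defaults below are assumptions inferred from the code, not the original CLI:

import argparse

parser = argparse.ArgumentParser('MAML training (sketch)')
parser.add_argument('folder', type=str, help='Root folder for the datasets.')
parser.add_argument('--dataset', type=str, default='omniglot',
                    choices=['sinusoid', 'omniglot', 'miniimagenet'])
parser.add_argument('--output-folder', type=str, default=None)
parser.add_argument('--num-ways', type=int, default=5)
parser.add_argument('--num-shots', type=int, default=5)
parser.add_argument('--num-shots-test', type=int, default=15)
parser.add_argument('--hidden-size', type=int, default=64)
parser.add_argument('--batch-size', type=int, default=25)
parser.add_argument('--num-steps', type=int, default=1)
parser.add_argument('--num-epochs', type=int, default=50)
parser.add_argument('--num-batches', type=int, default=100)
parser.add_argument('--step-size', type=float, default=0.1)
parser.add_argument('--first-order', action='store_true')
parser.add_argument('--meta-lr', type=float, default=0.001)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--use-cuda', action='store_true')

main(parser.parse_args())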
Example #2
def main(args):
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    device = torch.device('cuda' if args.use_cuda
                          and torch.cuda.is_available() else 'cpu')

    # Set up output folder, which will contain the saved model
    # and the config file for running test.py
    if (args.output_folder is not None):
        if not os.path.exists(args.output_folder):
            os.makedirs(args.output_folder)
            logging.debug('Creating output folder `{0}`'.format(args.output_folder))

        folder = os.path.join(args.output_folder,
                              time.strftime('%Y-%m-%d_%H%M%S'))
        os.makedirs(folder)
        logging.debug('Creating folder `{0}`'.format(folder))

        args.folder = os.path.abspath(args.folder)
        args.model_path = os.path.abspath(os.path.join(folder, 'model.th'))
        
        save_folder = os.path.abspath(folder)
        ckpt_folder = os.path.join(save_folder, 'checkpoints')
        if not os.path.exists(ckpt_folder):
            os.makedirs(ckpt_folder)
            logging.debug('Creating model checkpoint folder `{0}`'.format(ckpt_folder))

        # Save the configuration in a config.json file
        with open(os.path.join(folder, 'config.json'), 'w') as f:
            json.dump(vars(args), f, indent=2)
        logging.info('Saving configuration file in `{0}`'.format(
                     os.path.abspath(os.path.join(folder, 'config.json'))))

    # Load a pre-configured dataset, model, and loss function.
    benchmark = get_benchmark_by_name(args.dataset,
                                      args.folder,
                                      args.num_ways,
                                      args.num_shots,
                                      args.num_shots_test,
                                      hidden_size=args.hidden_size)

    # Set up dataloaders:
    # MetaDataset (collection of Tasks) > Task (iterable Dataset of OrderedDicts) 
    #   > task[i] (OrderedDict with shuffled train/test split) > (tuples of input & target tensors)
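    # For illustration, a sketch of unpacking one batch of tasks
    # (variable names here are assumptions, not part of the original):
    #
    #   batch = next(iter(meta_train_dataloader))
    #   train_inputs, train_targets = batch['train']  # [batch_size, N*K, ...]
    #   test_inputs, test_targets = batch['test']     # [batch_size, N*K_test, ...]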

    # Train loader yields batches of tasks for meta-training (both inner and outer loop)
    meta_train_dataloader = BatchMetaDataLoader(benchmark.meta_train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=args.num_workers,
                                                pin_memory=True)

    # Val loader has the same format, but is used only for evaluating adaptation ability
    # without taking gradient steps on the outer loss.
    meta_val_dataloader = BatchMetaDataLoader(benchmark.meta_val_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    # Initialize the meta-optimizer and metalearner (MAML)
    meta_optimizer = torch.optim.Adam(benchmark.model.parameters(), lr=args.meta_lr)
    metalearner = ModelAgnosticMetaLearning(benchmark.model,
                                            meta_optimizer,
                                            first_order=args.first_order,
                                            num_adaptation_steps=args.num_steps,
                                            step_size=args.step_size,
                                            loss_function=benchmark.loss_function,
                                            device=device)

    best_value = None

    def append_results_to_dict(d, results):
        for key, val in results.items():
            d[key] = d.get(key, []) + [val]

    # Training loop: each epoch goes through all tasks in the entire dataset, in batches
    epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 + int(math.log10(args.num_epochs)))
    train_stats, val_stats = {}, {}
    for epoch in range(args.num_epochs):
        train_results = metalearner.train(meta_train_dataloader,
                                          max_batches=args.num_batches,
                                          verbose=args.verbose,
                                          desc='Training',
                                          leave=False)
        append_results_to_dict(train_stats, train_results)

        if epoch % args.validate_every == 0:
            val_results = metalearner.evaluate(meta_val_dataloader,
                                               max_batches=args.num_batches,
                                               verbose=args.verbose,
                                               desc=epoch_desc.format(epoch + 1))
            append_results_to_dict(val_stats, val_results)

        if (args.output_folder is not None) and (epoch % args.checkpoint_every == 0):
            ckpt_path = os.path.join(ckpt_folder, f'checkpoint-{epoch}.pt')
            with open(ckpt_path, 'wb') as f:
                torch.save(benchmark.model.state_dict(), f)

        # Save best model according to validation acc/loss
        save_model = False
        if epoch % args.validate_every == 0:
            if 'mean_accuracy_after' in val_results:
                if (best_value is None) or (best_value < val_results['mean_accuracy_after']):
                    best_value = val_results['mean_accuracy_after']
                    save_model = True
            elif (best_value is None) or (best_value > val_results['mean_outer_loss']):
                best_value = val_results['mean_outer_loss']
                save_model = True

        if save_model and (args.output_folder is not None):
            with open(args.model_path, 'wb') as f:
                torch.save(benchmark.model.state_dict(), f)

    # Save train and val stats as serialized dictionaries
    if args.output_folder is not None:
        with open(os.path.join(save_folder, 'train_stats.pkl'), 'wb') as f:
            pickle.dump(train_stats, f)
        with open(os.path.join(save_folder, 'val_stats.pkl'), 'wb') as f:
            pickle.dump(val_stats, f)

    if hasattr(benchmark.meta_train_dataset, 'close'):
        benchmark.meta_train_dataset.close()
        benchmark.meta_val_dataset.close()
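
Example #2 pickles the per-epoch statistics dictionaries; reading them back is straightforward. A sketch (the metric key is an assumption about what `metalearner.train` returns):

import pickle

with open('train_stats.pkl', 'rb') as f:
    train_stats = pickle.load(f)  # maps metric name -> list of per-epoch values

if 'mean_outer_loss' in train_stats:
    print('train outer loss per epoch:', train_stats['mean_outer_loss'])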
Example #3
def main(args):
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    if (args.output_folder is not None):
        if not os.path.exists(args.output_folder):
            os.makedirs(args.output_folder)
            logging.debug('Creating folder `{0}`'.format(args.output_folder))

        folder = os.path.join(args.output_folder,
                              time.strftime('%Y-%m-%d_%H%M%S'))
        os.makedirs(folder)
        logging.debug('Creating folder `{0}`'.format(folder))

        args.folder = os.path.abspath(args.folder)
        args.model_path = os.path.abspath(os.path.join(folder, 'model.th'))
        # Save the configuration in a config.json file
        with open(os.path.join(folder, 'config.json'), 'w') as f:
            json.dump(vars(args), f, indent=2)
        logging.info('Saving configuration file in `{0}`'.format(
            os.path.abspath(os.path.join(folder, 'config.json'))))

    benchmark = get_benchmark_by_name(args.dataset,
                                      args.folder,
                                      args.num_ways,
                                      args.num_shots,
                                      args.num_shots_test,
                                      hidden_size=args.hidden_size)

    meta_train_dataloader = BatchMetaDataLoader(benchmark.meta_train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=args.num_workers,
                                                pin_memory=True)
    meta_val_dataloader = BatchMetaDataLoader(benchmark.meta_val_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    meta_optimizer = torch.optim.Adam(benchmark.model.parameters(),
                                      lr=args.meta_lr)
    metalearner = ModelAgnosticMetaLearning(
        benchmark.model,
        meta_optimizer,
        first_order=args.first_order,
        num_adaptation_steps=args.num_steps,
        step_size=args.step_size,
        loss_function=benchmark.loss_function,
        device=device)

    best_value = None

    # Training loop
    all_results = []
    epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 +
                                             int(math.log10(args.num_epochs)))
    for epoch in range(args.num_epochs):
        metalearner.train(meta_train_dataloader,
                          max_batches=args.num_batches,
                          verbose=args.verbose,
                          desc='Training',
                          leave=False)
        results = metalearner.evaluate(meta_val_dataloader,
                                       max_batches=args.num_batches,
                                       verbose=args.verbose,
                                       desc=epoch_desc.format(epoch + 1))
        all_results.append(results['accuracies_after'])
        if args.output_folder is not None:
            with open(os.path.join(folder, 'results.json'), 'w') as f:
                json.dump(all_results, f, indent=2)

        # Save best model (validation accuracy if available, otherwise outer loss)
        if 'accuracies_after' in results:
            if (best_value is None) or (best_value <
                                        results['accuracies_after']):
                best_value = results['accuracies_after']
                save_model = True
            else:
                save_model = False
        elif (best_value is None) or (best_value > results['mean_outer_loss']):
            best_value = results['mean_outer_loss']
            save_model = True
        else:
            save_model = False

        if save_model and (args.output_folder is not None):
            with open(args.model_path, 'wb') as f:
                torch.save(benchmark.model.state_dict(), f)
    print(all_results)

    if hasattr(benchmark.meta_train_dataset, 'close'):
        benchmark.meta_train_dataset.close()
        benchmark.meta_val_dataset.close()
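
Example #3 rewrites results.json with the accuracy history after every epoch; a sketch for locating the best epoch offline:

import json

with open('results.json') as f:
    accuracies = json.load(f)  # per-epoch 'accuracies_after' values

best_epoch = max(range(len(accuracies)), key=accuracies.__getitem__)
print('best epoch: {0} (accuracy {1:.4f})'.format(best_epoch + 1,
                                                  accuracies[best_epoch]))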
Example #4
def main(args):
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    if (args.output_folder is not None):
        if not os.path.exists(args.output_folder):
            os.makedirs(args.output_folder)
            logging.debug('Creating folder `{0}`'.format(args.output_folder))

        folder = os.path.join(args.output_folder,
                              time.strftime('%Y-%m-%d_%H%M%S'))
        os.makedirs(folder)
        logging.debug('Creating folder `{0}`'.format(folder))

        args.folder = os.path.abspath(args.folder)
        args.model_path = os.path.abspath(os.path.join(folder, 'model.th'))
        outfile_path = os.path.abspath(
            os.path.join(folder, 'model_results.json'))
        # Save the configuration in a config.json file
        with open(os.path.join(folder, 'config.json'), 'w') as f:
            json.dump(vars(args), f, indent=2)
        logging.info('Saving configuration file in `{0}`'.format(
            os.path.abspath(os.path.join(folder, 'config.json'))))

    benchmark = get_benchmark_by_name(
        args.dataset,
        args.folder,
        args.num_ways,
        args.num_shots,
        args.num_shots_test,
        hidden_size=args.hidden_size,
        random_seed=args.random_seed,
        num_training_samples=args.num_training_samples)

    meta_train_dataloader = BatchMetaDataLoader(benchmark.meta_train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=args.num_workers,
                                                pin_memory=True)
    meta_val_dataloader = BatchMetaDataLoader(benchmark.meta_val_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    meta_optimizer = torch.optim.Adam(benchmark.model.parameters(),
                                      lr=args.meta_lr)
    metalearner = ModelAgnosticMetaLearning(
        benchmark.model,
        meta_optimizer,
        first_order=args.first_order,
        num_adaptation_steps=args.num_steps,
        step_size=args.step_size,
        loss_function=benchmark.loss_function,
        device=device)

    best_value = None
    output = []
    pretty_print('epoch', 'train loss', 'train acc', 'train prec', 'val loss',
                 'val acc', 'val prec')

    # Training loop
    epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 +
                                             int(math.log10(args.num_epochs)))
    for epoch in range(args.num_epochs):
        train_results = metalearner.train(meta_train_dataloader,
                                          max_batches=args.num_batches,
                                          verbose=args.verbose,
                                          desc='Training',
                                          leave=False)
        val_results = metalearner.evaluate(meta_val_dataloader,
                                           max_batches=args.num_batches,
                                           verbose=args.verbose,
                                           desc=epoch_desc.format(epoch + 1))
        pretty_print(
            (epoch + 1), train_results['mean_outer_loss'],
            train_results['accuracies_after'],
            train_results['precision_after'], val_results['mean_outer_loss'],
            val_results['accuracies_after'], val_results['precision_after'])

        # Save best model (validation accuracy if available, otherwise outer loss)
        if 'accuracies_after' in val_results:
            if (best_value is None) or (best_value <
                                        val_results['accuracies_after']):
                best_value = val_results['accuracies_after']
                save_model = True
            else:
                save_model = False
        elif (best_value is None) or (best_value >
                                      val_results['mean_outer_loss']):
            best_value = val_results['mean_outer_loss']
            save_model = True
        else:
            save_model = False

        if save_model and (args.output_folder is not None):
            with open(args.model_path, 'wb') as f:
                torch.save(benchmark.model.state_dict(), f)
        # saving results for later use - plotting, etc.
        output.append({
            'epoch': (epoch + 1),
            'train_loss': train_results['mean_outer_loss'],
            'train_acc': train_results['accuracies_after'],
            'train_prec': train_results['precision_after'],
            'val_loss': val_results['mean_outer_loss'],
            'val_acc': val_results['accuracies_after'],
            'val_prec': val_results['precision_after']
        })
        if (args.output_folder is not None):
            with open(outfile_path, 'w') as f:
                json.dump(output, f)

    if hasattr(benchmark.meta_train_dataset, 'close'):
        benchmark.meta_train_dataset.close()
        benchmark.meta_val_dataset.close()
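
Example #4 calls a pretty_print helper that is not defined in the snippet; a minimal sketch that would satisfy the calls above (column width and float formatting are arbitrary choices):

def pretty_print(*values):
    """Print values as fixed-width columns; floats are shown with 4 decimals."""
    cells = ['{0:>12.4f}'.format(v) if isinstance(v, float)
             else '{0:>12}'.format(v) for v in values]
    print(' | '.join(cells))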
Example #5
def main(args):
    logging.basicConfig(level=logging.INFO if args.silent else logging.DEBUG)
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    if not path.exists(args.output_folder):
        os.makedirs(args.output_folder)
        logging.debug('Creating folder `{0}`'.format(args.output_folder))

    if args.run_name is None:
        args.run_name = time.strftime('%Y-%m-%d_%H%M%S')

    folder = path.join(args.output_folder, args.run_name)
    os.makedirs(folder, exist_ok=False)
    logging.debug('Creating folder `{0}`'.format(folder))

    args.folder = path.abspath(args.folder)
    args.model_path = path.abspath(path.join(folder, 'model.th'))
    # Save the configuration in a config.json file
    with open(path.join(folder, 'config.json'), 'w') as f:
        stored_args = argparse.Namespace(**vars(args))
        stored_args.folder = path.relpath(stored_args.folder, folder)
        stored_args.model_path = path.relpath(stored_args.model_path, folder)
        json.dump(vars(stored_args), f, indent=2)
    logging.info('Saving configuration file in `{0}`'.format(
        path.abspath(path.join(folder, 'config.json'))))

    benchmark = get_benchmark_by_name(args.dataset,
                                      args.folder,
                                      args.num_ways,
                                      args.num_shots,
                                      args.num_shots_test,
                                      args.no_max_pool,
                                      hidden_size=args.hidden_size)

    meta_train_dataloader = BatchMetaDataLoaderWithLabels(
        benchmark.meta_train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)
    meta_val_dataloader = BatchMetaDataLoader(benchmark.meta_val_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True)
    meta_optimizer = torch.optim.Adam(benchmark.model.parameters(),
                                      lr=args.meta_lr)

    alpha = weighting.get_param_strategy(args.spsa_alpha_strategy,
                                         args.spsa_alpha,
                                         args.spsa_alpha_exp_gamma,
                                         args.spsa_alpha_step_step_every,
                                         args.spsa_alpha_step_multiplier)
    beta = weighting.get_param_strategy(args.spsa_beta_strategy,
                                        args.spsa_beta,
                                        args.spsa_beta_exp_gamma,
                                        args.spsa_beta_step_step_every,
                                        args.spsa_beta_step_multiplier)

    if args.task_weighting == 'none':
        task_weighting = weighting.TaskWeightingNone(device)
    elif args.task_weighting == 'spsa-delta':
        task_weighting = weighting.SpsaWeighting(args.batch_size, alpha, beta,
                                                 device)
    elif args.task_weighting == 'spsa-per-class':
        task_weighting = weighting.SpsaWeightingPerClass(
            max_classes=100,
            class_info_label='_class_ids',
            alpha=alpha,
            beta=beta,
            device=device)
    elif args.task_weighting == 'spsa-per-coarse-class':
        task_weighting = weighting.SpsaWeightingPerClass(
            max_classes=20,
            class_info_label='_coarse_class_ids',
            alpha=alpha,
            beta=beta,
            device=device)
    elif args.task_weighting == 'spsa-track':
        task_weighting = weighting.SpsaTrackWeighting(args.batch_size, alpha,
                                                      beta, device)
    elif args.task_weighting == 'sin':
        task_weighting = weighting.SinWeighting(args.batch_size, device)
    elif args.task_weighting == 'gradient':
        task_weighting = weighting.GradientWeighting(args.use_inner_optimizer,
                                                     args.batch_size,
                                                     device=device)
        meta_optimizer = torch.optim.Adam(
            list(benchmark.model.parameters()) +
            task_weighting.outer_optimization_weights,
            lr=args.meta_lr)
    elif args.task_weighting == 'gradient-novel-loss':
        task_weighting = weighting.GradientNovelLossWeighting(
            args.use_inner_optimizer, args.batch_size, device=device)
        meta_optimizer = torch.optim.Adam(
            list(benchmark.model.parameters()) +
            task_weighting.outer_optimization_weights,
            lr=args.meta_lr)
    else:
        raise ValueError(f'Unknown weighting value: {args.task_weighting}')

    metalearner = ModelAgnosticMetaLearning(
        benchmark.model,
        meta_optimizer,
        first_order=args.first_order,
        num_adaptation_steps=args.num_steps,
        step_size=args.step_size,
        loss_function=benchmark.loss_function,
        device=device)

    if args.load is not None:
        with open(args.load, 'rb') as f:
            benchmark.model.load_state_dict(torch.load(f, map_location=device))

    best_value = None

    weight_normalizer = weighting.WeightNormalizer(
        normalize_after=args.normalize_spsa_weights_after)

    # Training loop
    epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 +
                                             int(math.log10(args.num_epochs)))

    for epoch in range(args.num_epochs):
        metalearner.train(meta_train_dataloader,
                          task_weighting,
                          weight_normalizer,
                          epoch,
                          max_batches=args.num_batches,
                          silent=args.silent,
                          desc='Training',
                          leave=False)
        results = metalearner.evaluate(meta_val_dataloader,
                                       max_batches=args.num_batches,
                                       silent=args.silent,
                                       desc=epoch_desc.format(epoch + 1))

        # Save best model
        save_model = False
        if (best_value is None) or (best_value < results['accuracies_after']):
            best_value = results['accuracies_after']
            save_model = True

        if save_model and (args.output_folder is not None):
            with open(args.model_path, 'wb') as f:
                torch.save(benchmark.model.state_dict(), f)

    if hasattr(benchmark.meta_train_dataset, 'close'):
        benchmark.meta_train_dataset.close()
        benchmark.meta_val_dataset.close()
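
Example #5 drives SPSA-style task weighting through a custom weighting module whose internals are not shown. Purely as an illustration of what a parameter strategy selected by weighting.get_param_strategy might compute; the strategy names, signature, and formulas below are all assumptions inferred from the argument names:

def make_param_strategy(strategy, base_value, exp_gamma, step_every, step_multiplier):
    """Hypothetical epoch -> value schedule mirroring the args passed above."""
    if strategy == 'constant':
        return lambda epoch: base_value
    if strategy == 'exponential':
        # e.g. spsa_alpha * spsa_alpha_exp_gamma ** epoch
        return lambda epoch: base_value * (exp_gamma ** epoch)
    if strategy == 'step':
        # multiply every `step_every` epochs
        return lambda epoch: base_value * (step_multiplier ** (epoch // step_every))
    raise ValueError(f'Unknown strategy: {strategy}')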
Example #6
def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    wandb.init(project='meta-analogy')

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    device = torch.device('cuda' if args.use_cuda
                          and torch.cuda.is_available() else 'cpu')

    if (args.output_folder is not None):
        if not os.path.exists(args.output_folder):
            os.makedirs(args.output_folder)
            logging.debug('Creating folder `{0}`'.format(args.output_folder))

        folder = os.path.join(args.output_folder,
                              time.strftime('%Y-%m-%d_%H%M%S')+args.model_type)
        os.makedirs(folder)
        logging.debug('Creating folder `{0}`'.format(folder))

        # args.folder = os.path.abspath(args.folder)
        args.model_path = os.path.abspath(os.path.join(folder, 'model.th'))
        # Save the configuration in a config.json file
        with open(os.path.join(folder, 'config.json'), 'w') as f:
            json.dump(vars(args), f, indent=2)
        logging.info('Saving configuration file in `{0}`'.format(
                     os.path.abspath(os.path.join(folder, 'config.json'))))

    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=args.num_shots,
                                      num_test_per_class=args.num_shots_test)

    meta_train_dataset = Analogy(dataset_transform=dataset_transform)
    meta_val_dataset = Analogy(dataset_transform=dataset_transform)

    meta_train_dataloader = BatchMetaDataLoader(meta_train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=args.num_workers,
                                                pin_memory=True)
    meta_val_dataloader = BatchMetaDataLoader(meta_val_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    if args.model_type == 'linear':
        model = MetaLinear(in_features=300, out_features=300)
    elif args.model_type == 'mlp1':
        model = MetaMLPModel(in_features=300, out_features=300, hidden_sizes=[500])
    elif args.model_type == 'mlp2':
        model = MetaMLPModel(in_features=300, out_features=300, hidden_sizes=[500, 500])
    else:
        raise ValueError('unrecognized model type')

    loss_function = nn.MSELoss()
    wandb.watch(model)

    meta_optimizer = torch.optim.Adam(model.parameters(), lr=args.meta_lr)
    metalearner = ModelAgnosticMetaLearning(model,
                                            meta_optimizer,
                                            first_order=args.first_order,
                                            num_adaptation_steps=args.num_steps,
                                            step_size=args.step_size,
                                            loss_function=loss_function,
                                            device=device)

    best_value = None

    # Training loop
    epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 + int(math.log10(args.num_epochs)))
    for epoch in range(args.num_epochs):
        metalearner.train(meta_train_dataloader,
                          max_batches=args.num_batches,
                          verbose=args.verbose,
                          desc='Training',
                          leave=False)
        results = metalearner.evaluate(meta_val_dataloader,
                                       max_batches=args.num_batches,
                                       verbose=args.verbose,
                                       desc=epoch_desc.format(epoch + 1))
        wandb.log({'results': results})

        # Save best model
        if 'accuracies_after' in results:
            if (best_value is None) or (best_value < results['accuracies_after']):
                best_value = results['accuracies_after']
                save_model = True
            else:
                save_model = False
        elif (best_value is None) or (best_value > results['mean_outer_loss']):
            best_value = results['mean_outer_loss']
            save_model = True
        else:
            save_model = False

        if save_model and (args.output_folder is not None):
            with open(args.model_path, 'wb') as f:
                torch.save(model.state_dict(), f)
            # Also mirror the best weights into the wandb run directory
            torch.save(model.state_dict(), os.path.join(wandb.run.dir, 'model.pt'))

    if hasattr(meta_train_dataset, 'close'):
        meta_train_dataset.close()
        meta_val_dataset.close()
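
Examples #4 and #6 seed the RNGs with slightly different sets of calls; a consolidated helper combining both (the cudnn flags follow Example #4):

import random

import numpy as np
import torch


def set_seed(seed):
    """Seed all RNGs used above; deterministic cudnn trades speed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False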
Example #7
# Fragment: assumes args, benchmark, metalearner and the dataloaders
# from the previous examples are already defined.
best_value = None

from maml.utils import tensors_to_device, compute_accuracy

# Fetch a single batch of tasks and move it to the GPU
out = next(iter(meta_train_dataloader))
out_c = tensors_to_device(out, device='cuda')

# Restore previously trained weights
benchmark.model.load_state_dict(torch.load('model.th', map_location='cuda'))

# Training loop
epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 + int(math.log10(args.num_epochs)))
for epoch in range(args.num_epochs):
    print(epoch)
    metalearner.train(meta_train_dataloader,
                      max_batches=args.num_batches,
                      verbose=args.verbose,
                      desc='Training',
                      leave=False)
    results = metalearner.evaluate(meta_val_dataloader,
                                   max_batches=args.num_batches,
                                   verbose=args.verbose,
                                   desc=epoch_desc.format(epoch + 1))

    # Save best model (validation accuracy if available, otherwise outer loss)
    if 'accuracies_after' in results:
        if (best_value is None) or (best_value < results['accuracies_after']):
            best_value = results['accuracies_after']
            save_model = True
        else:
            save_model = False
    elif (best_value is None) or (best_value > results['mean_outer_loss']):
        best_value = results['mean_outer_loss']
        save_model = True
    else:
        save_model = False
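
Example #7 fetches a single batch onto the GPU but never uses it, presumably as a smoke test. A plausible continuation, sketched below; the flattening of the task and sample dimensions, and the compute_accuracy(logits, targets) signature, are assumptions:

# Hypothetical sanity check on the support ('train') split of the fetched batch
train_inputs, train_targets = out_c['train']  # [tasks, samples, ...]
with torch.no_grad():
    logits = benchmark.model(train_inputs.flatten(0, 1))
print('support accuracy:', compute_accuracy(logits, train_targets.flatten()))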
Example #8
def main(args):
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    wandb.init(project="geometric-meta-learning")

    if (args.output_folder is not None):
        if not os.path.exists(args.output_folder):
            os.makedirs(args.output_folder)
            logging.debug('Creating folder `{0}`'.format(args.output_folder))

        folder = os.path.join(args.output_folder,
                              time.strftime('%Y-%m-%d_%H%M%S'))
        os.makedirs(folder)
        logging.debug('Creating folder `{0}`'.format(folder))

        args.folder = os.path.abspath(args.folder)
        args.model_path = os.path.abspath(os.path.join(folder, 'model.th'))
        # Save the configuration in a config.json file
        with open(os.path.join(folder, 'config.json'), 'w') as f:
            json.dump(vars(args), f, indent=2)
        logging.info('Saving configuration file in `{0}`'.format(
            os.path.abspath(os.path.join(folder, 'config.json'))))

    ensemble_size = args.ensemble_size if args.ensemble else 0
    benchmark = get_benchmark_by_name(args.dataset,
                                      args.folder,
                                      args.num_ways,
                                      args.num_shots,
                                      args.num_shots_test,
                                      hidden_size=args.hidden_size,
                                      meta_batch_size=args.batch_size,
                                      ensemble_size=ensemble_size)

    meta_train_dataloader = BatchMetaDataLoader(benchmark.meta_train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=args.num_workers,
                                                pin_memory=True)
    meta_val_dataloader = BatchMetaDataLoader(benchmark.meta_val_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    warp_model = None
    if args.warp:
        warp_model = make_warp_model(benchmark.model, constant=args.constant)
        meta_optimizer = torch.optim.Adam(benchmark.model.parameters(),
                                          lr=args.meta_lr)
        warp_meta_optimizer = torch.optim.Adam(warp_model.parameters(),
                                               lr=args.warp_lr)
    else:
        meta_optimizer = torch.optim.Adam(benchmark.model.parameters(),
                                          lr=args.meta_lr)
        warp_meta_optimizer = None

    if args.link_ensemble:
        # Placeholder: a TransformerEnsembler(args.num_ways,
        # benchmark.model.feature_size) with its own Adam optimizer was
        # previously used here.
        ensembler = nn.Identity()
        ensembler_optimizer = None
    else:
        ensembler = None
        ensembler_optimizer = None

    # Learning-rate schedulers are disabled; linear warmup schedules
    # (get_linear_schedule_with_warmup) were previously used here.
    warp_scheduler = None
    ensembler_scheduler = None
    scheduler = None

    metalearner = ModelAgnosticMetaLearning(
        benchmark.model,
        meta_optimizer,
        scheduler=scheduler,
        warp_optimizer=warp_meta_optimizer,
        warp_model=warp_model,
        warp_scheduler=warp_scheduler,
        first_order=args.first_order,
        num_adaptation_steps=args.num_steps,
        step_size=args.step_size,
        learn_step_size=False,
        loss_function=benchmark.loss_function,
        device=device,
        num_maml_steps=args.num_maml_steps,
        ensembler=ensembler,
        ensembler_optimizer=ensembler_optimizer,
        ensemble_size=ensemble_size,
        ensembler_scheduler=ensembler_scheduler)

    best_value = None

    # Training loop
    epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 +
                                             int(math.log10(args.num_epochs)))
    for epoch in tqdm(range(args.num_epochs)):
        metalearner.train(meta_train_dataloader,
                          max_batches=args.num_batches,
                          verbose=args.verbose,
                          desc='Training',
                          leave=False)
        results = metalearner.evaluate(meta_val_dataloader,
                                       max_batches=args.num_eval_batches,
                                       verbose=args.verbose,
                                       desc=epoch_desc.format(epoch + 1))

        # Save best model
        if 'accuracies_after' in results:
            if (best_value is None) or (best_value <
                                        results['accuracies_after']):
                best_value = results['accuracies_after']
                save_model = True
            else:
                save_model = False
        elif (best_value is None) or (best_value > results['mean_outer_loss']):
            best_value = results['mean_outer_loss']
            save_model = True
        else:
            save_model = False

        if save_model and (args.output_folder is not None):
            with open(args.model_path, 'wb') as f:
                torch.save(benchmark.model.state_dict(), f)

    if hasattr(benchmark.meta_train_dataset, 'close'):
        benchmark.meta_train_dataset.close()
        benchmark.meta_val_dataset.close()
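
Every example stores model.th next to a config.json intended for a separate test script. A minimal sketch of restoring both for evaluation, reusing the get_benchmark_by_name call from the examples above:

import argparse
import json

import torch

with open('config.json', 'r') as f:
    config = argparse.Namespace(**json.load(f))

benchmark = get_benchmark_by_name(config.dataset,
                                  config.folder,
                                  config.num_ways,
                                  config.num_shots,
                                  config.num_shots_test,
                                  hidden_size=config.hidden_size)
with open(config.model_path, 'rb') as f:
    benchmark.model.load_state_dict(torch.load(f, map_location='cpu'))
benchmark.model.eval()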