# Exemplo n.º 1
    def test_anil(self):
        """Meta-train LightningANIL on CIFAR-FS and require >= 20% test accuracy."""
        ways, shots = 5, 5
        meta_batch_size = 32
        max_epochs = 5
        adaptation_lr = 5e-1

        pl.seed_everything(42)

        # The benchmark interface returns (train, validation, test) tasksets.
        tasksets = l2l.vision.benchmarks.get_tasksets(
            "cifarfs",
            train_samples=2 * shots,
            train_ways=ways,
            test_samples=2 * shots,
            test_ways=ways,
            root="~/data",
        )
        self.assertTrue(len(tasksets) == 3)

        # Backbone: ConvBase flattened into a 256-dim embedding, plus a
        # linear classification head -- ANIL only adapts the head.
        backbone = l2l.vision.models.ConvBase(
            output_size=64,
            channels=3,
            max_pool=True,
        )
        backbone = torch.nn.Sequential(
            backbone,
            Lambda(lambda t: t.view(-1, 256)),
        )
        head = torch.nn.Linear(256, ways)

        anil = LightningANIL(backbone, head, adaptation_lr=adaptation_lr)
        episodic_data = EpisodicBatcher(
            tasksets.train,
            tasksets.validation,
            tasksets.test,
            epoch_length=meta_batch_size,
        )

        # Accumulate gradients across the whole meta-batch before stepping.
        trainer = pl.Trainer(
            accumulate_grad_batches=meta_batch_size,
            min_epochs=max_epochs,
            max_epochs=max_epochs,
            progress_bar_refresh_rate=0,
            deterministic=True,
            weights_summary=None,
        )
        trainer.fit(anil, episodic_data)
        results = trainer.test(test_dataloaders=tasksets.test, verbose=False)
        self.assertTrue(results[0]["valid_accuracy"] >= 0.20)
    def test_maml(self):
        """Meta-train LightningMAML on CIFAR-FS and require >= 20% test accuracy."""
        ways, shots = 5, 5
        meta_batch_size = 4
        max_epochs = 20
        adaptation_lr = 5e-1

        pl.seed_everything(42)

        # The benchmark interface returns (train, validation, test) tasksets.
        tasksets = l2l.vision.benchmarks.get_tasksets(
            "cifarfs",
            train_samples=2 * shots,
            train_ways=ways,
            test_samples=2 * shots,
            test_ways=ways,
            root="~/data",
        )
        self.assertTrue(len(tasksets) == 3)

        # Full CNN4 backbone -- MAML adapts every parameter.
        backbone = l2l.vision.models.CNN4(ways, embedding_size=32 * 4)

        maml = LightningMAML(backbone, adaptation_lr=adaptation_lr)
        episodic_data = EpisodicBatcher(
            tasksets.train,
            tasksets.validation,
            tasksets.test,
            epoch_length=meta_batch_size,
        )

        # Accumulate gradients across the whole meta-batch before stepping.
        trainer = pl.Trainer(
            accumulate_grad_batches=meta_batch_size,
            min_epochs=max_epochs,
            max_epochs=max_epochs,
            progress_bar_refresh_rate=0,
            deterministic=True,
            weights_summary=None,
        )
        trainer.fit(maml, episodic_data)
        results = trainer.test(test_dataloaders=tasksets.test, verbose=False)
        self.assertTrue(results[0]["valid_accuracy"] >= 0.20)
# Exemplo n.º 3
    def eval_func(hyperparams):
        """Run one full meta-learn -> fine-tune -> test trial and return the test AUC.

        Phase 1 meta-trains the encoder with MAML, phase 2 fine-tunes the
        resulting initialization with standard supervised training, phase 3
        scores the fine-tuned model on the held-out test set.

        Defined inside the tuning function so it can close over train_df,
        valid_df, all_trains, all_valids, test_df, is_marker and
        return_model -- Ax tuning does not allow passing extra arguments
        to the evaluation function.

        Args:
            hyperparams: dict with keys 'layer_1_size', 'layer_2_size',
                'layer_3_size', 'learning_rate', 'dropout',
                'adaptation_lr', and 'k'.

        Returns:
            auc_val (float), or the tuple (auc_val, lightning) when the
            enclosing scope's return_model flag is truthy.
        """

        # Phase 1: meta-learning. Build the base FFNN Lightning module.
        lightning = LitFFNN(train_df,
                            valid_df,
                            layer_1_dim=hyperparams['layer_1_size'],
                            layer_2_dim=hyperparams['layer_2_size'],
                            layer_3_dim=hyperparams['layer_3_size'],
                            learning_rate=hyperparams['learning_rate'],
                            batch_size=50,
                            is_marker=is_marker,
                            dropout=hyperparams['dropout'])

        # Wrap the FFNN's network in learn2learn's LightningMAML for
        # episodic (few-shot) meta-training.
        maml = LightningMAML(
            lightning.model,
            lr=hyperparams['learning_rate'],
            adaptation_lr=hyperparams['adaptation_lr'],
            train_ways=2,
            test_ways=2,
            train_shots=hyperparams['k'],
            test_shots=hyperparams['k'],
            train_queries=hyperparams['k'],
            test_queries=hyperparams['k'],
            adaptation_steps=hyperparams['k']
        )  # 'k' need not be shared across every few-shot knob,
        # but reusing it keeps the hyperparameter search space small.

        # Build meta-learning tasksets (k shots + k queries per class).
        train_set = build_taskset(all_trains, k=hyperparams['k'] * 2)
        val_set = build_taskset(all_valids, k=hyperparams['k'] * 2)

        # NOTE(review): this loop unpacks every task and discards the
        # result -- presumably it forces/sanity-checks task sampling;
        # confirm it is not dead code before removing.
        for task in train_set:
            X, y = task

        episodic_data = EpisodicBatcher(train_set, val_set)

        # Checkpoint the best meta-learned weights by validation loss.
        checkpoint_callback = ModelCheckpoint(filepath='checkpoint_dir',
                                              save_top_k=1,
                                              verbose=False,
                                              monitor='valid_loss',
                                              mode='min')

        # Trainer for the MAML phase; early-stops on 'valid_loss', which is
        # only computed every 40 epochs.
        trainer = pl.Trainer(
            max_epochs=1000,
            min_epochs=50,
            progress_bar_refresh_rate=0,
            weights_summary=None,
            check_val_every_n_epoch=40,
            checkpoint_callback=checkpoint_callback,
            callbacks=[EarlyStopping(monitor='valid_loss', patience=20)])

        # Run the meta-learning phase.
        trainer.fit(maml, episodic_data)

        # Load the best meta-learned parameters back into the plain
        # Lightning module before task-specific training.
        maml.load_state_dict(
            torch.load(checkpoint_callback.best_model_path)['state_dict'])
        lightning.model.load_state_dict(maml.model.module.state_dict())

        # Remove the log/model files so repeated tuning trials don't
        # accumulate on disk.
        os.system('rm ' + checkpoint_callback.best_model_path)
        os.system('rm -R lightning_logs/*')

        # Phase 2: supervised fine-tuning. Note the monitored metric is
        # 'val_loss' here, not 'valid_loss' as in the MAML phase.
        checkpoint_callback = ModelCheckpoint(filepath='checkpoint_dir',
                                              save_top_k=1,
                                              verbose=False,
                                              monitor='val_loss',
                                              mode='min')

        tube_logger = TestTubeLogger('checkpoint_dir', name='test_tube_logger')

        trainer = pl.Trainer(
            max_epochs=500,
            min_epochs=1,
            logger=tube_logger,
            progress_bar_refresh_rate=0,
            weights_summary=None,
            check_val_every_n_epoch=1,
            checkpoint_callback=checkpoint_callback,
            callbacks=[
                EarlyStopping(monitor='val_loss', patience=10)
            ])  # NOTE(review): the DeepMicro paper reportedly uses patience 20,
            # but the code uses 10 -- confirm which is intended.

        # Run the fine-tuning phase.
        trainer.fit(lightning)

        # Restore the weights from the best-performing epoch.
        lightning.load_state_dict(
            torch.load(checkpoint_callback.best_model_path)['state_dict'])

        # Delete the files once we're done with them -- no need to store
        # anything across trials.
        os.system('rm ' + checkpoint_callback.best_model_path)
        os.system('rm -R checkpoint_dir/test_tube_logger/*')

        # Switch to eval mode for inference.
        lightning = lightning.eval()

        # Phase 3: compute AUC for the full setup on the test set.
        test_dataset = MicroDataset(test_df, is_marker=is_marker)
        test_loader = DataLoader(test_dataset, shuffle=False, batch_size=50)

        test_preds = []
        test_trues = []

        for batch in test_loader:
            test_preds.append(lightning(batch[0]))
            test_trues.append(batch[1])

        # Column 1 of the predictions is used as the positive-class score.
        test_preds = torch.cat(test_preds).detach().numpy()
        test_roc = roc_curve(
            torch.cat(test_trues).detach().numpy(), test_preds[:, 1])
        auc_val = auc(test_roc[0], test_roc[1])
        if math.isnan(auc_val):
            # A NaN AUC (e.g. degenerate ROC) is scored as 0 so the tuner
            # treats the trial as failed rather than crashing.
            auc_val = 0

        if return_model:
            return (auc_val, lightning)

        return (auc_val)
# Exemplo n.º 4
def main():
    """Train a few-shot meta-learning algorithm on a vision benchmark.

    Parses CLI arguments, builds the tasksets and model for the chosen
    dataset, instantiates the requested algorithm, runs training, and
    finally evaluates the best checkpoint on the test taskset.
    """
    parser = ArgumentParser(conflict_handler="resolve", add_help=True)
    # Merge model- and trainer-specific args; "resolve" lets later parsers
    # override options an earlier parser already declared.
    parser = LightningPrototypicalNetworks.add_model_specific_args(parser)
    parser = LightningMetaOptNet.add_model_specific_args(parser)
    parser = LightningMAML.add_model_specific_args(parser)
    parser = LightningANIL.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)

    # Script-specific args.
    parser.add_argument("--algorithm", type=str, default="protonet")
    parser.add_argument("--dataset", type=str, default="mini-imagenet")
    parser.add_argument("--root", type=str, default="~/data")
    parser.add_argument("--meta_batch_size", type=int, default=16)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    dict_args = vars(args)

    pl.seed_everything(args.seed)

    # Create tasksets using the benchmark interface.
    # Fix: the original condition read `if False and args.dataset in [...]`,
    # a debugging leftover that made this branch unreachable and silently
    # forced "normalize" augmentation for every dataset.
    if args.dataset in ["mini-imagenet", "tiered-imagenet"]:
        data_augmentation = "lee2019"
    else:
        data_augmentation = "normalize"
    tasksets = l2l.vision.benchmarks.get_tasksets(
        name=args.dataset,
        train_samples=args.train_queries + args.train_shots,
        train_ways=args.train_ways,
        test_samples=args.test_queries + args.test_shots,
        test_ways=args.test_ways,
        root=args.root,
        data_augmentation=data_augmentation,
    )
    episodic_data = EpisodicBatcher(
        tasksets.train,
        tasksets.validation,
        tasksets.test,
        epoch_length=args.meta_batch_size * 10,
    )

    # Init model: ResNet12 for the ImageNet-derived benchmarks,
    # CNN4 for the CIFAR-derived ones (CIFAR-FS, FC100).
    if args.dataset in ["mini-imagenet", "tiered-imagenet"]:
        model = l2l.vision.models.ResNet12(output_size=args.train_ways)
    else:
        model = l2l.vision.models.CNN4(
            output_size=args.train_ways,
            hidden_size=64,
            embedding_size=64 * 4,
        )
    features = model.features
    classifier = model.classifier

    # Init algorithm; fail fast on an unknown name instead of crashing
    # later with a NameError on `algorithm`.
    if args.algorithm == "protonet":
        algorithm = LightningPrototypicalNetworks(features=features,
                                                  **dict_args)
    elif args.algorithm == "maml":
        algorithm = LightningMAML(model, **dict_args)
    elif args.algorithm == "anil":
        algorithm = LightningANIL(features, classifier, **dict_args)
    elif args.algorithm == "metaoptnet":
        algorithm = LightningMetaOptNet(features, **dict_args)
    else:
        raise ValueError("Unknown algorithm: " + args.algorithm)

    trainer = pl.Trainer.from_argparse_args(
        args,
        gpus=1,  # assumes a single-GPU machine
        accumulate_grad_batches=args.meta_batch_size,
        callbacks=[
            l2l.utils.lightning.TrackTestAccuracyCallback(),
            l2l.utils.lightning.NoLeaveProgressBar(),
        ],
    )
    trainer.fit(model=algorithm, datamodule=episodic_data)
    # Evaluate the checkpoint with the best monitored metric.
    trainer.test(ckpt_path="best")