def cli_main():  # pragma: no-cover
    """Fine-tune a CPCV2 backbone loaded from a checkpoint, then run the test loop.

    Command-line arguments (in addition to every ``pl.Trainer`` flag):
        --dataset:   dataset to fine-tune on, ``cifar10`` (default) or ``stl10``.
        --ckpt_path: path to the CPCV2 checkpoint used as the backbone (required).
        --data_dir:  directory holding the dataset; defaults to the current
                     working directory.
    """
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    # ``choices`` makes argparse reject unknown datasets at parse time instead of
    # failing later because no datamodule was built.
    parser.add_argument('--dataset', type=str, help='stl10, cifar10', default='cifar10',
                        choices=('cifar10', 'stl10'))
    # Without a checkpoint there is no backbone to fine-tune.
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt', required=True)
    parser.add_argument('--data_dir', type=str, help='path to the dataset directory',
                        default=os.getcwd())
    args = parser.parse_args()

    # load the backbone
    backbone = CPCV2.load_from_checkpoint(args.ckpt_path, strict=False)

    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsCIFAR10()
        dm.val_transforms = CPCEvalTransformsCIFAR10()
        dm.test_transforms = CPCEvalTransformsCIFAR10()

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Fine-tuning is supervised, so route the default train/val loaders to
        # the datamodule's "labeled" loaders.
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        dm.train_transforms = CPCTrainTransformsSTL10()
        dm.val_transforms = CPCEvalTransformsSTL10()
        dm.test_transforms = CPCEvalTransformsSTL10()

    else:
        # Unreachable while the ``choices`` tuple above lists every branch here.
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    # finetune
    tuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
예제 #2
0
def test_cpcv2(tmpdir):
    """Smoke-test CPCV2 on CIFAR-10 with a single ``fast_dev_run`` pass.

    After fitting, the ``val_nce`` entry of the trainer's progress-bar dict
    must convert to a strictly positive float.
    """
    seed_everything()

    cifar_dm = CIFAR10DataModule(data_dir=tmpdir, num_workers=0, batch_size=2)
    cifar_dm.train_transforms = CPCTrainTransformsCIFAR10()
    cifar_dm.val_transforms = CPCEvalTransformsCIFAR10()

    cpc_model = CPCV2(
        encoder='resnet18',
        data_dir=tmpdir,
        batch_size=2,
        online_ft=True,
        datamodule=cifar_dm,
    )
    pl_trainer = pl.Trainer(
        fast_dev_run=True,
        max_epochs=1,
        default_root_dir=tmpdir,
    )
    pl_trainer.fit(cpc_model)

    val_nce = pl_trainer.progress_bar_dict['val_nce']
    assert float(val_nce) > 0
예제 #3
0
def test_cpcv2(tmpdir):
    """Smoke-test CPCV2 on CIFAR-10 by overfitting two batches for one epoch.

    After fitting, the ``loss`` entry of the trainer's callback metrics must
    compare greater than zero.
    """
    reset_seed()

    cifar_dm = CIFAR10DataModule(data_dir=tmpdir, num_workers=0)
    cifar_dm.train_transforms = CPCTrainTransformsCIFAR10()
    cifar_dm.val_transforms = CPCEvalTransformsCIFAR10()

    cpc_model = CPCV2(
        encoder='resnet18',
        data_dir=tmpdir,
        batch_size=2,
        online_ft=True,
        datamodule=cifar_dm,
    )
    pl_trainer = pl.Trainer(
        overfit_batches=2,
        max_epochs=1,
        default_root_dir=tmpdir,
    )
    pl_trainer.fit(cpc_model)

    train_loss = pl_trainer.callback_metrics['loss']
    assert train_loss > 0
예제 #4
0
    def apply_to(self, setting: IIDSetting) -> IIDResults:
        """Apply this method to the given experimental setting.

        Extend this class and override this method to customize training.

        Work in progress. The method proceeds in these steps:

        - Configure ``setting`` with CPC CIFAR-10 train/val transforms.
        - Train a ``CPCV2`` model on it for a ``fast_dev_run``.
        - Call ``trainer.test`` and print the result.
        - Always raise ``NotImplementedError``.

        Args:
            setting: The setting to run on. It is mutated in place (transforms,
                ``data_dir``, ``config``, ``batch_size``). It is also passed to
                ``CPCV2`` and ``trainer.fit`` as the datamodule.

        Raises:
            RuntimeError: If this method is not applicable to ``setting``.
            NotImplementedError: Always, once the test run has finished.
        """
        if not self.is_applicable(setting):
            raise RuntimeError(
                f"Can only apply methods of type {type(self)} on settings "
                f"that inherit from {type(self).target_setting}. "
                f"(Given setting is of type {type(setting)})."
            )

        # Seed everything first:
        self.config.seed_everything()
        setting.configure(config=self.config)
        from sequoia.common.transforms import ToTensor
        # ToTensor is placed first in each pipeline. Presumably the CPC
        # transforms expect tensor input (TODO confirm).
        setting.train_transforms = [ToTensor(), CPCTrainTransformsCIFAR10()]
        setting.val_transforms = [ToTensor(), CPCEvalTransformsCIFAR10()]

        # TODO: Seems weird that we would have to do this.
        setting.data_dir = self.config.data_dir
        setting.config = self.config
        setting.batch_size = 16

        # Load resnet18 pretrained using CPC on imagenet.
        # TODO: consider freezing the pretrained encoder (``model.encoder.freeze()``).
        model = CPCV2(pretrained='resnet18', datamodule=setting)

        # Only request a GPU when CUDA is available: a hard-coded ``gpus=1``
        # makes the Trainer fail on CPU-only machines.
        import torch
        n_gpus = 1 if torch.cuda.is_available() else 0
        trainer = pl.Trainer(gpus=n_gpus, fast_dev_run=True)
        trainer.fit(model, datamodule=setting)
        test_results = trainer.test(model)
        print(f"Test outputs: {test_results}")
        raise NotImplementedError("TODO: The CPCV2 model doesn't have a 'test_step' method.")