def cli_main():  # pragma: no-cover
    """Fine-tune a pretrained CPCV2 backbone from the command line.

    Command-line arguments (in addition to every ``pl.Trainer`` flag):

    * ``--dataset``: ``'cifar10'`` (default) or ``'stl10'``.
    * ``--ckpt_path``: path to the CPCV2 checkpoint to load (required).
    * ``--data_dir``: directory holding the dataset (default: the current
      working directory).

    The checkpoint is loaded with ``strict=False`` and wrapped in an
    ``SSLFineTuner`` sized from the backbone's ``z_dim`` and ``num_classes``.
    The tuner is fitted on the chosen datamodule and then evaluated with
    ``trainer.test``.
    """
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    # ``choices`` rejects unknown datasets at parse time. Without it, an
    # unknown name fell through both branches below and crashed later with
    # an UnboundLocalError on ``dm``.
    parser.add_argument(
        '--dataset',
        type=str,
        choices=['stl10', 'cifar10'],
        help='stl10, cifar10',
        default='cifar10',
    )
    # Required: the checkpoint is loaded unconditionally below, so a missing
    # path should be a clear argparse error rather than a failure inside
    # ``load_from_checkpoint(None)``.
    parser.add_argument('--ckpt_path', type=str, required=True, help='path to ckpt')
    parser.add_argument(
        '--data_dir',
        type=str,
        help='path to the dataset directory',
        default=os.getcwd(),
    )
    args = parser.parse_args()

    # load the backbone
    backbone = CPCV2.load_from_checkpoint(args.ckpt_path, strict=False)

    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsCIFAR10()
        dm.val_transforms = CPCEvalTransformsCIFAR10()
        dm.test_transforms = CPCEvalTransformsCIFAR10()
    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Swap in the labelled loaders (presumably STL10 also exposes an
        # unlabelled split by default — confirm against STL10DataModule).
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        dm.train_transforms = CPCTrainTransformsSTL10()
        dm.val_transforms = CPCEvalTransformsSTL10()
        dm.test_transforms = CPCEvalTransformsSTL10()

    # finetune
    tuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
def test_cpcv2(tmpdir):
    """Smoke-test CPCV2 on CIFAR10 with a single ``fast_dev_run`` pass.

    The model is trained with the online fine-tuner enabled. The test then
    checks that the ``val_nce`` value in the progress-bar metrics is
    strictly positive.
    """
    seed_everything()

    cifar_dm = CIFAR10DataModule(data_dir=tmpdir, num_workers=0, batch_size=2)
    cifar_dm.train_transforms = CPCTrainTransformsCIFAR10()
    cifar_dm.val_transforms = CPCEvalTransformsCIFAR10()

    cpc = CPCV2(
        encoder='resnet18',
        data_dir=tmpdir,
        batch_size=2,
        online_ft=True,
        datamodule=cifar_dm,
    )

    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        fast_dev_run=True,
    )
    trainer.fit(cpc)

    val_nce = float(trainer.progress_bar_dict['val_nce'])
    assert val_nce > 0
def test_cpcv2(tmpdir):
    """Train CPCV2 on two overfit CIFAR10 batches and check the loss.

    The model uses a resnet18 encoder with online fine-tuning enabled. The
    test asserts that the ``loss`` entry in ``callback_metrics`` is
    positive after one epoch.
    """
    reset_seed()

    cifar_dm = CIFAR10DataModule(data_dir=tmpdir, num_workers=0)
    cifar_dm.train_transforms = CPCTrainTransformsCIFAR10()
    cifar_dm.val_transforms = CPCEvalTransformsCIFAR10()

    cpc = CPCV2(
        encoder='resnet18',
        data_dir=tmpdir,
        batch_size=2,
        online_ft=True,
        datamodule=cifar_dm,
    )

    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        overfit_batches=2,
    )
    trainer.fit(cpc)

    final_loss = trainer.callback_metrics['loss']
    assert final_loss > 0
def apply_to(self, setting: IIDSetting) -> IIDResults:
    """Train a CPCV2 model on ``setting`` (work in progress; never returns).

    Steps, in order:

    1. Reject settings for which ``self.is_applicable(setting)`` is false.
    2. Seed via ``self.config.seed_everything()``, then call
       ``setting.configure(config=self.config)``.
    3. Overwrite the setting's train/val transforms with
       ``[ToTensor(), CPC*TransformsCIFAR10()]``. Copy ``data_dir`` and
       ``config`` from ``self.config`` onto the setting, and force
       ``batch_size = 16``.
    4. Build ``CPCV2(pretrained='resnet18', datamodule=setting)``, fit it
       with a ``pl.Trainer(gpus=1, fast_dev_run=True)``, call
       ``trainer.test(model)`` and print the result.

    Extend this class and override this method to customize training.

    Args:
        setting: The experimental setting to apply this method to. It is
            mutated in place (transforms, data_dir, config, batch_size).

    Returns:
        Nothing in practice. Although annotated as ``IIDResults``, the
        method always ends by raising.

    Raises:
        RuntimeError: If this method is not applicable to ``setting``.
        NotImplementedError: Raised unconditionally at the end, if the
            preceding calls succeed. Its message notes that CPCV2 has no
            ``test_step``.
    """
    if not self.is_applicable(setting):
        raise RuntimeError(
            f"Can only apply methods of type {type(self)} on settings "
            f"that inherit from {type(self).target_setting}. "
            f"(Given setting is of type {type(setting)})."
        )
    # Seed everything first, before the setting configures itself:
    self.config.seed_everything()
    setting.configure(config=self.config)
    from sequoia.common.transforms import ToTensor
    # setting.transforms = [ToTensor(), CPCTrainTransformsCIFAR10()]
    # NOTE(review): the CIFAR10-specific CPC transforms are hard-coded here
    # regardless of the dataset `setting` actually uses — confirm intended.
    setting.train_transforms = [ToTensor(), CPCTrainTransformsCIFAR10()]
    setting.val_transforms = [ToTensor(), CPCEvalTransformsCIFAR10()]
    # TODO: Seems weird that we would have to do this.
    setting.data_dir = self.config.data_dir
    setting.config = self.config
    setting.batch_size = 16
    # Presumably loads a resnet18 encoder pretrained with CPC (an earlier
    # note said "on imagenet") — confirm against CPCV2's `pretrained` option.
    model = CPCV2(pretrained='resnet18', datamodule=setting)
    # cpc_resnet18 = model.encoder
    # cpc_resnet18.freeze()
    # NOTE(review): gpus=1 is hard-coded, so this presumably fails on a
    # machine without a GPU. fast_dev_run=True is Lightning's quick
    # smoke-test mode, not a full training run.
    trainer = pl.Trainer(gpus=1, fast_dev_run=True)
    trainer.fit(model, datamodule=setting)
    # NOTE(review): per the error message below, CPCV2 lacks `test_step`, so
    # this call may fail before the NotImplementedError is reached.
    test_results = trainer.test(model)
    print(f"Test outputs: {test_results}")
    raise NotImplementedError("TODO: The CPCV2 model doesn't have a 'test_step' method.")