def cpc_v2_example():
    """Usage snippet: pretrain CPC v2 on CIFAR-10 with DDP, then load released weights.

    Part 1 fits a fresh ``CPC_v2`` model (``cpc_encoder``) on two GPUs using the
    CPC CIFAR-10 train/eval transforms. Part 2 loads a published checkpoint from
    the pl-bolts S3 bucket (non-strictly) and freezes it.

    NOTE(review): ``pl`` is not imported inside this function; it presumably
    comes from a module-level ``import pytorch_lightning as pl`` — confirm.
    """
    from pl_bolts.models.self_supervised import CPC_v2
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.cpc import CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10
    from pytorch_lightning.plugins import DDPPlugin

    # --- Data: CIFAR-10 with the CPC-specific augmentations ---------------
    cifar = CIFAR10DataModule(num_workers=12, batch_size=32)
    cifar.train_transforms = CPCTrainTransformsCIFAR10()
    cifar.val_transforms = CPCEvalTransformsCIFAR10()

    # --- Model and two-GPU DDP training -----------------------------------
    cpc_model = CPC_v2(encoder="cpc_encoder")
    ddp_trainer = pl.Trainer(
        gpus=2,
        accelerator="ddp",
        plugins=DDPPlugin(find_unused_parameters=False),
    )
    ddp_trainer.fit(cpc_model, datamodule=cifar)

    # --- Load a released checkpoint ---------------------------------------
    checkpoint_urls = {
        "cifar10": "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-cifar10-v4-exp3/epoch%3D474.ckpt",
        "stl10": "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-stl10-v0-exp3/epoch%3D624.ckpt",
    }
    # Change the key to "stl10" to load the STL-10 weights instead.
    pretrained = CPC_v2.load_from_checkpoint(checkpoint_urls["cifar10"], strict=False)
    pretrained.freeze()
def test_byol(tmpdir, datadir):
    """Smoke-test BYOL for a single ``fast_dev_run`` pass on CIFAR-10.

    Args:
        tmpdir: pytest temporary directory, used as the trainer's root dir.
        datadir: directory holding (or receiving) the CIFAR-10 data.
    """
    datamodule = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    datamodule.train_transforms = CPCTrainTransformsCIFAR10()
    datamodule.val_transforms = CPCEvalTransformsCIFAR10()

    # `num_classes` expects the class count, not the datamodule object itself.
    model = BYOL(data_dir=datadir, num_classes=datamodule.num_classes)
    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, datamodule=datamodule)
def test_byol(tmpdir):
    """Run BYOL briefly on CIFAR-10 and check the progress-bar loss is below 1.0.

    Args:
        tmpdir: pytest temporary directory, used for both data and trainer root.
    """
    seed_everything()
    datamodule = CIFAR10DataModule(data_dir=tmpdir, num_workers=0, batch_size=2)
    datamodule.train_transforms = CPCTrainTransformsCIFAR10()
    datamodule.val_transforms = CPCEvalTransformsCIFAR10()

    # `num_classes` expects the class count, not the datamodule object itself.
    model = BYOL(data_dir=tmpdir, num_classes=datamodule.num_classes)
    trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir, max_steps=2)
    # Pass the datamodule by keyword: the second positional slot of `fit` is
    # the train dataloader, not the datamodule.
    trainer.fit(model, datamodule=datamodule)

    loss = trainer.progress_bar_dict['loss']
    assert float(loss) < 1.0
def test_cpcv2(tmpdir):
    """Smoke-test ``CPCV2`` (resnet18 encoder, online fine-tuning) on CIFAR-10.

    Runs one ``fast_dev_run`` pass and checks that the ``val_nce`` value shown
    on the progress bar is strictly positive.
    """
    seed_everything()

    dm = CIFAR10DataModule(data_dir=tmpdir, num_workers=0, batch_size=2)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    cpc = CPCV2(
        encoder='resnet18',
        data_dir=tmpdir,
        batch_size=2,
        online_ft=True,
        datamodule=dm,
    )
    # The datamodule is handed to the model constructor, so fit() is given
    # only the model here.
    trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir)
    trainer.fit(cpc)

    val_nce = float(trainer.progress_bar_dict['val_nce'])
    assert val_nce > 0
def test_cpcv2(tmpdir):
    """Overfit ``CPCV2`` on two CIFAR-10 batches and check the logged loss is positive."""
    reset_seed()

    # NOTE(review): the datamodule uses its default batch size while the model
    # is built with batch_size=2 — confirm which one actually drives loading.
    dm = CIFAR10DataModule(data_dir=tmpdir, num_workers=0)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    cpc = CPCV2(
        encoder='resnet18',
        data_dir=tmpdir,
        batch_size=2,
        online_ft=True,
        datamodule=dm,
    )
    trainer = pl.Trainer(overfit_batches=2, max_epochs=1, default_root_dir=tmpdir)
    trainer.fit(cpc)

    train_loss = trainer.callback_metrics['loss']
    assert train_loss > 0
def test_cpcv2(tmpdir, datadir):
    """``fast_dev_run`` smoke test for ``CPC_v2`` with a MobileNetV3-small encoder.

    Uses 8-pixel patches with 2-pixel overlap and online fine-tuning sized to
    the datamodule's class count.
    """
    cifar10 = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    cifar10.train_transforms = CPCTrainTransformsCIFAR10()
    cifar10.val_transforms = CPCEvalTransformsCIFAR10()

    cpc_kwargs = dict(
        encoder='mobilenet_v3_small',
        patch_size=8,
        patch_overlap=2,
        online_ft=True,
        num_classes=cifar10.num_classes,
    )
    cpc = CPC_v2(**cpc_kwargs)

    smoke_trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    smoke_trainer.fit(cpc, datamodule=cifar10)
def mix_and_match_any_part_or_subclass_example():
    """Usage snippet: train ``CPC_v2`` with a custom contrastive task plugged in.

    Shows that ``CPC_v2`` accepts a ``contrastive_task`` argument; here a
    ``FeatureMapContrastiveTask`` built with ``comparisons="01, 11, 02"`` and
    ``bidirectional=True`` (named an AMDIM-style task in the original). Trains
    on CIFAR-10 across two GPUs with DDP.

    NOTE(review): ``pl`` is not imported inside this function; it presumably
    comes from a module-level import — confirm.
    """
    from pl_bolts.models.self_supervised import CPC_v2
    from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.cpc import CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10
    from pytorch_lightning.plugins import DDPPlugin

    # --- Data -------------------------------------------------------------
    cifar = CIFAR10DataModule(num_workers=12, batch_size=32)
    cifar.train_transforms = CPCTrainTransformsCIFAR10()
    cifar.val_transforms = CPCEvalTransformsCIFAR10()

    # --- Model with the swapped-in contrastive task -----------------------
    feature_map_task = FeatureMapContrastiveTask(comparisons="01, 11, 02", bidirectional=True)
    cpc_model = CPC_v2(encoder="cpc_encoder", contrastive_task=feature_map_task)

    # --- Two-GPU DDP training ---------------------------------------------
    ddp_trainer = pl.Trainer(
        gpus=2,
        accelerator="ddp",
        plugins=DDPPlugin(find_unused_parameters=False),
    )
    ddp_trainer.fit(cpc_model, datamodule=cifar)
def test_cpcv2(tmpdir, datadir):
    """``fast_dev_run`` smoke test for ``CPC_v2``; uses one GPU when CUDA is present."""
    cifar10 = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    cifar10.train_transforms = CPCTrainTransformsCIFAR10()
    cifar10.val_transforms = CPCEvalTransformsCIFAR10()

    cpc = CPC_v2(
        encoder="mobilenet_v3_small",
        patch_size=8,
        patch_overlap=2,
        online_ft=True,
        num_classes=cifar10.num_classes,
    )
    # FIXME: workaround for bug caused by
    # https://github.com/PyTorchLightning/lightning-bolts/commit/2e903c333c37ea83394c7da2ce826de1b82fb356
    cpc.datamodule = cifar10

    # Cap at a single GPU: 1 if any CUDA device is visible, otherwise 0 (CPU).
    n_gpus = min(torch.cuda.device_count(), 1)
    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, gpus=n_gpus)
    trainer.fit(cpc, datamodule=cifar10)
def apply_to(self, setting: IIDSetting) -> IIDResults:
    """Apply this method to the given experimental setting.

    Subclasses may override this to customize training. The current flow:
    check applicability, seed, configure ``setting``, install ToTensor + CPC
    CIFAR-10 transforms, fit a ``CPCV2`` model with ``setting`` as the
    datamodule, then run ``trainer.test`` and print its output.

    Raises:
        RuntimeError: if ``self.is_applicable(setting)`` is false.
        NotImplementedError: always, after training and testing (work in
            progress; no ``IIDResults`` is ever returned yet).
    """
    if not self.is_applicable(setting):
        method_type = type(self)
        raise RuntimeError(
            f"Can only apply methods of type {method_type} on settings "
            f"that inherit from {method_type.target_setting}. "
            f"(Given setting is of type {type(setting)})."
        )

    # Seed first, before the setting is configured or any model is built.
    self.config.seed_everything()
    setting.configure(config=self.config)

    from sequoia.common.transforms import ToTensor

    setting.train_transforms = [ToTensor(), CPCTrainTransformsCIFAR10()]
    setting.val_transforms = [ToTensor(), CPCEvalTransformsCIFAR10()]

    # TODO: It seems odd that these have to be set by hand here.
    setting.data_dir = self.config.data_dir
    setting.config = self.config
    setting.batch_size = 16

    # Per the original note: a resnet18 pretrained with CPC on ImageNet.
    # `setting` doubles as the Lightning datamodule.
    cpc_model = CPCV2(pretrained='resnet18', datamodule=setting)

    trainer = pl.Trainer(gpus=1, fast_dev_run=True)
    trainer.fit(cpc_model, datamodule=setting)
    results = trainer.test(cpc_model)
    print(f"Test outputs: {results}")

    raise NotImplementedError("TODO: The CPCV2 model doesn't have a 'test_step' method.")