def cpc_v2_example():
    """Pretrain CPC v2 on CIFAR-10 with DDP, then load a published pretrained checkpoint.

    Example/demo function: trains from scratch, then demonstrates loading a
    pretrained CIFAR-10 (or STL-10) checkpoint and freezing it for feature extraction.
    """
    from pl_bolts.models.self_supervised import CPC_v2
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.cpc import CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10
    from pytorch_lightning.plugins import DDPPlugin
    # BUG FIX: `pl` was used below but never imported in this function, while every
    # other dependency is imported locally — without a file-level alias this raised NameError.
    import pytorch_lightning as pl

    # Data module.
    dm = CIFAR10DataModule(num_workers=12, batch_size=32)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    # Model.
    model = CPC_v2(encoder="cpc_encoder")

    # Fit. find_unused_parameters=False avoids the DDP unused-parameter scan overhead.
    trainer = pl.Trainer(gpus=2, accelerator="ddp", plugins=DDPPlugin(find_unused_parameters=False))
    trainer.fit(model, datamodule=dm)

    # --------------------
    # CIFAR-10 pretrained model:
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-cifar10-v4-exp3/epoch%3D474.ckpt"
    # STL-10 pretrained model:
    # weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-stl10-v0-exp3/epoch%3D624.ckpt"

    # strict=False: the checkpoint may contain keys (e.g. online-eval head) not in the bare model.
    cpc_v2 = CPC_v2.load_from_checkpoint(weight_path, strict=False)
    cpc_v2.freeze()
def test_cpcv2(tmpdir, datadir):
    """Smoke-test CPC_v2 training: one fast-dev-run pass on a tiny CIFAR-10 batch."""
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    # Small encoder + online fine-tuning head so the run stays fast.
    cpc_model = CPC_v2(
        encoder='mobilenet_v3_small',
        patch_size=8,
        patch_overlap=2,
        online_ft=True,
        num_classes=dm.num_classes,
    )

    smoke_trainer = pl.Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    smoke_trainer.fit(cpc_model, datamodule=dm)
def cli_main():  # pragma: no cover
    """CLI entry point: fine-tune a pretrained CPC_v2 backbone on CIFAR-10 or STL-10.

    Loads the backbone from --ckpt_path, builds the chosen datamodule with the
    matching CPC transforms, and runs SSLFineTuner training + test.
    """
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule

    seed_everything(1234)

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument('--dataset', type=str, help='stl10, cifar10', default='cifar10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    # BUG FIX: help text was copy-pasted from --ckpt_path ('path to ckpt').
    parser.add_argument('--data_dir', type=str, help='path to dataset', default=os.getcwd())
    args = parser.parse_args()

    # Load the pretrained backbone; strict=False tolerates extra/missing head keys.
    backbone = CPC_v2.load_from_checkpoint(args.ckpt_path, strict=False)

    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsCIFAR10()
        dm.val_transforms = CPCEvalTransformsCIFAR10()
        dm.test_transforms = CPCEvalTransformsCIFAR10()
    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # STL-10 fine-tuning uses only the labeled splits.
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        dm.train_transforms = CPCTrainTransformsSTL10()
        dm.val_transforms = CPCEvalTransformsSTL10()
        dm.test_transforms = CPCEvalTransformsSTL10()
    else:
        # BUG FIX: previously fell through with `dm` unbound, raising NameError below.
        raise ValueError(f"unsupported --dataset {args.dataset!r}; expected 'cifar10' or 'stl10'")

    # finetune
    tuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)
    trainer = Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
def mix_and_match_any_part_or_subclass_example():
    """Example: swap CPC v2's default contrastive task for an AMDIM-style one.

    Demonstrates that the `contrastive_task` component of CPC_v2 is pluggable:
    here a FeatureMapContrastiveTask replaces the default.
    """
    from pl_bolts.models.self_supervised import CPC_v2
    from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.cpc import CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10
    from pytorch_lightning.plugins import DDPPlugin
    # BUG FIX: `pl` was used below but never imported in this function, while every
    # other dependency is imported locally — without a file-level alias this raised NameError.
    import pytorch_lightning as pl

    # Data module.
    dm = CIFAR10DataModule(num_workers=12, batch_size=32)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    # Model: CPC v2 with an AMDIM-style feature-map contrastive task.
    amdim_task = FeatureMapContrastiveTask(comparisons="01, 11, 02", bidirectional=True)
    model = CPC_v2(encoder="cpc_encoder", contrastive_task=amdim_task)

    # Fit. find_unused_parameters=False avoids the DDP unused-parameter scan overhead.
    trainer = pl.Trainer(gpus=2, accelerator="ddp", plugins=DDPPlugin(find_unused_parameters=False))
    trainer.fit(model, datamodule=dm)
def test_cpcv2(tmpdir, datadir):
    """Fast-dev-run smoke test for CPC_v2 with the online fine-tuner enabled."""
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    model = CPC_v2(
        encoder="mobilenet_v3_small",
        patch_size=8,
        patch_overlap=2,
        online_ft=True,
        num_classes=dm.num_classes,
    )
    # FIXME: workaround for bug caused by
    # https://github.com/PyTorchLightning/lightning-bolts/commit/2e903c333c37ea83394c7da2ce826de1b82fb356
    model.datamodule = dm

    # Use a single GPU when one is available, otherwise run on CPU.
    gpu_count = 1 if torch.cuda.device_count() > 0 else 0
    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, gpus=gpu_count)
    trainer.fit(model, datamodule=dm)