コード例 #1
0
def cpc_v2_example():
	"""Usage example: train CPC v2 on CIFAR-10, then load a published checkpoint.

	First fits a fresh ``CPC_v2`` model (``cpc_encoder`` backbone) on CIFAR-10
	with a 2-GPU DDP trainer. Then, independently of that run, loads pretrained
	weights from the pl-bolts S3 bucket and calls ``freeze()`` on the result.
	"""
	# Every other dependency of this example is imported locally, and `pl` is
	# used below for `pl.Trainer`, so bind it here too (previously unbound).
	import pytorch_lightning as pl
	from pl_bolts.models.self_supervised import CPC_v2
	from pl_bolts.datamodules import CIFAR10DataModule
	from pl_bolts.models.self_supervised.cpc import CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10
	from pytorch_lightning.plugins import DDPPlugin

	# Data module, with the CPC-specific train/val transforms.
	dm = CIFAR10DataModule(num_workers=12, batch_size=32)
	dm.train_transforms = CPCTrainTransformsCIFAR10()
	dm.val_transforms = CPCEvalTransformsCIFAR10()

	# Model.
	model = CPC_v2(encoder="cpc_encoder")

	# Fit on 2 GPUs with DDP; unused-parameter detection is disabled.
	trainer = pl.Trainer(gpus=2, accelerator="ddp", plugins=DDPPlugin(find_unused_parameters=False))
	trainer.fit(model, datamodule=dm)

	#--------------------
	# CIFAR-10 pretrained model:
	weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-cifar10-v4-exp3/epoch%3D474.ckpt"
	# STL-10 pretrained model (swap in to use STL-10 weights instead):
	#weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-stl10-v0-exp3/epoch%3D624.ckpt"
	# strict=False: presumably tolerates checkpoint keys that do not exactly
	# match the model's state_dict — confirm against the checkpoint contents.
	cpc_v2 = CPC_v2.load_from_checkpoint(weight_path, strict=False)

	cpc_v2.freeze()
コード例 #2
0
def test_cpcv2(tmpdir, datadir):
    """Smoke-test CPC v2 with a MobileNetV3-small encoder on CIFAR-10.

    Uses tiny batches, no worker processes and online fine-tuning, and runs a
    single ``fast_dev_run`` fit.
    """
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    cpc_kwargs = dict(
        encoder='mobilenet_v3_small',
        patch_size=8,
        patch_overlap=2,
        online_ft=True,
        num_classes=dm.num_classes,
    )
    model = CPC_v2(**cpc_kwargs)

    trainer = pl.Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, datamodule=dm)
コード例 #3
0
def cli_main():  # pragma: no cover
    """Fine-tune a pretrained CPC v2 backbone on CIFAR-10 or STL-10 from the CLI.

    Loads the backbone from ``--ckpt_path`` and builds the datamodule selected
    by ``--dataset`` with the matching CPC transforms. It then fits an
    ``SSLFineTuner`` on top of the backbone and finally runs ``trainer.test``.
    """
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule

    seed_everything(1234)

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    # Restrict to the supported datasets so an unknown name is rejected at
    # parse time instead of leaving `dm` unbound further down.
    parser.add_argument('--dataset',
                        type=str,
                        choices=('cifar10', 'stl10'),
                        help='stl10, cifar10',
                        default='cifar10')
    # The backbone checkpoint is mandatory: loading from `None` would fail.
    parser.add_argument('--ckpt_path', type=str, required=True, help='path to ckpt')
    parser.add_argument('--data_dir',
                        type=str,
                        help='path to the dataset directory',
                        default=os.getcwd())
    args = parser.parse_args()

    # load the backbone
    backbone = CPC_v2.load_from_checkpoint(args.ckpt_path, strict=False)

    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsCIFAR10()
        dm.val_transforms = CPCEvalTransformsCIFAR10()
        dm.test_transforms = CPCEvalTransformsCIFAR10()

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Swap in the labeled loaders: fine-tuning a classifier needs labels.
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        dm.train_transforms = CPCTrainTransformsSTL10()
        dm.val_transforms = CPCEvalTransformsSTL10()
        dm.test_transforms = CPCEvalTransformsSTL10()

    # finetune
    tuner = SSLFineTuner(backbone,
                         in_features=backbone.z_dim,
                         num_classes=backbone.num_classes)
    # NOTE(review): `early_stop_callback` is an older Trainer argument; confirm
    # it is still accepted by the installed Lightning version.
    trainer = Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    # NOTE(review): no model is passed, so this relies on the Trainer reusing
    # the model it just fitted — confirm for the installed Lightning version.
    trainer.test(datamodule=dm)
コード例 #4
0
def mix_and_match_any_part_or_subclass_example():
	"""Usage example: swap a different contrastive task into CPC v2.

	Builds a ``FeatureMapContrastiveTask`` and passes it to ``CPC_v2`` via
	``contrastive_task``. It then trains on CIFAR-10 with a 2-GPU DDP trainer.
	"""
	# Every other dependency of this example is imported locally, and `pl` is
	# used below for `pl.Trainer`, so bind it here too (previously unbound).
	import pytorch_lightning as pl
	from pl_bolts.models.self_supervised import CPC_v2
	from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask
	from pl_bolts.datamodules import CIFAR10DataModule
	from pl_bolts.models.self_supervised.cpc import CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10
	from pytorch_lightning.plugins import DDPPlugin

	# Data module, with the CPC-specific train/val transforms.
	dm = CIFAR10DataModule(num_workers=12, batch_size=32)
	dm.train_transforms = CPCTrainTransformsCIFAR10()
	dm.val_transforms = CPCEvalTransformsCIFAR10()

	# Model. The variable name suggests an AMDIM-style task; `comparisons`
	# presumably lists which feature-map pairs to contrast — confirm in the
	# FeatureMapContrastiveTask docs.
	amdim_task = FeatureMapContrastiveTask(comparisons="01, 11, 02", bidirectional=True)
	model = CPC_v2(encoder="cpc_encoder", contrastive_task=amdim_task)

	# Fit on 2 GPUs with DDP; unused-parameter detection is disabled.
	trainer = pl.Trainer(gpus=2, accelerator="ddp", plugins=DDPPlugin(find_unused_parameters=False))
	trainer.fit(model, datamodule=dm)
コード例 #5
0
def test_cpcv2(tmpdir, datadir):
    """Smoke-test CPC v2 (MobileNetV3-small encoder) on CIFAR-10.

    Runs a single ``fast_dev_run`` fit. One GPU is used when CUDA reports any
    device, otherwise CPU.
    """
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    model = CPC_v2(
        encoder="mobilenet_v3_small",
        patch_size=8,
        patch_overlap=2,
        online_ft=True,
        num_classes=dm.num_classes,
    )

    # FIXME: attach the datamodule to the model by hand as a workaround for a
    # bug introduced by
    # https://github.com/PyTorchLightning/lightning-bolts/commit/2e903c333c37ea83394c7da2ce826de1b82fb356
    model.datamodule = dm

    num_gpus = int(torch.cuda.device_count() > 0)
    trainer = Trainer(
        fast_dev_run=True,
        default_root_dir=tmpdir,
        gpus=num_gpus,
    )
    trainer.fit(model, datamodule=dm)