示例#1
0
def cpc_v2_example():
    """Train CPC v2 on CIFAR-10 with DDP, then load and freeze a pretrained checkpoint.

    The function contains two independent demos:

    1. Fit a fresh ``CPC_v2`` model (``cpc_encoder``) on CIFAR-10 using two
       GPUs with the DDP accelerator.
    2. Load published CPC v2 weights from ``weight_path`` (CIFAR-10 by
       default; an STL-10 alternative is left commented out) and call
       ``freeze()`` on the loaded model.

    Returns:
        The pretrained ``CPC_v2`` model loaded from ``weight_path``, after
        ``freeze()`` has been called on it.
    """
    # Imported locally so the example is self-contained, like the other
    # imports here; ``pl.Trainer`` is used below.
    import pytorch_lightning as pl
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised import CPC_v2
    from pl_bolts.models.self_supervised.cpc import CPCEvalTransformsCIFAR10, CPCTrainTransformsCIFAR10
    from pytorch_lightning.plugins import DDPPlugin

    # Data module.
    dm = CIFAR10DataModule(num_workers=12, batch_size=32)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    # Model.
    model = CPC_v2(encoder="cpc_encoder")

    # Fit.
    trainer = pl.Trainer(gpus=2, accelerator="ddp", plugins=DDPPlugin(find_unused_parameters=False))
    trainer.fit(model, datamodule=dm)

    # --------------------
    # CIFAR-10 pretrained model:
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-cifar10-v4-exp3/epoch%3D474.ckpt"
    # STL-10 pretrained model:
    # weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-stl10-v0-exp3/epoch%3D624.ckpt"
    # strict=False tolerates checkpoint keys that don't match this model
    # exactly (e.g. heads added or removed between library versions).
    cpc_v2 = CPC_v2.load_from_checkpoint(weight_path, strict=False)

    cpc_v2.freeze()
    # Return the frozen model so callers can actually use it.
    return cpc_v2
def test_byol(tmpdir, datadir):
    """Smoke-test BYOL on CIFAR-10 for a single ``fast_dev_run`` pass.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
        datadir: Directory holding (or receiving) the CIFAR-10 data.
    """
    datamodule = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    datamodule.train_transforms = CPCTrainTransformsCIFAR10()
    datamodule.val_transforms = CPCEvalTransformsCIFAR10()

    # ``num_classes`` must be the class count exposed by the datamodule,
    # not the datamodule object itself.
    model = BYOL(data_dir=datadir, num_classes=datamodule.num_classes)
    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, datamodule=datamodule)
示例#3
0
def test_byol(tmpdir):
    """Run BYOL on CIFAR-10 for one ``fast_dev_run`` pass and check the loss.

    Args:
        tmpdir: Used both as the CIFAR-10 data directory and the trainer's
            ``default_root_dir``.
    """
    seed_everything()

    datamodule = CIFAR10DataModule(data_dir=tmpdir, num_workers=0, batch_size=2)
    datamodule.train_transforms = CPCTrainTransformsCIFAR10()
    datamodule.val_transforms = CPCEvalTransformsCIFAR10()

    # ``num_classes`` must be the class count exposed by the datamodule,
    # not the datamodule object itself.
    model = BYOL(data_dir=tmpdir, num_classes=datamodule.num_classes)
    # fast_dev_run already limits training to a single batch/epoch, so
    # separate max_epochs/max_steps limits would be redundant.
    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    # Pass the datamodule by keyword: the second positional slot of
    # ``fit`` is for training dataloaders.
    trainer.fit(model, datamodule=datamodule)
    loss = trainer.progress_bar_dict['loss']

    assert float(loss) < 1.0
示例#4
0
def test_cpcv2(tmpdir):
    """Fit CPCV2 (resnet18) for one fast_dev_run pass and check ``val_nce`` is positive."""
    seed_everything()

    cifar = CIFAR10DataModule(data_dir=tmpdir, num_workers=0, batch_size=2)
    cifar.train_transforms = CPCTrainTransformsCIFAR10()
    cifar.val_transforms = CPCEvalTransformsCIFAR10()

    # The datamodule is handed to the model, so ``fit`` needs only the model.
    model_kwargs = dict(
        encoder='resnet18',
        data_dir=tmpdir,
        batch_size=2,
        online_ft=True,
        datamodule=cifar,
    )
    model = CPCV2(**model_kwargs)

    trainer = pl.Trainer(
        fast_dev_run=True,
        max_epochs=1,
        default_root_dir=tmpdir,
    )
    trainer.fit(model)

    val_nce = float(trainer.progress_bar_dict['val_nce'])
    assert val_nce > 0
示例#5
0
def test_cpcv2(tmpdir):
    """Overfit CPCV2 (resnet18) on two batches and check the logged loss is positive."""
    reset_seed()

    cifar = CIFAR10DataModule(data_dir=tmpdir, num_workers=0)
    cifar.train_transforms = CPCTrainTransformsCIFAR10()
    cifar.val_transforms = CPCEvalTransformsCIFAR10()

    model = CPCV2(
        encoder='resnet18',
        data_dir=tmpdir,
        batch_size=2,
        online_ft=True,
        datamodule=cifar,
    )
    trainer = pl.Trainer(
        overfit_batches=2,
        max_epochs=1,
        default_root_dir=tmpdir,
    )
    trainer.fit(model)

    logged = trainer.callback_metrics
    assert logged['loss'] > 0
def test_cpcv2(tmpdir, datadir):
    """Smoke-test CPC_v2 with a MobileNetV3-small encoder on 8px CIFAR-10 patches."""
    cifar = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    cifar.train_transforms = CPCTrainTransformsCIFAR10()
    cifar.val_transforms = CPCEvalTransformsCIFAR10()

    n_classes = cifar.num_classes
    model = CPC_v2(encoder='mobilenet_v3_small', patch_size=8, patch_overlap=2, online_ft=True, num_classes=n_classes)

    trainer = pl.Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, datamodule=cifar)
示例#7
0
def mix_and_match_any_part_or_subclass_example():
    """Train CPC v2 on CIFAR-10 with a custom contrastive task swapped in.

    Builds a ``FeatureMapContrastiveTask`` with comparisons ``"01, 11, 02"``
    and ``bidirectional=True``, passes it to ``CPC_v2`` as its
    ``contrastive_task``, and fits on CIFAR-10 with two GPUs using the DDP
    accelerator.

    Returns:
        The fitted ``CPC_v2`` model.
    """
    # Imported locally so the example is self-contained, like the other
    # imports here; ``pl.Trainer`` is used below.
    import pytorch_lightning as pl
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask
    from pl_bolts.models.self_supervised import CPC_v2
    from pl_bolts.models.self_supervised.cpc import CPCEvalTransformsCIFAR10, CPCTrainTransformsCIFAR10
    from pytorch_lightning.plugins import DDPPlugin

    # Data module.
    dm = CIFAR10DataModule(num_workers=12, batch_size=32)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    # Model. The comparison string is presumably a list of feature-map index
    # pairs (AMDIM-style); see the FeatureMapContrastiveTask docs to confirm.
    amdim_task = FeatureMapContrastiveTask(comparisons="01, 11, 02", bidirectional=True)
    model = CPC_v2(encoder="cpc_encoder", contrastive_task=amdim_task)

    # Fit.
    trainer = pl.Trainer(gpus=2, accelerator="ddp", plugins=DDPPlugin(find_unused_parameters=False))
    trainer.fit(model, datamodule=dm)
    return model
示例#8
0
def test_cpcv2(tmpdir, datadir):
    """Smoke-test CPC_v2 (MobileNetV3-small encoder) on CIFAR-10, using one GPU if present."""
    cifar = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    cifar.train_transforms = CPCTrainTransformsCIFAR10()
    cifar.val_transforms = CPCEvalTransformsCIFAR10()

    model = CPC_v2(encoder="mobilenet_v3_small", patch_size=8, patch_overlap=2, online_ft=True, num_classes=cifar.num_classes)

    # FIXME: attach the datamodule to the model by hand to work around the bug
    # introduced by
    # https://github.com/PyTorchLightning/lightning-bolts/commit/2e903c333c37ea83394c7da2ce826de1b82fb356
    model.datamodule = cifar

    has_cuda = torch.cuda.device_count() > 0
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=1 if has_cuda else 0)
    trainer.fit(model, datamodule=cifar)
示例#9
0
    def apply_to(self, setting: IIDSetting) -> IIDResults:
        """Apply this method to the given experimental setting.

        Extend this class and override this method to customize training.

        The current implementation is work in progress:

        1. Configure ``setting`` with CPC CIFAR-10 train/val transforms.
        2. Fit a ``CPCV2`` model on it for a single ``fast_dev_run`` pass.
        3. Run ``trainer.test`` and print the results.
        4. Raise ``NotImplementedError``, because no ``IIDResults`` is built yet.

        Args:
            setting: The setting to train and evaluate on. It must satisfy
                ``self.is_applicable(setting)``.

        Raises:
            RuntimeError: If this method is not applicable to ``setting``.
            NotImplementedError: Raised at the end whenever training and
                testing complete without error.
        """
        if not self.is_applicable(setting):
            raise RuntimeError(
                f"Can only apply methods of type {type(self)} on settings "
                f"that inherit from {type(self).target_setting}. "
                f"(Given setting is of type {type(setting)})."
            )

        # Seed everything first:
        self.config.seed_everything()
        setting.configure(config=self.config)
        from sequoia.common.transforms import ToTensor
        # ToTensor runs before the CPC transforms in both pipelines.
        setting.train_transforms = [ToTensor(), CPCTrainTransformsCIFAR10()]
        setting.val_transforms = [ToTensor(), CPCEvalTransformsCIFAR10()]

        # TODO: It seems weird that we would have to do this.
        setting.data_dir = self.config.data_dir
        setting.config = self.config
        setting.batch_size = 16

        # Load resnet18 pretrained using CPC (on ImageNet, per the original note).
        model = CPCV2(pretrained='resnet18', datamodule=setting)

        import torch
        # Use a GPU only when one is present instead of hard-coding gpus=1,
        # which fails on CPU-only machines.
        gpus = 1 if torch.cuda.device_count() > 0 else 0
        trainer = pl.Trainer(gpus=gpus, fast_dev_run=True)
        trainer.fit(model, datamodule=setting)
        # Pass the setting explicitly so testing uses its test data.
        test_results = trainer.test(model, datamodule=setting)
        print(f"Test outputs: {test_results}")
        raise NotImplementedError("TODO: The CPCV2 model doesn't have a 'test_step' method.")