Example #1
    def __init__(
        self,
        *args,
        **kwargs,
    ):

        self._threshold = kwargs.get("threshold", None)
        self.__instantiate_transform(kwargs)
        BaseDatasetSamplerMixin.__init__(self, *args, **kwargs)
        BaseTasksMixin.__init__(self, *args, **kwargs)
        self.clean_kwargs(kwargs)
        LightningDataModule.__init__(self, *args, **kwargs)

        self.dataset_train = None
        self.dataset_val = None
        self.dataset_test = None

        self._seed = 42
        self._num_workers = 2
        self._shuffle = True
        self._drop_last = False
        self._pin_memory = True
        self._follow_batch = []

        self._hyper_parameters = {}
def test_v1_7_0_datamodule_transform_properties(tmpdir):
    dm = MNISTDataModule()
    with pytest.deprecated_call(match=r"DataModule property `train_transforms` was deprecated in v1.5"):
        dm.train_transforms = "a"
    with pytest.deprecated_call(match=r"DataModule property `val_transforms` was deprecated in v1.5"):
        dm.val_transforms = "b"
    with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
        dm.test_transforms = "c"
    with pytest.deprecated_call(match=r"DataModule property `train_transforms` was deprecated in v1.5"):
        _ = LightningDataModule(train_transforms="a")
    with pytest.deprecated_call(match=r"DataModule property `val_transforms` was deprecated in v1.5"):
        _ = LightningDataModule(val_transforms="b")
    with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
        _ = LightningDataModule(test_transforms="c")
    with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
        _ = LightningDataModule(test_transforms="c", dims=(1, 1, 1))
def test_dm_init_from_datasets(tmpdir):

    train_ds = DummyDS()
    valid_ds = DummyDS()
    test_ds = DummyDS()

    valid_dss = [DummyDS(), DummyDS()]
    test_dss = [DummyDS(), DummyDS()]

    dm = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0)
    assert torch.all(next(iter(dm.train_dataloader())) == torch.ones(4))
    assert dm.val_dataloader() is None
    assert dm.test_dataloader() is None

    dm = LightningDataModule.from_datasets(train_ds, valid_ds, test_ds, batch_size=4, num_workers=0)
    assert torch.all(next(iter(dm.val_dataloader())) == torch.ones(4))
    assert torch.all(next(iter(dm.test_dataloader())) == torch.ones(4))

    dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, batch_size=4, num_workers=0)
    assert torch.all(next(iter(dm.val_dataloader()[0])) == torch.ones(4))
    assert torch.all(next(iter(dm.val_dataloader()[1])) == torch.ones(4))
    assert torch.all(next(iter(dm.test_dataloader()[0])) == torch.ones(4))
    assert torch.all(next(iter(dm.test_dataloader()[1])) == torch.ones(4))
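
The DummyDS / DummyIDS fixtures are not shown in these snippets. A minimal sketch consistent with the assertions above (every item is a one, so a batch of 4 compares equal to torch.ones(4)); the class names mirror the tests, but the dataset length is an assumption:

import torch
from torch.utils.data import Dataset, IterableDataset

class DummyDS(Dataset):
    """Map-style dataset whose items are all scalar ones."""

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return torch.tensor(1.0)

class DummyIDS(IterableDataset):
    """Iterable-style counterpart used by the `iterable` test cases."""

    def __iter__(self):
        for _ in range(100):
            yield torch.tensor(1.0)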
Example #4
def test_dm_init_from_datasets_dataloaders(iterable):
    ds = DummyIDS if iterable else DummyDS

    train_ds = ds()
    dm = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.train_dataloader()
        dl_mock.assert_called_once_with(train_ds, batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True)
    with pytest.raises(MisconfigurationException, match="`val_dataloader` must be implemented"):
        _ = dm.val_dataloader()
    with pytest.raises(MisconfigurationException, match="`test_dataloader` must be implemented"):
        _ = dm.test_dataloader()

    train_ds_sequence = [ds(), ds()]
    dm = LightningDataModule.from_datasets(train_ds_sequence, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.train_dataloader()
        dl_mock.assert_has_calls(
            [
                call(train_ds_sequence[0], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
                call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
            ]
        )
    with pytest.raises(MisconfigurationException, match="`val_dataloader` must be implemented"):
        _ = dm.val_dataloader()
    with pytest.raises(MisconfigurationException, match="`test_dataloader` must be implemented"):
        _ = dm.test_dataloader()

    valid_ds = ds()
    test_ds = ds()
    dm = LightningDataModule.from_datasets(val_dataset=valid_ds, test_dataset=test_ds, batch_size=2, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.val_dataloader()
        dl_mock.assert_called_with(valid_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
        dm.test_dataloader()
        dl_mock.assert_called_with(test_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
    with pytest.raises(MisconfigurationException, match="`train_dataloader` must be implemented"):
        _ = dm.train_dataloader()

    valid_dss = [ds(), ds()]
    test_dss = [ds(), ds()]
    predict_dss = [ds(), ds()]
    dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, predict_dss, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.val_dataloader()
        dm.test_dataloader()
        dm.predict_dataloader()
        dl_mock.assert_has_calls(
            [
                call(valid_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(valid_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(predict_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(predict_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
            ]
        )
def model_cases():
    class TestHparamsNamespace(LightningModule):
        learning_rate = 1

        def __contains__(self, item):
            return item == "learning_rate"

    TestHparamsDict = {"learning_rate": 2}

    class TestModel1(LightningModule):  # test for namespace
        learning_rate = 0

    model1 = TestModel1()

    class TestModel2(LightningModule):  # test for hparams namespace
        hparams = TestHparamsNamespace()

    model2 = TestModel2()

    class TestModel3(LightningModule):  # test for hparams dict
        hparams = TestHparamsDict

    model3 = TestModel3()

    class TestModel4(LightningModule):  # fail case
        batch_size = 1

    model4 = TestModel4()

    trainer = Trainer()
    datamodule = LightningDataModule()
    datamodule.batch_size = 8
    trainer.datamodule = datamodule

    model5 = LightningModule()
    model5.trainer = trainer

    class TestModel6(LightningModule):  # test for datamodule w/ hparams w/o attribute (should use datamodule)
        hparams = TestHparamsDict

    model6 = TestModel6()
    model6.trainer = trainer

    TestHparamsDict2 = {"batch_size": 2}

    class TestModel7(LightningModule):  # test for datamodule w/ hparams w/ attribute (should use datamodule)
        hparams = TestHparamsDict2

    model7 = TestModel7()
    model7.trainer = trainer

    return model1, model2, model3, model4, model5, model6, model7
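
model_cases is presumably consumed by Lightning's attribute-lookup tests. A hedged usage sketch, assuming the lightning_hasattr helper from pytorch_lightning.utilities.parsing (which checks the model itself, its hparams, and trainer.datamodule in turn):

from pytorch_lightning.utilities.parsing import lightning_hasattr

def test_lightning_hasattr():
    model1, model2, model3, model4, model5, model6, model7 = model_cases()
    assert lightning_hasattr(model1, "learning_rate")      # plain attribute
    assert lightning_hasattr(model2, "learning_rate")      # via hparams namespace
    assert lightning_hasattr(model3, "learning_rate")      # via hparams dict
    assert not lightning_hasattr(model4, "learning_rate")  # fail case
    assert lightning_hasattr(model5, "batch_size")         # via trainer.datamodule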
Example #6
def predict_test(trainer: Trainer, model: LightningModule,
                 dm: LightningDataModule):
    """Checks if the trained model has high accuracy on the test set."""
    trainer.fit(model, datamodule=dm)
    dm.setup(stage="test")
    test_loader = dm.test_dataloader()
    acc = pl.metrics.Accuracy()
    for batch in test_loader:
        x, y = batch
        with torch.no_grad():
            y_hat = model(x)
        y_hat = y_hat.cpu()
        acc.update(y_hat, y)
    average_acc = acc.compute()
    assert average_acc >= 0.5, f"This model is expected to get >= {0.5} on " \
                               f"the test set (it got {average_acc})"
def gaussian_noise_pig_dl(
    dm: pl.LightningDataModule,
    batch_size: int,
    N_hat_multiplier: float = 1,
    sigma_multiplier: float = 1,
) -> DataLoader:
    # Build a pseudo-input loader by adding Gaussian noise (scaled by the
    # data's standard deviation) to a resampled copy of the training inputs.
    x, _ = unpack_dataloader(dm.train_dataloader())
    N = x.shape[0]
    noise_std = torch.std(x)

    # Does not exactly respect N_hat, but this keeps the logic simple
    if N_hat_multiplier >= 1:
        n_repeats = int(N_hat_multiplier)
        x = x.repeat(1, n_repeats).view(-1, x.shape[1])
    else:
        N_hat = int(N * N_hat_multiplier)
        idx = torch.randperm(x.shape[0])[:N_hat]
        x = x[idx]

    x_hat = x + noise_std * torch.randn_like(x)

    dl = DataLoader(TensorDataset(x_hat), batch_size=batch_size, shuffle=True)
    return dl
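
gaussian_noise_pig_dl relies on an unpack_dataloader helper that is not defined in these snippets. A plausible sketch, assuming it simply concatenates every (x, y) batch of a DataLoader into two full tensors:

import torch
from torch.utils.data import DataLoader

def unpack_dataloader(dl: DataLoader):
    # Collect all batches, then concatenate along the batch dimension.
    xs, ys = zip(*[(x, y) for x, y in dl])
    return torch.cat(xs), torch.cat(ys)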
def run_model_test_without_loggers(trainer_options: dict,
                                   model: LightningModule,
                                   data: LightningDataModule = None,
                                   min_acc: float = 0.50):
    reset_seed()

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=data)

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    model2 = load_model_from_checkpoint(
        trainer.logger, trainer.checkpoint_callback.best_model_path,
        type(model))

    # test new model accuracy
    test_loaders = model2.test_dataloader() if not data else data.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    if not isinstance(model2, BoringModel):
        for dataloader in test_loaders:
            run_prediction_eval_model_template(model2,
                                               dataloader,
                                               min_acc=min_acc)
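
load_model_from_checkpoint is also undefined here. A minimal sketch, assuming it wraps the standard LightningModule.load_from_checkpoint classmethod (the logger argument is kept only to mirror the call sites above):

def load_model_from_checkpoint(logger, ckpt_path, module_class):
    model = module_class.load_from_checkpoint(ckpt_path)
    model.eval()
    return model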
Example #9
def lightning_train(trainer: pl.Trainer, model: pl.LightningModule,
                    data_module: pl.LightningDataModule, checkpoint_path: str,
                    resume: bool):
    if resume:
        trainer = pl.Trainer(resume_from_checkpoint=checkpoint_path)

    exp_key = trainer.logger.experiment.get_key()
    print("Running {}-model...".format(model.name))

    # main part here
    data_module.setup('fit')
    trainer.fit(model=model, datamodule=data_module)

    model.save(trainer, checkpoint_path)

    return trainer, model, exp_key
def test_v1_7_0_datamodule_dims_property(tmpdir):
    dm = MNISTDataModule()
    with pytest.deprecated_call(
            match=r"DataModule property `dims` was deprecated in v1.5"):
        _ = dm.dims
    with pytest.deprecated_call(
            match=r"DataModule property `dims` was deprecated in v1.5"):
        _ = LightningDataModule(dims=(1, 1, 1))
def run_model_test(
    trainer_options,
    model: LightningModule,
    data: LightningDataModule = None,
    on_gpu: bool = True,
    version=None,
    with_hpc: bool = True,
    min_acc: float = 0.25,
):
    reset_seed()
    save_dir = trainer_options["default_root_dir"]

    # logger file to get meta
    logger = get_default_logger(save_dir, version=version)
    trainer_options.update(logger=logger)
    trainer = Trainer(**trainer_options)
    initial_values = torch.tensor(
        [torch.sum(torch.abs(x)) for x in model.parameters()])
    trainer.fit(model, datamodule=data)
    post_train_values = torch.tensor(
        [torch.sum(torch.abs(x)) for x in model.parameters()])

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    # Check that the model is actually changed post-training
    change_ratio = torch.norm(initial_values - post_train_values)
    assert change_ratio > 0.1, f"the model changed by only {change_ratio}"

    # test model loading
    pretrained_model = load_model_from_checkpoint(
        logger, trainer.checkpoint_callback.best_model_path, type(model))

    # test new model accuracy
    test_loaders = model.test_dataloader() if not data else data.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    if not isinstance(model, BoringModel):
        for dataloader in test_loaders:
            run_prediction_eval_model_template(model,
                                               dataloader,
                                               min_acc=min_acc)

    if with_hpc:
        if trainer._distrib_type in (DistributedType.DDP,
                                     DistributedType.DDP_SPAWN,
                                     DistributedType.DDP2):
            # on hpc this would work fine... but need to hack it for the purpose of the test
            trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = trainer.init_optimizers(
                pretrained_model)

        # test HPC saving
        trainer.checkpoint_connector.hpc_save(save_dir, logger)
        # test HPC loading
        checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(
            save_dir)
        trainer.checkpoint_connector.restore(checkpoint_path)
def test_dm_init_from_datasets_dataloaders(iterable):
    ds = DummyIDS if iterable else DummyDS

    train_ds = ds()
    dm = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.train_dataloader()
        dl_mock.assert_called_once_with(train_ds, batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True)
    assert dm.val_dataloader() is None
    assert dm.test_dataloader() is None

    train_ds_sequence = [ds(), ds()]
    dm = LightningDataModule.from_datasets(train_ds_sequence, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.train_dataloader()
        dl_mock.assert_has_calls([
            call(train_ds_sequence[0], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
            call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True)
        ])
    assert dm.val_dataloader() is None
    assert dm.test_dataloader() is None

    valid_ds = ds()
    test_ds = ds()
    dm = LightningDataModule.from_datasets(val_dataset=valid_ds, test_dataset=test_ds, batch_size=2, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.val_dataloader()
        dl_mock.assert_called_with(valid_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
        dm.test_dataloader()
        dl_mock.assert_called_with(test_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
    assert dm.train_dataloader() is None

    valid_dss = [ds(), ds()]
    test_dss = [ds(), ds()]
    dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.val_dataloader()
        dm.test_dataloader()
        dl_mock.assert_has_calls([
            call(valid_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
            call(valid_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
            call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
            call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
        ])
Example #13
    def __init__(
        self,
        train_sequences: List[str],
        validation_sequences: List[str],
        alphabet: AlphabetDataLoader,
        masking_ratio: float,
        masking_prob: float,
        random_token_prob: float,
        num_workers: int,
        toks_per_batch: int,
        crop_sizes: Tuple[int, int] = (512, 1024),
    ):
        LightningDataModule.__init__(self)
        self._train_sequences = train_sequences
        self._validation_sequences = validation_sequences
        self._alphabet = alphabet
        self._masking_ratio = masking_ratio
        self._masking_prob = masking_prob
        self._random_token_prob = random_token_prob
        self._num_workers = num_workers
        self._toks_per_batch = toks_per_batch
        self._crop_sizes = crop_sizes
Example #14
    def __post_init__(self,
                      observation_space: gym.Space = None,
                      action_space: gym.Space = None,
                      reward_space: gym.Space = None):
        """ Initializes the fields of the setting that weren't set from the
        command-line.
        """
        logger.debug(f"__post_init__ of Setting")
        if len(self.train_transforms) == 1 and isinstance(self.train_transforms[0], list):
            self.train_transforms = self.train_transforms[0]
        if len(self.val_transforms) == 1 and isinstance(self.val_transforms[0], list):
            self.val_transforms = self.val_transforms[0]
        if len(self.test_transforms) == 1 and isinstance(self.test_transforms[0], list):
            self.test_transforms = self.test_transforms[0]

        # Actually compose the list of Transforms or callables into a single transform.
        self.train_transforms: Compose = Compose(self.train_transforms)
        self.val_transforms: Compose = Compose(self.val_transforms)
        self.test_transforms: Compose = Compose(self.test_transforms)

        LightningDataModule.__init__(self,
            train_transforms=self.train_transforms,
            val_transforms=self.val_transforms,
            test_transforms=self.test_transforms,
        )
        
        self._observation_space = observation_space
        self._action_space = action_space
        self._reward_space = reward_space

        # TODO: It's a bit confusing to also have a `config` attribute on the
        # Setting. Might want to change this a bit.
        self.config: Config = None

        self.train_env: Environment = None  # type: ignore
        self.val_env: Environment = None  # type: ignore
        self.test_env: Environment = None  # type: ignore
Example #15
def model_from_config(parameters: dict,
                      datamodule: LightningDataModule) -> LightningModule:
    steps_per_epoch = len(datamodule.train_dataloader())
    total_steps = parameters["trainer"]["epochs"] * steps_per_epoch

    class_weight = compute_class_weight(datamodule.dataset_train)

    regression = parameters["data"]["targets"] == "regression"
    if regression:
        model = ImageRegression(
            n_channel=datamodule.size(0),
            lr_scheduler_total_steps=total_steps,
            **parameters["model"],
        )
    else:
        model = ImageClassification(
            n_channel=datamodule.size(0),
            n_class=len(parameters["data"]["targets"]["classes"]),
            class_weight=class_weight,
            lr_scheduler_total_steps=total_steps,
            **parameters["model"],
        )

    return model
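
compute_class_weight is not defined in these snippets. One plausible implementation, assuming the training dataset exposes its integer targets through a y() method as in Example #21, is inverse-frequency weighting:

from collections import Counter

import torch

def compute_class_weight(dataset):
    # Weight each class by the inverse of its frequency.
    counts = Counter(dataset.y())
    total = sum(counts.values())
    return torch.tensor([total / counts[c] for c in sorted(counts)], dtype=torch.float)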
Example #16
def run_model_test(
    trainer_options,
    model: LightningModule,
    data: LightningDataModule = None,
    on_gpu: bool = True,
    version=None,
    with_hpc: bool = True,
    min_acc: float = 0.25,
):
    reset_seed()
    save_dir = trainer_options["default_root_dir"]

    # logger file to get meta
    logger = get_default_logger(save_dir, version=version)
    trainer_options.update(logger=logger)
    trainer = Trainer(**trainer_options)
    initial_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])
    trainer.fit(model, datamodule=data)
    post_train_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    # Check that the model is actually changed post-training
    change_ratio = torch.norm(initial_values - post_train_values)
    assert change_ratio > 0.03, f"the model changed by only {change_ratio}"

    # test model loading
    _ = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))

    # test new model accuracy
    test_loaders = model.test_dataloader() if not data else data.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    if not isinstance(model, BoringModel):
        for dataloader in test_loaders:
            run_model_prediction(model, dataloader, min_acc=min_acc)

    if with_hpc:
        # test HPC saving
        # save logger to make sure we get all the metrics
        if logger:
            logger.finalize("finished")
        hpc_save_path = trainer.checkpoint_connector.hpc_save_path(save_dir)
        trainer.save_checkpoint(hpc_save_path)
        # test HPC loading
        checkpoint_path = trainer.checkpoint_connector._CheckpointConnector__get_max_ckpt_path_from_folder(save_dir)
        trainer.checkpoint_connector.restore(checkpoint_path)
Example #17
    def __init__(
            self,
            datamodule: pl.LightningDataModule = None,
            encoder: Union[str, torch.nn.Module, pl.LightningModule] = 'cpc_encoder',
            patch_size: int = 8,
            patch_overlap: int = 4,
            online_ft: bool = True,
            task: str = 'cpc',
            num_workers: int = 4,
            learning_rate: float = 1e-4,
            data_dir: str = '',
            batch_size: int = 32,
            pretrained: str = None,
            **kwargs,
    ):
        """
        PyTorch Lightning implementation of `Data-Efficient Image Recognition with Contrastive Predictive Coding
        <https://arxiv.org/abs/1905.09272>`_

        Paper authors: (Olivier J. Hénaff, Aravind Srinivas, Jeffrey De Fauw, Ali Razavi,
        Carl Doersch, S. M. Ali Eslami, Aaron van den Oord).

        Model implemented by:

            - `William Falcon <https://github.com/williamFalcon>`_
            - `Tullie Murrell <https://github.com/tullie>`_

        Example:

            >>> from pl_bolts.models.self_supervised import CPCV2
            ...
            >>> model = CPCV2()

        Train::

            trainer = Trainer()
            trainer.fit(model)

        CLI command::

            # cifar10
            python cpc_module.py --gpus 1

            # imagenet
            python cpc_module.py
                --gpus 8
                --dataset imagenet2012
                --data_dir /path/to/imagenet/
                --meta_dir /path/to/folder/with/meta.bin/
                --batch_size 32

        To Finetune::

            python cpc_finetuner.py --ckpt_path path/to/checkpoint.ckpt --dataset cifar10 --gpus x

        Some uses::

            # load resnet18 pretrained using CPC on imagenet
            model = CPCV2(encoder='resnet18', pretrained=True)
            resnet18 = model.encoder
            resnet18.freeze()

            # it supports any torchvision resnet
            model = CPCV2(encoder='resnet50', pretrained=True)

            # use it as a feature extractor
            x = torch.rand(2, 3, 224, 224)
            out = model(x)

        Args:
            datamodule: A Datamodule (optional). Otherwise set the dataloaders directly
            encoder: A string for any of the resnets in torchvision, or the original CPC encoder,
                or a custom nn.Module encoder
            patch_size: How big to make the image patches
            patch_overlap: How much overlap should each patch have.
            online_ft: Enable a 1024-unit MLP to fine-tune online
            task: Which self-supervised task to use ('cpc', 'amdim', etc...)
            num_workers: number of dataloader workers
            learning_rate: what learning rate to use
            data_dir: where to store data
            batch_size: batch size
            pretrained: If true, will use the weights pretrained (using CPC) on Imagenet
        """

        super().__init__()
        self.save_hyperparameters()

        self.online_evaluator = self.hparams.online_ft

        if pretrained:
            self.hparams.dataset = pretrained
            self.online_evaluator = True

        # link data
        if datamodule is None:
            datamodule = CIFAR10DataModule(
                self.hparams.data_dir,
                num_workers=self.hparams.num_workers,
                batch_size=batch_size
            )
            datamodule.train_transforms = CPCTrainTransformsCIFAR10()
            datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        self.datamodule = datamodule

        # init encoder
        self.encoder = encoder
        if isinstance(encoder, str):
            self.encoder = self.init_encoder()

        # info nce loss
        c, h = self.__compute_final_nb_c(self.hparams.patch_size)
        self.contrastive_task = CPCTask(num_input_channels=c, target_dim=64, embed_scale=0.1)

        self.z_dim = c * h * h
        self.num_classes = self.datamodule.num_classes

        if pretrained:
            self.load_pretrained(encoder)
def kde_pig_dl(
    dm: pl.LightningDataModule,
    batch_size: int,
    N_hat_multiplier: float = 1,
) -> DataLoader:
    # Gradient-descent hyperparameters for refining sampled pseudo-inputs.
    gd_n_steps, gd_lr, gd_threshold = 5, 4e-1, 0.005

    # Spherical = each component has single variance.
    bgm = BayesianGaussianMixture(
        n_components=batch_size,
        covariance_type="spherical",
        warm_start=True,
    )

    x_hat = torch.Tensor()
    for idx, batch in enumerate(iter(dm.train_dataloader())):
        x, _ = batch
        device = x.device
        x = x.detach().cpu().numpy()
        # The last batch might have fewer elements than the original n_components
        if x.shape[0] < bgm.n_components:
            bgm = BayesianGaussianMixture(
                n_components=x.shape[0],
                covariance_type="spherical",
            )
        # Estimate KDE
        bgm.fit(x)
        # [N_components, 1], [N_components, N_features], [N_components, 1]
        weights, means, variances = (
            torch.Tensor(bgm.weights_).to(device),
            torch.Tensor(bgm.means_).to(device),
            torch.Tensor(bgm.covariances_).to(device),
        )
        filter_weights_idx = weights >= 1e-5
        weights, means, variances = (
            weights[filter_weights_idx],
            means[filter_weights_idx],
            variances[filter_weights_idx][:, None],
        )
        n_selected_components = weights.shape[0]
        p_x = D.Independent(D.Normal(means, torch.sqrt(variances)), 1)
        mix = D.Categorical(weights)
        p_x = D.MixtureSameFamily(mix, p_x)
        # Sample according to multiplier
        x_start = p_x.sample(
            (
                # The sample shape must be an int; cast after applying the
                # (possibly fractional) multiplier.
                int(
                    n_selected_components
                    * ((batch_size // n_selected_components) + 1)
                    * N_hat_multiplier
                ),
            )
        ).reshape(-1, x.shape[1])
        # Use GD
        _x_hat = density_gradient_descent(
            p_x,
            x_start,
            {"N_steps": gd_n_steps, "lr": gd_lr, "threshold": gd_threshold},
        )
        # Ensure same device
        if x_hat.device != device:
            x_hat = x_hat.to(device)
        x_hat = torch.cat((x_hat, _x_hat.detach()))

    dl = DataLoader(TensorDataset(x_hat), batch_size=batch_size, shuffle=True)
    return dl
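
density_gradient_descent is assumed by kde_pig_dl but not shown. A minimal sketch under stated assumptions: it takes N_steps gradient steps on the mixture's log-density, and the threshold freezes points once their density is low enough. Both the step direction (toward low density) and the stopping rule are assumptions:

import torch

def density_gradient_descent(p_x, x_start, params):
    x = x_start.clone().requires_grad_(True)
    for _ in range(params["N_steps"]):
        # Gradient of the log-density at the current samples.
        grad = torch.autograd.grad(p_x.log_prob(x).sum(), x)[0]
        with torch.no_grad():
            # Only keep moving points whose density is still above the threshold.
            active = (p_x.log_prob(x).exp() >= params["threshold"]).float()[:, None]
            x -= params["lr"] * grad * active  # step against the gradient
    return x.detach()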
def mean_shift_pig_dl(
    dm: pl.LightningDataModule,
    batch_size: int,
    N_hat_multiplier: float = 1,
    max_iters: int = 20,
    h: float = None,
    h_factor: float = 1,
    sigma: float = None,
    sigma_factor: float = 1,
    kernel: str = "tophat",
    τ: float = 1e-5,
) -> DataLoader:
    x, _ = unpack_dataloader(dm.train_dataloader())

    if h is None:
        h = silverman_bandwidth(x) * h_factor
    if sigma is None:
        # Note multiplier is arbitrary
        sigma = torch.std(x) * sigma_factor

    N = x.shape[0]
    N_hat = int(N_hat_multiplier * N)
    if N_hat > N:
        # With replacement to allow for N_hat > N
        idx = torch.randint(x.shape[0], (N_hat,))
    else:
        # Without replacement
        idx = torch.randperm(x.shape[0])[:N_hat]
    # [n_pig, D]
    x_start = x[idx]

    # Sample locally around each point, [n_pig, D]
    x_hat = x_start + sigma * torch.randn_like(x_start)

    for _ in range(max_iters):
        # [N, n_pig]
        dst = sqr_dist(x, x_hat)
        if kernel == "tophat":
            kde = torch.where(dst <= h, torch.ones_like(dst), torch.zeros_like(dst))
        else:
            raise NotImplementedError(f"Kernel {kernel} invalid")
        # Replace nan by 0
        kde[kde != kde] = 0

        # Threshold for stopping updates
        out_kde = (torch.max(kde, dim=0)[0] < τ)[:, None]

        # [n_pig, 1]
        sum_kde = torch.sum(kde, dim=0).reshape((-1, 1))
        sum_kde = torch.where(out_kde, torch.ones_like(sum_kde), sum_kde)

        # Centroid for all estimates
        mu = (torch.transpose(kde, 0, 1) @ x) / sum_kde
        # Step size
        delta = mu - x_hat

        x_hat = torch.where(out_kde, x_hat, x_hat - delta)

    dl = DataLoader(TensorDataset(x_hat), batch_size=batch_size, shuffle=True)

    return dl
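
mean_shift_pig_dl depends on two small helpers that are not shown. Hedged sketches: silverman_bandwidth following Silverman's rule of thumb generalized to D dimensions, and sqr_dist as pairwise squared Euclidean distances:

import torch

def silverman_bandwidth(x: torch.Tensor) -> float:
    n, d = x.shape
    return (x.std() * (4.0 / ((d + 2) * n)) ** (1.0 / (d + 4))).item()

def sqr_dist(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # [N, M] matrix of squared distances between rows of a and b.
    return torch.cdist(a, b).pow(2)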
Example #20
def make_cli_parser(
        parser: argparse.ArgumentParser,
        datamodule_cls: pl.LightningDataModule) -> argparse.ArgumentParser:
    """make_cli_parser Augment an argument parser for slp with the default arguments

    Default arguments for training, logging, optimization etc. are added to the input {parser}.
    If you use make_cli_parser, the following command line arguments will be included

        usage: my_script.py [-h] [--hidden MODEL.INTERMEDIATE_HIDDEN]
                                        [--optimizer {Adam,AdamW,SGD,Adadelta,Adagrad,Adamax,ASGD,RMSprop}]
                                        [--lr OPTIM.LR] [--weight-decay OPTIM.WEIGHT_DECAY]
                                        [--lr-scheduler] [--lr-factor LR_SCHEDULE.FACTOR]
                                        [--lr-patience LR_SCHEDULE.PATIENCE]
                                        [--lr-cooldown LR_SCHEDULE.COOLDOWN]
                                        [--min-lr LR_SCHEDULE.MIN_LR] [--seed SEED] [--config CONFIG]
                                        [--experiment-name TRAINER.EXPERIMENT_NAME]
                                        [--run-id TRAINER.RUN_ID]
                                        [--experiment-group TRAINER.EXPERIMENT_GROUP]
                                        [--experiments-folder TRAINER.EXPERIMENTS_FOLDER]
                                        [--save-top-k TRAINER.SAVE_TOP_K]
                                        [--patience TRAINER.PATIENCE]
                                        [--wandb-project TRAINER.WANDB_PROJECT]
                                        [--tags [TRAINER.TAGS [TRAINER.TAGS ...]]]
                                        [--stochastic_weight_avg] [--gpus TRAINER.GPUS]
                                        [--val-interval TRAINER.CHECK_VAL_EVERY_N_EPOCH]
                                        [--clip-grad-norm TRAINER.GRADIENT_CLIP_VAL]
                                        [--epochs TRAINER.MAX_EPOCHS] [--steps TRAINER.MAX_STEPS]
                                        [--tbtt_steps TRAINER.TRUNCATED_BPTT_STEPS] [--debug]
                                        [--offline] [--early-stop-on TRAINER.EARLY_STOP_ON]
                                        [--early-stop-mode {min,max}] [--num-trials TUNE.NUM_TRIALS]
                                        [--gpus-per-trial TUNE.GPUS_PER_TRIAL]
                                        [--cpus-per-trial TUNE.CPUS_PER_TRIAL]
                                        [--tune-metric TUNE.METRIC] [--tune-mode {max,min}]
                                        [--val-percent DATA.VAL_PERCENT]
                                        [--test-percent DATA.TEST_PERCENT] [--bsz DATA.BATCH_SIZE]
                                        [--bsz-eval DATA.BATCH_SIZE_EVAL]
                                        [--num-workers DATA.NUM_WORKERS] [--no-pin-memory]
                                        [--drop-last] [--no-shuffle-eval]

        optional arguments:
          -h, --help            show this help message and exit
          --hidden MODEL.INTERMEDIATE_HIDDEN
                                                        Intermediate hidden layers for linear module
          --optimizer {Adam,AdamW,SGD,Adadelta,Adagrad,Adamax,ASGD,RMSprop}
                                                        Which optimizer to use
          --lr OPTIM.LR         Learning rate
          --weight-decay OPTIM.WEIGHT_DECAY
                                                        Weight decay
          --lr-scheduler        Use learning rate scheduling. Currently only
                                                        ReduceLROnPlateau is supported out of the box
          --lr-factor LR_SCHEDULE.FACTOR
                                                        Multiplicative factor by which LR is reduced. Used if
                                                        --lr-scheduler is provided.
          --lr-patience LR_SCHEDULE.PATIENCE
                                                        Number of epochs with no improvement after which
                                                        learning rate will be reduced. Used if --lr-scheduler
                                                        is provided.
          --lr-cooldown LR_SCHEDULE.COOLDOWN
                                                        Number of epochs to wait before resuming normal
                                                        operation after lr has been reduced. Used if --lr-
                                                        scheduler is provided.
          --min-lr LR_SCHEDULE.MIN_LR
                                                        Minimum lr for LR scheduling. Used if --lr-scheduler
                                                        is provided.
          --seed SEED           Seed for reproducibility
          --config CONFIG       Path to YAML configuration file
          --experiment-name TRAINER.EXPERIMENT_NAME
                                                        Name of the running experiment
          --run-id TRAINER.RUN_ID
                                                        Unique identifier for the current run. If not provided
                                                        it is inferred from datetime.now()
          --experiment-group TRAINER.EXPERIMENT_GROUP
                                                        Group of current experiment. Useful when evaluating
                                                        for different seeds / cross-validation etc.
          --experiments-folder TRAINER.EXPERIMENTS_FOLDER
                                                        Top-level folder where experiment results &
                                                        checkpoints are saved
          --save-top-k TRAINER.SAVE_TOP_K
                                                        Save checkpoints for top k models
          --patience TRAINER.PATIENCE
                                                        Number of epochs to wait before early stopping
          --wandb-project TRAINER.WANDB_PROJECT
                                                        Wandb project under which results are saved
          --tags [TRAINER.TAGS [TRAINER.TAGS ...]]
                                                        Tags for current run to make results searchable.
          --stochastic_weight_avg
                                                        Use Stochastic weight averaging.
          --gpus TRAINER.GPUS   Number of GPUs to use
          --val-interval TRAINER.CHECK_VAL_EVERY_N_EPOCH
                                                        Run validation every n epochs
          --clip-grad-norm TRAINER.GRADIENT_CLIP_VAL
                                                        Clip gradients with ||grad(w)|| >= args.clip_grad_norm
          --epochs TRAINER.MAX_EPOCHS
                                                        Maximum number of training epochs
          --steps TRAINER.MAX_STEPS
                                                        Maximum number of training steps
          --tbtt_steps TRAINER.TRUNCATED_BPTT_STEPS
                                                        Truncated Back-propagation-through-time steps.
          --debug               If true, we run a full run on a small subset of the
                                                        input data and overfit 10 training batches
          --offline             If true, forces offline execution of wandb logger
          --early-stop-on TRAINER.EARLY_STOP_ON
                                                        Metric for early stopping
          --early-stop-mode {min,max}
                                                        Minimize or maximize early stopping metric
          --num-trials TUNE.NUM_TRIALS
                                                        Number of trials to run for hyperparameter tuning
          --gpus-per-trial TUNE.GPUS_PER_TRIAL
                                                        How many gpus to use for each trial. If gpus_per_trial
                                                        < 1 multiple trials are packed in the same gpu
          --cpus-per-trial TUNE.CPUS_PER_TRIAL
                                                        How many cpus to use for each trial.
          --tune-metric TUNE.METRIC
                                                        Tune this metric. Need to be one of the keys of
                                                        metrics_map passed into make_trainer_for_ray_tune.
          --tune-mode {max,min}
                                                        Maximize or minimize metric
          --val-percent DATA.VAL_PERCENT
                                                        Percent of validation data to be randomly split from
                                                        the training set, if no validation set is provided
          --test-percent DATA.TEST_PERCENT
                                                        Percent of test data to be randomly split from the
                                                        training set, if no test set is provided
          --bsz DATA.BATCH_SIZE
                                                        Training batch size
          --bsz-eval DATA.BATCH_SIZE_EVAL
                                                        Evaluation batch size
          --num-workers DATA.NUM_WORKERS
                                                        Number of workers to be used in the DataLoader
          --no-pin-memory       Don't pin data to GPU memory when transferring
          --drop-last           Drop last incomplete batch
          --no-shuffle-eval     Don't shuffle val & test sets

    Args:
        parser (argparse.ArgumentParser): A parent argument to be augmented
        datamodule_cls (pytorch_lightning.LightningDataModule): A data module class that injects arguments through the add_argparse_args method

    Returns:
        argparse.ArgumentParser: The augmented command line parser

    Examples:
        >>> import argparse
        >>> from slp.plbind.dm import PLDataModuleFromDatasets
        >>> parser = argparse.ArgumentParser("My cool model")
        >>> parser.add_argument("--hidden", dest="model.hidden", type=int)  # Create parser with model arguments and anything else you need
        >>> parser = make_cli_parser(parser, PLDataModuleFromDatasets)
        >>> args = parser.parse_args(args=["--bsz", "64", "--lr", "0.01"])
        >>> args.data.batch_size
        64
        >>> args.optim.lr
        0.01
    """
    parser = add_optimizer_args(parser)
    parser = add_trainer_args(parser)
    parser = add_tune_args(parser)
    parser = datamodule_cls.add_argparse_args(parser)

    return parser
Example #21
def data_info(datamodule: pl.LightningDataModule, parameters: dict, save_path: Path = None) -> dict:
    """
    Summary of the data used for training/testing.
    Gives: class-balance in each split, shape of the data, range of the data, split sizes, plot of examples.
    Datasets in datamodule must have y() method to access targets (like BaseDataset).
    """

    info = {}

    # Prepare datasets
    datasets = {}
    try:
        datasets["train"] = datamodule.train_dataloader().dataset
    except Exception:
        pass
    try:
        datasets["val"] = datamodule.val_dataloader().dataset
    except Exception:
        pass
    try:
        datasets["test"] = datamodule.test_dataloader().dataset
    except Exception:
        pass

    # Analyze class-balance
    def class_balance(dataset: BaseDataset):
        if isinstance(dataset, BaseDataset):
            y = dataset.y()
        elif isinstance(dataset, Subset):
            ds = dataset.dataset
            if not isinstance(ds, BaseDataset):
                raise AttributeError("dataset must be a BaseDataset or a Subset of a BaseDataset.")
            y = ds.y(dataset.indices)
        else:
            raise AttributeError("dataset must be a BaseDataset or a Subset of a BaseDataset.")

        counter = Counter(y)
        class_names = [list(c.keys())[0] for c in parameters["data"]["targets"]["classes"]]
        return {class_names[i]: counter[i] for i in range(len(class_names))}

    if "data" in parameters and "targets" in parameters["data"] and "classes" in parameters["data"]["targets"]:
        info["class-balance"] = {}
        for ds_name, ds in datasets.items():
            info["class-balance"][ds_name] = class_balance(ds)

    # Shape of data
    info["shape"] = str(datamodule.size())

    # Range
    if "train" in datasets:
        if isinstance(datasets["train"][0][0], tuple):
            t = torch.cat([datasets["train"][0][0][0], datasets["train"][1][0][0], datasets["train"][2][0][0]])
        else:
            t = torch.cat([datasets["train"][0][0], datasets["train"][1][0], datasets["train"][2][0]])
        info["tensor-data"] = {
            "min": t.min().item(),
            "max": t.max().item(),
            "mean": t.mean().item(),
            "std": t.std().item(),
        }

    # Sizes of splits
    info["set-sizes"] = {}
    for ds_name, ds in datasets.items():
        info["set-sizes"][ds_name] = len(ds)

    # Examples of each batch
    def plot_batch(ds_name: str, ds: Dataset, n_images: int = 32):
        if isinstance(ds[0][0], tuple):
            images = [ds[i][0][0][0] for i in random.sample(range(len(ds)), n_images)]
        else:
            images = [ds[i][0][0] for i in random.sample(range(len(ds)), n_images)]

        plot_image_grid(images, max_images=n_images, columns=8, save_path=save_path / f"data-examples-{ds_name}.png")

    if save_path is not None:
        for ds_name, ds in datasets.items():
            plot_batch(ds_name, ds)

    return info
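
data_info's docstring requires datasets with a y() method "like BaseDataset". A hedged sketch of that interface, with the indices overload used for Subset objects:

import numpy as np
from torch.utils.data import Dataset

class BaseDataset(Dataset):
    def __init__(self, x, targets):
        self.x = x
        self.targets = np.asarray(targets)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.targets[idx]

    def y(self, indices=None):
        # Targets for the whole set, or restricted to a Subset's indices.
        return self.targets if indices is None else self.targets[indices]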
    def __init__(self,
                 base_encoder: Union[str, torch.nn.Module] = 'resnet18',
                 emb_dim: int = 128,
                 num_negatives: int = 65536,
                 encoder_momentum: float = 0.999,
                 softmax_temperature: float = 0.07,
                 learning_rate: float = 0.03,
                 momentum: float = 0.9,
                 weight_decay: float = 1e-4,
                 datamodule: pl.LightningDataModule = None,
                 data_dir: str = './',
                 batch_size: int = 256,
                 use_mlp: bool = False,
                 num_workers: int = 8,
                 *args,
                 **kwargs):
        """
        PyTorch Lightning implementation of `Moco <https://arxiv.org/abs/2003.04297>`_

        Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He.

        Code adapted from `facebookresearch/moco <https://github.com/facebookresearch/moco>`_ to Lightning by:

            - `William Falcon <https://github.com/williamFalcon>`_

        Example:

            >>> from pl_bolts.models.self_supervised import MocoV2
            ...
            >>> model = MocoV2()

        Train::

            trainer = Trainer()
            trainer.fit(model)

        CLI command::

            # cifar10
            python moco2_module.py --gpus 1

            # imagenet
            python moco2_module.py
                --gpus 8
                --dataset imagenet2012
                --data_dir /path/to/imagenet/
                --meta_dir /path/to/folder/with/meta.bin/
                --batch_size 32

        Args:
            base_encoder: torchvision model name or torch.nn.Module
            emb_dim: feature dimension (default: 128)
            num_negatives: queue size; number of negative keys (default: 65536)
            encoder_momentum: moco momentum of updating key encoder (default: 0.999)
            softmax_temperature: softmax temperature (default: 0.07)
            learning_rate: the learning rate
            momentum: optimizer momentum
            weight_decay: optimizer weight decay
            datamodule: the DataModule (train, val, test dataloaders)
            data_dir: the directory to store data
            batch_size: batch size
            use_mlp: add an mlp to the encoders
            num_workers: workers for the loaders
        """

        super().__init__()
        self.save_hyperparameters()

        # use CIFAR-10 by default if no datamodule passed in
        if datamodule is None:
            datamodule = CIFAR10DataModule(data_dir)
            datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
            datamodule.val_transforms = Moco2EvalCIFAR10Transforms()

        self.datamodule = datamodule

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q, self.encoder_k = self.init_encoders(base_encoder)

        if use_mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp),
                                              nn.ReLU(), self.encoder_q.fc)
            self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp),
                                              nn.ReLU(), self.encoder_k.fc)

        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(emb_dim, num_negatives))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
Example #23
    def __post_init__(
        self,
        observation_space: gym.Space = None,
        action_space: gym.Space = None,
        reward_space: gym.Space = None,
    ):
        """ Initializes the fields of the setting that weren't set from the
        command-line.
        """
        logger.debug("__post_init__ of Setting")
        # BUG: simple-parsing sometimes parses a list with a single item, itself the
        # list of transforms. Not sure if this still happens.

        def is_list_of_list(v: Any) -> bool:
            return isinstance(v, list) and len(v) == 1 and isinstance(v[0], list)

        if is_list_of_list(self.train_transforms):
            self.train_transforms = self.train_transforms[0]
        if is_list_of_list(self.val_transforms):
            self.val_transforms = self.val_transforms[0]
        if is_list_of_list(self.test_transforms):
            self.test_transforms = self.test_transforms[0]

        if all(
            t is None
            for t in [
                self.transforms,
                self.train_transforms,
                self.val_transforms,
                self.test_transforms,
            ]
        ):
            # Use these two transforms by default if no transforms are passed at all.
            # TODO: Remove this after the competition perhaps.
            self.transforms = Compose([Transforms.to_tensor, Transforms.three_channels])

        # If the constructor is called with just the `transforms` argument, like this:
        # <SomeSetting>(dataset="bob", transforms=foo_transform)
        # Then we use this value as the default for the train, val and test transforms.
        if self.transforms and not any(
            [self.train_transforms, self.val_transforms, self.test_transforms]
        ):
            if not isinstance(self.transforms, list):
                self.transforms = Compose([self.transforms])
            self.train_transforms = self.transforms.copy()
            self.val_transforms = self.transforms.copy()
            self.test_transforms = self.transforms.copy()

        if self.train_transforms is not None and not isinstance(
            self.train_transforms, list
        ):
            self.train_transforms = [self.train_transforms]

        if self.val_transforms is not None and not isinstance(
            self.val_transforms, list
        ):
            self.val_transforms = [self.val_transforms]

        if self.test_transforms is not None and not isinstance(
            self.test_transforms, list
        ):
            self.test_transforms = [self.test_transforms]

        # Actually compose the list of Transforms or callables into a single transform.
        self.train_transforms: Compose = Compose(self.train_transforms or [])
        self.val_transforms: Compose = Compose(self.val_transforms or [])
        self.test_transforms: Compose = Compose(self.test_transforms or [])

        LightningDataModule.__init__(
            self,
            train_transforms=self.train_transforms,
            val_transforms=self.val_transforms,
            test_transforms=self.test_transforms,
        )

        self._observation_space = observation_space
        self._action_space = action_space
        self._reward_space = reward_space

        # TODO: It's a bit confusing to also have a `config` attribute on the
        # Setting. Might want to change this a bit.
        self.config: Config = None

        self.train_env: Environment = None  # type: ignore
        self.val_env: Environment = None  # type: ignore
        self.test_env: Environment = None  # type: ignore
Example #24
def train(
    parameters: dict,
    datamodule: LightningDataModule,
    model: LightningModule,
    callbacks: Optional[Union[List[Callback], Callback]] = None,
):
    validate_train_parameters(parameters)

    model_path = Path(parameters["path"])
    plot_path = Path(parameters["path"]) / "train_plots"
    plot_path.mkdir(parents=True, exist_ok=True)
    regression = parameters["data"]["targets"] == "regression"
    n_class = 1 if regression else len(parameters["data"]["targets"]["classes"])

    # Callbacks
    if callbacks is None:
        callbacks = []
    elif not isinstance(callbacks, list):
        callbacks = [callbacks]
    timer_callback = TimerCallback()

    early_stop_callback = (
        EarlyStopping(
            monitor="val_loss", min_delta=0.00, patience=parameters["trainer"]["patience"], verbose=True, mode="min"
        )
        if parameters["trainer"]["patience"] > 0
        else None
    )
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        dirpath=str(model_path),
        filename="model",
    )
    callbacks += [
        early_stop_callback,
        checkpoint_callback,
        InfoLogCallback(),
        timer_callback,
        PlotTrainValCurveCallback(plot_path, "loss"),
        PlotTrainValCurveCallback(plot_path, "accuracy"),
    ]
    callbacks = [c for c in callbacks if c is not None]

    # Logging
    im_logger = InMemoryLogger()
    pl_logger = [im_logger]

    # Tracking
    tracking: Tracking = NeptuneNewTracking(
        parameters=parameters,
        tags=parameters.get("tags", []),
        disabled=(not parameters["tracking"]) or parameters["tune_lr"],
    )
    tracking_logger = tracking.get_callback("pytorch-lightning")
    if tracking_logger is not None:
        pl_logger.append(tracking_logger)
        callbacks.append(LearningRateMonitor(logging_interval="step"))

    trainer = pl.Trainer(
        gpus=parameters["system"]["gpus"],
        logger=pl_logger,
        callbacks=callbacks,
        max_epochs=parameters["trainer"]["epochs"],
        num_sanity_val_steps=0,
        fast_dev_run=False,
        accelerator=None if parameters["system"]["gpus"] in [None, 0, 1] else "dp",
        sync_batchnorm=True,
        deterministic=True,
    )

    if parameters["tune_lr"]:
        lr_find(
            trainer,
            model,
            datamodule,
            model_path,
            min_lr=parameters.get("min_lr", 1e-6),
            max_lr=parameters.get("max_lr", 1e-2),
            num_training=parameters.get("num_training", 100),
        )
        return

    trainer.fit(model, datamodule=datamodule)

    # Save pytorch model / nn.Module / weights
    torch.save(model.state_dict(), Path(model_path / "model.pt"))

    # Save meaningful parameters for loading
    model_config = {
        "model": type(model).__name__,
        "backbone": model.backbone_name,
        "n_channel": datamodule.size(0),
        "output_size": model.output_size,
        "hparams": dict(model.hparams),
    }
    write_yaml(model_path / "model_config.yaml", model_config)

    # Log actual config used for training
    write_yaml(model_path / "config.yaml", parameters)

    # Log model summary
    pytorch_model_summary(model, model_path)

    # Output tracking run id to continue logging in test step
    run_id = tracking.get_id()

    # Metadata
    steps_per_epoch = len(datamodule.train_dataloader())
    metadata = get_training_summary(
        model_path / "model.ckpt",
        parameters,
        run_id,
        timer_callback,
        datamodule,
        early_stop_callback,
        checkpoint_callback,
        steps_per_epoch,
        save_path=plot_path,
    )
    write_yaml(model_path / "metadata.yaml", metadata)
    tracking.log_property("metadata", metadata)
    tracking.log_artifact(plot_path / "data-examples-train.png")
    tracking.log_artifact(plot_path / "data-examples-val.png")
    if (plot_path / "data-examples-test.png").exists():
        tracking.log_artifact(plot_path / "data-examples-test.png")

    tracking.end()
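
The lr_find helper called from train is not shown. A sketch assuming it wraps Lightning's built-in learning-rate finder (trainer.tuner.lr_find) and saves the resulting plot next to the model; the plot filename is an assumption:

def lr_find(trainer, model, datamodule, model_path, min_lr=1e-6, max_lr=1e-2, num_training=100):
    finder = trainer.tuner.lr_find(
        model, datamodule=datamodule, min_lr=min_lr, max_lr=max_lr, num_training=num_training
    )
    fig = finder.plot(suggest=True)
    fig.savefig(model_path / "lr_find.png")
    print(f"Suggested learning rate: {finder.suggestion()}")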
    def __init__(self,
                 datamodule: pl.LightningDataModule = None,
                 data_dir: str = './',
                 learning_rate: float = 0.2,
                 weight_decay: float = 15e-6,
                 input_height: int = 32,
                 batch_size: int = 32,
                 num_workers: int = 4,
                 warmup_epochs: int = 10,
                 max_epochs: int = 1000,
                 **kwargs):
        """
        PyTorch Lightning implementation of `Bring Your Own Latent (BYOL)
        <https://arxiv.org/pdf/2006.07733.pdf>`_

        Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, \
        Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, \
        Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko.

        Model implemented by:
            - `Annika Brundyn <https://github.com/annikabrundyn>`_

        .. warning:: Work in progress. This implementation is still being verified.

        TODOs:
            - verify on CIFAR-10
            - verify on STL-10
            - pre-train on imagenet

        Example:

            >>> from pl_bolts.models.self_supervised import BYOL
            ...
            >>> model = BYOL()

        Train::

            trainer = Trainer()
            trainer.fit(model)

        CLI command::

            # cifar10
            python byol_module.py --gpus 1

            # imagenet
            python byol_module.py
                --gpus 8
                --dataset imagenet2012
                --data_dir /path/to/imagenet/
                --meta_dir /path/to/folder/with/meta.bin/
                --batch_size 32

        Args:
            datamodule: The datamodule to use (defaults to CIFAR-10 with SimCLR transforms)
            data_dir: directory to store data
            learning_rate: the learning rate
            weight_decay: optimizer weight decay
            input_height: image input height
            batch_size: the batch size
            num_workers: number of dataloader workers
            warmup_epochs: number of warm-up epochs for the learning-rate scheduler
            max_epochs: maximum number of epochs for the learning-rate scheduler
        """
        super().__init__()
        self.save_hyperparameters()

        # init default datamodule
        if datamodule is None:
            datamodule = CIFAR10DataModule(data_dir,
                                           num_workers=num_workers,
                                           batch_size=batch_size)
            datamodule.train_transforms = SimCLRTrainDataTransform(
                input_height)
            datamodule.val_transforms = SimCLREvalDataTransform(input_height)

        self.datamodule = datamodule

        self.online_network = SiameseArm()
        self.target_network = deepcopy(self.online_network)

        self.weight_callback = BYOLMAWeightUpdate()

        # for finetuning callback
        self.z_dim = 2048
        self.num_classes = self.datamodule.num_classes
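# Hedged sketch (illustration only, not pl_bolts' `BYOLMAWeightUpdate`): the
# core of the callback above is an exponential moving average that slowly pulls
# the target network toward the online network after every training step. The
# decay 0.996 is the base tau reported in the BYOL paper.
def ema_update(online_network, target_network, tau=0.996):
    for online_p, target_p in zip(online_network.parameters(),
                                  target_network.parameters()):
        target_p.data = tau * target_p.data + (1.0 - tau) * online_p.data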
    def __init__(
        self,
        datamodule: pl.LightningDataModule = None,
        encoder: Union[str, torch.nn.Module,
                       pl.LightningModule] = 'cpc_encoder',
        patch_size: int = 8,
        patch_overlap: int = 4,
        online_ft: bool = True,
        task: str = 'cpc',
        num_workers: int = 4,
        learning_rate: float = 1e-4,
        data_dir: str = '',
        batch_size: int = 32,
        pretrained: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            datamodule: A datamodule (optional). Otherwise set the dataloaders directly
            encoder: A string naming any of the torchvision resnets, the original CPC encoder,
                or a custom nn.Module encoder
            patch_size: How big to make the image patches
            patch_overlap: How much overlap each patch should have
            online_ft: Enable a 1024-unit MLP to fine-tune online
            task: Which self-supervised task to use ('cpc', 'amdim', etc.)
            num_workers: Number of dataloader workers
            learning_rate: The learning rate
            data_dir: Where to store data
            batch_size: The batch size
            pretrained: If set, uses weights pretrained (with CPC) on ImageNet
        """

        super().__init__()
        self.save_hyperparameters()

        self.online_evaluator = self.hparams.online_ft

        if pretrained:
            self.hparams.dataset = pretrained
            self.online_evaluator = True

        # link data
        if datamodule is None:
            datamodule = CIFAR10DataModule(
                self.hparams.data_dir,
                num_workers=self.hparams.num_workers,
                batch_size=batch_size)
            datamodule.train_transforms = CPCTrainTransformsCIFAR10()
            datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        self.datamodule = datamodule

        # init encoder
        self.encoder = encoder
        if isinstance(encoder, str):
            self.encoder = self.init_encoder()

        # info nce loss
        c, h = self.__compute_final_nb_c(self.hparams.patch_size)
        self.contrastive_task = CPCTask(num_input_channels=c,
                                        target_dim=64,
                                        embed_scale=0.1)

        self.z_dim = c * h * h
        self.num_classes = self.datamodule.num_classes

        if pretrained:
            self.load_pretrained(encoder)
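# Hedged sketch (illustration only): how `patch_size` and `patch_overlap` turn
# an image into the grid of overlapping patches CPC operates on. The stride is
# patch_size - patch_overlap, and `Tensor.unfold` slides the window over both
# spatial dimensions.
def image_to_patches(img, patch_size=8, patch_overlap=4):
    # img: (B, C, H, W) -> (B, n_h, n_w, C, patch_size, patch_size)
    stride = patch_size - patch_overlap
    patches = img.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
    return patches.permute(0, 2, 3, 1, 4, 5)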
    def __init__(self,
                 datamodule: pl.LightningDataModule = None,
                 data_dir: str = './',
                 learning_rate: float = 0.00006,
                 weight_decay: float = 0.0005,
                 input_height: int = 32,
                 batch_size: int = 128,
                 online_ft: bool = False,
                 num_workers: int = 4,
                 optimizer: str = 'lars',
                 lr_sched_step: float = 30.0,
                 lr_sched_gamma: float = 0.5,
                 lars_momentum: float = 0.9,
                 lars_eta: float = 0.001,
                 loss_temperature: float = 0.5,
                 **kwargs):
        """
        PyTorch Lightning implementation of `SimCLR <https://arxiv.org/abs/2002.05709>`_

        Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.

        Model implemented by:

            - `William Falcon <https://github.com/williamFalcon>`_
            - `Tullie Murrell <https://github.com/tullie>`_

        Example:

            >>> from pl_bolts.models.self_supervised import SimCLR
            ...
            >>> model = SimCLR()

        Train::

            trainer = Trainer()
            trainer.fit(model)

        CLI command::

            # cifar10
            python simclr_module.py --gpus 1

            # imagenet
            python simclr_module.py
                --gpus 8
                --dataset imagenet2012
                --data_dir /path/to/imagenet/
                --meta_dir /path/to/folder/with/meta.bin/
                --batch_size 32

        Args:
            datamodule: The datamodule to use (defaults to CIFAR-10 with SimCLR transforms)
            data_dir: directory to store data
            learning_rate: the learning rate
            weight_decay: optimizer weight decay
            input_height: image input height
            batch_size: the batch size
            online_ft: whether to attach an online evaluator for fine-tuning
            num_workers: number of dataloader workers
            optimizer: optimizer name
            lr_sched_step: step for the learning-rate scheduler
            lr_sched_gamma: gamma for the learning-rate scheduler
            lars_momentum: momentum parameter for the LARS optimizer
            lars_eta: eta (trust coefficient) for the LARS optimizer
            loss_temperature: temperature used in the contrastive loss
        """
        super().__init__()
        self.save_hyperparameters()
        self.online_evaluator = online_ft

        # init default datamodule
        if datamodule is None:
            datamodule = CIFAR10DataModule(data_dir,
                                           num_workers=num_workers,
                                           batch_size=batch_size)
            datamodule.train_transforms = SimCLRTrainDataTransform(
                input_height)
            datamodule.val_transforms = SimCLREvalDataTransform(input_height)

        self.datamodule = datamodule

        self.loss_func = self.init_loss()
        self.encoder = self.init_encoder()
        self.projection = self.init_projection()

        if self.online_evaluator:
            z_dim = self.projection.output_dim
            num_classes = self.datamodule.num_classes
            self.non_linear_evaluator = SSLEvaluator(n_input=z_dim,
                                                     n_classes=num_classes,
                                                     p=0.2,
                                                     n_hidden=1024)
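# Hedged sketch (illustration only, not this module's `init_loss`): the NT-Xent
# contrastive loss that `loss_temperature` parameterizes. `z1` and `z2` are the
# projections of the two augmented views of the same batch, shape (N, dim).
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.5):
    z = F.normalize(torch.cat([z1, z2]), dim=1)  # (2N, dim), unit-norm rows
    sim = z @ z.t() / temperature                # scaled cosine similarities
    n = z1.size(0)
    # mask out self-similarity so each row's softmax ignores its own entry
    sim.masked_fill_(torch.eye(2 * n, dtype=torch.bool, device=z.device),
                     float("-inf"))
    # the positive for row i is the other view of the same image, n rows away
    targets = torch.cat([torch.arange(n, 2 * n),
                         torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)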
    def manual_test_step(self, dm: LightningDataModule) -> None:
        """Manually perform the test step. Used for debugging."""
        dm.setup()
        batch = next(iter(dm.test_dataloader()))
        self.test_step(batch, 0)
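# Hedged usage sketch: `MyModule` and `MyDataModule` are hypothetical stand-ins
# for a concrete LightningModule/LightningDataModule pair. Calling the helper
# directly sidesteps the Trainer, which helps when debugging shapes or metrics
# inside `test_step`:
# model = MyModule()
# model.manual_test_step(MyDataModule())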