def __init__(
        self,
        datamodule: Union[str, pl.LightningDataModule] = 'cifar10',
        encoder: Union[str, torch.nn.Module,
                       pl.LightningModule] = 'amdim_encoder',
        contrastive_task: FeatureMapContrastiveTask = FeatureMapContrastiveTask('01, 02, 11'),
        image_channels: int = 3,
        image_height: int = 32,
        encoder_feature_dim: int = 320,
        embedding_fx_dim: int = 1280,
        conv_block_depth: int = 10,
        use_bn: bool = False,
        tclip: float = 20.0,
        learning_rate: float = 2e-4,
        data_dir: str = '',
        num_classes: int = 10,
        batch_size: int = 200,
        num_workers: int = 16,
        **kwargs,
    ):
        """
        Args:
            datamodule: a LightningDataModule or the name of a built-in datamodule (e.g. 'cifar10')
            encoder: an encoder string or model
            image_channels: number of channels in the input images (3 for RGB)
            image_height: input image height in pixels
            encoder_feature_dim: called `ndf` in the paper, this is the representation size for the encoder
            embedding_fx_dim: output dim of the embedding function (`nrkhs` in the paper,
                for Reproducing Kernel Hilbert Spaces)
            conv_block_depth: depth of each encoder block
            use_bn: if True, use batch norm in the encoder
            tclip: soft-clipping non-linearity applied to the scores after the regularization term
                is computed and before the log-softmax; this is the 'second trick' used in the paper
            learning_rate: the learning rate
            data_dir: where to store/download the data
            num_classes: number of classes in the dataset
            batch_size: the batch size
            num_workers: number of workers for the dataloaders
        """
        super().__init__()
        self.save_hyperparameters()

        # init encoder
        self.encoder = encoder
        if isinstance(encoder, str):
            self.encoder = self.init_encoder()

        # the task
        self.contrastive_task = contrastive_task

        self.tng_split = None
        self.val_split = None
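# Sketch (not part of the original listing): instantiating the module whose __init__ is
# shown above with the hyperparameters its docstring describes. The AMDIM import and the
# values mirror the signature defaults; treat this as an illustration, not canonical usage.
from pl_bolts.models.self_supervised import AMDIM

model = AMDIM(
    encoder='amdim_encoder',
    encoder_feature_dim=320,   # `ndf` in the paper: encoder representation size
    embedding_fx_dim=1280,     # `nrkhs` in the paper: embedding head output size
    tclip=20.0,                # soft score clipping before the log-softmax
    learning_rate=2e-4,
    batch_size=200,
)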
def mix_and_match_any_part_or_subclass_example():
    """Swap AMDIM's feature-map contrastive task into CPC_v2 and train it on CIFAR-10 with DDP."""
    import pytorch_lightning as pl
    from pl_bolts.models.self_supervised import CPC_v2
    from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.cpc import CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10
    from pytorch_lightning.plugins import DDPPlugin

    # Data module with the CPC transforms.
    dm = CIFAR10DataModule(num_workers=12, batch_size=32)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    # Model: the CPC v2 backbone paired with AMDIM's contrastive task.
    amdim_task = FeatureMapContrastiveTask(comparisons="01, 11, 02", bidirectional=True)
    model = CPC_v2(encoder="cpc_encoder", contrastive_task=amdim_task)

    # Fit on 2 GPUs with DDP.
    trainer = pl.Trainer(gpus=2, accelerator="ddp", plugins=DDPPlugin(find_unused_parameters=False))
    trainer.fit(model, datamodule=dm)
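# Sketch (not part of the original snippet): how the `comparisons` string above is read.
# In pl_bolts' FeatureMapContrastiveTask, each two-digit group names a pair of encoder
# feature-map indices to contrast, so "01, 11, 02" compares map 0 with map 1, map 1 with
# itself, and map 0 with map 2; `bidirectional=True` scores each pair in both directions.
# The single-pair task below is only for illustration.
from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask

single_pair_task = FeatureMapContrastiveTask(comparisons="01", bidirectional=True)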
    def __init__(
            self,
            datamodule: Union[str, pl_bolts.datamodules.LightningDataModule] = 'cifar10',
            encoder: Union[str, torch.nn.Module, pl.LightningModule] = 'amdim_encoder',
            contrastive_task: FeatureMapContrastiveTask = FeatureMapContrastiveTask('01, 02, 11'),
            image_channels: int = 3,
            image_height: int = 32,
            encoder_feature_dim: int = 320,
            embedding_fx_dim: int = 1280,
            conv_block_depth: int = 10,
            use_bn: bool = False,
            tclip: float = 20.0,
            learning_rate: float = 2e-4,
            data_dir: str = '',
            num_classes: int = 10,
            batch_size: int = 200,
            **kwargs,
    ):
        """
        PyTorch Lightning implementation of
        `Augmented Multiscale Deep InfoMax (AMDIM) <https://arxiv.org/abs/1906.00910>`_

        Paper authors: Philip Bachman, R Devon Hjelm, William Buchwalter.

        Model implemented by: `William Falcon <https://github.com/williamFalcon>`_

        This code is adapted to Lightning using the original author repo
        (`the original repo <https://github.com/Philip-Bachman/amdim-public>`_).

        Example:

            >>> from pl_bolts.models.self_supervised import AMDIM
            >>> model = AMDIM(encoder='resnet18')

        Train::

            trainer = Trainer()
            trainer.fit(model)

        Args:
            datamodule: a LightningDataModule or the name of a built-in datamodule (e.g. 'cifar10')
            encoder: an encoder string or model
            image_channels: number of channels in the input images (3 for RGB)
            image_height: input image height in pixels
            encoder_feature_dim: called `ndf` in the paper, this is the representation size for the encoder
            embedding_fx_dim: output dim of the embedding function (`nrkhs` in the paper,
                for Reproducing Kernel Hilbert Spaces)
            conv_block_depth: depth of each encoder block
            use_bn: if True, use batch norm in the encoder
            tclip: soft-clipping non-linearity applied to the scores after the regularization term
                is computed and before the log-softmax; this is the 'second trick' used in the paper
            learning_rate: the learning rate
            data_dir: where to store/download the data
            num_classes: number of classes in the dataset
            batch_size: the batch size
        """
        super().__init__()
        self.save_hyperparameters()

        # init encoder
        self.encoder = encoder
        if isinstance(encoder, str):
            self.encoder = self.init_encoder()

        # the task
        self.contrastive_task = contrastive_task

        self.tng_split = None
        self.val_split = None
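# Sketch (not part of the original listing): combining the docstring's Example and
# Train sections above into one script. 'resnet18' comes from the docstring; the
# CPU-only Trainer and max_epochs=1 are assumptions to keep the sketch small, and
# the default 'cifar10' datamodule is expected to place data under `data_dir`.
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import AMDIM

model = AMDIM(encoder='resnet18', num_classes=10, batch_size=200)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model)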