import importlib
import os

import numpy as np
from torch.utils.data import Dataset

from pl_bolts.utils.warnings import warn_missing_pkg

_PIL_AVAILABLE = importlib.util.find_spec("PIL") is not None
if _PIL_AVAILABLE:
    from PIL import Image
else:
    warn_missing_pkg('PIL')  # pragma: no-cover

# Kitti semantic label ids 0..33 (plus the -1 sentinel) are split into two disjoint groups that
# together cover every id. Pixels carrying a "void" id are set to `ignore_index` by
# `KittiDataset.encode_segmap`; "valid" ids are remapped to contiguous class indices.
DEFAULT_VOID_LABELS = (
    0, 1, 2, 3, 4, 5, 6,
    9, 10,
    14, 15, 16,
    18,
    29, 30,
    -1,
)
DEFAULT_VALID_LABELS = (
    7, 8,
    11, 12, 13,
    17,
    19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
    31, 32, 33,
)


class KittiDataset(Dataset):
    """
    Note:
        You need to have downloaded the Kitti dataset first and provide the path to where it is saved.
        You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015

    There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These
    useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored
    in `valid_labels`.

    The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index`
    (250 by default). It also sets all of the valid pixels to the appropriate value between 0 and
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset
from pl_bolts.losses.rl import dqn_loss
from pl_bolts.models.rl.common.agents import ValueAgent
from pl_bolts.models.rl.common.gym_wrappers import make_environment
from pl_bolts.models.rl.common.memory import MultiStepBuffer
from pl_bolts.models.rl.common.networks import CNN
from pl_bolts.utils import _GYM_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _GYM_AVAILABLE:
    from gym import Env
else:
    warn_missing_pkg('gym')  # pragma: no-cover
    Env = object


class DQN(pl.LightningModule):
    """
    Basic DQN Model

    PyTorch Lightning implementation of `DQN <https://arxiv.org/abs/1312.5602>`_
    Paper authors: Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves,
    Ioannis Antonoglou, Daan Wierstra, Martin Riedmiller.
    Model implemented by:

        - `Donal Byrne <https://github.com/djbyrne>`

    Example:
import math
from typing import Any

import numpy as np
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from pl_bolts.utils.warnings import warn_missing_pkg

try:
    from sklearn.utils import shuffle as sk_shuffle
except ModuleNotFoundError:
    warn_missing_pkg("sklearn")  # pragma: no-cover
    _SKLEARN_AVAILABLE = False
else:
    _SKLEARN_AVAILABLE = True


class SklearnDataset(Dataset):
    """
    Mapping between numpy (or sklearn) datasets to PyTorch datasets.

    Example:
        >>> from sklearn.datasets import load_boston
        >>> from pl_bolts.datamodules import SklearnDataset
        ...
        >>> X, y = load_boston(return_X_y=True)
        >>> dataset = SklearnDataset(X, y)
        >>> len(dataset)
        506
Ejemplo n.º 4
0
import random

from pl_bolts.transforms.dataset_normalizations import (imagenet_normalization,
                                                        cifar10_normalization,
                                                        stl10_normalization)
from pl_bolts.utils.warnings import warn_missing_pkg

try:
    from torchvision import transforms
except ModuleNotFoundError:
    warn_missing_pkg('torchvision')  # pragma: no-cover
    _TORCHVISION_AVAILABLE = False
else:
    _TORCHVISION_AVAILABLE = True

try:
    from PIL import ImageFilter
except ModuleNotFoundError:
    warn_missing_pkg('PIL', pypi_name='Pillow')  # pragma: no-cover
    _PIL_AVAILABLE = False
else:
    _PIL_AVAILABLE = True


class Moco2TrainCIFAR10Transforms:
    """
    Moco 2 augmentation:
    https://arxiv.org/pdf/2003.04297.pdf
    """
    def __init__(self, height: int = 32):
        if not _TORCHVISION_AVAILABLE:
Ejemplo n.º 5
0
from typing import Sequence

import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from torch import Tensor, nn

from pl_bolts.utils import _MATPLOTLIB_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _MATPLOTLIB_AVAILABLE:
    from matplotlib import pyplot as plt
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure
else:  # pragma: no cover
    warn_missing_pkg("matplotlib")
    Axes = object
    Figure = object


class ConfusedLogitCallback(Callback):  # pragma: no cover
    """Takes the logit predictions of a model and when the probabilities of two classes are very close, the model
    doesn't have high certainty that it should pick one vs the other class.

    This callback shows how the input would have to change to swing the model from one label prediction
    to the other.

    In this case, the network predicts a 5... but gives almost equal probability to an 8.
    The images show what about the original 5 would have to change to make it more like a 5 or more like an 8.

    For each confused logit the confused images are generated by taking the gradient from a logit wrt an input
    for the top two closest logits.
https://github.com/Shmuma/ptan/blob/master/ptan/common/wrappers.py
"""
import collections

import numpy as np
import torch

from pl_bolts.utils import _GYM_AVAILABLE, _OPENCV_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _GYM_AVAILABLE:
    import gym.spaces
    from gym import make as gym_make
    from gym import ObservationWrapper, Wrapper
else:  # pragma: no-cover
    warn_missing_pkg('gym')
    Wrapper = object
    ObservationWrapper = object

if _OPENCV_AVAILABLE:
    import cv2
else:
    warn_missing_pkg('cv2', pypi_name='opencv-python')  # pragma: no-cover


class ToTensor(Wrapper):
    """Gym wrapper that forwards ``env`` to the base ``Wrapper``.

    NOTE(review): the previous docstring described a FIRE-reset wrapper, which does not match this class;
    the name suggests tensor conversion, but no such logic is visible here — confirm the intended description.
    """

    def __init__(self, env=None):
        # NOTE(review): when gym is not installed, `Wrapper` falls back to `object`, and
        # `object.__init__` rejects the extra `env` argument with a TypeError.
        super(ToTensor, self).__init__(env)
Ejemplo n.º 7
0
from typing import Any

from pl_bolts.datasets.mnist_dataset import MNIST
from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin
from pl_bolts.utils import _PIL_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _PIL_AVAILABLE:
    from PIL import Image
else:  # pragma: no cover
    warn_missing_pkg("PIL", pypi_name="Pillow")


class SRMNIST(SRDatasetMixin, MNIST):
    """MNIST dataset that can be used to train Super Resolution models.

    Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image.

    Args:
        scale_factor: Integer downscaling factor; the low-resolution side length is
            ``28 // scale_factor``.
        *args: Forwarded to the parent constructors after the image sizes and channel count.
        **kwargs: Forwarded to the parent constructors.
    """

    def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None:
        hr_image_size = 28  # high-resolution side length
        lr_image_size = hr_image_size // scale_factor
        self.image_channels = 1  # single channel, matching the mode "L" images built in `_get_image`
        super().__init__(hr_image_size, lr_image_size, self.image_channels, *args, **kwargs)

    def _get_image(self, index: int) -> "Image.Image":
        """Return sample ``index`` as a single-channel (mode ``"L"``) PIL image.

        Raises:
            ModuleNotFoundError: If Pillow is not installed (otherwise ``Image`` would be undefined
                and fail with an opaque NameError).
        """
        if not _PIL_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError("You want to use `PIL` which is not installed yet.")
        # `self.data` comes from the MNIST parent; presumably a tensor indexed per sample — confirm.
        return Image.fromarray(self.data[index].numpy(), mode="L")
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset
from pl_bolts.losses.rl import dqn_loss
from pl_bolts.models.rl.common.agents import ValueAgent
from pl_bolts.models.rl.common.gym_wrappers import make_environment
from pl_bolts.models.rl.common.memory import MultiStepBuffer
from pl_bolts.models.rl.common.networks import CNN
from pl_bolts.utils import _GYM_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _GYM_AVAILABLE:
    from gym import Env
else:  # pragma: no cover
    warn_missing_pkg('gym')
    Env = object


class DQN(LightningModule):
    """
    Basic DQN Model

    PyTorch Lightning implementation of `DQN <https://arxiv.org/abs/1312.5602>`_
    Paper authors: Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves,
    Ioannis Antonoglou, Daan Wierstra, Martin Riedmiller.
    Model implemented by:

        - `Donal Byrne <https://github.com/djbyrne>`

    Example:
Ejemplo n.º 9
0
import numpy as np

from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms as transforms
else:  # pragma: no cover
    warn_missing_pkg('torchvision')

if _OPENCV_AVAILABLE:
    import cv2
else:  # pragma: no cover
    warn_missing_pkg('cv2', pypi_name='opencv-python')


class SimCLRTrainDataTransform(object):
    """
    Transforms for SimCLR

    Transform::

        RandomResizedCrop(size=self.input_height)
        RandomHorizontalFlip()
        RandomApply([color_jitter], p=0.8)
        RandomGrayscale(p=0.2)
        GaussianBlur(kernel_size=int(0.1 * self.input_height))
        transforms.ToTensor()

    Example::
Ejemplo n.º 10
0
import math
from typing import List, Sequence, Tuple, Union

import numpy as np
import torch
from torch import Tensor

from pl_bolts.utils import _SKLEARN_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _SKLEARN_AVAILABLE:
    from sklearn.utils import shuffle as sk_shuffle
else:  # pragma: no cover
    warn_missing_pkg("sklearn", pypi_name="scikit-learn")


class Identity(torch.nn.Module):
    """An identity class to replace arbitrary layers in pretrained models.

    Example::

        from pl_bolts.utils import Identity

        model = resnet18()
        model.fc = Identity()
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
Ejemplo n.º 11
0
from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_9_1
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
    from torchvision.datasets import MNIST
else:  # pragma: no cover
    warn_missing_pkg("torchvision")
    MNIST = object

if _PIL_AVAILABLE:
    from PIL import Image
else:  # pragma: no cover
    warn_missing_pkg("PIL", pypi_name="Pillow")

# TODO(akihironitta): This is needed to avoid 503 error when downloading MNIST dataset
# from http://yann.lecun.com/exdb/mnist/ and can be removed after `torchvision==0.9.1`.
# See https://github.com/pytorch/vision/issues/3549 for details.
if _TORCHVISION_AVAILABLE and _TORCHVISION_LESS_THAN_0_9_1:
    MNIST.resources = [
        (
            "https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz",
            "f68b3c2dcbeaaa9fbdd348bbdeb94873",
        ),
        (
            "https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz",
            "d53e105ee54ea40749a09fcbcd1e9432",
        ),
        (
            "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz",
            "9fb629c4189551a2d022fa330f9573f3",
        ),
import os
import pickle
import tarfile
from typing import Callable, Optional, Sequence, Tuple

import torch
from torch import Tensor

from pl_bolts.datasets.base_dataset import LightDataset
from pl_bolts.utils import _PIL_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _PIL_AVAILABLE:
    from PIL import Image
else:  # pragma: no cover
    warn_missing_pkg('PIL', pypi_name='Pillow')


class CIFAR10(LightDataset):
    """
    Customized `CIFAR10 <http://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset for testing Pytorch Lightning
    without the torchvision dependency.

    Part of the code was copied from
    https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/

    Args:
        data_dir: Root directory of dataset where ``CIFAR10/processed/training.pt``
            and  ``CIFAR10/processed/test.pt`` exist.
        train: If ``True``, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
Ejemplo n.º 13
0
import importlib
import math

import numpy as np
import torch

from pl_bolts.utils.warnings import warn_missing_pkg

_SKLEARN_AVAILABLE = importlib.util.find_spec("sklearn") is not None
if _SKLEARN_AVAILABLE:
    from sklearn.utils import shuffle as sk_shuffle
else:
    warn_missing_pkg('sklearn', pypi_name='scikit-learn')  # pragma: no-cover


class Identity(torch.nn.Module):
    """
    An identity class to replace arbitrary layers in pretrained models

    Example::

        from pl_bolts.utils import Identity

        model = resnet18()
        model.fc = Identity()

    """
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
"""
import importlib
from abc import ABC
from collections import deque, namedtuple
from typing import Callable, Iterable, List, Tuple

import torch
from torch.utils.data import IterableDataset

from pl_bolts.utils.warnings import warn_missing_pkg

_GYM_AVAILABLE = importlib.util.find_spec("gym") is not None
if _GYM_AVAILABLE:
    from gym import Env
else:
    warn_missing_pkg("gym")  # pragma: no-cover


# Datasets

# Record of a single interaction step, with fields (state, action, reward, done, new_state).
Experience = namedtuple("Experience", "state action reward done new_state")


class ExperienceSourceDataset(IterableDataset):
    """
    Basic experience source dataset. Takes a generate_batch function that returns an iterator.
    The logic for the experience source and how the batch is generated is defined in the Lightning model itself.
    """
Ejemplo n.º 15
0
https://github.com/Shmuma/ptan/blob/master/ptan/common/wrappers.py
"""
import collections

import numpy as np
import torch

from pl_bolts.utils import _GYM_AVAILABLE, _OPENCV_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _GYM_AVAILABLE:
    import gym.spaces
    from gym import make as gym_make
    from gym import ObservationWrapper, Wrapper
else:  # pragma: no cover
    warn_missing_pkg('gym')
    Wrapper = object
    ObservationWrapper = object

if _OPENCV_AVAILABLE:
    import cv2
else:  # pragma: no cover
    warn_missing_pkg('cv2', pypi_name='opencv-python')


class ToTensor(Wrapper):
    """Gym wrapper; requires `gym` to be installed.

    NOTE(review): the previous docstring described a FIRE-reset wrapper, which does not match this class
    name — confirm the intended description.
    """
    def __init__(self, env=None):
        # Fail early with a clear message: without gym, `Wrapper` is only an `object` placeholder.
        # NOTE(review): no `super().__init__(env)` call is visible after this check — confirm the base
        # wrapper is initialized with `env`.
        if not _GYM_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError(
                'You want to use `gym` which is not installed yet.')
Ejemplo n.º 16
0
import torch.nn as nn
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger, WandbLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle

from pl_bolts.utils import _WANDB_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _WANDB_AVAILABLE:
    import wandb
else:  # pragma: no cover
    warn_missing_pkg("wandb")
    wandb = None


class DataMonitorBase(Callback):

    supported_loggers = (
        TensorBoardLogger,
        WandbLogger,
    )

    def __init__(self, log_every_n_steps: int = None):
        """
        Base class for monitoring data histograms in a LightningModule.
        This requires a logger configured in the Trainer, otherwise no data is logged.
        The specific class that inherits from this base defines what data gets collected.
Ejemplo n.º 17
0
from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
    from torchvision.datasets import MNIST
else:  # pragma: no cover
    warn_missing_pkg('torchvision')
    MNIST = object

if _PIL_AVAILABLE:
    from PIL import Image
else:  # pragma: no cover
    warn_missing_pkg('PIL', pypi_name='Pillow')


class BinaryMNIST(MNIST):
    def __getitem__(self, idx):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if not _TORCHVISION_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError(
                'You want to use `torchvision` which is not installed yet.')

        img, target = self.data[idx], int(self.targets[idx])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
Ejemplo n.º 18
0
"""Set of wrapper functions for gym environments taken from
https://github.com/Shmuma/ptan/blob/master/ptan/common/wrappers.py."""
import collections

import numpy as np
import torch

from pl_bolts.utils import _GYM_AVAILABLE, _OPENCV_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _GYM_AVAILABLE:
    import gym.spaces
    from gym import ObservationWrapper, Wrapper
    from gym import make as gym_make
else:  # pragma: no cover
    warn_missing_pkg("gym")
    Wrapper = object
    ObservationWrapper = object

if _OPENCV_AVAILABLE:
    import cv2
else:  # pragma: no cover
    warn_missing_pkg("cv2", pypi_name="opencv-python")


class ToTensor(Wrapper):
    """Gym wrapper; requires `gym` to be installed.

    NOTE(review): the previous docstring described a FIRE-reset wrapper, which does not match this class
    name — confirm the intended description.
    """
    def __init__(self, env=None):
        # Fail early with a clear message: without gym, `Wrapper` is only an `object` placeholder.
        # NOTE(review): no `super().__init__(env)` call is visible after this check — confirm the base
        # wrapper is initialized with `env`.
        if not _GYM_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError(
                "You want to use `gym` which is not installed yet.")
from typing import Optional

from torch.utils.data import random_split

from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed
from pl_bolts.models.self_supervised.amdim import transforms as amdim_transforms
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
    from torchvision.datasets import STL10
else:  # pragma: no cover
    warn_missing_pkg('torchvision')


class AMDIMPretraining():
    """Factory methods that build datasets for AMDIM pretraining.

    For pretraining we use the train transform for both train and val.
    """
    @staticmethod
    def cifar10(dataset_root, split: str = 'train'):
        """Build a ``CIFAR10Mixed`` dataset for AMDIM pretraining.

        Args:
            dataset_root: Root directory handed to ``CIFAR10Mixed`` (constructed with ``download=True``).
            split: Either ``'train'`` or ``'val'``; both use ``AMDIMTrainTransformsCIFAR10``.

        Returns:
            The constructed ``CIFAR10Mixed`` dataset.
        """
        # NOTE(review): `assert` is stripped under `python -O`; consider raising ValueError instead.
        assert split in ('train', 'val')
        dataset = CIFAR10Mixed(
            root=dataset_root,
            split=split,
            transform=amdim_transforms.AMDIMTrainTransformsCIFAR10(),
            download=True,
        )
        return dataset
Ejemplo n.º 20
0
import math

import numpy as np
import torch

from pl_bolts.utils import _SKLEARN_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _SKLEARN_AVAILABLE:
    from sklearn.utils import shuffle as sk_shuffle
else:  # pragma: no-cover
    warn_missing_pkg('sklearn', pypi_name='scikit-learn')


class Identity(torch.nn.Module):
    """Pass-through module that returns its input unchanged.

    Useful for disabling a layer of a pretrained network without touching the rest of it.

    Example::

        from pl_bolts.utils import Identity

        model = resnet18()
        model.fc = Identity()
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No computation and no copy: the very same object is handed back.
        return x
Ejemplo n.º 21
0
"""Datamodules for RL models that rely on experiences generated during training Based on implementations found
here: https://github.com/Shmuma/ptan/blob/master/ptan/experience.py."""
from abc import ABC
from collections import deque, namedtuple
from typing import Callable, Iterator, List, Tuple

import torch
from torch.utils.data import IterableDataset

from pl_bolts.utils import _GYM_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _GYM_AVAILABLE:
    from gym import Env
else:  # pragma: no cover
    warn_missing_pkg("gym")
    Env = object

# Record of a single interaction step, with fields (state, action, reward, done, new_state).
Experience = namedtuple(
    typename="Experience",
    field_names=("state", "action", "reward", "done", "new_state"),
)


class ExperienceSourceDataset(IterableDataset):
    """Basic experience source dataset.

    Takes a generate_batch function that returns an iterator. The logic for the experience source and how the batch is
    generated is defined in the Lightning model itself.

    Args:
        generate_batch: Callable kept on the instance; presumably called to produce the batch iterator — confirm
            against ``__iter__``.
    """
    def __init__(self, generate_batch: Callable) -> None:
        # Only stores the callable; nothing is generated at construction time.
        self.generate_batch = generate_batch
Ejemplo n.º 22
0
import os

import numpy as np
from torch.utils.data import Dataset

from pl_bolts.utils import _PIL_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _PIL_AVAILABLE:
    from PIL import Image
else:  # pragma: no cover
    warn_missing_pkg('PIL')

# Label ids whose pixels are set to `ignore_index` (see the KittiDataset docstring).
DEFAULT_VOID_LABELS = (
    0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30,
    -1,
)
# Label ids kept for training; together with the void ids they cover 0..33 exactly once.
DEFAULT_VALID_LABELS = (
    7, 8, 11, 12, 13, 17,
    19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
    31, 32, 33,
)


class KittiDataset(Dataset):
    """
    Note:
        You need to have downloaded the Kitti dataset first and provide the path to where it is saved.
        You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015

    There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These
    useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored
    in `valid_labels`.

    The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index`
    (250 by default). It also sets all of the valid pixels to the appropriate value between 0 and
    `len(valid_labels)` (since that is the number of valid classes), so it can be used properly by
    the loss function when comparing with the output.
Ejemplo n.º 23
0
import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from torch import Tensor, nn
from torch.utils.data import DataLoader

from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms as transform_lib
    from torchvision.datasets import LSUN, MNIST
else:  # pragma: no cover
    warn_missing_pkg("torchvision")


class DCGAN(LightningModule):
    """DCGAN implementation.

    Example::

        from pl_bolts.models.gans import DCGAN

        m = DCGAN()
        Trainer(gpus=2).fit(m)

    Example CLI::

        # mnist
Ejemplo n.º 24
0
from typing import Tuple

import numpy as np

from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms
else:  # pragma: no cover
    warn_missing_pkg("torchvision")

if _OPENCV_AVAILABLE:
    import cv2
else:  # pragma: no cover
    warn_missing_pkg("cv2", pypi_name="opencv-python")


class SwAVTrainDataTransform:
    def __init__(
        self,
        normalize=None,
        size_crops: Tuple[int] = (96, 36),
        nmb_crops: Tuple[int] = (2, 4),
        min_scale_crops: Tuple[float] = (0.33, 0.10),
        max_scale_crops: Tuple[float] = (1, 0.33),
        gaussian_blur: bool = True,
        jitter_strength: float = 1.0,
    ):
        self.jitter_strength = jitter_strength
        self.gaussian_blur = gaussian_blur
from typing import List

import numpy as np

from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
    import torchvision.transforms as transforms
else:
    warn_missing_pkg('torchvision')  # pragma: no-cover

if _OPENCV_AVAILABLE:
    import cv2
else:
    warn_missing_pkg('cv2', pypi_name='opencv-python')  # pragma: no-cover


class SwAVTrainDataTransform(object):
    def __init__(self,
                 normalize=None,
                 size_crops: List[int] = [96, 36],
                 nmb_crops: List[int] = [2, 4],
                 min_scale_crops: List[float] = [0.33, 0.10],
                 max_scale_crops: List[float] = [1, 0.33],
                 gaussian_blur: bool = True,
                 jitter_strength: float = 1.):
        self.jitter_strength = jitter_strength
        self.gaussian_blur = gaussian_blur

        assert len(size_crops) == len(nmb_crops)
Ejemplo n.º 26
0
import math
from typing import Any, Tuple

import numpy as np
import torch
from pytorch_lightning import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from pl_bolts.utils import _SKLEARN_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _SKLEARN_AVAILABLE:
    from sklearn.utils import shuffle as sk_shuffle
else:  # pragma: no cover
    warn_missing_pkg("sklearn")


class SklearnDataset(Dataset):
    """Mapping between numpy (or sklearn) datasets to PyTorch datasets.

    Example:
        >>> from sklearn.datasets import load_diabetes
        >>> from pl_bolts.datamodules import SklearnDataset
        ...
        >>> X, y = load_diabetes(return_X_y=True)
        >>> dataset = SklearnDataset(X, y)
        >>> len(dataset)
        442
    """
    def __init__(self,
from pl_bolts.transforms.self_supervised import RandomTranslateWithReflect
from pl_bolts.utils.warnings import warn_missing_pkg

try:
    from torchvision import transforms
except ModuleNotFoundError:
    warn_missing_pkg('torchvision')  # pragma: no-cover
    _TORCHVISION_AVAILABLE = False
else:
    _TORCHVISION_AVAILABLE = True


class AMDIMTrainTransformsCIFAR10:
    """
    Transforms applied to AMDIM

    Transforms::

        img_jitter,
        col_jitter,
        rnd_gray,
        transforms.ToTensor(),
        normalize

    Example::

        x = torch.rand(5, 3, 32, 32)

        transform = AMDIMTrainTransformsCIFAR10()
        (view1, view2) = transform(x)
Ejemplo n.º 28
0
import importlib

import torch
from pytorch_lightning import Callback
from torch import nn

from pl_bolts.utils.warnings import warn_missing_pkg

_MATPLOTLIB_AVAILABLE = importlib.util.find_spec("matplotlib") is not None
if _MATPLOTLIB_AVAILABLE:
    from matplotlib import pyplot as plt
else:
    warn_missing_pkg("matplotlib")  # pragma: no-cover


class ConfusedLogitCallback(Callback):  # pragma: no-cover
    """
    Takes the logit predictions of a model and when the probabilities of two classes are very close, the model
    doesn't have high certainty that it should pick one vs the other class.

    This callback shows how the input would have to change to swing the model from one label prediction
    to the other.

    In this case, the network predicts a 5... but gives almost equal probability to an 8.
    The images show what about the original 5 would have to change to make it more like a 5 or more like an 8.

    For each confused logit the confused images are generated by taking the gradient from a logit wrt an input
    for the top two closest logits.

    Example::