import os from typing import Any, Callable, Optional import torch from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from torch.utils.data.dataset import random_split from ConSSL.datasets import KittiDataset from ConSSL.utils import _TORCHVISION_AVAILABLE from ConSSL.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: from torchvision import transforms as transforms else: # pragma: no cover warn_missing_pkg('torchvision') class KittiDataModule(LightningDataModule): name = 'kitti' def __init__( self, data_dir: Optional[str] = None, val_split: float = 0.2, test_split: float = 0.1, num_workers: int = 16, batch_size: int = 32, seed: int = 42, shuffle: bool = False,
from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger, WandbLogger from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from torch import nn as nn from torch import Tensor from torch.nn import Module from torch.utils.hooks import RemovableHandle from ConSSL.utils import _WANDB_AVAILABLE from ConSSL.utils.warnings import warn_missing_pkg if _WANDB_AVAILABLE: import wandb else: # pragma: no cover warn_missing_pkg("wandb") wandb = None # type: ignore class DataMonitorBase(Callback): supported_loggers = ( TensorBoardLogger, WandbLogger, ) def __init__(self, log_every_n_steps: int = None): """ Base class for monitoring data histograms in a LightningModule. This requires a logger configured in the Trainer, otherwise no data is logged. The specific class that inherits from this base defines what data gets collected.
from typing import Optional, Tuple, Union import numpy as np import torch from pytorch_lightning import Callback, LightningModule, Trainer from torch.utils.data import DataLoader from ConSSL.utils import _SKLEARN_AVAILABLE from ConSSL.utils.warnings import warn_missing_pkg if _SKLEARN_AVAILABLE: from sklearn.neighbors import KNeighborsClassifier else: # pragma: no cover warn_missing_pkg("sklearn", pypi_name="scikit-learn") class KNNOnlineEvaluator(Callback): # pragma: no cover """ Evaluates self-supervised K nearest neighbors. Example:: # your model must have 1 attribute model = Model() model.num_classes = ... # the num of classes in the model online_eval = KNNOnlineEvaluator( num_classes=model.num_classes, dataset='imagenet' )
import os import pickle import tarfile from typing import Callable, Optional, Sequence, Tuple import torch from torch import Tensor from ConSSL.datasets import LightDataset from ConSSL.utils import _PIL_AVAILABLE from ConSSL.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: from PIL import Image else: # pragma: no cover warn_missing_pkg('PIL', pypi_name='Pillow') class CIFAR10(LightDataset): """ Customized `CIFAR10 <http://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset for testing Pytorch Lightning without the torchvision dependency. Part of the code was copied from https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/ Args: data_dir: Root directory of dataset where ``CIFAR10/processed/training.pt`` and ``CIFAR10/processed/test.pt`` exist. train: If ``True``, creates dataset from ``training.pt``, otherwise from ``test.pt``.
import math from typing import List import numpy as np import torch from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback from torch import Tensor from ConSSL.utils import _TORCHVISION_AVAILABLE from ConSSL.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: import torchvision else: # pragma: no cover warn_missing_pkg("torchvision") class LatentDimInterpolator(Callback): """ Interpolates the latent space for a model by setting all dims to zero and stepping through the first two dims increasing one unit at a time. Default interpolates between [-5, 5] (-5, -4, -3, ..., 3, 4, 5) Example:: from ConSSL.callbacks import LatentDimInterpolator Trainer(callbacks=[LatentDimInterpolator()]) """
from ConSSL.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE from ConSSL.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: from torchvision.datasets import MNIST else: # pragma: no cover warn_missing_pkg('torchvision') MNIST = object if _PIL_AVAILABLE: from PIL import Image else: # pragma: no cover warn_missing_pkg('PIL', pypi_name='Pillow') class BinaryMNIST(MNIST): def __getitem__(self, idx): """ Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( 'You want to use `torchvision` which is not installed yet.') img, target = self.data[idx], int(self.targets[idx]) # doing this so that it is consistent with all other datasets # to return a PIL Image
Based on implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/experience.py """ from abc import ABC from collections import deque, namedtuple from typing import Callable, Iterable, List, Tuple import torch from torch.utils.data import IterableDataset from ConSSL.utils import _GYM_AVAILABLE from ConSSL.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: from gym import Env else: # pragma: no cover warn_missing_pkg("gym") Env = object Experience = namedtuple( "Experience", field_names=["state", "action", "reward", "done", "new_state"]) class ExperienceSourceDataset(IterableDataset): """ Basic experience source dataset. Takes a generate_batch function that returns an iterator. The logic for the experience source and how the batch is generated is defined the Lightning model itself """ def __init__(self, generate_batch: Callable) -> None: self.generate_batch = generate_batch
import numpy as np from ConSSL.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE from ConSSL.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: from torchvision import transforms as transforms else: # pragma: no cover warn_missing_pkg('torchvision') if _OPENCV_AVAILABLE: import cv2 else: # pragma: no cover warn_missing_pkg('cv2', pypi_name='opencv-python') class SimCLRTrainDataTransform(object): """ Transforms for SimCLR Transform:: RandomResizedCrop(size=self.input_height) RandomHorizontalFlip() RandomApply([color_jitter], p=0.8) RandomGrayscale(p=0.2) GaussianBlur(kernel_size=int(0.1 * self.input_height)) transforms.ToTensor() Example::
from typing import Sequence import torch from pytorch_lightning import Callback, LightningModule, Trainer from torch import nn, Tensor from ConSSL.utils import _MATPLOTLIB_AVAILABLE from ConSSL.utils.warnings import warn_missing_pkg if _MATPLOTLIB_AVAILABLE: from matplotlib import pyplot as plt from matplotlib.axes import Axes from matplotlib.figure import Figure else: # pragma: no cover warn_missing_pkg("matplotlib") Axes = object Figure = object class ConfusedLogitCallback(Callback): # pragma: no cover """ Takes the logit predictions of a model and when the probabilities of two classes are very close, the model doesn't have high certainty that it should pick one vs the other class. This callback shows how the input would have to change to swing the model from one label prediction to the other. In this case, the network predicts a 5... but gives almost equal probability to an 8. The images show what about the original 5 would have to change to make it more like a 5 or more like an 8. For each confused logit the confused images are generated by taking the gradient from a logit wrt an input
import math from typing import List, Sequence, Tuple, Union import numpy as np import torch from torch import Tensor from ConSSL.utils import _SKLEARN_AVAILABLE from ConSSL.utils.warnings import warn_missing_pkg if _SKLEARN_AVAILABLE: from sklearn.utils import shuffle as sk_shuffle else: # pragma: no cover warn_missing_pkg('sklearn', pypi_name='scikit-learn') class Identity(torch.nn.Module): """ An identity class to replace arbitrary layers in pretrained models Example:: from ConSSL.utils import Identity model = resnet18() model.fc = Identity() """ def __init__(self) -> None: super(Identity, self).__init__()
import math from typing import Any, Tuple import numpy as np import torch from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset from ConSSL.utils import _SKLEARN_AVAILABLE from ConSSL.utils.warnings import warn_missing_pkg if _SKLEARN_AVAILABLE: from sklearn.utils import shuffle as sk_shuffle else: # pragma: no cover warn_missing_pkg("sklearn") class SklearnDataset(Dataset): """ Mapping between numpy (or sklearn) datasets to PyTorch datasets. Example: >>> from sklearn.datasets import load_boston >>> from ConSSL.datamodules import SklearnDataset ... >>> X, y = load_boston(return_X_y=True) >>> dataset = SklearnDataset(X, y) >>> len(dataset) 506 """ def __init__(self,