import importlib import os import numpy as np from torch.utils.data import Dataset from pl_bolts.utils.warnings import warn_missing_pkg _PIL_AVAILABLE = importlib.util.find_spec("PIL") is not None if _PIL_AVAILABLE: from PIL import Image else: warn_missing_pkg('PIL') # pragma: no-cover DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1) DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) class KittiDataset(Dataset): """ Note: You need to have downloaded the Kitti dataset first and provide the path to where it is saved. You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015 There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored in `valid_labels`. The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index` (250 by default). It also sets all of the valid pixels to the appropriate value between 0 and
from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset from pl_bolts.losses.rl import dqn_loss from pl_bolts.models.rl.common.agents import ValueAgent from pl_bolts.models.rl.common.gym_wrappers import make_environment from pl_bolts.models.rl.common.memory import MultiStepBuffer from pl_bolts.models.rl.common.networks import CNN from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: from gym import Env else: warn_missing_pkg('gym') # pragma: no-cover Env = object class DQN(pl.LightningModule): """ Basic DQN Model PyTorch Lightning implementation of `DQN <https://arxiv.org/abs/1312.5602>`_ Paper authors: Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, Martin Riedmiller. Model implemented by: - `Donal Byrne <https://github.com/djbyrne>` Example:
import math from typing import Any import numpy as np import torch from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset from pl_bolts.utils.warnings import warn_missing_pkg try: from sklearn.utils import shuffle as sk_shuffle except ModuleNotFoundError: warn_missing_pkg("sklearn") # pragma: no-cover _SKLEARN_AVAILABLE = False else: _SKLEARN_AVAILABLE = True class SklearnDataset(Dataset): """ Mapping between numpy (or sklearn) datasets to PyTorch datasets. Example: >>> from sklearn.datasets import load_boston >>> from pl_bolts.datamodules import SklearnDataset ... >>> X, y = load_boston(return_X_y=True) >>> dataset = SklearnDataset(X, y) >>> len(dataset) 506
import random from pl_bolts.transforms.dataset_normalizations import (imagenet_normalization, cifar10_normalization, stl10_normalization) from pl_bolts.utils.warnings import warn_missing_pkg try: from torchvision import transforms except ModuleNotFoundError: warn_missing_pkg('torchvision') # pragma: no-cover _TORCHVISION_AVAILABLE = False else: _TORCHVISION_AVAILABLE = True try: from PIL import ImageFilter except ModuleNotFoundError: warn_missing_pkg('PIL', pypi_name='Pillow') # pragma: no-cover _PIL_AVAILABLE = False else: _PIL_AVAILABLE = True class Moco2TrainCIFAR10Transforms: """ Moco 2 augmentation: https://arxiv.org/pdf/2003.04297.pdf """ def __init__(self, height: int = 32): if not _TORCHVISION_AVAILABLE:
from typing import Sequence import torch from pytorch_lightning import Callback, LightningModule, Trainer from torch import Tensor, nn from pl_bolts.utils import _MATPLOTLIB_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _MATPLOTLIB_AVAILABLE: from matplotlib import pyplot as plt from matplotlib.axes import Axes from matplotlib.figure import Figure else: # pragma: no cover warn_missing_pkg("matplotlib") Axes = object Figure = object class ConfusedLogitCallback(Callback): # pragma: no cover """Takes the logit predictions of a model and when the probabilities of two classes are very close, the model doesn't have high certainty that it should pick one vs the other class. This callback shows how the input would have to change to swing the model from one label prediction to the other. In this case, the network predicts a 5... but gives almost equal probability to an 8. The images show what about the original 5 would have to change to make it more like a 5 or more like an 8. For each confused logit the confused images are generated by taking the gradient from a logit wrt an input for the top two closest logits.
https://github.com/Shmuma/ptan/blob/master/ptan/common/wrappers.py """ import collections import numpy as np import torch from pl_bolts.utils import _GYM_AVAILABLE, _OPENCV_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: import gym.spaces from gym import make as gym_make from gym import ObservationWrapper, Wrapper else: # pragma: no-cover warn_missing_pkg('gym') Wrapper = object ObservationWrapper = object if _OPENCV_AVAILABLE: import cv2 else: warn_missing_pkg('cv2', pypi_name='opencv-python') # pragma: no-cover class ToTensor(Wrapper): """For environments where the user need to press FIRE for the game to start.""" def __init__(self, env=None): super(ToTensor, self).__init__(env)
from typing import Any

from pl_bolts.datasets.mnist_dataset import MNIST
from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin
from pl_bolts.utils import _PIL_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _PIL_AVAILABLE:
    from PIL import Image
else:  # pragma: no cover
    warn_missing_pkg("PIL", pypi_name="Pillow")


class SRMNIST(SRDatasetMixin, MNIST):
    """MNIST variant intended for training Super Resolution models.

    Item access is not defined here: ``__getitem__`` comes from
    ``SRDatasetMixin`` and, per the original author's note, returns a tuple
    of the high and the low resolution image.
    """

    def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None:
        """Derive both image sizes and forward everything to the parent classes.

        Args:
            scale_factor: integer divisor applied (floor division) to the
                high resolution side length of 28 to obtain the low
                resolution side length.
            *args: extra positional arguments passed through unchanged.
            **kwargs: extra keyword arguments passed through unchanged.
        """
        full_side = 28
        reduced_side = full_side // scale_factor
        # Set before delegating so the attribute already exists while the
        # parent initialisers run.
        self.image_channels = 1
        super().__init__(full_side, reduced_side, self.image_channels, *args, **kwargs)

    def _get_image(self, index: int):
        """Return sample ``index`` of ``self.data`` as a PIL image in mode ``"L"``."""
        pixels = self.data[index].numpy()
        return Image.fromarray(pixels, mode="L")
from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset from pl_bolts.losses.rl import dqn_loss from pl_bolts.models.rl.common.agents import ValueAgent from pl_bolts.models.rl.common.gym_wrappers import make_environment from pl_bolts.models.rl.common.memory import MultiStepBuffer from pl_bolts.models.rl.common.networks import CNN from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: from gym import Env else: # pragma: no cover warn_missing_pkg('gym') Env = object class DQN(LightningModule): """ Basic DQN Model PyTorch Lightning implementation of `DQN <https://arxiv.org/abs/1312.5602>`_ Paper authors: Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, Martin Riedmiller. Model implemented by: - `Donal Byrne <https://github.com/djbyrne>` Example:
import numpy as np from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: from torchvision import transforms as transforms else: # pragma: no cover warn_missing_pkg('torchvision') if _OPENCV_AVAILABLE: import cv2 else: # pragma: no cover warn_missing_pkg('cv2', pypi_name='opencv-python') class SimCLRTrainDataTransform(object): """ Transforms for SimCLR Transform:: RandomResizedCrop(size=self.input_height) RandomHorizontalFlip() RandomApply([color_jitter], p=0.8) RandomGrayscale(p=0.2) GaussianBlur(kernel_size=int(0.1 * self.input_height)) transforms.ToTensor() Example::
import math from typing import List, Sequence, Tuple, Union import numpy as np import torch from torch import Tensor from pl_bolts.utils import _SKLEARN_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _SKLEARN_AVAILABLE: from sklearn.utils import shuffle as sk_shuffle else: # pragma: no cover warn_missing_pkg("sklearn", pypi_name="scikit-learn") class Identity(torch.nn.Module): """An identity class to replace arbitrary layers in pretrained models. Example:: from pl_bolts.utils import Identity model = resnet18() model.fc = Identity() """ def __init__(self) -> None: super().__init__() def forward(self, x: Tensor) -> Tensor:
from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_9_1 from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: from torchvision.datasets import MNIST else: # pragma: no cover warn_missing_pkg("torchvision") MNIST = object if _PIL_AVAILABLE: from PIL import Image else: # pragma: no cover warn_missing_pkg("PIL", pypi_name="Pillow") # TODO(akihironitta): This is needed to avoid 503 error when downloading MNIST dataset # from http://yann.lecun.com/exdb/mnist/ and can be removed after `torchvision==0.9.1`. # See https://github.com/pytorch/vision/issues/3549 for details. if _TORCHVISION_AVAILABLE and _TORCHVISION_LESS_THAN_0_9_1: MNIST.resources = [ ( "https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873", ), ( "https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432", ), ( "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3", ),
import os import pickle import tarfile from typing import Callable, Optional, Sequence, Tuple import torch from torch import Tensor from pl_bolts.datasets.base_dataset import LightDataset from pl_bolts.utils import _PIL_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: from PIL import Image else: # pragma: no cover warn_missing_pkg('PIL', pypi_name='Pillow') class CIFAR10(LightDataset): """ Customized `CIFAR10 <http://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset for testing Pytorch Lightning without the torchvision dependency. Part of the code was copied from https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/ Args: data_dir: Root directory of dataset where ``CIFAR10/processed/training.pt`` and ``CIFAR10/processed/test.pt`` exist. train: If ``True``, creates dataset from ``training.pt``, otherwise from ``test.pt``.
import importlib import math import numpy as np import torch from pl_bolts.utils.warnings import warn_missing_pkg _SKLEARN_AVAILABLE = importlib.util.find_spec("sklearn") is not None if _SKLEARN_AVAILABLE: from sklearn.utils import shuffle as sk_shuffle else: warn_missing_pkg('sklearn', pypi_name='scikit-learn') # pragma: no-cover class Identity(torch.nn.Module): """ An identity class to replace arbitrary layers in pretrained models Example:: from pl_bolts.utils import Identity model = resnet18() model.fc = Identity() """ def __init__(self): super(Identity, self).__init__() def forward(self, x):
""" import importlib from abc import ABC from collections import deque, namedtuple from typing import Callable, Iterable, List, Tuple import torch from torch.utils.data import IterableDataset from pl_bolts.utils.warnings import warn_missing_pkg _GYM_AVAILABLE = importlib.util.find_spec("gym") is not None if _GYM_AVAILABLE: from gym import Env else: warn_missing_pkg("gym") # pragma: no-cover # Datasets Experience = namedtuple( "Experience", field_names=["state", "action", "reward", "done", "new_state"] ) class ExperienceSourceDataset(IterableDataset): """ Basic experience source dataset. Takes a generate_batch function that returns an iterator. The logic for the experience source and how the batch is generated is defined the Lightning model itself """
https://github.com/Shmuma/ptan/blob/master/ptan/common/wrappers.py """ import collections import numpy as np import torch from pl_bolts.utils import _GYM_AVAILABLE, _OPENCV_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: import gym.spaces from gym import make as gym_make from gym import ObservationWrapper, Wrapper else: # pragma: no cover warn_missing_pkg('gym') Wrapper = object ObservationWrapper = object if _OPENCV_AVAILABLE: import cv2 else: # pragma: no cover warn_missing_pkg('cv2', pypi_name='opencv-python') class ToTensor(Wrapper): """For environments where the user need to press FIRE for the game to start.""" def __init__(self, env=None): if not _GYM_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( 'You want to use `gym` which is not installed yet.')
import torch.nn as nn from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger, WandbLogger from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from torch import Tensor from torch.nn import Module from torch.utils.hooks import RemovableHandle from pl_bolts.utils import _WANDB_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _WANDB_AVAILABLE: import wandb else: # pragma: no cover warn_missing_pkg("wandb") wandb = None class DataMonitorBase(Callback): supported_loggers = ( TensorBoardLogger, WandbLogger, ) def __init__(self, log_every_n_steps: int = None): """ Base class for monitoring data histograms in a LightningModule. This requires a logger configured in the Trainer, otherwise no data is logged. The specific class that inherits from this base defines what data gets collected.
from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: from torchvision.datasets import MNIST else: # pragma: no cover warn_missing_pkg('torchvision') MNIST = object if _PIL_AVAILABLE: from PIL import Image else: # pragma: no cover warn_missing_pkg('PIL', pypi_name='Pillow') class BinaryMNIST(MNIST): def __getitem__(self, idx): """ Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( 'You want to use `torchvision` which is not installed yet.') img, target = self.data[idx], int(self.targets[idx]) # doing this so that it is consistent with all other datasets # to return a PIL Image
"""Set of wrapper functions for gym environments taken from https://github.com/Shmuma/ptan/blob/master/ptan/common/wrappers.py.""" import collections import numpy as np import torch from pl_bolts.utils import _GYM_AVAILABLE, _OPENCV_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: import gym.spaces from gym import ObservationWrapper, Wrapper from gym import make as gym_make else: # pragma: no cover warn_missing_pkg("gym") Wrapper = object ObservationWrapper = object if _OPENCV_AVAILABLE: import cv2 else: # pragma: no cover warn_missing_pkg("cv2", pypi_name="opencv-python") class ToTensor(Wrapper): """For environments where the user need to press FIRE for the game to start.""" def __init__(self, env=None): if not _GYM_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( "You want to use `gym` which is not installed yet.")
from typing import Optional

from torch.utils.data import random_split

from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed
from pl_bolts.models.self_supervised.amdim import transforms as amdim_transforms
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
    from torchvision.datasets import STL10
else:  # pragma: no cover
    warn_missing_pkg('torchvision')


class AMDIMPretraining:
    """Dataset factories for AMDIM pretraining.

    For pretraining the train transform is used for both the train and the
    val split.
    """

    @staticmethod
    def cifar10(dataset_root, split: str = 'train'):
        """Build the ``CIFAR10Mixed`` dataset with the AMDIM CIFAR10 train transforms.

        Args:
            dataset_root: forwarded to ``CIFAR10Mixed`` as ``root``.
            split: which split to build; must be ``'train'`` or ``'val'``.

        Returns:
            The constructed ``CIFAR10Mixed`` instance. ``download=True`` is
            always passed to it.

        Raises:
            AssertionError: if ``split`` is neither ``'train'`` nor ``'val'``
                (the check is an ``assert``, so it is skipped under ``-O``).
        """
        assert split in ('train', 'val')
        return CIFAR10Mixed(
            root=dataset_root,
            split=split,
            transform=amdim_transforms.AMDIMTrainTransformsCIFAR10(),
            download=True,
        )
import math

import numpy as np
import torch

from pl_bolts.utils import _SKLEARN_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _SKLEARN_AVAILABLE:
    from sklearn.utils import shuffle as sk_shuffle
else:  # pragma: no-cover
    warn_missing_pkg('sklearn', pypi_name='scikit-learn')


class Identity(torch.nn.Module):
    """Pass-through module for replacing arbitrary layers in pretrained models.

    Example::

        from pl_bolts.utils import Identity

        model = resnet18()
        model.fc = Identity()
    """

    def __init__(self):
        """Initialise the underlying ``torch.nn.Module``; the layer holds no state of its own."""
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself, unchanged (no copy is made)."""
        return x
"""Datamodules for RL models that rely on experiences generated during training Based on implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/experience.py.""" from abc import ABC from collections import deque, namedtuple from typing import Callable, Iterator, List, Tuple import torch from torch.utils.data import IterableDataset from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: from gym import Env else: # pragma: no cover warn_missing_pkg("gym") Env = object Experience = namedtuple( "Experience", field_names=["state", "action", "reward", "done", "new_state"]) class ExperienceSourceDataset(IterableDataset): """Basic experience source dataset. Takes a generate_batch function that returns an iterator. The logic for the experience source and how the batch is generated is defined the Lightning model itself """ def __init__(self, generate_batch: Callable) -> None: self.generate_batch = generate_batch
import os import numpy as np from torch.utils.data import Dataset from pl_bolts.utils import _PIL_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: from PIL import Image else: # pragma: no cover warn_missing_pkg('PIL') DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1) DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) class KittiDataset(Dataset): """ Note: You need to have downloaded the Kitti dataset first and provide the path to where it is saved. You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015 There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored in `valid_labels`. The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index` (250 by default). It also sets all of the valid pixels to the appropriate value between 0 and `len(valid_labels)` (since that is the number of valid classes), so it can be used properly by the loss function when comparing with the output.
import torch from pytorch_lightning import LightningModule, Trainer, seed_everything from torch import Tensor, nn from torch.utils.data import DataLoader from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import LSUN, MNIST else: # pragma: no cover warn_missing_pkg("torchvision") class DCGAN(LightningModule): """DCGAN implementation. Example:: from pl_bolts.models.gans import DCGAN m = DCGAN() Trainer(gpus=2).fit(m) Example CLI:: # mnist
from typing import Tuple import numpy as np from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: from torchvision import transforms else: # pragma: no cover warn_missing_pkg("torchvision") if _OPENCV_AVAILABLE: import cv2 else: # pragma: no cover warn_missing_pkg("cv2", pypi_name="opencv-python") class SwAVTrainDataTransform: def __init__( self, normalize=None, size_crops: Tuple[int] = (96, 36), nmb_crops: Tuple[int] = (2, 4), min_scale_crops: Tuple[float] = (0.33, 0.10), max_scale_crops: Tuple[float] = (1, 0.33), gaussian_blur: bool = True, jitter_strength: float = 1.0, ): self.jitter_strength = jitter_strength self.gaussian_blur = gaussian_blur
from typing import List import numpy as np from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: import torchvision.transforms as transforms else: warn_missing_pkg('torchvision') # pragma: no-cover if _OPENCV_AVAILABLE: import cv2 else: warn_missing_pkg('cv2', pypi_name='opencv-python') # pragma: no-cover class SwAVTrainDataTransform(object): def __init__(self, normalize=None, size_crops: List[int] = [96, 36], nmb_crops: List[int] = [2, 4], min_scale_crops: List[float] = [0.33, 0.10], max_scale_crops: List[float] = [1, 0.33], gaussian_blur: bool = True, jitter_strength: float = 1.): self.jitter_strength = jitter_strength self.gaussian_blur = gaussian_blur assert len(size_crops) == len(nmb_crops)
import math from typing import Any, Tuple import numpy as np import torch from pytorch_lightning import LightningDataModule from torch import Tensor from torch.utils.data import DataLoader, Dataset from pl_bolts.utils import _SKLEARN_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _SKLEARN_AVAILABLE: from sklearn.utils import shuffle as sk_shuffle else: # pragma: no cover warn_missing_pkg("sklearn") class SklearnDataset(Dataset): """Mapping between numpy (or sklearn) datasets to PyTorch datasets. Example: >>> from sklearn.datasets import load_diabetes >>> from pl_bolts.datamodules import SklearnDataset ... >>> X, y = load_diabetes(return_X_y=True) >>> dataset = SklearnDataset(X, y) >>> len(dataset) 442 """ def __init__(self,
from pl_bolts.transforms.self_supervised import RandomTranslateWithReflect from pl_bolts.utils.warnings import warn_missing_pkg try: from torchvision import transforms except ModuleNotFoundError: warn_missing_pkg('torchvision') # pragma: no-cover _TORCHVISION_AVAILABLE = False else: _TORCHVISION_AVAILABLE = True class AMDIMTrainTransformsCIFAR10: """ Transforms applied to AMDIM Transforms:: img_jitter, col_jitter, rnd_gray, transforms.ToTensor(), normalize Example:: x = torch.rand(5, 3, 32, 32) transform = AMDIMTrainTransformsCIFAR10() (view1, view2) = transform(x)
import importlib import torch from pytorch_lightning import Callback from torch import nn from pl_bolts.utils.warnings import warn_missing_pkg _MATPLOTLIB_AVAILABLE = importlib.util.find_spec("matplotlib") is not None if _MATPLOTLIB_AVAILABLE: from matplotlib import pyplot as plt else: warn_missing_pkg("matplotlib") # pragma: no-cover class ConfusedLogitCallback(Callback): # pragma: no-cover """ Takes the logit predictions of a model and when the probabilities of two classes are very close, the model doesn't have high certainty that it should pick one vs the other class. This callback shows how the input would have to change to swing the model from one label prediction to the other. In this case, the network predicts a 5... but gives almost equal probability to an 8. The images show what about the original 5 would have to change to make it more like a 5 or more like an 8. For each confused logit the confused images are generated by taking the gradient from a logit wrt an input for the top two closest logits. Example::