Example 1
    def __init__(self,
                 x_ref: Union[np.ndarray, list],
                 p_val: float = .05,
                 preprocess_x_ref: bool = True,
                 update_x_ref: Optional[Dict[str, int]] = None,
                 preprocess_fn: Optional[Callable] = None,
                 sigma: Optional[np.ndarray] = None,
                 n_permutations: int = 100,
                 n_kernel_centers: Optional[int] = None,
                 lambda_rd_max: float = 0.2,
                 device: Optional[str] = None,
                 input_shape: Optional[tuple] = None,
                 data_type: Optional[str] = None) -> None:
        """
        Least-squares density difference (LSDD) data drift detector using a permutation test.

        Parameters
        ----------
        x_ref
            Data used as reference distribution.
        p_val
            p-value used for the significance of the permutation test.
        preprocess_x_ref
            Whether to already preprocess and store the reference data.
        update_x_ref
            Reference data can optionally be updated to the last n instances seen by the detector
            or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while
            for reservoir sampling {'reservoir_sampling': n} is passed.
        preprocess_fn
            Function to preprocess the data before computing the data drift metrics.
        sigma
            Optionally set the bandwidth of the Gaussian kernel used in estimating the LSDD. Can also pass multiple
            bandwidth values as an array. The kernel evaluation is then averaged over those bandwidths. If `sigma`
            is not specified, the 'median heuristic' is adopted whereby `sigma` is set as the median pairwise distance
            between reference samples.
        n_permutations
            Number of permutations used in the permutation test.
        n_kernel_centers
            The number of reference samples to use as centers in the Gaussian kernel model used to estimate LSDD.
            Defaults to 1/20th of the reference data.
        lambda_rd_max
            The maximum relative difference between two estimates of LSDD that the regularization parameter
            lambda is allowed to cause. Defaults to 0.2 as in the paper.
        device
            Device type used. The default None tries to use the GPU and falls back on CPU if needed.
            Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
        input_shape
            Shape of input data.
        data_type
            Optionally specify the data type (tabular, image or time-series). Added to metadata.
        """
        super().__init__(x_ref=x_ref,
                         p_val=p_val,
                         preprocess_x_ref=preprocess_x_ref,
                         update_x_ref=update_x_ref,
                         preprocess_fn=preprocess_fn,
                         sigma=sigma,
                         n_permutations=n_permutations,
                         n_kernel_centers=n_kernel_centers,
                         lambda_rd_max=lambda_rd_max,
                         input_shape=input_shape,
                         data_type=data_type)
        self.meta.update({'backend': 'pytorch'})

        # set device
        self.device = get_device(device)

        # TODO: TBD: the several type:ignore's below are because x_ref is typed as an np.ndarray
        #  in the method signature, so we can't cast it to torch.Tensor unless we change the signature
        #  to also accept torch.Tensor. We also can't redefine its type as that would involve enabling
        #  --allow-redefinition in mypy settings (which we might do eventually).
        if self.preprocess_x_ref or self.preprocess_fn is None:
            x_ref = torch.as_tensor(self.x_ref).to(self.device)  # type: ignore[assignment]
            self._configure_normalization(x_ref)  # type: ignore[arg-type]
            x_ref = self._normalize(x_ref)
            self._initialize_kernel(x_ref)  # type: ignore[arg-type]
            self._configure_kernel_centers(x_ref)  # type: ignore[arg-type]
            self.x_ref = x_ref.cpu().numpy()  # type: ignore[union-attr]
            # For stability in high dimensions we don't divide H by (pi*sigma^2)^(d/2)
            # Results in an alternative test-stat of LSDD*(pi*sigma^2)^(d/2). Same p-vals etc.
            self.H = GaussianRBF(np.sqrt(2.) * self.kernel.sigma)(self.kernel_centers, self.kernel_centers)
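
Below is a minimal usage sketch for this constructor. The class name `LSDDDriftTorch` and the `predict` call are assumptions for illustration, since the snippet only shows `__init__`.

import numpy as np

# Reference sample: 500 instances of 10 tabular features (illustrative only).
x_ref = np.random.randn(500, 10).astype(np.float32)

cd = LSDDDriftTorch(      # hypothetical class name wrapping the __init__ above
    x_ref,
    p_val=0.05,           # significance level of the permutation test
    n_permutations=100,   # permutations used to build the null distribution
    device='cpu',         # None would try the GPU first and fall back to CPU
)

# Detectors in this family typically expose a predict() method (assumed here):
# preds = cd.predict(np.random.randn(200, 10).astype(np.float32))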
Example 2
    def __init__(self,
                 x_ref: Union[np.ndarray, list],
                 ert: float,
                 window_size: int,
                 preprocess_fn: Optional[Callable] = None,
                 sigma: Optional[np.ndarray] = None,
                 n_bootstraps: int = 1000,
                 n_kernel_centers: Optional[int] = None,
                 lambda_rd_max: float = 0.2,
                 device: Optional[str] = None,
                 verbose: bool = True,
                 input_shape: Optional[tuple] = None,
                 data_type: Optional[str] = None) -> None:
        """
        Online least-squares density difference (LSDD) data drift detector using preconfigured thresholds,
        motivated by Bu et al. (2017): https://ieeexplore.ieee.org/abstract/document/7890493.
        However, we have made modifications so that a desired ERT can be accurately targeted.

        Parameters
        ----------
        x_ref
            Data used as reference distribution.
        ert
            The expected run-time (ERT) in the absence of drift. For the multivariate detectors, the ERT is defined
            as the expected run-time from t=0.
        window_size
            The size of the sliding test-window used to compute the test-statistic.
            Smaller windows respond more quickly to severe drift, while larger windows are better
            able to detect slight drift.
        preprocess_fn
            Function to preprocess the data before computing the data drift metrics.
        sigma
            Optionally set the bandwidth of the Gaussian kernel used in estimating the LSDD. Can also pass multiple
            bandwidth values as an array. The kernel evaluation is then averaged over those bandwidths. If `sigma`
            is not specified, the 'median heuristic' is adopted whereby `sigma` is set as the median pairwise distance
            between reference samples.
        n_bootstraps
            The number of bootstrap simulations used to configure the thresholds. The larger this is the
            more accurately the desired ERT will be targeted. Should ideally be at least an order of magnitude
            larger than the ERT.
        n_kernel_centers
            The number of reference samples to use as centers in the Gaussian kernel model used to estimate LSDD.
            Defaults to 2*window_size.
        lambda_rd_max
            The maximum relative difference between two estimates of LSDD that the regularization parameter
            lambda is allowed to cause. Defaults to 0.2 as in the paper.
        device
            Device type used. The default None tries to use the GPU and falls back on CPU if needed.
            Can be specified by passing either 'cuda', 'gpu' or 'cpu'. Only relevant for 'pytorch' backend.
        verbose
            Whether or not to print progress during configuration.
        input_shape
            Shape of input data.
        data_type
            Optionally specify the data type (tabular, image or time-series). Added to metadata.
        """
        super().__init__(x_ref=x_ref,
                         ert=ert,
                         window_size=window_size,
                         preprocess_fn=preprocess_fn,
                         n_bootstraps=n_bootstraps,
                         verbose=verbose,
                         input_shape=input_shape,
                         data_type=data_type)
        self.meta.update({'backend': 'pytorch'})
        self.n_kernel_centers = n_kernel_centers
        self.lambda_rd_max = lambda_rd_max

        # set device
        self.device = get_device(device)

        self._configure_normalization()

        # initialize kernel
        if sigma is None:
            x_ref = torch.from_numpy(self.x_ref).to(self.device)  # type: ignore[assignment]
            self.kernel = GaussianRBF()
            _ = self.kernel(x_ref, x_ref, infer_sigma=True)
        else:
            sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, np.ndarray) else None  # type: ignore[assignment]
            self.kernel = GaussianRBF(sigma)  # type: ignore[arg-type]

        if self.n_kernel_centers is None:
            self.n_kernel_centers = 2 * window_size

        self._configure_kernel_centers()
        self._configure_thresholds()
        self._initialise()
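
As with the offline variant, a minimal usage sketch follows; the class name `LSDDDriftOnlineTorch` and the `predict` call are assumptions, since the snippet only shows `__init__`. Note that the bootstrap threshold configuration makes instantiation the expensive step.

import numpy as np

x_ref = np.random.randn(1000, 10).astype(np.float32)

cd = LSDDDriftOnlineTorch(  # hypothetical class name
    x_ref,
    ert=500,                # expected run-time (in time steps) in the absence of drift
    window_size=20,         # smaller windows react faster to severe drift
    n_bootstraps=5000,      # ideally at least an order of magnitude larger than ert
    device='cpu',
)

# Online detectors are typically fed one instance per time step (assumed API):
# for x_t in stream:
#     pred = cd.predict(x_t)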
Example 3
    def __init__(self,
                 x_ref: Union[np.ndarray, list],
                 ert: float,
                 window_size: int,
                 preprocess_fn: Optional[Callable] = None,
                 kernel: Callable = GaussianRBF,
                 sigma: Optional[np.ndarray] = None,
                 n_bootstraps: int = 1000,
                 device: Optional[str] = None,
                 verbose: bool = True,
                 input_shape: Optional[tuple] = None,
                 data_type: Optional[str] = None) -> None:
        """
        Online Maximum Mean Discrepancy (MMD) data drift detector using preconfigured thresholds.

        Parameters
        ----------
        x_ref
            Data used as reference distribution.
        ert
            The expected run-time (ERT) in the absence of drift. For the multivariate detectors, the ERT is defined
            as the expected run-time from t=0.
        window_size
            The size of the sliding test-window used to compute the test-statistic.
            Smaller windows respond more quickly to severe drift, while larger windows are better
            able to detect slight drift.
        preprocess_fn
            Function to preprocess the data before computing the data drift metrics.
        kernel
            Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
        sigma
            Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array.
            The kernel evaluation is then averaged over those bandwidths. If `sigma` is not specified, the 'median
            heuristic' is adopted whereby `sigma` is set as the median pairwise distance between reference samples.
        n_bootstraps
            The number of bootstrap simulations used to configure the thresholds. The larger this is the
            more accurately the desired ERT will be targeted. Should ideally be at least an order of magnitude
            larger than the ERT.
        device
            Device type used. The default None tries to use the GPU and falls back on CPU if needed.
            Can be specified by passing either 'cuda', 'gpu' or 'cpu'. Only relevant for 'pytorch' backend.
        verbose
            Whether or not to print progress during configuration.
        input_shape
            Shape of input data.
        data_type
            Optionally specify the data type (tabular, image or time-series). Added to metadata.
        """
        super().__init__(x_ref=x_ref,
                         ert=ert,
                         window_size=window_size,
                         preprocess_fn=preprocess_fn,
                         n_bootstraps=n_bootstraps,
                         verbose=verbose,
                         input_shape=input_shape,
                         data_type=data_type)
        self.meta.update({'backend': 'pytorch'})

        # set device
        self.device = get_device(device)

        # initialize kernel
        sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, np.ndarray) else None  # type: ignore[assignment]
        self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel

        # compute kernel matrix for the reference data
        self.x_ref = torch.from_numpy(self.x_ref).to(self.device)
        self.k_xx = self.kernel(self.x_ref,
                                self.x_ref,
                                infer_sigma=(sigma is None))

        self._configure_thresholds()
        self._initialise()
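
A minimal usage sketch, again with a hypothetical class name (`MMDDriftOnlineTorch`), since the snippet only shows `__init__`:

import numpy as np

x_ref = np.random.randn(1000, 10).astype(np.float32)

cd = MMDDriftOnlineTorch(  # hypothetical class name
    x_ref,
    ert=200,
    window_size=10,
    n_bootstraps=2500,     # at least an order of magnitude larger than the ERT
    device='cpu',
)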
Example 4
    def __init__(self,
                 x_ref: Union[np.ndarray, list],
                 p_val: float = .05,
                 preprocess_x_ref: bool = True,
                 update_x_ref: Optional[Dict[str, int]] = None,
                 preprocess_fn: Optional[Callable] = None,
                 kernel: Callable = GaussianRBF,
                 sigma: Optional[np.ndarray] = None,
                 configure_kernel_from_x_ref: bool = True,
                 n_permutations: int = 100,
                 device: Optional[str] = None,
                 input_shape: Optional[tuple] = None,
                 data_type: Optional[str] = None) -> None:
        """
        Maximum Mean Discrepancy (MMD) data drift detector using a permutation test.

        Parameters
        ----------
        x_ref
            Data used as reference distribution.
        p_val
            p-value used for the significance of the permutation test.
        preprocess_x_ref
            Whether to already preprocess and store the reference data.
        update_x_ref
            Reference data can optionally be updated to the last n instances seen by the detector
            or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while
            for reservoir sampling {'reservoir_sampling': n} is passed.
        preprocess_fn
            Function to preprocess the data before computing the data drift metrics.
        kernel
            Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
        sigma
            Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array.
            The kernel evaluation is then averaged over those bandwidths.
        configure_kernel_from_x_ref
            Whether to already configure the kernel bandwidth from the reference data.
        n_permutations
            Number of permutations used in the permutation test.
        device
            Device type used. The default None tries to use the GPU and falls back on CPU if needed.
            Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
        input_shape
            Shape of input data.
        data_type
            Optionally specify the data type (tabular, image or time-series). Added to metadata.
        """
        super().__init__(
            x_ref=x_ref,
            p_val=p_val,
            preprocess_x_ref=preprocess_x_ref,
            update_x_ref=update_x_ref,
            preprocess_fn=preprocess_fn,
            sigma=sigma,
            configure_kernel_from_x_ref=configure_kernel_from_x_ref,
            n_permutations=n_permutations,
            input_shape=input_shape,
            data_type=data_type)
        self.meta.update({'backend': 'pytorch'})

        # set device
        self.device = get_device(device)

        # initialize kernel
        sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, np.ndarray) else None  # type: ignore[assignment]
        self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel

        # compute kernel matrix for the reference data
        if self.infer_sigma or isinstance(sigma, torch.Tensor):
            x = torch.from_numpy(self.x_ref).to(self.device)
            self.k_xx = self.kernel(x, x, infer_sigma=self.infer_sigma)
            self.infer_sigma = False
        else:
            self.k_xx, self.infer_sigma = None, True
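
The `sigma` argument accepts a single bandwidth or an array of bandwidths to average the kernel over; omitting it falls back to the median heuristic. A minimal sketch, with `MMDDriftTorch` as a hypothetical class name:

import numpy as np

x_ref = np.random.randn(500, 10).astype(np.float32)

cd = MMDDriftTorch(  # hypothetical class name
    x_ref,
    p_val=0.05,
    sigma=np.array([0.5, 1.0, 2.0], dtype=np.float32),  # averaged over bandwidths
    n_permutations=100,
    device='cpu',
)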
Example 5
    def __init__(self,
                 x_ref: Union[np.ndarray, list],
                 c_ref: np.ndarray,
                 p_val: float = .05,
                 preprocess_x_ref: bool = True,
                 update_ref: Optional[Dict[str, int]] = None,
                 preprocess_fn: Optional[Callable] = None,
                 x_kernel: Callable = GaussianRBF,
                 c_kernel: Callable = GaussianRBF,
                 n_permutations: int = 1000,
                 prop_c_held: float = 0.25,
                 n_folds: int = 5,
                 batch_size: Optional[int] = 256,
                 device: Optional[str] = None,
                 input_shape: Optional[tuple] = None,
                 data_type: Optional[str] = None,
                 verbose: bool = False) -> None:
        """
        A context-aware drift detector based on a conditional analogue of the maximum mean discrepancy (MMD).
        Only detects differences between samples that cannot be attributed to differences between associated
        sets of contexts. p-values are computed using a conditional permutation test.

        Parameters
        ----------
        x_ref
            Data used as reference distribution.
        c_ref
            Context for the reference distribution.
        p_val
            p-value used for the significance of the permutation test.
        preprocess_x_ref
            Whether to already preprocess and store the reference data `x_ref`.
        update_ref
            Reference data can optionally be updated to the last N instances seen by the detector.
            The parameter should be passed as a dictionary *{'last': N}*.
        preprocess_fn
            Function to preprocess the data before computing the data drift metrics.
        x_kernel
            Kernel defined on the input data, defaults to Gaussian RBF kernel.
        c_kernel
            Kernel defined on the context data, defaults to Gaussian RBF kernel.
        n_permutations
            Number of permutations used in the permutation test.
        prop_c_held
            Proportion of contexts held out to condition on.
        n_folds
            Number of cross-validation folds used when tuning the regularisation parameters.
        batch_size
            If not None, compute the MMDs in batches of this size rather than all at once.
        device
            Device type used. The default None tries to use the GPU and falls back on CPU if needed.
            Can be specified by passing either 'cuda', 'gpu' or 'cpu'. Only relevant for 'pytorch' backend.
        input_shape
            Shape of input data.
        data_type
            Optionally specify the data type (tabular, image or time-series). Added to metadata.
        verbose
            Whether or not to print progress during configuration.
        """
        super().__init__(x_ref=x_ref,
                         c_ref=c_ref,
                         p_val=p_val,
                         preprocess_x_ref=preprocess_x_ref,
                         update_ref=update_ref,
                         preprocess_fn=preprocess_fn,
                         x_kernel=x_kernel,
                         c_kernel=c_kernel,
                         n_permutations=n_permutations,
                         prop_c_held=prop_c_held,
                         n_folds=n_folds,
                         batch_size=batch_size,
                         input_shape=input_shape,
                         data_type=data_type,
                         verbose=verbose)
        self.meta.update({'backend': 'pytorch'})

        # set device
        self.device = get_device(device)

        # initialize kernel
        self.x_kernel = x_kernel(init_sigma_fn=_sigma_median_diag) if x_kernel == GaussianRBF else x_kernel
        self.c_kernel = c_kernel(init_sigma_fn=_sigma_median_diag) if c_kernel == GaussianRBF else c_kernel

        # Initialize classifier (hardcoded for now)
        self.clf = _SVCDomainClf(self.c_kernel)
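
The key difference from the plain MMD detector is the per-instance context array `c_ref`. A minimal sketch, with `ContextMMDDriftTorch` as a hypothetical class name:

import numpy as np

# Each instance comes with a context vector (e.g. time of day, season).
x_ref = np.random.randn(500, 10).astype(np.float32)
c_ref = np.random.randn(500, 2).astype(np.float32)

cd = ContextMMDDriftTorch(  # hypothetical class name
    x_ref,
    c_ref,
    p_val=0.05,
    n_permutations=1000,
    prop_c_held=0.25,       # fraction of contexts held out to condition on
    device='cpu',
)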
Example 6
    def __init__(self,
                 x_ref: Union[np.ndarray, list],
                 model: Union[nn.Module, nn.Sequential],
                 p_val: float = .05,
                 preprocess_x_ref: bool = True,
                 update_x_ref: Optional[Dict[str, int]] = None,
                 preprocess_fn: Optional[Callable] = None,
                 preds_type: str = 'probs',
                 binarize_preds: bool = False,
                 reg_loss_fn: Callable = (lambda model: 0),
                 train_size: Optional[float] = .75,
                 n_folds: Optional[int] = None,
                 retrain_from_scratch: bool = True,
                 seed: int = 0,
                 optimizer: Callable = torch.optim.Adam,
                 learning_rate: float = 1e-3,
                 batch_size: int = 32,
                 preprocess_batch_fn: Optional[Callable] = None,
                 epochs: int = 3,
                 verbose: int = 0,
                 train_kwargs: Optional[dict] = None,
                 device: Optional[str] = None,
                 dataset: Callable = TorchDataset,
                 dataloader: Callable = DataLoader,
                 data_type: Optional[str] = None) -> None:
        """
        Classifier-based drift detector. The classifier is trained on a fraction of the combined
        reference and test data and drift is detected on the remaining data. To use all the data
        to detect drift, a stratified cross-validation scheme can be chosen.

        Parameters
        ----------
        x_ref
            Data used as reference distribution.
        model
            PyTorch classification model used for drift detection.
        p_val
            p-value used for the significance of the test.
        preprocess_x_ref
            Whether to already preprocess and store the reference data.
        update_x_ref
            Reference data can optionally be updated to the last n instances seen by the detector
            or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while
            for reservoir sampling {'reservoir_sampling': n} is passed.
        preprocess_fn
            Function to preprocess the data before computing the data drift metrics.
        preds_type
            Whether the model outputs 'probs' or 'logits'.
        binarize_preds
            Whether to test for discrepancy on soft (e.g. probs/logits) model predictions directly
            with a K-S test or binarise to 0-1 prediction errors and apply a binomial test.
        reg_loss_fn
            The regularisation term reg_loss_fn(model) is added to the loss function being optimized.
        train_size
            Optional fraction (float between 0 and 1) of the dataset used to train the classifier.
            The drift is detected on `1 - train_size`. Cannot be used in combination with `n_folds`.
        n_folds
            Optional number of stratified folds used for training. The model predictions are then calculated
            on all the out-of-fold instances. This allows all the reference and test data to be leveraged
            for drift detection, at the expense of longer computation. If both `train_size` and `n_folds`
            are specified, `n_folds` is prioritized.
        retrain_from_scratch
            Whether the classifier should be retrained from scratch for each set of test data or whether
            it should instead continue training from where it left off on the previous set.
        seed
            Optional random seed for fold selection.
        optimizer
            Optimizer used during training of the classifier.
        learning_rate
            Learning rate used by optimizer.
        batch_size
            Batch size used during training of the classifier.
        preprocess_batch_fn
            Optional batch preprocessing function. For example to convert a list of objects to a batch which can be
            processed by the model.
        epochs
            Number of training epochs for the classifier for each (optional) fold.
        verbose
            Verbosity level during the training of the classifier. 0 is silent, 1 a progress bar.
        train_kwargs
            Optional additional kwargs when fitting the classifier.
        device
            Device type used. The default None tries to use the GPU and falls back on CPU if needed.
            Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
        dataset
            Dataset object used during training.
        dataloader
            Dataloader object used during training.
        data_type
            Optionally specify the data type (tabular, image or time-series). Added to metadata.
        """
        super().__init__(x_ref=x_ref,
                         p_val=p_val,
                         preprocess_x_ref=preprocess_x_ref,
                         update_x_ref=update_x_ref,
                         preprocess_fn=preprocess_fn,
                         preds_type=preds_type,
                         binarize_preds=binarize_preds,
                         train_size=train_size,
                         n_folds=n_folds,
                         retrain_from_scratch=retrain_from_scratch,
                         seed=seed,
                         data_type=data_type)

        if preds_type not in ['probs', 'logits']:
            raise ValueError("'preds_type' should be 'probs' or 'logits'")

        self.meta.update({'backend': 'pytorch'})

        # set device, define model and training kwargs
        self.device = get_device(device)
        self.original_model = model
        self.model = deepcopy(model)

        # define kwargs for dataloader and trainer
        self.loss_fn = nn.CrossEntropyLoss() if self.preds_type == 'logits' else nn.NLLLoss()
        self.dataset = dataset
        self.dataloader = partial(dataloader,
                                  batch_size=batch_size,
                                  shuffle=True)
        self.predict_fn = partial(predict_batch,
                                  device=self.device,
                                  preprocess_fn=preprocess_batch_fn,
                                  batch_size=batch_size)
        self.train_kwargs = {
            'optimizer': optimizer,
            'epochs': epochs,
            'preprocess_fn': preprocess_batch_fn,
            'reg_loss_fn': reg_loss_fn,
            'learning_rate': learning_rate,
            'verbose': verbose
        }
        if isinstance(train_kwargs, dict):
            self.train_kwargs.update(train_kwargs)
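
Since `preds_type='logits'` selects `nn.CrossEntropyLoss` while `'probs'` selects `nn.NLLLoss` (see the code above), the model output must match the chosen `preds_type`. A minimal sketch, with `ClassifierDriftTorch` as a hypothetical class name:

import numpy as np
import torch.nn as nn

x_ref = np.random.randn(500, 10).astype(np.float32)

# Any binary classifier works; this one emits raw logits to match preds_type.
model = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 2),
)

cd = ClassifierDriftTorch(  # hypothetical class name
    x_ref,
    model,
    preds_type='logits',    # selects nn.CrossEntropyLoss per the code above
    n_folds=5,              # out-of-fold predictions leverage all the data
    epochs=3,
    device='cpu',
)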
Example 7
    def __init__(
            self,
            x_ref: Union[np.ndarray, list],
            kernel: Union[nn.Module, nn.Sequential],
            p_val: float = .05,
            preprocess_x_ref: bool = True,
            update_x_ref: Optional[Dict[str, int]] = None,
            preprocess_fn: Optional[Callable] = None,
            n_permutations: int = 100,
            var_reg: float = 1e-5,
            reg_loss_fn: Callable = (lambda kernel: 0),
            train_size: Optional[float] = .75,
            retrain_from_scratch: bool = True,
            optimizer: torch.optim.Optimizer = torch.optim.Adam,  # type: ignore
            learning_rate: float = 1e-3,
            batch_size: int = 32,
            preprocess_batch_fn: Optional[Callable] = None,
            epochs: int = 3,
            verbose: int = 0,
            train_kwargs: Optional[dict] = None,
            device: Optional[str] = None,
            dataset: Callable = TorchDataset,
            dataloader: Callable = DataLoader,
            data_type: Optional[str] = None) -> None:
        """
        Maximum Mean Discrepancy (MMD) data drift detector where the kernel is trained to maximise an
        estimate of the test power. The kernel is trained on a split of the reference and test instances
        and then the MMD is evaluated on held out instances and a permutation test is performed.

        For details see Liu et al. (2020): Learning Deep Kernels for Non-Parametric Two-Sample Tests
        (https://arxiv.org/abs/2002.09116)


        Parameters
        ----------
        x_ref
            Data used as reference distribution.
        kernel
            Trainable PyTorch module that returns a similarity between two instances.
        p_val
            p-value used for the significance of the test.
        preprocess_x_ref
            Whether to already preprocess and store the reference data.
        update_x_ref
            Reference data can optionally be updated to the last n instances seen by the detector
            or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while
            for reservoir sampling {'reservoir_sampling': n} is passed.
        preprocess_fn
            Function to preprocess the data before applying the kernel.
        n_permutations
            The number of permutations to use in the permutation test once the MMD has been computed.
        var_reg
            Constant added to the estimated variance of the MMD for stability.
        reg_loss_fn
            The regularisation term reg_loss_fn(kernel) is added to the loss function being optimized.
        train_size
            Optional fraction (float between 0 and 1) of the dataset used to train the kernel.
            The drift is detected on `1 - train_size`.
        retrain_from_scratch
            Whether the kernel should be retrained from scratch for each set of test data or whether
            it should instead continue training from where it left off on the previous set.
        optimizer
            Optimizer used during training of the kernel.
        learning_rate
            Learning rate used by optimizer.
        batch_size
            Batch size used during training of the kernel.
        preprocess_batch_fn
            Optional batch preprocessing function. For example to convert a list of objects to a batch which can be
            processed by the kernel.
        epochs
            Number of training epochs for the kernel. Corresponds to the smaller of the reference and test sets.
        verbose
            Verbosity level during the training of the kernel. 0 is silent, 1 a progress bar.
        train_kwargs
            Optional additional kwargs when training the kernel.
        device
            Device type used. The default None tries to use the GPU and falls back on CPU if needed.
            Can be specified by passing either 'cuda', 'gpu' or 'cpu'. Only relevant for 'pytorch' backend.
        dataset
            Dataset object used during training.
        dataloader
            Dataloader object used during training. Only relevant for 'pytorch' backend.
        data_type
            Optionally specify the data type (tabular, image or time-series). Added to metadata.
        """
        super().__init__(x_ref=x_ref,
                         p_val=p_val,
                         preprocess_x_ref=preprocess_x_ref,
                         update_x_ref=update_x_ref,
                         preprocess_fn=preprocess_fn,
                         n_permutations=n_permutations,
                         train_size=train_size,
                         retrain_from_scratch=retrain_from_scratch,
                         data_type=data_type)
        self.meta.update({'backend': 'pytorch'})

        # set device, define model and training kwargs
        self.device = get_device(device)
        self.original_kernel = kernel
        self.kernel = deepcopy(kernel)

        # define kwargs for dataloader and trainer
        self.dataset = dataset
        self.dataloader = partial(dataloader,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  drop_last=True)
        self.kernel_mat_fn = partial(batch_compute_kernel_matrix,
                                     device=self.device,
                                     preprocess_fn=preprocess_batch_fn,
                                     batch_size=batch_size)
        self.train_kwargs = {
            'optimizer': optimizer,
            'epochs': epochs,
            'preprocess_fn': preprocess_batch_fn,
            'reg_loss_fn': reg_loss_fn,
            'learning_rate': learning_rate,
            'verbose': verbose
        }
        if isinstance(train_kwargs, dict):
            self.train_kwargs.update(train_kwargs)

        self.j_hat = LearnedKernelDriftTorch.JHat(self.kernel,
                                                  var_reg).to(self.device)
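
The kernel argument is any trainable nn.Module that returns similarities between two batches of instances. A minimal sketch with a toy learnable kernel; the detector class name `LearnedKernelDriftTorch` appears in the code above, while everything else here is illustrative:

import numpy as np
import torch
import torch.nn as nn

x_ref = np.random.randn(500, 10).astype(np.float32)

class ProjectedRBF(nn.Module):
    """Toy trainable kernel: Gaussian RBF on a learned linear projection."""

    def __init__(self, d_in: int = 10, d_proj: int = 4) -> None:
        super().__init__()
        self.proj = nn.Linear(d_in, d_proj)
        self.log_sigma = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        fx, fy = self.proj(x), self.proj(y)
        d2 = torch.cdist(fx, fy) ** 2  # pairwise squared distances
        return torch.exp(-d2 / (2. * torch.exp(self.log_sigma) ** 2))

cd = LearnedKernelDriftTorch(  # class name taken from the code above
    x_ref,
    ProjectedRBF(),
    epochs=3,
    device='cpu',
)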