Example 1
def compute_stochastic_elbo(a, b, nu, omega, x, y, a_0, b_0, mu_0):
    """
    Return a Monte Carlo estimate of the ELBO, using a single sample from Q(sigma^-2, beta)
    
    a, b are the Gamma 'shape' and 'rate' parameters for the variational posterior over *precision*: q(tau) = q(sigma^-2)
    nu_k, omega_k are Normal 'mean' and 'precision' parameters for the variational posterior over weights: q(beta_k)
    x is an n by k matrix, where each row contains the regression inputs [1, x, x^2, x^3]
    y is an n by 1 vector of target values
    a_0, b_0 are the 'shape' and 'rate' parameters of the Gamma prior over precision P(tau) = P(sigma^-2)
    mu_0 is the mean of the Normal prior on the weights beta
    """
    
    # Define mean field variational distribution over (beta, tau).
    Q_beta = Normal(nu, omega**-0.5)
    Q_tau = Gamma(a, b) 
    
    # Sample from variational distribution: (tau, beta) ~ Q
    # Use rsample to make sure that the result is differentiable.
    tau = Q_tau.rsample()
    sigma = tau**-0.5
    beta = Q_beta.rsample()
    
    # Create a single-sample Monte Carlo estimate of the ELBO.
    P_tau = Gamma(a_0, b_0) 
    P_beta = Normal(mu_0, sigma) 
    P_y = Normal((beta[None, :]*x).sum(dim=1, keepdim=True), sigma) 
    
    kl_tau = Q_tau.log_prob(tau) - P_tau.log_prob(tau)
    kl_beta = Q_beta.log_prob(beta).sum() - P_beta.log_prob(beta).sum()
    log_likelihood = P_y.log_prob(y).sum()

    elbo = log_likelihood - kl_tau - kl_beta
    return elbo
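A minimal usage sketch (not part of the original snippet, using toy data) showing how this estimator could drive a variational-inference loop; positivity of `a`, `b` and `omega` is enforced here by optimising their logs:

import torch
from torch.distributions import Normal, Gamma

# Toy cubic-regression data: design matrix with rows [1, x, x^2, x^3].
n, k = 100, 4
x_raw = torch.linspace(-1., 1., n)
x = torch.stack([x_raw**i for i in range(k)], dim=1)              # n x k
y = (x @ torch.tensor([0.5, -1.0, 2.0, 0.3]) + 0.1 * torch.randn(n)).unsqueeze(1)

# Unconstrained variational parameters (exponentiated where positivity is needed).
log_a, log_b = torch.zeros(()), torch.zeros(())
nu, log_omega = torch.zeros(k), torch.zeros(k)
variational_params = [log_a, log_b, nu, log_omega]
for p in variational_params:
    p.requires_grad_(True)

optimiser = torch.optim.Adam(variational_params, lr=1e-2)
for _ in range(2000):
    optimiser.zero_grad()
    elbo = compute_stochastic_elbo(log_a.exp(), log_b.exp(), nu, log_omega.exp(),
                                   x, y, a_0=1.0, b_0=1.0, mu_0=torch.zeros(k))
    (-elbo).backward()   # maximise the ELBO by minimising its negative
    optimiser.step()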
Example 2
    def log_likelihood(self, x_norm, y_norm):
        mean, var, shape, rate, mixture_var = self(x_norm)
        norm_dist = Normal(mean, torch.sqrt(var))
        gamma_dist = Gamma(shape, rate)
        y = y_norm * self.y_std + self.y_mean + 10**(-4)

        only_normal_bool = (torch.abs(1 - mixture_var) < 10**(-4)).type(
            torch.float)
        only_gamma_bool = (mixture_var < 10**(-4)).type(torch.float)

        normal_component = norm_dist.log_prob(y_norm) + torch.log(mixture_var)
        gamma_component = gamma_dist.log_prob(y) + torch.log(1 - mixture_var)

        combined_tensor = torch.stack((normal_component, gamma_component),
                                      dim=0)
        output = torch.logsumexp(combined_tensor, dim=0)

        if float(mixture_var.mean()) < 0.9:
            logging.debug("Mixture var: {}".format(float(mixture_var.mean())))
            logging.debug("NLLs: {:.3f}, {:.3f}".format(
                -float(norm_dist.log_prob(y_norm).mean()),
                -float(gamma_dist.log_prob(y).mean()),
            ))
            logging.debug("Combined NLL: {:.3f} or {:.3f}".format(
                -float(output.mean()), -float(old_output)))

        return output.mean()
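A self-contained illustration (not from the original model) of the same logsumexp trick for a two-component Normal/Gamma mixture, with fixed component parameters chosen purely for the example:

import torch
from torch.distributions import Normal, Gamma

y = torch.rand(5) + 0.1                  # strictly positive targets
mixture_var = torch.full_like(y, 0.7)    # weight of the Normal component

# log p(y) = logsumexp( log w + log N(y | .), log(1 - w) + log Gamma(y | .) )
normal_component = Normal(0.5, 1.0).log_prob(y) + torch.log(mixture_var)
gamma_component = Gamma(2.0, 1.0).log_prob(y) + torch.log(1 - mixture_var)
log_lik = torch.logsumexp(torch.stack((normal_component, gamma_component), dim=0), dim=0)

print(log_lik.mean())                    # average log-likelihood over the batch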
Example 3
def gamma_ll(target_vals, v):
    """
    Evaluate gamma-bernoulli mixture likelihood
    Parameters
    ----------
    target_vals: torch.Tensor(batch, 86)
        Target values to evaluate the likelihood at.
    v: torch.Tensor(batch, 86, channels)
        Parameters from the model [rho, alpha, beta].
    """

    # Reshape
    target_vals = target_vals.reshape(-1)
    v = v.reshape(-1, 3)

    # Deal with cases where data is missing for a station
    v = v[~torch.isnan(target_vals), :]
    target_vals = target_vals[~torch.isnan(target_vals)]

    # Make r mask
    r, target_vals = make_r_mask(target_vals)

    gamma = Gamma(concentration=v[:, 1], rate=v[:, 2])
    logp = gamma.log_prob(target_vals)

    total = r * (torch.log(v[:, 0]) + logp) + (1 - r) * torch.log(1 - v[:, 0])

    return torch.mean(total)
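`make_r_mask` is not shown in this snippet. For a Bernoulli-Gamma precipitation model it would plausibly build a rain/no-rain indicator and keep the targets strictly positive for the Gamma density; the helper below is a hypothetical sketch (names and threshold are assumptions, not the original code):

import torch

def make_r_mask(target_vals, threshold=0.0, eps=1e-6):
    """Hypothetical helper: r = 1 where the target is positive, 0 otherwise.

    Gamma.log_prob is only defined on (0, inf), so zero targets are clamped to
    a small epsilon; their Gamma term is weighted by r = 0 in the mixture anyway.
    """
    r = (target_vals > threshold).float()
    target_vals = torch.clamp(target_vals, min=eps)
    return r, target_vals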
Example 4
    def test_loss(self, test_data):
        """
        Outputs the losses (NLL, RMSE, calibration) on the test data.
        """
        x, y = test_data[:]
        if not test_data.x_normalised:
            constant_x = self.x_std == 0
            x = (x - self.x_mean) / self.x_std
            x[:, constant_x] = 0
        self.train(False)
        shape, rate = self(x)

        y = y.squeeze()
        shape = shape.squeeze()
        rate = rate.squeeze()

        gamma_dist = Gamma(shape, rate)

        test_nll = -gamma_dist.log_prob(y + 10**(-8)).mean()
        test_rmse = (((y - gamma_dist.mean)**2).mean())**0.5
        calibration_arr = self.calibration_test(y.detach().numpy(),
                                                shape.detach().numpy(),
                                                rate.detach().numpy())

        return float(test_nll), float(test_rmse), calibration_arr
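`calibration_test` is defined elsewhere in the class; one common recipe, sketched below as an assumption rather than the repository's implementation, is to measure how often the observed `y` falls below each predicted Gamma quantile (note `scipy.stats.gamma` uses shape `a` and `scale = 1/rate`):

import numpy as np
from scipy.stats import gamma as gamma_sp

def calibration_curve(y, shape, rate, quantiles=np.linspace(0.05, 0.95, 19)):
    """Empirical coverage of the predicted quantiles (hypothetical helper)."""
    # Predicted quantiles per observation: array of shape (n, len(quantiles))
    pred_q = gamma_sp.ppf(quantiles[None, :], a=shape[:, None], scale=1.0 / rate[:, None])
    return (y[:, None] <= pred_q).mean(axis=0)

For a well-calibrated model, the returned coverages should track the nominal quantiles.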
Example 5
def gamma_logpdf(inputs, loc, scale, reduction=None):
    """Gamma log-density.

    Args:
        inputs (tensor): Inputs.
        loc (tensor): Concentration (shape) parameter.
        scale (tensor): Rate parameter.
        reduction (str, optional): Reduction. Defaults to no reduction.
            Possible values are "sum", "mean", and "batched_mean".

    Returns:
        tensor: Log-density.
    """
    dist = Gamma(concentration=loc, rate=scale)
    logp = dist.log_prob(inputs)

    if not reduction:
        return logp
    elif reduction == 'sum':
        return torch.sum(logp)
    elif reduction == 'mean':
        return torch.mean(logp)
    elif reduction == 'batched_mean':
        return torch.mean(torch.sum(logp, 1))
    else:
        raise RuntimeError(f'Unknown reduction "{reduction}".')
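Example usage (assumed shapes; note that despite their names, `loc` and `scale` are passed through as the Gamma concentration and rate):

import torch

inputs = torch.rand(8, 10) + 0.1                  # positive samples, batch of 8
concentration = torch.full((8, 10), 2.0)
rate = torch.full((8, 10), 1.5)

print(gamma_logpdf(inputs, concentration, rate).shape)            # torch.Size([8, 10])
print(gamma_logpdf(inputs, concentration, rate, 'mean'))          # scalar
print(gamma_logpdf(inputs, concentration, rate, 'batched_mean'))  # mean over per-row sums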
Example 6
    def fit(self, train_data):
        self.train(True)
        self.x_mean, self.x_std = train_data.normalise_x()
        data_generator = data.DataLoader(train_data, batch_size=self.batch_size)

        optimiser = torch.optim.Adam(self.parameters(), lr=self.lr)

        for _ in range(self.n_epochs):
            for i, sample in enumerate(data_generator):
                x, y = sample
                shape, rate = self(x)
                gamma_dist = Gamma(shape, rate)
                optimiser.zero_grad()
                loss = -gamma_dist.log_prob(y.squeeze() + 10 ** (-8)).mean()
                loss.backward()
                optimiser.step()
Example 7
def tbi_func(x, v):
    """
    Evaluate gamma-GP-Bernoulli mixture likelihood
    Parameters
    ----------
    x: torch.Tensor(batch*86)
        Target values to evaluate the likelihood at.
    v: torch.Tensor(batch*86, channels)
        Parameters from the model.
    """
    # Gamma distribution
    g = Gamma(concentration=v[:, 2], rate=v[:, 3])
    gamma = torch.exp(torch.clamp(g.log_prob(x), min=-1e5, max=1e5))

    # Weight term
    weight_term = (1 / 2) + (1 / np.pi) * torch.atan((x - v[:, 5]) / v[:, 6])

    # GP distribution
    gp = (1 / v[:, 4]) * (1 + (v[:, 1] * x / v[:, 4]))**((-1 / v[:, 1]) - 1)

    # total
    tbi = gamma * (1 - weight_term) + gp * weight_term
    return torch.clamp(tbi, min=1e-5)
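A quick smoke test (assumed shapes: positive targets and at least 7 parameter channels, of which `v[:, 0]` is unused by this function; requires `torch`, `numpy as np`, and `torch.distributions.Gamma` as in the original module):

import torch

x = torch.rand(16) + 0.1          # positive target values
v = torch.rand(16, 7) + 0.5       # positive dummy parameters for every channel
density = tbi_func(x, v)          # pointwise mixture density, clamped below at 1e-5
print(density.shape, float(density.min()))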
Example 8
class MLLGP():
    def __init__(self, model_gp, likelihood_gp, hyperpriors: dict) -> None:
        self.model_gp = model_gp
        self.likelihood_gp = likelihood_gp
        self.hyperpriors = hyperpriors

        a_beta = self.hyperpriors["lengthscales"].kwds["a"]
        b_beta = self.hyperpriors["lengthscales"].kwds["b"]

        self.Beta_tmp = Beta(concentration1=a_beta, concentration0=b_beta)

        a_gg = self.hyperpriors["outputscale"].kwds["a"]
        b_gg = self.hyperpriors["outputscale"].kwds["scale"]

        self.Gamma_tmp = Gamma(concentration=a_gg, rate=1. / b_gg)

    def log_marginal(self, lengthscales, outputscale) -> float:
        """
        Log marginal likelihood of the GP at the given hyperparameters,
        plus the hyperprior log-densities over lengthscales and outputscale.
        """

        # print("lengthscales.shape:",lengthscales.shape)
        # print("outputscale.shape:",outputscale.shape)
        if lengthscales.dim() == 3 or outputscale.dim() == 3:
            Nels = lengthscales.shape[0]
            loss_vec = torch.zeros(Nels)
            for k in range(Nels):
                loss_vec[k] = self.log_marginal(lengthscales[k, 0, :],
                                                outputscale[k, 0, :])
            return loss_vec

        assert lengthscales.dim() <= 1 and outputscale.dim() <= 1

        assert not torch.any(torch.isnan(lengthscales)) and not torch.any(
            torch.isinf(lengthscales)), "lengthscales is inf or NaN"
        assert not torch.isnan(outputscale) and not torch.isinf(
            outputscale), "outputscale is inf or NaN"

        # Update hyperparameters:
        self.model_gp.covar_module.outputscale = outputscale
        self.model_gp.covar_module.base_kernel.lengthscale = lengthscales

        # self.model_gp.display_hyperparameters()

        # Get the log prob of the marginal distribution:
        function_dist = self.model_gp(self.model_gp.train_inputs[0])
        output = self.likelihood_gp(function_dist)
        loss_val = output.log_prob(self.model_gp.train_targets).view(1)

        # if self.debug == True:
        #     pdb.set_trace()

        loss_lengthscales_hyperprior = torch.sum(
            self.Beta_tmp.log_prob(lengthscales)).view(1)
        loss_outputscale_hyperprior = self.Gamma_tmp.log_prob(outputscale)

        # loss_lengthscales_hyperprior = sum(self.hyperpriors["lengthscales"].logpdf(lengthscales))
        # loss_outputscale_hyperprior = self.hyperpriors["outputscale"].logpdf(outputscale).item()

        loss_val += loss_lengthscales_hyperprior + loss_outputscale_hyperprior

        try:
            assert not torch.any(torch.isnan(loss_val)) and not torch.any(
                torch.isinf(loss_val)), "loss_val is Inf or NaN"
        except AssertionError:  # debug TODO DEBUG
            logger.info("loss_val: {0:s}".format(str(loss_val)))
            logger.info("loss_lengthscales_hyperprior: {0:s}".format(
                str(loss_lengthscales_hyperprior)))
            logger.info("loss_outputscale_hyperprior: {0:s}".format(
                str(loss_outputscale_hyperprior)))

        return loss_val

    def __call__(self, pars_in):
        # Slice only last dimension: https://pytorch.org/docs/stable/tensors.html#torch.Tensor.narrow
        lengthscales = pars_in.narrow(
            dim=-1,
            start=self.model_gp.idx_hyperpars["lengthscales"][0],
            length=len(self.model_gp.idx_hyperpars["lengthscales"]))

        outputscale = pars_in.narrow(
            dim=-1,
            start=self.model_gp.idx_hyperpars["outputscale"][0],
            length=len(self.model_gp.idx_hyperpars["outputscale"]))

        return -self.log_marginal(
            lengthscales,
            outputscale)  # Negated so that minimising this value maximises the marginal likelihood
Example 9
class BayesianNN:
    def __init__(self, X_train, y_train, batch_size, num_particles,
                 hidden_dim):
        self.gamma_prior = Gamma(torch.tensor(1., device=device),
                                 torch.tensor(1 / 0.1, device=device))
        self.lambda_prior = Gamma(torch.tensor(1., device=device),
                                  torch.tensor(1 / 0.1, device=device))
        self.X_train = X_train
        self.y_train = y_train
        self.batch_size = batch_size
        self.num_particles = num_particles
        self.n_features = X_train.shape[1]
        self.hidden_dim = hidden_dim

    def forward(self, inputs, theta):
        # Unpack theta
        w1 = theta[:, 0:self.n_features * self.hidden_dim].reshape(
            -1, self.n_features, self.hidden_dim)
        b1 = theta[:, self.n_features * self.hidden_dim:(self.n_features + 1) *
                   self.hidden_dim].unsqueeze(1)
        w2 = theta[:, (self.n_features + 1) *
                   self.hidden_dim:(self.n_features + 2) *
                   self.hidden_dim].unsqueeze(2)
        b2 = theta[:, -3].reshape(-1, 1, 1)
        # log_gamma, log_lambda = theta[-2], theta[-1]

        # num_particles times of forward
        inputs = inputs.unsqueeze(0).repeat(self.num_particles, 1, 1)
        inter = F.relu(torch.bmm(inputs, w1) + b1)
        out = torch.bmm(inter, w2) + b2
        out = out.squeeze()
        return out

    def log_prob(self, theta):
        model_gamma = torch.exp(theta[:, -2])
        model_lambda = torch.exp(theta[:, -1])
        model_w = theta[:, :-2]
        # w_prior should be decided based on current lambda (not sure)
        w_prior = Normal(
            0, torch.sqrt(torch.ones_like(model_lambda) / model_lambda))

        random_idx = random.sample(range(self.X_train.shape[0]),
                                   self.batch_size)
        X_batch = self.X_train[random_idx]
        y_batch = self.y_train[random_idx]

        outputs = self.forward(X_batch, theta)  # [num_particles, batch_size]
        model_gamma_repeat = model_gamma.unsqueeze(1).repeat(
            1, self.batch_size)
        y_batch_repeat = y_batch.unsqueeze(0).repeat(self.num_particles, 1)
        distribution = Normal(
            outputs,
            torch.sqrt(
                torch.ones_like(model_gamma_repeat) / model_gamma_repeat))
        log_p_data = distribution.log_prob(y_batch_repeat).sum(dim=1)

        log_p0 = w_prior.log_prob(
            model_w.t()).sum(dim=0) + self.gamma_prior.log_prob(
                model_gamma) + self.lambda_prior.log_prob(model_lambda)
        log_p = log_p0 + log_p_data * (
            self.X_train.shape[0] / self.batch_size)  # Eq. (8) in the paper
        return log_p
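A minimal sketch (assumptions: toy data, and `device`, `random`, and `torch.nn.functional as F` available as in the original module) of evaluating `log_prob` for a set of particles, as an SVGD sampler would:

import torch

X_train = torch.randn(200, 5)
y_train = torch.randn(200)
model = BayesianNN(X_train, y_train, batch_size=32, num_particles=10, hidden_dim=16)

# Each particle stacks the flattened (w1, b1, w2, b2) plus log_gamma and log_lambda.
theta_dim = (model.n_features + 2) * model.hidden_dim + 3
theta = torch.randn(model.num_particles, theta_dim, requires_grad=True)

log_p = model.log_prob(theta)     # one log-posterior value per particle
log_p.sum().backward()            # gradients w.r.t. theta, as needed by SVGD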
Example 10
class MAMLParticles(MetaNetwork):
    """
    Object that contains all the particles.
    """
    def __init__(self,
                 feature_extractor_params,
                 lr_chaser=0.001,
                 lr_leader=None,
                 n_epochs_chaser=1,
                 n_epochs_predict=0,
                 s_epochs_leader=1,
                 m_particles=2,
                 kernel_function='rbf',
                 n_samples=10,
                 a_likelihood=2.,
                 b_likelihood=.2,
                 a_prior=2.,
                 b_prior=.2,
                 use_mse=False):
        """
        Initialises the object.

        Parameters
        ----------
        feature_extractor_params: dict
            Parameters for the feature extractor.
        lr_chaser: float
            Learning rate for the chaser
        lr_leader: float
            Learning rate for the leader
        n_epochs_chaser: int
            Number of steps to be performed by the chaser.
        n_epochs_predict: int
            Number of additional chaser steps performed at prediction time.
        s_epochs_leader: int
            Number of steps to be performed by the leader.
        m_particles: int
            Number of particles.
        kernel_function: str, {'rbf', 'quadratic'}
            The kernel function to use.
        n_samples: int
            Number of noise samples used to estimate the predictive standard deviation.
        a_likelihood, b_likelihood: float
            Shape and rate of the Gamma prior over the likelihood precision (kappa_likelihood).
        a_prior, b_prior: float
            Shape and rate of the Gamma prior over the weight-prior precision (kappa_prior).
        use_mse: bool
            Whether to use MSE loss or Chaser loss.
        """

        super(MAMLParticles, self).__init__()

        self.kernel_function = kernel_function

        self.n_epochs_chaser = n_epochs_chaser
        self.s_epochs_leader = s_epochs_leader

        self.n_epochs_predict = n_epochs_predict

        if lr_leader is None:
            lr_leader = lr_chaser / 10

        self.lr = {
            'chaser': lr_chaser,
            'leader': lr_leader,
        }

        self.m_particles = m_particles

        self.n_samples = n_samples

        self.feature_extractor = FeaturesExtractorFactory()(
            **feature_extractor_params)
        self.fe_output_dim = self.feature_extractor.output_dim

        self.gamma_likelihood = Gamma(a_likelihood, b_likelihood)
        self.gamma_prior = Gamma(a_prior, b_prior)

        # The particles only implement the last (linear) layer.
        # The first two columns are the kappas (likelihood then prior)
        self.particles = nn.Parameter(
            torch.cat((
                self.gamma_likelihood.sample((m_particles, 1)),
                self.gamma_prior.sample((m_particles, 1)),
                nn.init.kaiming_uniform_(
                    torch.empty((m_particles, self.fe_output_dim + 1))),
            ),
                      dim=1))

        self.loss = 0

        self.use_mse = use_mse

    @property
    def return_var(self):
        return True

    def kernel(self, weights):
        """
        Computes the cross-particle kernel. Given the stacked parameter vectors of the particles,
        outputs the kernel (be it RBF or quadratic).

        Parameters
        ----------
        weights: torch.Tensor
            B * M * M * (D + 1) tensor. Expanded versions of the weights.

        Returns
        -------
        kernel: torch.Tensor
            B * M * M tensor representing the cross-particle kernel.
        """
        def rbf_kernel(pv):
            """
            Computes the RBF kernel for a set of parameter vectors.

            Parameters
            ----------
            pv: torch.Tensor
                Stack of flattened parameters for each particle.

            Returns
            -------
            kernel: m x m torch.Tensor
                A m x m torch tensor representing the kernel.
            """

            x = pv - pv.transpose(1, 2)
            x = -x.norm(2, dim=3).pow(2) / 2
            x = x.exp()

            return x

        def quadratic_kernel(pv):
            """
            Computes the quadratic kernel for a set of parameter vectors.

            Parameters
            ----------
            pv: torch.Tensor
                Stack of flattened parameters for each particle.

            Returns
            -------
            kernel: m x m torch.Tensor
                A m x m torch tensor representing the kernel.
            """

            x = pv - pv.transpose(1, 2)
            x = -x.norm(2, dim=3).pow(2)
            x = 1 / (1 - x)  # inverse quadratic 1 / (1 + ||dx||^2); the offset avoids division by zero on the diagonal

            return x

        kernel_functions = {'rbf': rbf_kernel, 'quadratic': quadratic_kernel}

        kernel = kernel_functions[self.kernel_function]

        return kernel(weights)

    @staticmethod
    def compute_predictions(features, parameters):
        """

        Parameters
        ----------
        features: torch.Tensor
            B * N * D tensor representing the features.
        parameters: torch.Tensor
            B * M * (D + 3) tensor representing the M particles
            (including the bias-feature trick and two kappa vectors).

        Returns
        -------
        predictions: torch.Tensor
            B * M * N tensor, representing the predictions.
        """
        # Obtains the weights
        weights = parameters[..., 2:]

        # Implements the bias-feature trick
        features = torch.cat((features, torch.ones_like(features[..., :1])),
                             dim=2)

        predictions = torch.bmm(weights, features.transpose(1, 2))

        return predictions

    def compute_mean_std(self, features, parameters):
        """

        Parameters
        ----------
        features: torch.Tensor
            B * N * D tensor representing the features.
        parameters: torch.Tensor
            B * M * (D + 3) tensor representing the M particles
            (including the bias-feature trick and two kappa vectors).

        Returns
        -------
        mean: torch.Tensor
            B * N tensor of predictive means.
        std: torch.Tensor
            B * N tensor of predictive standard deviations.
        """
        # Obtains the kappas (B * M)
        kappa_likelihood = parameters[..., 0]

        # Computes the predictions (B * M * N)
        predictions = self.compute_predictions(features, parameters)

        # Transposes the predictions to B * N * M
        predictions = predictions.transpose(1, 2)

        # Computes the mean
        mean = predictions.mean(dim=2)

        # Adds the variability
        variability = torch.randn(
            (*predictions.size(), self.n_samples)).to(mean.device)
        variability = variability / kappa_likelihood.unsqueeze(1).unsqueeze(
            3).pow(.5)
        predictions = predictions.unsqueeze(3) + variability

        # Reshapes the predictions to B * N * (M x S), where S is the number of samples
        predictions = predictions.view(*predictions.shape[:2], -1)

        # mean = predictions.mean(dim=2)
        std = predictions.std(dim=2)

        return mean, std

    def posterior(self, predictions, targets, mask, weights, kappa_likelihood,
                  kappa_prior):
        r"""
        Computes the posterior of the configuration.

        Parameters
        ----------
        predictions: torch.Tensor
            B * M * N tensor representing the prediction made by the network.
        targets: torch.Tensor
            B * N * 1 tensor representing the targets.
        mask: torch.Tensor
            B * N mask of the examples (some tasks have fewer than N examples).
        weights: torch.Tensor
            B * M * (D + 1) tensor representing the weights, including the bias-feature trick
        kappa_likelihood: torch.Tensor:
            B * M tensor representing $\kappa_{likelihood}$.
        kappa_prior: torch.Tensor:
            B * M tensor representing $\kappa_{prior}$.

        Returns
        -------
        objective: torch.Tensor
            B * M tensor, representing the posterior of each particle, for each batch.
        """
        # Computing the log-likelihood
        log_likelihood = log_pdf(predictions - targets.transpose(1, 2),
                                 kappa_likelihood)  # B * M * N
        log_likelihood = log_likelihood * mask.unsqueeze(
            1)  # Keep only the actual examples
        log_likelihood = log_likelihood.sum(dim=2)

        # We enforce a Gaussian prior on the weights
        log_prior = log_pdf(weights[..., :-1], kappa_prior).sum(dim=2)

        # Gamma prior on the kappas
        log_prior_kappa = self.gamma_likelihood.log_prob(kappa_likelihood)
        log_prior_kappa = log_prior_kappa + self.gamma_prior.log_prob(
            kappa_prior)

        objective = log_likelihood + log_prior + log_prior_kappa

        return objective

    def svgd(self, features, targets, mask, parameters, update_type='chaser'):
        r"""
        Performs the Stein Variational Gradient Update on the particles.

        For each particle, the update is given by
        :math:`\theta_{t+1} \gets \theta_t + \varepsilon_t \phi(\theta_t)` where:

        .. math::

            \phi(\theta_t) = \frac{1}{M} \sum_{m=1}^M \left[ k(\theta_t^{(m)}, \theta_t)
            \nabla_{\theta_t^{(m)}} \log p(\theta_t^{(m)}) +
            \nabla_{\theta_t^{(m)}} k(\theta_t^{(m)}, \theta_t) \right]

        Parameters
        ----------
        features: torch.Tensor
            B * N * D tensor. The precomputed features associated with the dataset.
        targets: torch.Tensor
            B * N * 1 tensor. The targets associated to the features. Useful to compute the posterior.
        mask: torch.Tensor
            B * N mask of the examples (some tasks have fewer than N examples).
        parameters: torch.Tensor
            B * M * (D + 3) tensor containing the full parameters, already expanded along a batch dimension.
        update_type: str, 'chaser' or 'leader'
            Defines which learning rate to use.
        """

        # Expands the parameters : B * M * (D + 3) -> B * M * M * (D + 3)
        expanded_parameters = parameters.unsqueeze(1)
        expanded_parameters = expanded_parameters.expand(
            (parameters.size(0), self.m_particles, *parameters.shape[1:]))

        # Splits the different parameters
        kappa_likelihood = parameters[..., 0]
        kappa_prior = parameters[..., 1]
        weights = parameters[..., 2:]

        expanded_weights = expanded_parameters[..., 2:]

        # weights is B * M * (D + 1), features is B * N * D
        # predictions is B * M * N
        predictions = self.compute_predictions(features, parameters)

        # B * M * M
        kernel = self.kernel(expanded_weights)

        # B * M
        objectives = self.posterior(
            predictions=predictions,
            targets=targets,
            mask=mask,
            weights=weights,
            kappa_likelihood=kappa_likelihood,
            kappa_prior=kappa_prior,
        )

        # Computes the gradients for the objective (B * M * (D + 3))
        objective_grads = autograd.grad(objectives.sum(),
                                        parameters,
                                        create_graph=True)[0]

        # Computes the gradients for the kernel, using the expanded parameters (B * M * M * (D + 3))
        kernel_grads = autograd.grad(kernel.sum(),
                                     expanded_parameters,
                                     create_graph=True)[0]

        # Computes the update
        # The matmul term multiplies batches of matrices that are B * M * M and B * M * (D + 3)
        update = torch.matmul(
            kernel, objective_grads) / self.m_particles + kernel_grads.mean(
                dim=2)

        # Performs the update
        new_parameters = parameters + self.lr[update_type] * update

        # We need to make sure that the kappas remain in the right range for numerical stability
        new_parameters = torch.cat(
            [torch.clamp(new_parameters[..., :2], min=1e-8), new_parameters[..., 2:]],
            dim=2)

        return new_parameters

    def forward(self,
                episodes,
                train=None,
                test=None,
                query=None,
                trim_ends=True):
        """
        Performs a forward and backward pass on a single episode.
        To keep memory load low, the backward pass is done simultaneously.

        Parameters
        ----------
        episodes: list
            A batch of meta-learning episodes.
        train: dataset
            The train dataset.
        test: dataset
            The test dataset.
        query: dataset
            The query dataset.
        trim_ends: bool
            Whether to trim the results.

        Returns
        -------
        results: list(tuple)
            A list of tuples containing the mean and standard deviation computed by the network
            for each episode.
        query_results: list(tuple)
            A list of tuples containing the mean and standard deviation computed by the network
            for each episode of the query set.
        """

        if episodes is not None:
            train, test = pack_episodes(episodes,
                                        return_ys_test=True,
                                        return_query=False)
            x_test, y_test, len_test, mask_test = test
            query = None
        else:
            assert (train is not None) and (test is not None)
            x_test, len_test, mask_test = test

        # x is B * N * D dimensional, y is B * N * 1 dimensional
        x_train, y_train, len_train, mask_train = train

        b, n, d = x_train.size()

        train_features = self.feature_extractor(x_train.reshape(
            -1, d)).reshape(b, -1, self.fe_output_dim)
        test_features = self.feature_extractor(x_test.reshape(-1, d)).reshape(
            b, -1, self.fe_output_dim)

        # Expands the parameters along the batch dimension : M * (D + 3) -> B * M * (D + 3)
        parameters = self.particles.unsqueeze(0).expand(
            (b, *self.particles.size()))

        with autograd.enable_grad():
            # Initialise the chaser as a new tensor
            chaser = parameters + 0.
            for i in range(self.n_epochs_chaser):
                chaser = self.svgd(train_features,
                                   y_train,
                                   mask_train,
                                   parameters=chaser,
                                   update_type='chaser')

            if self.training and not self.use_mse:
                full_features = torch.cat((train_features, test_features),
                                          dim=1)
                y_full = torch.cat((y_train, y_test), dim=1)
                mask_full = torch.cat((mask_train, mask_test), dim=1)

                leader = chaser + 0.
                for i in range(self.s_epochs_leader):
                    leader = self.svgd(full_features,
                                       y_full,
                                       mask_full,
                                       parameters=leader,
                                       update_type='leader')

                # Added stability
                self.loss = (leader.detach() - chaser)[...,
                                                       2:].pow(2).sum() / b

        with autograd.enable_grad():
            for i in range(self.n_epochs_predict):
                chaser = self.svgd(train_features,
                                   y_train,
                                   mask_train,
                                   parameters=chaser,
                                   update_type='chaser')

        # Computes the mean and standard deviation
        mean, std = self.compute_mean_std(test_features, chaser)

        # Unsqueezes the results to keep the same shape as the targets
        mean = mean.unsqueeze(2)
        std = std.unsqueeze(2)

        # Re-organises the results in the episodic form
        mean = [m[:n] for m, n in zip(mean, len_test)]
        std = [s[:n] for s, n in zip(std, len_test)]

        results = [(m[:n], s[:n]) for m, s, n in zip(mean, std, len_test)
                   ] if trim_ends else (mean, std)

        if query is None:
            return results

        x_query, _, len_query, mask_query = query

        query_features = self.feature_extractor(x_query.reshape(
            -1, d)).reshape(b, -1, self.fe_output_dim)

        mean, std = self.compute_mean_std(query_features, chaser)

        # Unsqueezes the results to keep the same shape as the targets
        mean = mean.unsqueeze(2)
        std = std.unsqueeze(2)

        query_results = [
            (m[:n], s[:n]) for m, s, n in zip(mean, std, len_query)
        ] if trim_ends else (mean, std)

        return results, query_results
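For reference, a stripped-down, self-contained version of the SVGD step described in the `svgd` docstring (RBF kernel with unit bandwidth, single batch, no kappas); this is an illustrative sketch, not the class's actual update:

import torch
from torch import autograd

def svgd_step(particles, log_prob_fn, lr=1e-2, bandwidth=1.0):
    """One SVGD update: theta <- theta + lr * phi(theta) for M particles of dimension D."""
    particles = particles.detach().requires_grad_(True)
    m = particles.size(0)

    # Gradient of the log-posterior at every particle: M x D
    grad_log_p = autograd.grad(log_prob_fn(particles).sum(), particles)[0]

    with torch.no_grad():
        diffs = particles.unsqueeze(1) - particles.unsqueeze(0)        # M x M x D, diffs[i, j] = x_i - x_j
        kernel = torch.exp(-diffs.pow(2).sum(-1) / (2 * bandwidth))    # M x M
        # Driving term: sum_j k(x_j, x_i) * grad log p(x_j)
        drive = kernel @ grad_log_p                                    # M x D
        # Repulsive term: sum_j grad_{x_j} k(x_j, x_i) = sum_j k[i, j] * (x_i - x_j) / bandwidth
        repulse = (kernel.unsqueeze(-1) * diffs).sum(dim=1) / bandwidth
        phi = (drive + repulse) / m

    return (particles + lr * phi).detach()

For example, with `log_prob_fn = lambda t: Normal(0., 1.).log_prob(t).sum(dim=1)` (using `torch.distributions.Normal`), repeated calls move the particles toward a standard normal.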