Ejemplo n.º 1
0
    def predict(
            self,
            x: Union[torch.Tensor, np.ndarray],
            probability_space: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Query the model for posterior mean and variance.

        Args:
            x (Union[torch.Tensor, np.ndarray]): Points at which to predict
                from the model.
            probability_space (bool, optional): Return outputs in units of
                response probability instead of latent function value.
                Defaults to False.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at
                the query points, each passed through ``promote_0d``.
        """
        with torch.no_grad():
            post = self.posterior(x)
        fmean = post.mean.squeeze()
        fvar = post.variance.squeeze()

        # Latent-space request: nothing to transform.
        if not probability_space:
            return promote_0d(fmean), promote_0d(fvar)

        if isinstance(self.likelihood, BernoulliLikelihood):
            # Probability-space mean and variance for Bernoulli-probit models is
            # available in closed form, Proposition 1 in Letham et al. 2022 (AISTATS).
            a_star = fmean / torch.sqrt(1 + fvar)
            pmean = Normal(0, 1).cdf(a_star)
            # owens_t is evaluated on NumPy arrays. Tensor.numpy() refuses
            # tensors that track gradients or live off-CPU, so detach and copy
            # to host memory before converting.
            a_star_np = a_star.detach().cpu().numpy()
            fvar_np = fvar.detach().cpu().numpy()
            # Build the result on the same device/dtype as a_star so it can be
            # combined with pmean below.
            t_term = torch.tensor(
                owens_t(a_star_np, 1 / np.sqrt(1 + 2 * fvar_np)),
                dtype=a_star.dtype,
                device=a_star.device,
            )
            pvar = pmean - 2 * t_term - pmean.square()
            return promote_0d(pmean), promote_0d(pvar)

        # No closed form used for other likelihoods: estimate the moments from
        # posterior samples pushed through the response function.
        fsamps = post.sample(torch.Size([10000]))
        if hasattr(self.likelihood, "objective"):
            psamps = self.likelihood.objective(fsamps)
        else:
            # Standard-normal CDF evaluated in torch (same call as the closed-form
            # branch) so this path also yields tensors, matching the annotation.
            psamps = Normal(0, 1).cdf(fsamps)
        pmean, pvar = psamps.mean(0), psamps.var(0)
        return promote_0d(pmean), promote_0d(pvar)