Example #1
0
def test_ensure_batch_dim():
    """Check that the batching helpers prepend exactly one leading dimension.

    Three cases are exercised: a 1-D parameter vector, a 1-D observation, and
    a 2-D observation (e.g. an image). Each must come back with ``ndim``
    increased by one.
    """
    # 1-D parameter vector -> 2-D (batched) parameter tensor.
    theta_vector = torch.tensor([0.0, -1.0, 1.0])
    assert torchutils.ensure_theta_batched(theta_vector).ndim == 2

    # 1-D observation -> 2-D (batched) observation tensor.
    x_vector = torch.tensor([0.0, -1.0, 1.0])
    assert torchutils.ensure_x_batched(x_vector).ndim == 2

    # 2-D observation (image-like) -> 3-D (batched) observation tensor.
    x_image = torch.tensor([[1, 2, 3], [1, 2, 3]])
    assert torchutils.ensure_x_batched(x_image).ndim == 3
Example #2
0
    def np_potential(self, theta: np.ndarray) -> ScalarFloat:
        r"""Return potential for Numpy slice sampler.

        Each row of theta is concatenated with the observation `self.x` and
        passed to `self.classifier`. The flattened classifier output is then
        added to `self.prior.log_prob(theta)`.

        Args:
            theta: Parameters $\theta$, batch dimension 1.

        Returns:
            Posterior log probability of theta.

        Raises:
            AssertionError: If the batched observation is not 2-dimensional,
                since it has to be concatenated with theta along dim 1.
        """
        theta = ensure_theta_batched(torch.as_tensor(theta, dtype=torch.float32))
        num_batch = theta.shape[0]
        x_batched = ensure_x_batched(self.x)

        # Check the shape before building any repeated tensors: x is
        # concatenated with theta along dim 1, so it must be (batch, features).
        assert (
            x_batched.ndim == 2
        ), """X must not be multidimensional for ratio-based methods because it will be
              concatenated with theta."""

        # Repeat x over the batch dim so every row of theta is paired with x.
        x_repeated = x_batched.repeat(num_batch, 1)

        with torch.set_grad_enabled(False):
            # theta is created from a NumPy array, so move it to x's device for
            # the concatenation, and bring the result back to the CPU.
            log_ratio = (
                self.classifier(
                    torch.cat((theta.to(self.x.device), x_repeated), dim=1)
                )
                .reshape(-1)
                .cpu()
            )

        # Notice opposite sign to pyro potential.
        return log_ratio + self.prior.log_prob(theta)
Example #3
0
    def np_potential(self, theta: np.ndarray) -> ScalarFloat:
        r"""Return the posterior log-probability $\log p(\theta|x)$ of theta.

        Rows of theta whose prior log-probability is not finite are treated as
        outside the prior and get $-\infty$.

        Args:
            theta: Parameters $\theta$, batch dimension 1.

        Returns:
            Posterior log probability $\log(p(\theta|x))$, $-\infty$ for rows of
            theta outside the prior.
        """
        batched_theta = ensure_theta_batched(
            torch.as_tensor(theta, dtype=torch.float32)
        )

        observation = ensure_x_batched(self.x)
        # Tile the observation along dim 0 only, so trailing (possibly
        # multi-dimensional) event dims of x are left untouched.
        tile_counts = (batched_theta.shape[0],) + (1,) * (observation.ndim - 1)
        context = observation.repeat(*tile_counts)

        with torch.set_grad_enabled(False):
            log_prob = self.posterior_nn.log_prob(
                inputs=batched_theta.to(self.x.device),
                context=context,
            )
            prior_log_prob = self.prior.log_prob(batched_theta)
            log_prob[~torch.isfinite(prior_log_prob)] = -float("Inf")

        return log_prob
Example #4
0
    def np_potential(self, theta: np.ndarray) -> ScalarFloat:
        r"""Return posterior log prob. of theta $p(\theta|x)$.

        Computed as `self.likelihood_nn.log_prob` of the observation given each
        row of theta, plus `self.prior.log_prob(theta)`.

        Args:
            theta: Parameters $\theta$, batch dimension 1.

        Returns:
            Posterior log probability of the theta, $-\infty$ if impossible under prior.
        """
        theta = ensure_theta_batched(torch.as_tensor(theta, dtype=torch.float32))
        num_batch = theta.shape[0]

        x_batched = ensure_x_batched(self.x)
        # Repeat x over the batch dim to match theta. Only dim 0 is tiled, so an
        # observation with more than one event dim no longer makes `repeat`
        # fail (the old `repeat(num_batch, 1)` only accepted 2-D x).
        x = x_batched.repeat(num_batch, *(1 for _ in range(x_batched.ndim - 1)))

        with torch.set_grad_enabled(False):
            log_likelihood = self.likelihood_nn.log_prob(inputs=x, context=theta)

        # Notice opposite sign to pyro potential.
        return log_likelihood + self.prior.log_prob(theta)
Example #5
0
    def np_potential(self, theta: np.ndarray) -> ScalarFloat:
        r"""Return potential for Numpy slice sampler.

        The classifier is evaluated on the concatenation of theta and the
        observation `self.x`, reshaped to a single row, and
        `self.prior.log_prob(theta)` is added to its output.

        Args:
            theta: Parameters $\theta$, batch dimension 1.

        Returns:
            Posterior log probability of theta.
        """
        theta = torch.as_tensor(theta, dtype=torch.float32)

        # Theta and x should have shape (1, dim).
        theta = ensure_theta_batched(theta)
        x = ensure_x_batched(self.x)

        # The numpy slice sampler does not use gradients. Evaluating the
        # classifier without autograd avoids building a graph, so its output
        # carries no autograd history.
        with torch.set_grad_enabled(False):
            log_ratio = self.classifier(
                torch.cat((theta, x), dim=1).reshape(1, -1)
            )

        # Notice opposite sign to pyro potential.
        return log_ratio + self.prior.log_prob(theta)
Example #6
0
    def np_potential(self, theta: np.ndarray) -> ScalarFloat:
        r"""Return potential for Numpy slice sampler.

        Each row of theta is concatenated with the observation `self.x` and
        passed to `self.classifier`. The flattened classifier output is then
        added to `self.prior.log_prob(theta)`.

        Args:
            theta: Parameters $\theta$, batch dimension 1.

        Returns:
            Posterior log probability of theta.
        """
        theta = torch.as_tensor(theta, dtype=torch.float32)
        theta = ensure_theta_batched(theta)
        num_batch = theta.shape[0]
        # Repeat x over the batch dim so every row of theta is paired with x.
        x = ensure_x_batched(self.x).repeat(num_batch, 1)

        with torch.set_grad_enabled(False):
            # theta is created from a NumPy array; move it to x's device so the
            # concatenation also works when x is not on the CPU, then bring the
            # result back to the CPU to add the prior term computed on theta.
            log_ratio = (
                self.classifier(torch.cat((theta.to(x.device), x), dim=1))
                .reshape(-1)
                .cpu()
            )

        # Notice opposite sign to pyro potential.
        return log_ratio + self.prior.log_prob(theta)
Example #7
0
    def pyro_potential(self, theta: Dict[str, Tensor]) -> Tensor:
        r"""Return potential for Pyro sampler.

        The potential is the negated sum of the classifier output on
        `[theta, x]` and the prior log-probability of theta.

        Args:
            theta: Parameters $\theta$. The tensor's shape will be
             (1, shape_of_single_theta) if running a single chain or just
             (shape_of_single_theta) for multiple chains.

        Returns:
            Potential $-(\log r(x_o, \theta) + \log p(\theta))$.
        """
        # Only the first value of the incoming dict is used.
        site_values = iter(theta.values())
        params = next(site_values)

        # Both params and the observation are brought to shape (1, dim).
        params = ensure_theta_batched(params)
        observation = ensure_x_batched(self.x)

        # NOTE(review): autograd is left enabled here; presumably the Pyro
        # sampler needs gradients of the potential — confirm before changing.
        log_ratio = self.classifier([params.to(observation.device), observation])
        log_ratio = log_ratio.cpu()
        log_prior = self.prior.log_prob(params)

        return -(log_ratio + log_prior)
Example #8
0
    def np_potential(self, theta: np.ndarray) -> ScalarFloat:
        r"""Return posterior theta log prob. $p(\theta|x)$, $-\infty$ if outside prior.

        Rows of theta whose prior log-probability is not finite get $-\infty$.

        Args:
            theta: Parameters $\theta$, batch dimension 1.

        Returns:
            Posterior log probability $\log(p(\theta|x))$.
        """
        theta = torch.as_tensor(theta, dtype=torch.float32)
        theta = ensure_theta_batched(theta)
        num_batch = theta.shape[0]

        x_batched = ensure_x_batched(self.x)
        # Repeat x over the batch dim so each row of theta gets its own copy of
        # the observation as context. Only dim 0 is tiled, so multi-D x works.
        x_repeated = x_batched.repeat(
            num_batch, *(1 for _ in range(x_batched.ndim - 1))
        )

        with torch.set_grad_enabled(False):
            # Previously the un-repeated `self.x` was passed as context, leaving
            # the repeated tensor unused and the context unmatched to theta.
            target_log_prob = self.posterior_nn.log_prob(
                inputs=theta.to(self.x.device),
                context=x_repeated,
            )
            is_within_prior = torch.isfinite(self.prior.log_prob(theta))
            target_log_prob[~is_within_prior] = -float("Inf")

        return target_log_prob