Example #1
    def test_repeat_rows(self):
        x = torch.randn(2, 3, 4, 5)
        self.assertEqual(torchutils.repeat_rows(x, 1), x)
        y = torchutils.repeat_rows(x, 2)
        self.assertEqual(y.shape, torch.Size([4, 3, 4, 5]))
        self.assertEqual(x[0], y[0])
        self.assertEqual(x[0], y[1])
        self.assertEqual(x[1], y[2])
        self.assertEqual(x[1], y[3])
        with self.assertRaises(Exception):
            torchutils.repeat_rows(x, 0)
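The assertions above pin down the semantics of torchutils.repeat_rows: each row is repeated consecutively along the leading dimension (not tiled), and a non-positive num_reps raises. Below is a minimal sketch consistent with this test, written against plain PyTorch; the actual torchutils implementation may differ in validation and naming.

import torch

def repeat_rows_sketch(x: torch.Tensor, num_reps: int) -> torch.Tensor:
    # Hypothetical re-implementation: repeat each row num_reps times
    # consecutively along dim 0, e.g. [a, b] -> [a, a, b, b] for num_reps=2.
    if num_reps < 1:
        raise ValueError("num_reps must be a positive integer")
    shape = x.shape
    x = x.unsqueeze(1).expand(shape[0], num_reps, *shape[1:])
    return x.reshape(shape[0] * num_reps, *shape[1:])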
Example #2
    def _sample_approx_posterior_mog(
        self, num_samples, x: Tensor, batch_size: int
    ) -> Tensor:
        r"""
        Sample from the approximate posterior.

        Args:
            num_samples: Desired number of samples.
            x: Conditioning context for posterior $p(\theta|x)$.
            batch_size: Batch size for sampling.

        Returns:
            Samples from the approximate mixture of Gaussians posterior.
        """

        # Compute the mixture components of the posterior.
        logits_p, m_p, prec_p = self._posthoc_correction(x)

        # Compute the precision factors, i.e. the upper-triangular matrices of
        # the Cholesky decomposition of prec_p. (torch.cholesky is deprecated
        # in recent PyTorch; torch.linalg.cholesky(prec_p, upper=True) is the
        # modern equivalent.)
        prec_factors_p = torch.cholesky(prec_p, upper=True)

        assert logits_p.ndim == 2
        assert m_p.ndim == 3
        assert prec_p.ndim == 4
        assert prec_factors_p.ndim == 4

        # Replicate to use batched sampling from pyknos.
        if batch_size is not None and batch_size > 1:
            logits_p = logits_p.repeat(batch_size, 1)
            m_p = m_p.repeat(batch_size, 1, 1)
            prec_factors_p = prec_factors_p.repeat(batch_size, 1, 1, 1)

        # Get (optionally z-scored) MoG samples.
        theta = MultivariateGaussianMDN.sample_mog(
            num_samples, logits_p, m_p, prec_factors_p
        )

        embedded_context = self._neural_net._embedding_net(x)
        if embedded_context is not None:
            # Merge the context dimension with the sample dimension so the
            # transform can be applied to all samples at once.
            theta = torchutils.merge_leading_dims(theta, num_dims=2)
            embedded_context = torchutils.repeat_rows(
                embedded_context, num_reps=num_samples
            )

        theta, _ = self._neural_net._transform.inverse(theta, context=embedded_context)

        if embedded_context is not None:
            # Split the context dimension back out from the sample dimension.
            theta = torchutils.split_leading_dim(theta, shape=[-1, num_samples])

        return theta
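For reference, the leading-dimension helpers used above can be sketched as follows, assuming the behavior implied by their call sites (merging so the flow transform runs over all samples at once, then splitting to recover the per-context sample dimension). The real torchutils functions may add input validation; this is just a sketch of the assumed semantics.

import torch

def merge_leading_dims(x: torch.Tensor, num_dims: int) -> torch.Tensor:
    # Collapse the first num_dims dimensions into one, e.g. for num_dims=2
    # a tensor of shape (B, S, ...) becomes (B * S, ...).
    return x.reshape(-1, *x.shape[num_dims:])

def split_leading_dim(x: torch.Tensor, shape) -> torch.Tensor:
    # Inverse operation: re-expand the leading dimension, e.g.
    # shape=[-1, num_samples] turns (B * S, ...) back into (B, S, ...).
    return x.reshape(*shape, *x.shape[1:])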