Code example #1
 def test_split_merge_leading_dims_are_consistent(self):
     # Merging the first n leading dims and then splitting them back
     # should round-trip to the original tensor.
     x = torch.randn(2, 3, 4, 5)
     y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 1), [2])
     self.assertEqual(y, x)
     y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 2), [2, 3])
     self.assertEqual(y, x)
     y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 3), [2, 3, 4])
     self.assertEqual(y, x)
     y = torchutils.split_leading_dim(
         torchutils.merge_leading_dims(x, 4), [2, 3, 4, 5]
     )
     self.assertEqual(y, x)
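The round trip asserted above can be reproduced with plain reshape calls. A minimal sketch, assuming merge_leading_dims(x, n) flattens the first n dimensions and split_leading_dim unflattens the leading dimension again (the tests themselves presumably run inside a tensor-aware TestCase such as torchtestcase, with torch and torchutils imported at module level):

import torch

x = torch.randn(2, 3, 4, 5)
merged = x.reshape(6, 4, 5)            # flatten the first two dims: (2, 3) -> 6
restored = merged.reshape(2, 3, 4, 5)  # split the leading dim back into (2, 3)
assert torch.equal(restored, x)        # the round trip is lossless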
Code example #2
 def test_merge_leading_dims(self):
     # merge_leading_dims(x, n) flattens the first n dims into one;
     # n outside the range [1, x.dim()] is expected to raise.
     x = torch.randn(2, 3, 4, 5)
     self.assertEqual(torchutils.merge_leading_dims(x, 1), x)
     self.assertEqual(torchutils.merge_leading_dims(x, 2), x.view(6, 4, 5))
     self.assertEqual(torchutils.merge_leading_dims(x, 3), x.view(24, 5))
     self.assertEqual(torchutils.merge_leading_dims(x, 4), x.view(120))
     with self.assertRaises(Exception):
         torchutils.merge_leading_dims(x, 0)
     with self.assertRaises(Exception):
         torchutils.merge_leading_dims(x, 5)
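For reference, a hypothetical pure-reshape equivalent that matches the shapes and error cases asserted above; this is a sketch, not the library's actual source:

import torch

def merge_leading_dims_sketch(x: torch.Tensor, num_dims: int) -> torch.Tensor:
    # Flatten the first `num_dims` dimensions of `x` into a single dimension.
    if not 1 <= num_dims <= x.dim():
        raise ValueError("num_dims must be between 1 and x.dim()")
    return x.reshape(-1, *x.shape[num_dims:])

assert merge_leading_dims_sketch(torch.randn(2, 3, 4, 5), 2).shape == (6, 4, 5)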
Code example #3
File: snpe_a.py Project: michaeldeistler/sbi
    def _sample_approx_posterior_mog(
        self, num_samples, x: Tensor, batch_size: int
    ) -> Tensor:
        r"""
        Sample from the approximate posterior.

        Args:
            num_samples: Desired number of samples.
            x: Conditioning context for posterior $p(\theta|x)$.
            batch_size: Batch size for sampling.

        Returns:
            Samples from the approximate mixture of Gaussians posterior.
        """

        # Compute the mixture components of the posterior.
        logits_p, m_p, prec_p = self._posthoc_correction(x)

        # Compute the precision factors, i.e. the upper-triangular Cholesky
        # factors of prec_p.
        prec_factors_p = torch.cholesky(prec_p, upper=True)

        assert logits_p.ndim == 2
        assert m_p.ndim == 3
        assert prec_p.ndim == 4
        assert prec_factors_p.ndim == 4

        # Replicate to use batched sampling from pyknos.
        if batch_size is not None and batch_size > 1:
            logits_p = logits_p.repeat(batch_size, 1)
            m_p = m_p.repeat(batch_size, 1, 1)
            prec_factors_p = prec_factors_p.repeat(batch_size, 1, 1, 1)

        # Get (optionally z-scored) MoG samples.
        theta = MultivariateGaussianMDN.sample_mog(
            num_samples, logits_p, m_p, prec_factors_p
        )

        embedded_context = self._neural_net._embedding_net(x)
        if embedded_context is not None:
            # Merge the context dimension with the sample dimension so the
            # transform can be applied over a single flat batch.
            theta = torchutils.merge_leading_dims(theta, num_dims=2)
            embedded_context = torchutils.repeat_rows(
                embedded_context, num_reps=num_samples
            )

        theta, _ = self._neural_net._transform.inverse(theta, context=embedded_context)

        if embedded_context is not None:
            # Split the merged batch back into context and sample dimensions.
            theta = torchutils.split_leading_dim(theta, shape=[-1, num_samples])

        return theta
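The merge/repeat/split sequence at the end is a generic pattern for applying a flat-batch transform to per-context samples. A standalone sketch, with a toy element-wise shift standing in for the flow's inverse transform and repeat_interleave standing in for repeat_rows (assuming repeat_rows repeats each context row num_reps times consecutively):

import torch

batch, num_samples, dim = 3, 5, 2
theta = torch.randn(batch, num_samples, dim)  # [batch, num_samples, dim]
context = torch.randn(batch, dim)             # one context row per batch entry

# merge_leading_dims(theta, num_dims=2): [batch * num_samples, dim]
flat_theta = theta.reshape(-1, dim)
# repeat_rows(context, num_reps=num_samples): each row repeated consecutively,
# so every flattened sample lines up with its own context row.
flat_context = context.repeat_interleave(num_samples, dim=0)

transformed = flat_theta + flat_context  # stand-in for _transform.inverse

# split_leading_dim(..., shape=[-1, num_samples]): back to [batch, num_samples, dim]
result = transformed.reshape(-1, num_samples, dim)
assert result.shape == (batch, num_samples, dim)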