def test_split_merge_leading_dims_are_consistent(self):
    """Merging the first k dims and splitting them again must give back the input."""
    original = torch.randn(2, 3, 4, 5)
    full_shape = list(original.shape)

    # Round-trip for every possible number of merged leading dims (1..4).
    for num_dims in range(1, original.dim() + 1):
        merged = torchutils.merge_leading_dims(original, num_dims)
        restored = torchutils.split_leading_dim(merged, full_shape[:num_dims])
        self.assertEqual(restored, original)
def test_merge_leading_dims(self):
    """Check the merged result for each valid count and an error for invalid ones."""
    x = torch.randn(2, 3, 4, 5)

    # Expected result keyed by the number of leading dims that get merged.
    expected_by_num_dims = {
        1: x,
        2: x.view(6, 4, 5),
        3: x.view(24, 5),
        4: x.view(120),
    }
    for num_dims, expected in expected_by_num_dims.items():
        self.assertEqual(torchutils.merge_leading_dims(x, num_dims), expected)

    # Zero dims, and more dims than the 4-dim tensor has, must both raise.
    for invalid_num_dims in (0, 5):
        with self.assertRaises(Exception):
            torchutils.merge_leading_dims(x, invalid_num_dims)
def _sample_approx_posterior_mog(
    self, num_samples, x: Tensor, batch_size: int
) -> Tensor:
    r"""
    Sample from the approximate posterior.

    The mixture-of-Gaussians parameters are obtained from
    `self._posthoc_correction(x)`, samples are drawn with
    `MultivariateGaussianMDN.sample_mog`, and the result is passed through the
    inverse of the network's transform.

    Args:
        num_samples: Desired number of samples.
        x: Conditioning context for posterior $p(\theta|x)$.
        batch_size: Batch size for sampling. When greater than 1, the mixture
            parameters are replicated this many times along their leading
            dimension; `None` or a value `<= 1` leaves them untouched.

    Returns:
        Samples from the approximate mixture of Gaussians posterior.
    """

    # Compute the mixture components of the posterior.
    logits_p, m_p, prec_p = self._posthoc_correction(x)

    # Compute the precision factors which represent the upper triangular matrix
    # of the cholesky decomposition of the prec_p.
    # `torch.cholesky` is deprecated; `torch.linalg.cholesky` is its documented
    # replacement and returns the same upper factor with `upper=True`.
    prec_factors_p = torch.linalg.cholesky(prec_p, upper=True)

    # Sanity checks on the ranks of the mixture parameters.
    assert logits_p.ndim == 2
    assert m_p.ndim == 3
    assert prec_p.ndim == 4
    assert prec_factors_p.ndim == 4

    # Replicate to use batched sampling from pyknos.
    if batch_size is not None and batch_size > 1:
        logits_p = logits_p.repeat(batch_size, 1)
        m_p = m_p.repeat(batch_size, 1, 1)
        prec_factors_p = prec_factors_p.repeat(batch_size, 1, 1, 1)

    # Get (optionally z-scored) MoG samples.
    theta = MultivariateGaussianMDN.sample_mog(
        num_samples, logits_p, m_p, prec_factors_p
    )

    embedded_context = self._neural_net._embedding_net(x)
    if embedded_context is not None:
        # Merge the context dimension with sample dimension in order to
        # apply the transform.
        theta = torchutils.merge_leading_dims(theta, num_dims=2)
        embedded_context = torchutils.repeat_rows(
            embedded_context, num_reps=num_samples
        )

    # Map the samples back through the inverse transform; the second return
    # value is discarded.
    theta, _ = self._neural_net._transform.inverse(theta, context=embedded_context)

    if embedded_context is not None:
        # Split the context dimension from sample dimension.
        theta = torchutils.split_leading_dim(theta, shape=[-1, num_samples])

    return theta