def sample_and_log_prob(self, num_samples, context=None):
    """Draws samples and evaluates their log probabilities in a single call.

    Args:
        num_samples: int, number of samples to generate.
        context: Tensor or None, conditioning variables. If None, the context is ignored.

    Returns:
        A tuple of:
            * A Tensor containing the samples, with shape [num_samples, ...] if context is
              None, or [context_size, num_samples, ...] if context is given.
            * A Tensor containing the log probabilities of the samples, with shape
              [num_samples, ...] if context is None, or [context_size, num_samples, ...] if
              context is given.
    """
    samples = self.sample(num_samples, context=context)

    # Unconditional case: the samples can be scored as they are.
    if context is None:
        return samples, self.log_prob(samples, context=context)

    # Conditional case: `log_prob` is called with one context row per sample, so fold
    # the context and sample dimensions together and repeat the context to match.
    flat_samples = torchutils.merge_leading_dims(samples, num_dims=2)
    tiled_context = torchutils.repeat_rows(context, num_reps=num_samples)
    assert flat_samples.shape[0] == tiled_context.shape[0]

    flat_log_prob = self.log_prob(flat_samples, context=tiled_context)

    # Restore the separate context and sample dimensions on both outputs.
    return (
        torchutils.split_leading_dim(flat_samples, shape=[-1, num_samples]),
        torchutils.split_leading_dim(flat_log_prob, shape=[-1, num_samples]),
    )
def sample_and_log_prob(self, num_samples, context=None):
    """Generates samples from the flow, together with their log probabilities.

    For flows, this is more efficient than calling `sample` and `log_prob` separately:
    noise and its log probability are drawn once from the base distribution, then the
    noise is pushed through the inverse transform.

    Args:
        num_samples: number of samples to generate (per context row when a context
            is given).
        context: conditioning variables or None; passed through the embedding net
            before being used.

    Returns:
        A tuple `(samples, log_prob)`, where `log_prob` is the base distribution's
        log probability minus the log absolute determinant returned by the inverse
        transform.
    """
    embedded_context = self._embedding_net(context)
    noise, log_prob = self._distribution.sample_and_log_prob(
        num_samples, context=embedded_context
    )

    # Without an embedded context there is no extra leading dimension to handle.
    if embedded_context is None:
        samples, logabsdet = self._transform.inverse(noise, context=embedded_context)
        return samples, log_prob - logabsdet

    # The transform is applied with the context and sample dimensions folded
    # together, and the embedded context repeated once per sample.
    flat_noise = torchutils.merge_leading_dims(noise, num_dims=2)
    tiled_context = torchutils.repeat_rows(embedded_context, num_reps=num_samples)
    flat_samples, flat_logabsdet = self._transform.inverse(
        flat_noise, context=tiled_context
    )

    # Unfold the leading dimension back into separate context and sample dimensions.
    samples = torchutils.split_leading_dim(flat_samples, shape=[-1, num_samples])
    logabsdet = torchutils.split_leading_dim(flat_logabsdet, shape=[-1, num_samples])
    return samples, log_prob - logabsdet
def test_split_leading_dim(self):
    """Checks `torchutils.split_leading_dim` on valid and invalid target shapes.

    Valid shapes must reproduce the corresponding `view` of the input; shapes that
    cannot describe the leading dimension of size 24 must raise.
    """
    x = torch.randn(24, 5)
    self.assertEqual(torchutils.split_leading_dim(x, [-1]), x)
    self.assertEqual(torchutils.split_leading_dim(x, [2, -1]), x.view(2, 12, 5))
    self.assertEqual(torchutils.split_leading_dim(x, [2, 3, -1]), x.view(2, 3, 4, 5))

    # Call the function directly inside `assertRaises`. Wrapping it in `assertEqual`
    # would let the AssertionError from a mere mismatch satisfy
    # `assertRaises(Exception)`, so the test could pass without the function raising.
    with self.assertRaises(Exception):
        torchutils.split_leading_dim(x, [])
    with self.assertRaises(Exception):
        torchutils.split_leading_dim(x, [5, 5])
def test_split_merge_leading_dims_are_consistent(self):
    """Merging the first k dims and splitting them back must reproduce the input."""
    dims = [2, 3, 4, 5]
    x = torch.randn(*dims)

    # Round-trip for every possible number of merged leading dimensions (1..4).
    for num_dims in range(1, len(dims) + 1):
        merged = torchutils.merge_leading_dims(x, num_dims)
        restored = torchutils.split_leading_dim(merged, dims[:num_dims])
        self.assertEqual(restored, x)
def _sample(self, num_samples, context):
    """Draws standard-normal samples, optionally batched over a context.

    Args:
        num_samples: int, number of samples to draw (per context row when a context
            is given).
        context: Tensor or None. Only its leading size and device are used; its
            values are ignored.

    Returns:
        A Tensor of shape [num_samples, *self._shape] if context is None, otherwise
        the result of splitting the leading dimension into
        [context_size, num_samples].
    """
    if context is None:
        return torch.randn(num_samples, *self._shape)

    # Create the noise on the context's device so that downstream operations that
    # combine samples with the context do not hit a device mismatch.
    context_size = context.shape[0]
    samples = torch.randn(
        context_size * num_samples, *self._shape, device=context.device
    )
    return torchutils.split_leading_dim(samples, [context_size, num_samples])
def _sample(self, num_samples, context):
    """Draws uniform samples between `self._low` and `self._high`.

    Args:
        num_samples: int, number of samples to draw (per context row when a context
            is given).
        context: Tensor or None. Only its leading size is used.

    Returns:
        A Tensor with leading dimension `num_samples` if context is None, otherwise
        the samples with the leading dimension split into
        [context_size, num_samples].
    """
    context_size = 1 if context is None else context.shape[0]

    # Both bounds are expanded to one row per requested sample, using the shape of
    # the lower bound for the trailing dimensions.
    batch_shape = (context_size * num_samples, *self._low.shape)
    lower = self._low.expand(*batch_shape)
    upper = self._high.expand(*batch_shape)

    # Scale unit-interval noise into the [lower, upper) box.
    unit_noise = torch.rand(*batch_shape, device=self._low.device)
    samples = lower + unit_noise * (upper - lower)

    if context is not None:
        return torchutils.split_leading_dim(samples, [context_size, num_samples])
    return samples
def _sample(self, num_samples, context):
    """Draws Bernoulli samples whose probabilities are computed from the context.

    Args:
        num_samples: int, number of samples to draw per context row.
        context: Tensor of conditioning variables; its leading size determines
            the number of context rows. Must not be None, since `context.shape`
            is read.

    Returns:
        A float Tensor of zeros and ones, with the leading dimension split into
        [context_size, num_samples].
    """
    # Compute parameters.
    logits = self._compute_params(context)
    probs = torch.sigmoid(logits)
    probs = torchutils.repeat_rows(probs, num_samples)

    # Generate samples. The noise is created on the same device as the
    # probabilities; otherwise the comparison below fails when the parameters
    # live on a non-default device (e.g. GPU).
    context_size = context.shape[0]
    noise = torch.rand(
        context_size * num_samples, *self._shape, device=probs.device
    )
    samples = (noise < probs).float()
    return torchutils.split_leading_dim(samples, [context_size, num_samples])
def _sample(self, num_samples, context):
    """Draws Gaussian samples whose mean and std are computed from the context.

    Args:
        num_samples: int, number of samples to draw per context row.
        context: Tensor of conditioning variables; its leading size determines
            the number of context rows. Must not be None, since `context.shape`
            is read.

    Returns:
        A Tensor of samples `means + stds * noise`, with the leading dimension
        split into [context_size, num_samples].
    """
    # Compute parameters.
    means, log_stds = self._compute_params(context)
    stds = torch.exp(log_stds)
    means = torchutils.repeat_rows(means, num_samples)
    stds = torchutils.repeat_rows(stds, num_samples)

    # Generate samples. The noise is created on the same device as the
    # parameters; otherwise the arithmetic below fails when they live on a
    # non-default device (e.g. GPU).
    context_size = context.shape[0]
    noise = torch.randn(
        context_size * num_samples, *self._shape, device=means.device
    )
    samples = means + stds * noise
    return torchutils.split_leading_dim(samples, [context_size, num_samples])
def _sample(self, num_samples, context):
    """Samples from the flow by inverting the transform on base-distribution noise.

    Args:
        num_samples: number of samples to generate (per context row when a context
            is given).
        context: conditioning variables or None; passed through the embedding net
            before being used.

    Returns:
        The transformed samples. When the embedded context is not None, the leading
        dimension is split back into separate context and sample dimensions.
    """
    embedded_context = self._embedding_net(context)
    noise = self._distribution.sample(num_samples, context=embedded_context)

    # Without an embedded context the noise can be transformed as it is.
    if embedded_context is None:
        samples, _ = self._transform.inverse(noise, context=embedded_context)
        return samples

    # The transform is applied with the context and sample dimensions folded
    # together, and the embedded context repeated once per sample.
    flat_noise = torchutils.merge_leading_dims(noise, num_dims=2)
    tiled_context = torchutils.repeat_rows(embedded_context, num_reps=num_samples)
    flat_samples, _ = self._transform.inverse(flat_noise, context=tiled_context)

    # Unfold the leading dimension back into separate context and sample dimensions.
    return torchutils.split_leading_dim(flat_samples, shape=[-1, num_samples])
def _sample(self, num_samples, context):
    """Draws reparameterized samples from the wrapped `self.dist`.

    Args:
        num_samples: int, number of samples to draw (per context row when a context
            is given).
        context: Tensor or None. Its values are ignored; only its leading size
            is read here.

    Returns:
        `self.dist.rsample` output for `num_samples` draws if context is None,
        otherwise `context_size * num_samples` draws with the leading dimension
        split into [context_size, num_samples].
    """
    if context is not None:
        # One batch of `num_samples` draws is needed for every context row.
        context_size = context.shape[0]
        flat_samples = self.dist.rsample(sample_shape=[context_size * num_samples])
        return torchutils.split_leading_dim(flat_samples, [context_size, num_samples])

    return self.dist.rsample(sample_shape=[num_samples])