def sample_and_log_prob(self, num_samples, context=None):
    """Draw samples and evaluate their log probability in a single call.

    Args:
        num_samples: int, how many samples to draw.
        context: Tensor or None, conditioning variables. If None, sampling
            is unconditional.

    Returns:
        A tuple ``(samples, log_prob)`` where:
            * samples has shape [num_samples, ...] when context is None, or
              [context_size, num_samples, ...] when context is given.
            * log_prob has the matching leading shape.
    """
    samples = self.sample(num_samples, context=context)

    conditional = context is not None
    if conditional:
        # Flatten [context_size, num_samples, ...] to a single leading dim so
        # log_prob sees one row per sample; tile the context to match.
        samples = torchutils.merge_leading_dims(samples, num_dims=2)
        context = torchutils.repeat_rows(context, num_reps=num_samples)
        assert samples.shape[0] == context.shape[0]

    log_prob = self.log_prob(samples, context=context)

    if conditional:
        # Restore the [context_size, num_samples, ...] layout.
        samples = torchutils.split_leading_dim(samples, shape=[-1, num_samples])
        log_prob = torchutils.split_leading_dim(log_prob, shape=[-1, num_samples])

    return samples, log_prob
# --- Example 2 ---
    def sample_and_log_prob(self, num_samples, context=None):
        """Draw samples from the flow together with their log probabilities.

        For flows this is more efficient than calling `sample` and `log_prob`
        separately, since the base noise is transformed only once.
        """
        embedded_context = self._embedding_net(context)
        noise, log_prob = self._distribution.sample_and_log_prob(
            num_samples, context=embedded_context
        )

        if embedded_context is None:
            samples, logabsdet = self._transform.inverse(noise, context=None)
        else:
            # Collapse [context_size, num_samples, ...] into one leading dim
            # and tile the context so the transform sees one row per sample.
            flat_noise = torchutils.merge_leading_dims(noise, num_dims=2)
            tiled_context = torchutils.repeat_rows(
                embedded_context, num_reps=num_samples
            )
            samples, logabsdet = self._transform.inverse(
                flat_noise, context=tiled_context
            )
            # Restore the [context_size, num_samples, ...] layout.
            samples = torchutils.split_leading_dim(samples, shape=[-1, num_samples])
            logabsdet = torchutils.split_leading_dim(logabsdet, shape=[-1, num_samples])

        return samples, log_prob - logabsdet
# --- Example 3 ---
 def test_split_leading_dim(self):
     """split_leading_dim reshapes the leading dim and rejects bad shapes."""
     x = torch.randn(24, 5)
     cases = [
         ([-1], x),
         ([2, -1], x.view(2, 12, 5)),
         ([2, 3, -1], x.view(2, 3, 4, 5)),
     ]
     for shape, expected in cases:
         self.assertEqual(torchutils.split_leading_dim(x, shape), expected)
     for bad_shape in ([], [5, 5]):
         with self.assertRaises(Exception):
             torchutils.split_leading_dim(x, bad_shape)
 def test_split_merge_leading_dims_are_consistent(self):
     """Merging then splitting leading dims is the identity."""
     x = torch.randn(2, 3, 4, 5)
     full_shape = [2, 3, 4, 5]
     for num_dims in range(1, 5):
         merged = torchutils.merge_leading_dims(x, num_dims)
         restored = torchutils.split_leading_dim(merged, full_shape[:num_dims])
         self.assertEqual(restored, x)
# --- Example 5 ---
 def _sample(self, num_samples, context):
     if context is None:
         return torch.randn(num_samples, *self._shape)
     else:
         # The value of the context is ignored, only its size is taken into account.
         context_size = context.shape[0]
         samples = torch.randn(context_size * num_samples, *self._shape)
         return torchutils.split_leading_dim(samples, [context_size, num_samples])
    def _sample(self, num_samples, context):   
        context_size = 1 if context is None else context.shape[0]
        low_expanded =  self._low.expand(context_size  * num_samples, *self._low.shape)
        high_expanded = self._high.expand(context_size * num_samples, *self._low.shape)
        samples = low_expanded + torch.rand(context_size * num_samples, *self._low.shape, device=self._low.device)*(high_expanded - low_expanded)

        if context is None:
            return samples
        else:
            return torchutils.split_leading_dim(samples, [context_size, num_samples])
# --- Example 7 ---
    def _sample(self, num_samples, context):
        """Sample Bernoulli variables with context-dependent probabilities."""
        # Per-context success probabilities, tiled once per requested sample.
        probs = torch.sigmoid(self._compute_params(context))
        probs = torchutils.repeat_rows(probs, num_samples)

        # Inverse-CDF sampling: a draw succeeds when the uniform falls below p.
        context_size = context.shape[0]
        uniform = torch.rand(context_size * num_samples, *self._shape)
        samples = (uniform < probs).float()
        return torchutils.split_leading_dim(samples, [context_size, num_samples])
# --- Example 8 ---
    def _sample(self, num_samples, context):
        """Sample from a diagonal Gaussian parameterized by the context."""
        means, log_stds = self._compute_params(context)
        # Tile per-context parameters so each requested sample has its own row.
        means = torchutils.repeat_rows(means, num_samples)
        stds = torchutils.repeat_rows(torch.exp(log_stds), num_samples)

        # Reparameterized draw: mean + std * standard normal noise.
        context_size = context.shape[0]
        eps = torch.randn(context_size * num_samples, *self._shape)
        return torchutils.split_leading_dim(
            means + stds * eps, [context_size, num_samples]
        )
# --- Example 9 ---
    def _sample(self, num_samples, context):
        embedded_context = self._embedding_net(context)
        noise = self._distribution.sample(num_samples, context=embedded_context)

        if embedded_context is not None:
            # Merge the context dimension with sample dimension in order to apply the transform.
            noise = torchutils.merge_leading_dims(noise, num_dims=2)
            embedded_context = torchutils.repeat_rows(
                embedded_context, num_reps=num_samples
            )

        samples, _ = self._transform.inverse(noise, context=embedded_context)

        if embedded_context is not None:
            # Split the context dimension from sample dimension.
            samples = torchutils.split_leading_dim(samples, shape=[-1, num_samples])

        return samples
# --- Example 10 ---
    def _sample(self, num_samples, context):
        if context is None:
            return self.dist.rsample(
                sample_shape=[num_samples]
            )  #torch.randn(num_samples, *self._shape, device=self._log_z.device)
        else:
            # The value of the context is ignored, only its size and device are taken into account.
            context_size = context.shape[0]
            # samples = self.dist.low + (self.dist.high - self.dist.low) * torch.rand(context_size * num_samples,
            #                                                                        *self._shape,
            #                                                                        device=context.device)

            # context_size * num_samples is used to adjust for how many cases of context we have!
            samples = self.dist.rsample(
                sample_shape=[context_size * num_samples])

            return torchutils.split_leading_dim(samples,
                                                [context_size, num_samples])