Example #1
    def sample_and_log_prob(self, num_samples, context=None):
        """Generates samples from the distribution together with their log probability.

        Args:
            num_samples: int, number of samples to generate.
            context: Tensor or None, conditioning variables. If None, the context is ignored.

        Returns:
            A tuple of:
                * A Tensor containing the samples, with shape [num_samples, ...] if context is None,
                  or [context_size, num_samples, ...] if context is given.
                * A Tensor containing the log probabilities of the samples, with shape
                  [num_samples, ...] if context is None, or [context_size, num_samples, ...] if
                  context is given.
        """
        samples = self.sample(num_samples, context=context)

        if context is not None:
            # Merge the context dimension with sample dimension in order to call log_prob.
            samples = utils.merge_leading_dims(samples, num_dims=2)
            context = utils.repeat_rows(context, num_reps=num_samples)
            assert samples.shape[0] == context.shape[0]

        log_prob = self.log_prob(samples, context=context)

        if context is not None:
            # Split the context dimension from sample dimension.
            samples = utils.split_leading_dim(samples, shape=[-1, num_samples])
            log_prob = utils.split_leading_dim(log_prob, shape=[-1, num_samples])

        return samples, log_prob
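This example relies on three leading-dimension helpers from a `utils` module whose source is not shown here. A minimal sketch consistent with how they are used above (an assumption, not the original implementation):

import torch

def merge_leading_dims(x, num_dims):
    # Collapse the first `num_dims` dimensions of `x` into a single leading dimension.
    return torch.reshape(x, (-1, *x.shape[num_dims:]))

def split_leading_dim(x, shape):
    # Inverse of merge_leading_dims: reshape the leading dimension of `x` into `shape`.
    return torch.reshape(x, (*shape, *x.shape[1:]))

def repeat_rows(x, num_reps):
    # Repeat each row of `x` `num_reps` times along the leading dimension, so that
    # row i of the input occupies rows [i * num_reps, (i + 1) * num_reps) of the output.
    x = x.unsqueeze(1).expand(x.shape[0], num_reps, *x.shape[1:])
    return merge_leading_dims(x, num_dims=2)

With these definitions, repeating the context rows and merging the sample dimension keeps every sample aligned with its own context row, which is what makes the later split back into [context_size, num_samples, ...] valid.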
Example #2
    def _sample(self, num_samples, context):
        if context is None:
            return torch.randn(num_samples, *self._shape)
        else:
            # The value of the context is ignored, only its size is taken into account.
            context_size = context.shape[0]
            samples = torch.randn(context_size * num_samples, *self._shape)
            return utils.split_leading_dim(samples, [context_size, num_samples])
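A quick shape check of this behaviour, assuming the method belongs to a standard-normal distribution class (the class name and constructor below are hypothetical):

dist = StandardNormal(shape=[2])        # hypothetical wrapper around _sample
x = dist._sample(5, context=None)       # -> shape [5, 2]
context = torch.zeros(3, 8)             # only context.shape[0] == 3 is used
y = dist._sample(5, context=context)    # -> shape [3, 5, 2]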
Example #3
    def _sample(self, num_samples, context):
        # Compute parameters.
        logits = self._compute_params(context)
        probs = torch.sigmoid(logits)
        probs = utils.repeat_rows(probs, num_samples)

        # Generate samples.
        context_size = context.shape[0]
        noise = torch.rand(context_size * num_samples, *self._shape)
        samples = (noise < probs).float()
        return utils.split_leading_dim(samples, [context_size, num_samples])
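The thresholding trick in the last lines, viewed in isolation: comparing uniform noise against the success probabilities yields Bernoulli samples. The snippet below (illustrative values only) draws from the same distribution as torch.bernoulli:

probs = torch.sigmoid(torch.randn(4, 3))   # arbitrary per-element success probabilities
noise = torch.rand(4, 3)                   # U(0, 1) noise of the same shape
samples = (noise < probs).float()          # distributed exactly like torch.bernoulli(probs)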
Example #4
    def _sample(self, num_samples, context):
        # Compute parameters.
        means, log_stds = self._compute_params(context)
        stds = torch.exp(log_stds)
        means = utils.repeat_rows(means, num_samples)
        stds = utils.repeat_rows(stds, num_samples)

        # Generate samples.
        context_size = context.shape[0]
        noise = torch.randn(context_size * num_samples, *self._shape)
        samples = means + stds * noise
        return utils.split_leading_dim(samples, [context_size, num_samples])
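The sampling line is the standard reparameterization of a diagonal Gaussian: scale standard-normal noise and shift by the mean. A minimal illustration with dummy parameters (not part of the original code):

means = torch.zeros(4, 3)
stds = torch.ones(4, 3) * 0.5
eps = torch.randn(4, 3)                    # standard normal noise
samples = means + stds * eps               # same law as Normal(means, stds).sample()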
Example #5
    def sample_and_log_prob(self, num_samples, context=None):
        """Generates samples from the flow, together with their log probabilities.

        For flows, this is more efficient than calling `sample` and `log_prob` separately.
        """
        noise, log_prob = self._distribution.sample_and_log_prob(
            num_samples, context=context)

        if context is not None:
            # Merge the context dimension with sample dimension in order to apply the transform.
            noise = utils.merge_leading_dims(noise, num_dims=2)
            context = utils.repeat_rows(context, num_reps=num_samples)

        samples, logabsdet = self._transform.inverse(noise, context=context)

        if context is not None:
            # Split the context dimension from sample dimension.
            samples = utils.split_leading_dim(samples, shape=[-1, num_samples])
            logabsdet = utils.split_leading_dim(logabsdet,
                                                shape=[-1, num_samples])

        return samples, log_prob - logabsdet
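The returned value follows the change-of-variables formula: the base-distribution log prob of the noise minus the log absolute determinant of the transform's Jacobian. A hedged consistency check, assuming a `flow` object that exposes the methods used above with these signatures:

samples, log_prob = flow.sample_and_log_prob(16)
recomputed = flow.log_prob(samples)                     # change of variables gives the same value
print(torch.allclose(log_prob, recomputed, atol=1e-4))  # expected: True, up to floating-point error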
Example #6
    def _sample(self, num_samples, context):
        noise = self._distribution.sample(num_samples, context=context)

        if context is not None:
            # Merge the context dimension with sample dimension in order to apply the transform.
            noise = utils.merge_leading_dims(noise, num_dims=2)
            context = utils.repeat_rows(context, num_reps=num_samples)

        samples, _ = self._transform.inverse(noise, context=context)

        if context is not None:
            # Split the context dimension from sample dimension.
            samples = utils.split_leading_dim(samples, shape=[-1, num_samples])

        return samples
Example #7
    def stochastic_elbo(self,
                        inputs,
                        num_samples=1,
                        kl_multiplier=1,
                        keepdim=False):
        """Calculates an unbiased Monte-Carlo estimate of the evidence lower bound.

        Note: the KL term is also estimated via Monte Carlo.

        Args:
            inputs: Tensor of shape [batch_size, ...], the inputs.
            num_samples: int, number of samples to use for the Monte-Carlo estimate.
            kl_multiplier: float, multiplier applied to the estimated KL term.
            keepdim: bool, if True the per-sample estimates of shape [batch_size, num_samples]
                are returned instead of being averaged over the samples.

        Returns:
            A Tensor of shape [batch_size] with an ELBO estimate for each input, or of shape
            [batch_size, num_samples] if keepdim is True.
        """
        # Sample latents and calculate their log prob under the encoder.
        if self._inputs_encoder is None:
            posterior_context = inputs
        else:
            posterior_context = self._inputs_encoder(inputs)
        latents, log_q_z = self._approximate_posterior.sample_and_log_prob(
            num_samples, context=posterior_context)
        latents = utils.merge_leading_dims(latents, num_dims=2)
        log_q_z = utils.merge_leading_dims(log_q_z, num_dims=2)

        # Compute log prob of latents under the prior.
        log_p_z = self._prior.log_prob(latents)

        # Compute log prob of inputs under the decoder.
        inputs = utils.repeat_rows(inputs, num_reps=num_samples)
        log_p_x = self._likelihood.log_prob(inputs, context=latents)

        # Compute ELBO.
        # TODO: maybe compute KL analytically when possible?
        elbo = log_p_x + kl_multiplier * (log_p_z - log_q_z)
        elbo = utils.split_leading_dim(elbo, [-1, num_samples])
        if keepdim:
            return elbo
        else:
            return torch.sum(
                elbo, dim=1) / num_samples  # Average ELBO across samples.
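A typical way to use this estimate in training is to minimise the negative ELBO averaged over the batch. A minimal sketch, assuming hypothetical `model`, `batch` and `optimizer` objects:

elbo = model.stochastic_elbo(batch, num_samples=1)   # shape [batch_size]
loss = -elbo.mean()                                  # negative ELBO as the training loss
optimizer.zero_grad()
loss.backward()
optimizer.step()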
Example #8
    def reconstruct(self, inputs, num_samples=None, mean=False):
        """Reconstruct given inputs.

        Args:
            inputs: Tensor of shape [batch_size, ...], the inputs to reconstruct.
            num_samples: int or None, the number of reconstructions to generate per input. If None,
                only one reconstruction is generated per input.
            mean: bool, if True, use the mean of the decoder instead of sampling from it.

        Returns:
            A Tensor of shape [batch_size, num_samples, ...] or [batch_size, ...] if num_samples
            is None, the reconstructions for each input.
        """
        latents = self.encode(inputs, num_samples)
        if num_samples is not None:
            latents = utils.merge_leading_dims(latents, num_dims=2)
        recons = self._decode(latents, mean)
        if num_samples is not None:
            recons = utils.split_leading_dim(recons, [-1, num_samples])
        return recons
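A shape-level usage sketch of `reconstruct` (the `model` and `inputs` names are assumptions):

recons = model.reconstruct(inputs)                    # [batch_size, ...], one reconstruction each
recons_k = model.reconstruct(inputs, num_samples=8)   # [batch_size, 8, ...]
recons_mean = model.reconstruct(inputs, mean=True)    # decoder mean instead of a sample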