Example #1
    def _resample_posterior(self, x, num_samples, context):
        """Resamples posterior latents using self-normalized importance weights."""

        samples, log_q_z = self.model._approximate_posterior.sample_and_log_prob(
            num_samples, context=context)
        samples = utils.merge_leading_dims(samples, num_dims=2)
        log_q_z = utils.merge_leading_dims(log_q_z, num_dims=2)

        # Compute log prob of latents under the prior.
        log_p_z = self.model._prior.log_prob(samples)

        # Compute log prob of inputs under the decoder.
        x = utils.repeat_rows(x, num_reps=num_samples)
        log_p_x = self.model._likelihood.log_prob(x, context=samples)

        # Compute self-normalized log importance weights.
        log_w = log_p_x + log_p_z - log_q_z
        log_w = utils.split_leading_dim(log_w, [-1, num_samples])
        log_w -= torch.logsumexp(log_w, dim=-1)[:, None]

        samples = utils.split_leading_dim(samples, [-1, num_samples])
        idx = torch.distributions.Categorical(logits=log_w).sample(
            [num_samples])

        # Gather the resampled latents for each input using the drawn indices.
        return samples[torch.arange(len(samples), device=self.device)[:, None, None],
                       idx.T[:, :, None],
                       torch.arange(self.dimensions, device=self.device)[
                           None, None, :]]
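
The utils helpers used throughout these examples (merge_leading_dims, split_leading_dim, repeat_rows) are not shown here. The following is a minimal sketch of the behavior the calls above assume, not the library's actual implementation.

import torch

def merge_leading_dims(x, num_dims):
    # Flatten the first num_dims dimensions into one, e.g. [B, S, D] -> [B * S, D].
    return x.reshape(-1, *x.shape[num_dims:])

def split_leading_dim(x, shape):
    # Inverse of merge_leading_dims: reshape the leading dim, e.g. [B * S, D] -> [B, S, D].
    return x.reshape(*shape, *x.shape[1:])

def repeat_rows(x, num_reps):
    # Repeat each row num_reps times along the leading dim: [B, ...] -> [B * num_reps, ...].
    x = x.unsqueeze(1).expand(x.shape[0], num_reps, *x.shape[1:])
    return merge_leading_dims(x, num_dims=2)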
Example #2
    def stochastic_elbo(self,
                        inputs,
                        num_samples=1,
                        kl_multiplier=1,
                        keepdim=False):
        """Calculates an unbiased Monte-Carlo estimate of the evidence lower bound.
        Note: the KL term is also estimated via Monte Carlo.
        Args:
            inputs: Tensor of shape [batch_size, ...], the inputs.
            num_samples: int, number of samples to use for the Monte-Carlo estimate.
        Returns:
            A Tensor of shape [batch_size], an ELBO estimate for each input.
        """
        # Sample latents and calculate their log prob under the encoder.
        if self._inputs_encoder is None:
            posterior_context = inputs
        else:
            posterior_context = self._inputs_encoder(inputs)

        latents, log_q_z = self._approximate_posterior.sample_and_log_prob(
            num_samples, context=posterior_context)
        latents = utils.merge_leading_dims(latents, num_dims=2)
        log_q_z = utils.merge_leading_dims(log_q_z, num_dims=2)

        # Compute log prob of latents under the prior.
        log_p_z = self._prior.log_prob(latents)

        # Compute log prob of inputs under the decoder.
        inputs = utils.repeat_rows(inputs, num_reps=num_samples)
        log_p_x = self._likelihood.log_prob(inputs, context=latents)

        # Compute ELBO.
        # TODO: maybe compute KL analytically when possible?
        elbo = log_p_x + kl_multiplier * (log_p_z - log_q_z)
        elbo = utils.split_leading_dim(elbo, [-1, num_samples])
        if keepdim:
            return elbo
        else:
            return torch.mean(elbo, dim=1)  # Average ELBO across samples.
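
A minimal training-step sketch built on stochastic_elbo (the model, optimizer and batch names are illustrative, not part of the source): the negative mean ELBO over the batch serves as the loss.

def training_step(model, optimizer, batch, num_samples=1):
    # stochastic_elbo returns a [batch_size] tensor; minimize its negative mean.
    loss = -model.stochastic_elbo(batch, num_samples=num_samples).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()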
Example #3
    def sample_and_log_prob(self, num_samples, context):
        """Generates samples from the distribution together with their log probability."""
        samples = self.sample(num_samples, context)

        if context is not None:
            # Merge the context dimension with the sample dimension in order to call log_prob.
            samples = utils.merge_leading_dims(samples, num_dims=2)
            context = utils.repeat_rows(context, num_reps=num_samples)

        log_prob = self.log_prob(samples, context)

        if context is not None:
            # Split the context dimension from sample dimension.
            samples = utils.split_leading_dim(samples, shape=[-1, num_samples])
            log_prob = utils.split_leading_dim(log_prob,
                                               shape=[-1, num_samples])

        return samples, log_prob
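
A quick shape sketch of the contract implemented above (q stands in for any conditional distribution exposing this method; it is not a name from the examples):

import torch

context = torch.randn(16, 8)                                   # batch of 16 conditioning vectors
samples, log_prob = q.sample_and_log_prob(5, context=context)  # hypothetical distribution q
assert samples.shape[:2] == (16, 5)                            # [batch_size, num_samples, ...]
assert log_prob.shape == (16, 5)                               # [batch_size, num_samples]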
Example #4
    def reconstruct(self, inputs, num_samples=None, mean=False):
        """Reconstruct given inputs.
        Args:
            inputs: Tensor of shape [batch_size, ...], the inputs to reconstruct.
            num_samples: int or None, the number of reconstructions to generate per input. If None,
                only one reconstruction is generated per input.
            mean: bool, if True it uses the mean of the decoder instead of sampling from it.
        Returns:
            A Tensor of shape [batch_size, num_samples, ...] or [batch_size, ...] if num_samples
            is None, the reconstructions for each input.
        """
        latents = self.encode(inputs, num_samples)
        if num_samples is not None:
            latents = utils.merge_leading_dims(latents, num_dims=2)
        recons = self._decode(latents, mean)
        if num_samples is not None:
            recons = utils.split_leading_dim(recons, [-1, num_samples])
        return recons
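
A hedged usage sketch (vae and x are placeholder names): requesting several reconstructions per input adds a sample dimension, while omitting num_samples does not.

recons_single = vae.reconstruct(x)                             # shape [batch_size, ...]
recons_multi = vae.reconstruct(x, num_samples=4)               # shape [batch_size, 4, ...]
recons_mean = vae.reconstruct(x, mean=True)                    # decoder mean instead of a sample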
Example #5
    def encode(self, inputs, num_samples=None):
        """Encodes inputs into the latent space.
        Args:
            inputs: Tensor of shape [batch_size, ...], the inputs to encode.
            num_samples: int or None, the number of latent samples to generate per input. If None,
                only one latent sample is generated per input.
        Returns:
            A Tensor of shape [batch_size, num_samples, ...] or [batch_size, ...] if num_samples
            is None, the latent samples for each input.
        """
        if num_samples is None:
            latents = self._approximate_posterior.sample(num_samples=1,
                                                         context=inputs)
            latents = utils.merge_leading_dims(latents, num_dims=2)
        else:
            latents = self._approximate_posterior.sample(
                num_samples=num_samples, context=inputs)
        return latents
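
A short sketch of the shape asymmetry described in the docstring above (vae and x are placeholder names):

z_single = vae.encode(x)                  # shape [batch_size, ...]
z_multi = vae.encode(x, num_samples=8)    # shape [batch_size, 8, ...]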
Example #6
    def _decode(self, latents, mean):
        """Decodes latents into data space, returning the decoder mean or a single sample."""
        if mean:
            return self._likelihood.mean(context=latents)
        else:
            samples = self._likelihood.sample(num_samples=1, context=latents)
            return utils.merge_leading_dims(samples, num_dims=2)