def sample_and_log_prob(self, num_samples, context=None):
    """Draw samples and evaluate their log probability in a single call.

    Args:
        num_samples: int, how many samples to draw (per context row when a
            context is supplied).
        context: Tensor or None, conditioning variables. When None the
            context is ignored.

    Returns:
        A tuple ``(samples, log_prob)``:
            * samples, with shape [num_samples, ...] when context is None,
              or [context_size, num_samples, ...] when a context is given.
            * the log probabilities of those samples, laid out with the same
              leading dimensions ([num_samples] or
              [context_size, num_samples]).
    """
    samples = self.sample(num_samples, context=context)

    if context is None:
        return samples, self.log_prob(samples, context=context)

    # log_prob works on a flat batch: fold the [context_size, num_samples]
    # leading dims into one, and repeat every context row num_samples times
    # so each sample is paired with its own conditioning row.
    flat_samples = utils.merge_leading_dims(samples, num_dims=2)
    flat_context = utils.repeat_rows(context, num_reps=num_samples)
    assert flat_samples.shape[0] == flat_context.shape[0]

    flat_log_prob = self.log_prob(flat_samples, context=flat_context)

    # Undo the folding so callers see [context_size, num_samples, ...] again.
    return (
        utils.split_leading_dim(flat_samples, shape=[-1, num_samples]),
        utils.split_leading_dim(flat_log_prob, shape=[-1, num_samples]),
    )
def _sample(self, num_samples, context):
    """Draw samples from a standard normal of shape ``self._shape``.

    Args:
        num_samples: int, number of samples to draw (per context row when a
            context is given).
        context: Tensor or None. Only its leading size and its device are
            used; its values are ignored.

    Returns:
        A Tensor of shape [num_samples, *self._shape] if context is None,
        otherwise split to leading dimensions [context_size, num_samples].
    """
    if context is None:
        # NOTE(review): with no context there is nothing to infer a device
        # from, so the samples are created on torch's default device.
        return torch.randn(num_samples, *self._shape)

    # The value of the context is ignored; only its size and device are used.
    # Allocating on context.device keeps the samples next to the context
    # rather than always on the default device (e.g. CPU while the context
    # is on a GPU).
    context_size = context.shape[0]
    samples = torch.randn(
        context_size * num_samples, *self._shape, device=context.device
    )
    return utils.split_leading_dim(samples, [context_size, num_samples])
def _sample(self, num_samples, context):
    """Draw independent Bernoulli samples conditioned on ``context``.

    Args:
        num_samples: int, number of samples to draw per context row.
        context: Tensor, conditioning variables; its leading dimension is
            the number of context rows.

    Returns:
        A float Tensor of 0.0/1.0 values whose leading dimensions are
        [context_size, num_samples].
    """
    # Compute parameters: logits -> success probabilities, then repeat each
    # row num_samples times so every sample has its own probability row.
    logits = self._compute_params(context)
    probs = torch.sigmoid(logits)
    probs = utils.repeat_rows(probs, num_samples)

    # Generate samples by thresholding uniform noise against the
    # probabilities. The noise is allocated on probs' device so the
    # comparison below never mixes devices.
    context_size = context.shape[0]
    noise = torch.rand(
        context_size * num_samples, *self._shape, device=probs.device
    )
    samples = (noise < probs).float()
    return utils.split_leading_dim(samples, [context_size, num_samples])
def _sample(self, num_samples, context):
    """Draw samples from a diagonal Gaussian conditioned on ``context``.

    Args:
        num_samples: int, number of samples to draw per context row.
        context: Tensor, conditioning variables; its leading dimension is
            the number of context rows.

    Returns:
        A Tensor of samples whose leading dimensions are
        [context_size, num_samples].
    """
    # Compute parameters: means and log standard deviations, repeated so
    # that every sample gets its own copy of its context's parameters.
    means, log_stds = self._compute_params(context)
    stds = torch.exp(log_stds)
    means = utils.repeat_rows(means, num_samples)
    stds = utils.repeat_rows(stds, num_samples)

    # Generate samples via the reparameterisation means + stds * eps.
    # The noise is allocated on the parameters' device so the arithmetic
    # does not mix devices.
    context_size = context.shape[0]
    noise = torch.randn(
        context_size * num_samples, *self._shape, device=means.device
    )
    samples = means + stds * noise
    return utils.split_leading_dim(samples, [context_size, num_samples])
def sample_and_log_prob(self, num_samples, context=None):
    """Generates samples from the flow, together with their log probabilities.

    For flows, this is more efficient than calling `sample` and `log_prob`
    separately: the base log probability comes for free when sampling the
    noise, and the inverse transform also yields its log-abs-determinant.

    Args:
        num_samples: int, number of samples to generate (per context row
            when a context is given).
        context: Tensor or None, conditioning variables. If None, the
            context is ignored.

    Returns:
        A tuple ``(samples, log_prob)``. When a context is given, both are
        reshaped to leading dimensions [context_size, num_samples].
    """
    noise, log_prob = self._distribution.sample_and_log_prob(
        num_samples, context=context
    )

    if context is not None:
        # Merge the context dimension with the sample dimension in order to
        # apply the transform, repeating each context row per sample.
        noise = utils.merge_leading_dims(noise, num_dims=2)
        context = utils.repeat_rows(context, num_reps=num_samples)

    samples, logabsdet = self._transform.inverse(noise, context=context)

    if context is not None:
        # Split the context dimension back out from the sample dimension.
        samples = utils.split_leading_dim(samples, shape=[-1, num_samples])
        logabsdet = utils.split_leading_dim(logabsdet, shape=[-1, num_samples])

    # Change of variables: subtract the inverse transform's log|det J|.
    return samples, log_prob - logabsdet
def _sample(self, num_samples, context):
    """Sample from the flow by pushing base-distribution noise through the
    inverse transform.

    Args:
        num_samples: int, number of samples (per context row when a context
            is given).
        context: Tensor or None, conditioning variables.

    Returns:
        The transformed samples; with a context, their leading dimensions
        are [context_size, num_samples].
    """
    base_noise = self._distribution.sample(num_samples, context=context)

    if context is None:
        transformed, _ = self._transform.inverse(base_noise, context=context)
        return transformed

    # The transform expects a flat batch: fold [context_size, num_samples]
    # into one leading dim and give every sample its own context row.
    flat_noise = utils.merge_leading_dims(base_noise, num_dims=2)
    flat_context = utils.repeat_rows(context, num_reps=num_samples)
    flat_samples, _ = self._transform.inverse(flat_noise, context=flat_context)

    # Restore the [context_size, num_samples] leading layout.
    return utils.split_leading_dim(flat_samples, shape=[-1, num_samples])
def stochastic_elbo(self, inputs, num_samples=1, kl_multiplier=1, keepdim=False):
    """Calculates an unbiased Monte-Carlo estimate of the evidence lower bound.

    Note: the KL term is also estimated via Monte Carlo.

    Args:
        inputs: Tensor of shape [batch_size, ...], the inputs.
        num_samples: int, number of latent samples drawn per input for the
            Monte-Carlo estimate.
        kl_multiplier: number, weight applied to the sampled KL term
            ``log p(z) - log q(z|x)``. A value of 1 gives the standard ELBO.
        keepdim: bool, if True return the per-sample ELBO terms instead of
            averaging them over the sample dimension.

    Returns:
        A Tensor of shape [batch_size] with an ELBO estimate for each input,
        or of shape [batch_size, num_samples] if `keepdim` is True.
    """
    # Sample latents and calculate their log prob under the encoder.
    if self._inputs_encoder is None:
        posterior_context = inputs
    else:
        posterior_context = self._inputs_encoder(inputs)
    latents, log_q_z = self._approximate_posterior.sample_and_log_prob(
        num_samples, context=posterior_context
    )
    # Flatten [batch_size, num_samples, ...] into one batch dimension so the
    # prior and decoder can evaluate every (input, latent) pair at once.
    latents = utils.merge_leading_dims(latents, num_dims=2)
    log_q_z = utils.merge_leading_dims(log_q_z, num_dims=2)

    # Compute log prob of latents under the prior.
    log_p_z = self._prior.log_prob(latents)

    # Compute log prob of inputs under the decoder. Each input is repeated
    # num_samples times to line up with its flattened latents.
    inputs = utils.repeat_rows(inputs, num_reps=num_samples)
    log_p_x = self._likelihood.log_prob(inputs, context=latents)

    # Compute ELBO.
    # TODO: maybe compute KL analytically when possible?
    elbo = log_p_x + kl_multiplier * (log_p_z - log_q_z)
    elbo = utils.split_leading_dim(elbo, [-1, num_samples])

    if keepdim:
        return elbo
    # Average ELBO across samples.
    return elbo.mean(dim=1)
def reconstruct(self, inputs, num_samples=None, mean=False):
    """Encode the inputs and decode them back into reconstructions.

    Args:
        inputs: Tensor of shape [batch_size, ...], the inputs to reconstruct.
        num_samples: int or None, how many reconstructions to produce per
            input. None produces a single reconstruction per input.
        mean: bool, forwarded to the decoder; if True the decoder's mean is
            used instead of a sample.

    Returns:
        A Tensor of shape [batch_size, ...] when num_samples is None, or
        [batch_size, num_samples, ...] otherwise.
    """
    latents = self.encode(inputs, num_samples)

    if num_samples is None:
        return self._decode(latents, mean)

    # Decode all (input, sample) latents as one flat batch, then restore the
    # per-input sample dimension.
    flat_latents = utils.merge_leading_dims(latents, num_dims=2)
    flat_recons = self._decode(flat_latents, mean)
    return utils.split_leading_dim(flat_recons, [-1, num_samples])