def sample(self, num_samples, context): """ Generated num_samples independent samples from p(inputs | context). NB: Generates num_samples samples for EACH item in context batch i.e. returns (num_samples * batch_size) samples in total. :param num_samples: int Number of samples to generate. :param context: torch.Tensor [batch_size, context_dim] Conditioning variable. :return: torch.Tensor [batch_size, num_samples, output_dim] Batch of generated samples. """ # Get necessary quantities. logits, means, _, _, precision_factors = self.get_mixture_components( context) batch_size, n_mixtures, output_dim = means.shape # We need (batch_size * num_samples) samples in total. means, precision_factors = ( utils.repeat_rows(means, num_samples), utils.repeat_rows(precision_factors, num_samples), ) # Normalize the logits for the coefficients. coefficients = F.softmax(logits, dim=-1) # [batch_size, num_components] # Choose num_samples mixture components per example in the batch. choices = torch.multinomial(coefficients, num_samples=num_samples, replacement=True).view( -1) # [batch_size, num_samples] # Create dummy index for indexing means and precision factors. ix = utils.repeat_rows(torch.arange(batch_size), num_samples) # Select means and precision factors. chosen_means = means[ix, choices, :] chosen_precision_factors = precision_factors[ix, choices, :, :] # Batch triangular solve to multiply standard normal samples by inverse # of upper triangular precision factor. zero_mean_samples, _ = torch.triangular_solve( torch.randn(batch_size * num_samples, output_dim, 1), # Need dummy final dimension. chosen_precision_factors, ) # Mow center samples at chosen means, removing dummy final dimension # from triangular solve. samples = chosen_means + zero_mean_samples.squeeze(-1) return samples.reshape(batch_size, num_samples, output_dim)
def _sample(self, num_samples, context):
    """Draws num_samples Gaussian samples for every row of context.

    Returns a tensor of shape [context_size, num_samples, *self._shape].
    """
    # Mean and log-std come from the conditioner network.
    mean, log_std = self._compute_params(context)
    std = torch.exp(log_std)

    # One parameter row per requested sample.
    mean = utils.repeat_rows(mean, num_samples)
    std = utils.repeat_rows(std, num_samples)

    # Reparameterized draw: mean + std * eps with eps ~ N(0, I).
    num_contexts = context.shape[0]
    eps = torch.randn(num_contexts * num_samples, *self._shape)
    samples = mean + std * eps

    return utils.split_leading_dim(samples, [num_contexts, num_samples])
def sample(self, num_samples, context=None):
    """Autoregressively samples from the mixture-of-Gaussians model.

    Samples one feature at a time: each pass through the network
    conditions on the features sampled so far, a mixture component is
    drawn per example, and a Gaussian sample for the current feature is
    taken from that component.

    :param num_samples: int
        Number of samples per context item (total number of samples when
        context is None).
    :param context: torch.Tensor [batch_size, context_dim] or None
        Conditioning variables; None for unconditional sampling.
    :return: torch.Tensor [batch_size, num_samples, features]
        Generated samples (batch_size is 1 when context is None).
    """
    if context is not None:
        # One context row per requested sample.
        context = utils.repeat_rows(context, num_samples)
        total_samples = context.shape[0]
    else:
        # Bug fix: the original dereferenced context.shape[0]
        # unconditionally, crashing when context was None despite the
        # context=None default in the signature.
        total_samples = num_samples

    with torch.no_grad():
        samples = torch.zeros(total_samples, self.features)
        for feature in range(self.features):
            outputs = self.forward(samples, context)
            # Per feature, the network emits (logit, mean, raw-std)
            # triples for each mixture component.
            outputs = outputs.reshape(*samples.shape,
                                      self.num_mixture_components, 3)
            logits, means, unconstrained_stds = (
                outputs[:, feature, :, 0],
                outputs[:, feature, :, 1],
                outputs[:, feature, :, 2],
            )
            logits = torch.log_softmax(logits, dim=-1)
            # Softplus keeps stds positive; epsilon avoids collapse to 0.
            stds = F.softplus(unconstrained_stds) + self.epsilon

            # Pick one mixture component per example.
            component_distribution = distributions.Categorical(
                logits=logits)
            components = component_distribution.sample((1,)).reshape(-1, 1)
            means, stds = (
                means.gather(1, components).reshape(-1),
                stds.gather(1, components).reshape(-1),
            )
            # Reparameterized Gaussian draw for the current feature.
            samples[:, feature] = (
                means + torch.randn(total_samples) * stds).detach()

    return samples.reshape(-1, num_samples, self.features)
def sample_and_log_prob(self, num_samples, context=None):
    """Generates samples from the distribution together with their log
    probability.

    Args:
        num_samples: int, number of samples to generate.
        context: Tensor or None, conditioning variables. If None, the
            context is ignored.

    Returns:
        A tuple of:
        * A Tensor containing the samples, with shape
          [num_samples, ...] if context is None, or
          [context_size, num_samples, ...] if context is given.
        * A Tensor containing the log probabilities of the samples, with
          shape [num_samples, ...] if context is None, or
          [context_size, num_samples, ...] if context is given.
    """
    samples = self.sample(num_samples, context=context)

    # Unconditional case: shapes are already [num_samples, ...].
    if context is None:
        return samples, self.log_prob(samples, context=context)

    # Merge the context dimension with the sample dimension so log_prob
    # sees one context row per sample row.
    flat_samples = utils.merge_leading_dims(samples, num_dims=2)
    tiled_context = utils.repeat_rows(context, num_reps=num_samples)
    assert flat_samples.shape[0] == tiled_context.shape[0]
    log_prob = self.log_prob(flat_samples, context=tiled_context)

    # Split the context dimension back out of the sample dimension.
    samples = utils.split_leading_dim(flat_samples, shape=[-1, num_samples])
    log_prob = utils.split_leading_dim(log_prob, shape=[-1, num_samples])
    return samples, log_prob
def _get_log_prob(parameters, observations):
    # Atomic (contrastive) classifier log prob. `self` and `batch_size`
    # are free variables from the enclosing scope.
    num_atoms = batch_size if self._num_atoms <= 0 else self._num_atoms
    assert 0 < num_atoms - 1 < batch_size

    # Each observation is evaluated against num_atoms parameter sets,
    # so repeat every observation row num_atoms times.
    repeated_observations = utils.repeat_rows(observations, num_atoms)

    # For each example, draw num_atoms - 1 contrasting parameter rows
    # uniformly (without replacement) from the REST of the batch: the
    # zeroed diagonal forbids picking an example's own parameters.
    off_diagonal_probs = (1 / (batch_size - 1)) * torch.ones(
        batch_size, batch_size) * (1 - torch.eye(batch_size))
    choices = torch.multinomial(off_diagonal_probs,
                                num_samples=num_atoms - 1,
                                replacement=False)
    contrasting_parameters = parameters[choices]

    # Atom 0 is always the example's own parameters, followed by the
    # contrasting sets; flatten to (batch_size * num_atoms) rows.
    atomic_parameters = torch.cat(
        (parameters[:, None, :], contrasting_parameters),
        dim=1,
    ).reshape(batch_size * num_atoms, -1)

    # Classify all atoms in one large batch.
    classifier_inputs = torch.cat(
        (atomic_parameters, repeated_observations), dim=1)
    logits = self._classifier(classifier_inputs).reshape(
        batch_size, num_atoms)

    # Log-softmax over atoms, keeping only the true-pair column.
    return logits[:, 0] - torch.logsumexp(logits, dim=-1)
def _sample(self, num_samples, context):
    """Draws num_samples Bernoulli samples for every row of context.

    Returns a {0., 1.}-valued tensor of shape
    [context_size, num_samples, *self._shape].
    """
    # Success probabilities from the conditioner network's logits.
    logits = self._compute_params(context)
    success_probs = utils.repeat_rows(torch.sigmoid(logits), num_samples)

    # Inverse-CDF trick: uniform noise below the probability -> 1.
    num_contexts = context.shape[0]
    uniform_noise = torch.rand(num_contexts * num_samples, *self._shape)
    samples = (uniform_noise < success_probs).float()

    return utils.split_leading_dim(samples, [num_contexts, num_samples])
def _sample(self, num_samples, context):
    """Samples base-distribution noise and pushes it through the inverse
    transform."""
    noise = self._distribution.sample(num_samples, context=context)

    # Unconditional case: transform the noise as-is.
    if context is None:
        samples, _ = self._transform.inverse(noise, context=context)
        return samples

    # Merge the context dimension with the sample dimension so the
    # transform sees one context row per noise row.
    flat_noise = utils.merge_leading_dims(noise, num_dims=2)
    tiled_context = utils.repeat_rows(context, num_reps=num_samples)
    samples, _ = self._transform.inverse(flat_noise, context=tiled_context)

    # Split the context dimension back out of the sample dimension.
    return utils.split_leading_dim(samples, shape=[-1, num_samples])
def sample_and_log_prob(self, num_samples, context=None):
    """Generates samples from the flow, together with their log
    probabilities.

    For flows, this is more efficient than calling `sample` and
    `log_prob` separately.
    """
    noise, log_prob = self._distribution.sample_and_log_prob(
        num_samples, context=context)

    # Unconditional case: transform the noise as-is.
    if context is None:
        samples, logabsdet = self._transform.inverse(noise, context=context)
        return samples, log_prob - logabsdet

    # Merge the context dimension with the sample dimension so the
    # transform sees one context row per noise row.
    flat_noise = utils.merge_leading_dims(noise, num_dims=2)
    tiled_context = utils.repeat_rows(context, num_reps=num_samples)
    samples, logabsdet = self._transform.inverse(flat_noise,
                                                 context=tiled_context)

    # Split the context dimension back out of the sample dimension.
    samples = utils.split_leading_dim(samples, shape=[-1, num_samples])
    logabsdet = utils.split_leading_dim(logabsdet, shape=[-1, num_samples])

    # Change-of-variables: subtract the log |det J| of the transform.
    return samples, log_prob - logabsdet
def _get_log_prob_proposal_posterior(inputs, context, masks):
    """Evaluates the atomic proposal posterior on a batch.

    We have two main options when evaluating the proposal posterior.
    (1) Generate atoms from the proposal prior.
    (2) Generate atoms from a more targeted distribution, such as the
    most recent posterior. If we choose the latter, it is likely
    beneficial not to do this in the first round, since we would be
    sampling from a randomly initialized neural density estimator.

    Uses `round_` and `batch_size` as free variables from the enclosing
    scope.  # NOTE(review): confirm both are defined by the caller.

    :param inputs: torch.Tensor
        Batch of parameters.
    :param context: torch.Tensor
        Batch of observations.
    :param masks: torch.Tensor
        Per-example masks; only used when self._use_combined_loss is set.
    :return: torch.Tensor [batch_size]
        log_prob_proposal_posterior
    """
    log_prob_posterior_non_atomic = self._neural_posterior.log_prob(
        inputs, context)

    # Just do maximum likelihood in the first round.
    if round_ == 0:
        return log_prob_posterior_non_atomic

    # num_atoms <= 0 means "use the whole batch as atoms".
    num_atoms = self._num_atoms if self._num_atoms > 0 else batch_size

    # Each set of parameter atoms is evaluated using the same
    # observation, so we repeat rows of the context.
    # e.g. [1, 2] -> [1, 1, 2, 2]
    repeated_context = utils.repeat_rows(context, num_atoms)

    # To generate the full set of atoms for a given item in the batch,
    # we sample without replacement num_atoms - 1 times from the rest
    # of the parameters in the batch (the zeroed diagonal forbids
    # choosing an example's own parameters).
    assert 0 < num_atoms - 1 < batch_size
    probs = ((1 / (batch_size - 1)) * torch.ones(batch_size, batch_size) *
             (1 - torch.eye(batch_size)))
    choices = torch.multinomial(probs,
                                num_samples=num_atoms - 1,
                                replacement=False)
    contrasting_inputs = inputs[choices]

    # We can now create our sets of atoms from the contrasting parameter
    # sets we have generated; atom 0 is the example's own parameters.
    atomic_inputs = torch.cat((inputs[:, None, :], contrasting_inputs),
                              dim=1).reshape(batch_size * num_atoms, -1)

    # Evaluate large batch giving (batch_size * num_atoms) log prob
    # posterior evals.
    log_prob_posterior = self._neural_posterior.log_prob(
        atomic_inputs, repeated_context)
    assert utils.notinfnotnan(
        log_prob_posterior), "NaN/inf detected in posterior eval."
    log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms)

    # Get (batch_size * num_atoms) log prob prior evals.
    # A Uniform prior returns per-dimension log probs, so sum them.
    if isinstance(self._prior, distributions.Uniform):
        log_prob_prior = self._prior.log_prob(atomic_inputs).sum(-1)
    else:
        log_prob_prior = self._prior.log_prob(atomic_inputs)
    log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms)
    assert utils.notinfnotnan(
        log_prob_prior), "NaN/inf detected in prior eval."

    # Compute unnormalized proposal posterior.
    unnormalized_log_prob_proposal_posterior = (log_prob_posterior -
                                                log_prob_prior)

    # Normalize proposal posterior across discrete set of atoms
    # (log-softmax over atoms, keeping only the true-pair column).
    log_prob_proposal_posterior = unnormalized_log_prob_proposal_posterior[:, 0] - torch.logsumexp(
        unnormalized_log_prob_proposal_posterior, dim=-1)
    assert utils.notinfnotnan(
        log_prob_proposal_posterior
    ), "NaN/inf detected in proposal posterior eval."

    # Optionally blend in the non-atomic (maximum-likelihood) loss,
    # weighted by the per-example masks.
    if self._use_combined_loss:
        masks = masks.reshape(-1)
        log_prob_proposal_posterior = (
            masks * log_prob_posterior_non_atomic +
            log_prob_proposal_posterior)

    return log_prob_proposal_posterior