def sample(self, num_samples, context): """ Generated num_samples independent samples from p(inputs | context). NB: Generates num_samples samples for EACH item in context batch i.e. returns (num_samples * batch_size) samples in total. :param num_samples: int Number of samples to generate. :param context: torch.Tensor [batch_size, context_dim] Conditioning variable. :return: torch.Tensor [batch_size, num_samples, output_dim] Batch of generated samples. """ # Get necessary quantities. logits, means, _, _, precision_factors = self.get_mixture_components( context) batch_size, n_mixtures, output_dim = means.shape # We need (batch_size * num_samples) samples in total. means, precision_factors = ( torchutils.repeat_rows(means, num_samples), torchutils.repeat_rows(precision_factors, num_samples), ) # Normalize the logits for the coefficients. coefficients = F.softmax(logits, dim=-1) # [batch_size, num_components] # Choose num_samples mixture components per example in the batch. choices = torch.multinomial(coefficients, num_samples=num_samples, replacement=True).view( -1) # [batch_size, num_samples] # Create dummy index for indexing means and precision factors. ix = torchutils.repeat_rows(torch.arange(batch_size), num_samples) # Select means and precision factors. chosen_means = means[ix, choices, :] chosen_precision_factors = precision_factors[ix, choices, :, :] # Batch triangular solve to multiply standard normal samples by inverse # of upper triangular precision factor. arg_A = torch.randn(batch_size * num_samples, output_dim, 1) cuda_check = context.is_cuda if cuda_check: get_cuda_device = context.get_device() arg_A = arg_A.to(get_cuda_device) arg_B = chosen_precision_factors zero_mean_samples, _ = torch.triangular_solve(arg_A, arg_B) # Mow center samples at chosen means, removing dummy final dimension # from triangular solve. samples = chosen_means + zero_mean_samples.squeeze(-1) samples = samples.reshape(batch_size, num_samples, output_dim) return samples
def sample(self, num_samples: int, context: Tensor) -> Tensor:
    """Return num_samples independent samples from MoG(inputs | context).

    Generates num_samples samples for EACH item in context batch i.e. returns
    (num_samples * batch_size) samples in total.

    Args:
        num_samples: Number of samples to generate.
        context: Conditioning variable, leading dimension is batch dimension.

    Returns:
        Generated samples: (num_samples, output_dim) with leading batch
        dimension.
    """
    # Get necessary quantities.
    logits, means, _, _, precision_factors = self.get_mixture_components(
        context)
    batch_size, n_mixtures, output_dim = means.shape

    # We need (batch_size * num_samples) samples in total: repeat_rows
    # duplicates each batch row num_samples times, consecutively.
    means, precision_factors = (
        torchutils.repeat_rows(means, num_samples),
        torchutils.repeat_rows(precision_factors, num_samples),
    )

    # Normalize the logits for the coefficients.
    coefficients = F.softmax(logits, dim=-1)  # [batch_size, num_components]

    # Choose num_samples mixture components per example in the batch, then
    # flatten in batch-major order to match the repeated rows above.
    choices = torch.multinomial(
        coefficients, num_samples=num_samples, replacement=True
    ).view(-1)  # [batch_size * num_samples]

    # Dummy index for indexing means and precision factors.
    ix = torchutils.repeat_rows(torch.arange(batch_size), num_samples)

    # Select means and precision factors for the chosen components.
    chosen_means = means[ix, choices, :]
    chosen_precision_factors = precision_factors[ix, choices, :, :]

    # Batch triangular solve to multiply standard normal samples by inverse
    # of upper triangular precision factor. The noise must be allocated on
    # the same device as the precision factors; the original allocated it
    # on the CPU unconditionally, which fails for CUDA contexts.
    noise = torch.randn(
        batch_size * num_samples,
        output_dim,
        1,  # Dummy final dimension required by triangular_solve.
        device=chosen_precision_factors.device,
    )
    zero_mean_samples, _ = torch.triangular_solve(
        noise, chosen_precision_factors)

    # Now center samples at chosen means, removing dummy final dimension
    # from triangular solve.
    samples = chosen_means + zero_mean_samples.squeeze(-1)
    return samples.reshape(batch_size, num_samples, output_dim)
def test_repeat_rows(self):
    x = torch.randn(2, 3, 4, 5)

    # Repeating each row once is the identity.
    self.assertEqual(torchutils.repeat_rows(x, 1), x)

    # Repeating twice doubles the leading dimension only, with each
    # original row appearing twice, consecutively.
    repeated = torchutils.repeat_rows(x, 2)
    self.assertEqual(repeated.shape, torch.Size([4, 3, 4, 5]))
    for src, dst in [(0, 0), (0, 1), (1, 2), (1, 3)]:
        self.assertEqual(x[src], repeated[dst])

    # A non-positive repeat count is rejected.
    with self.assertRaises(Exception):
        torchutils.repeat_rows(x, 0)
def _sample(self, num_samples, context):
    """Draw num_samples Gaussian samples for each context row."""
    # Parameters of the conditional Gaussian.
    means, log_stds = self._compute_params(context)
    stds = torch.exp(log_stds)

    # One parameter row per requested sample.
    means = torchutils.repeat_rows(means, num_samples)
    stds = torchutils.repeat_rows(stds, num_samples)

    # Reparameterized sampling: mean + std * standard normal noise.
    context_size = context.shape[0]
    eps = torch.randn(context_size * num_samples, *self._shape)
    samples = means + stds * eps

    # [context_size * num_samples, ...] -> [context_size, num_samples, ...].
    return torchutils.split_leading_dim(samples, [context_size, num_samples])
def sample_and_log_prob(self, num_samples, context=None):
    """Generates samples from the flow, together with their log probabilities.

    For flows, this is more efficient that calling `sample` and `log_prob` separately.
    """
    ctx = self._embedding_net(context)
    base_samples, base_log_prob = self._distribution.sample_and_log_prob(
        num_samples, context=ctx
    )

    conditional = ctx is not None
    if conditional:
        # Flatten [batch, num_samples, ...] into a single leading dimension
        # so the transform receives one context row per sample.
        base_samples = torchutils.merge_leading_dims(base_samples, num_dims=2)
        ctx = torchutils.repeat_rows(ctx, num_reps=num_samples)

    samples, logabsdet = self._transform.inverse(base_samples, context=ctx)

    if conditional:
        # Restore the [batch, num_samples, ...] layout.
        samples = torchutils.split_leading_dim(samples, shape=[-1, num_samples])
        logabsdet = torchutils.split_leading_dim(logabsdet, shape=[-1, num_samples])

    # Change of variables: log p(x) = log p(z) - log |det J|.
    return samples, base_log_prob - logabsdet
def sample_and_log_prob(self, num_samples, context=None):
    """Generates samples from the distribution together with their log probability.

    Args:
        num_samples: int, number of samples to generate.
        context: Tensor or None, conditioning variables. If None, the context is
            ignored.

    Returns:
        A tuple of:
            * A Tensor containing the samples, with shape [num_samples, ...] if
              context is None, or [context_size, num_samples, ...] if context
              is given.
            * A Tensor containing the log probabilities of the samples, with
              shape [num_samples, ...] if context is None, or
              [context_size, num_samples, ...] if context is given.
    """
    samples = self.sample(num_samples, context=context)

    conditional = context is not None
    if conditional:
        # Flatten [context_size, num_samples, ...] into a single leading
        # dimension so log_prob sees one context row per sample.
        samples = torchutils.merge_leading_dims(samples, num_dims=2)
        context = torchutils.repeat_rows(context, num_reps=num_samples)
        assert samples.shape[0] == context.shape[0]

    log_prob = self.log_prob(samples, context=context)

    if conditional:
        # Restore the [context_size, num_samples, ...] layout.
        samples = torchutils.split_leading_dim(samples, shape=[-1, num_samples])
        log_prob = torchutils.split_leading_dim(log_prob, shape=[-1, num_samples])

    return samples, log_prob
def sample(self, num_samples, context=None):
    """Autoregressively draw samples from the mixture-of-Gaussians MADE.

    Args:
        num_samples: int, number of samples to generate (per context row, if
            a context is given).
        context: Tensor or None, conditioning variables.

    Returns:
        Samples of shape [num_samples, features] if context is None, or
        [context_size, num_samples, features] if a context is given.
    """
    if context is not None:
        # One context row per requested sample.
        context = torchutils.repeat_rows(context, num_samples)
        total_samples = context.shape[0]
    else:
        # Fix: the original read context.shape[0] unconditionally, which
        # crashed with AttributeError for the documented context=None case.
        total_samples = num_samples

    with torch.no_grad():
        # NOTE(review): buffers are created on the CPU; presumably callers
        # pass CPU tensors here -- confirm before using a CUDA context.
        samples = torch.zeros(total_samples, self.features)

        # Sample one feature at a time, each conditioned on the features
        # already filled in (autoregressive order).
        for feature in range(self.features):
            outputs = self.forward(samples, context)
            outputs = outputs.reshape(
                *samples.shape, self.num_mixture_components, 3)
            logits, means, unconstrained_stds = (
                outputs[:, feature, :, 0],
                outputs[:, feature, :, 1],
                outputs[:, feature, :, 2],
            )
            logits = torch.log_softmax(logits, dim=-1)
            # Softplus keeps the stds positive; epsilon guards against
            # degenerate zero-width components.
            stds = F.softplus(unconstrained_stds) + self.epsilon

            # Pick one mixture component per sample...
            component_distribution = distributions.Categorical(logits=logits)
            components = component_distribution.sample((1, )).reshape(-1, 1)

            # ...and draw from the corresponding Gaussian.
            means, stds = (
                means.gather(1, components).reshape(-1),
                stds.gather(1, components).reshape(-1),
            )
            samples[:, feature] = (
                means + torch.randn(total_samples) * stds).detach()

    if context is None:
        return samples
    return samples.reshape(-1, num_samples, self.features)
def _sample(self, num_samples, context):
    """Draw num_samples Bernoulli samples for each context row."""
    # Success probabilities from the conditional logits.
    probs = torch.sigmoid(self._compute_params(context))
    probs = torchutils.repeat_rows(probs, num_samples)

    # Inverse-CDF sampling: uniform noise below the probability -> 1, else 0.
    context_size = context.shape[0]
    uniform = torch.rand(context_size * num_samples, *self._shape)
    samples = (uniform < probs).float()

    # [context_size * num_samples, ...] -> [context_size, num_samples, ...].
    return torchutils.split_leading_dim(samples, [context_size, num_samples])
def _sample(self, num_samples, context):
    """Sample from the flow: base noise pushed through the inverse transform."""
    ctx = self._embedding_net(context)
    noise = self._distribution.sample(num_samples, context=ctx)

    conditional = ctx is not None
    if conditional:
        # Flatten [batch, num_samples, ...] into a single leading dimension
        # so the transform receives one context row per noise sample.
        noise = torchutils.merge_leading_dims(noise, num_dims=2)
        ctx = torchutils.repeat_rows(ctx, num_reps=num_samples)

    samples, _ = self._transform.inverse(noise, context=ctx)

    if conditional:
        # Restore the [batch, num_samples, ...] layout.
        samples = torchutils.split_leading_dim(samples, shape=[-1, num_samples])

    return samples