Code Example #1
File: mdn.py Project: plcrodrigues/pyknos
    def sample(self, num_samples, context):
        """
        Generate num_samples independent samples from p(inputs | context).
        NB: Generates num_samples samples for EACH item in the context batch,
        i.e. returns (num_samples * batch_size) samples in total.

        :param num_samples: int
            Number of samples to generate.
        :param context: torch.Tensor [batch_size, context_dim]
            Conditioning variable.
        :return: torch.Tensor [batch_size, num_samples, output_dim]
            Batch of generated samples.
        """

        # Get necessary quantities.
        logits, means, _, _, precision_factors = self.get_mixture_components(
            context)
        batch_size, n_mixtures, output_dim = means.shape

        # We need (batch_size * num_samples) samples in total.
        means, precision_factors = (
            torchutils.repeat_rows(means, num_samples),
            torchutils.repeat_rows(precision_factors, num_samples),
        )

        # Normalize the logits for the coefficients.
        coefficients = F.softmax(logits,
                                 dim=-1)  # [batch_size, num_components]

        # Choose num_samples mixture components per example in the batch.
        choices = torch.multinomial(coefficients,
                                    num_samples=num_samples,
                                    replacement=True).view(
                                        -1)  # [batch_size * num_samples]

        # Create dummy index for indexing means and precision factors.
        ix = torchutils.repeat_rows(torch.arange(batch_size), num_samples)

        # Select means and precision factors.
        chosen_means = means[ix, choices, :]
        chosen_precision_factors = precision_factors[ix, choices, :, :]

        # Batch triangular solve to multiply standard normal samples by inverse
        # of upper triangular precision factor.
        arg_A = torch.randn(batch_size * num_samples, output_dim, 1,
                            device=context.device)
        arg_B = chosen_precision_factors
        zero_mean_samples, _ = torch.triangular_solve(arg_A, arg_B)

        # Now center samples at chosen means, removing dummy final dimension
        # from triangular solve.
        samples = chosen_means + zero_mean_samples.squeeze(-1)
        samples = samples.reshape(batch_size, num_samples, output_dim)

        return samples
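The triangular solve above is the key sampling step: if U is an upper-triangular precision factor with precision matrix Λ = UᵀU, then solving U z = ε for standard-normal ε yields zero-mean samples with covariance Λ⁻¹. Below is a minimal numerical check of that identity (toy values, not from the project; it uses torch.linalg.solve_triangular, the newer equivalent of the torch.triangular_solve call above, which takes its arguments in the opposite order):

import torch

torch.manual_seed(0)
d, n = 3, 200_000

# A well-conditioned upper-triangular precision factor U (toy values),
# so the precision matrix is U.T @ U.
U = torch.triu(torch.randn(d, d)) + 4.0 * torch.eye(d)
precision = U.T @ U

# Solve U z = eps for standard-normal eps; each column of z is one sample.
eps = torch.randn(d, n)
z = torch.linalg.solve_triangular(U, eps, upper=True)

# The empirical covariance should approach inverse(precision).
empirical_cov = z @ z.T / n
print(torch.allclose(empirical_cov, torch.inverse(precision), atol=0.05))  # True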
Code Example #2
    def sample(self, num_samples: int, context: Tensor) -> Tensor:
        """
        Return num_samples independent samples from MoG(inputs | context).

        Generates num_samples samples for EACH item in the context batch, i.e. returns
        (num_samples * batch_size) samples in total.

        Args:
            num_samples: Number of samples to generate.
            context: Conditioning variable, leading dimension is batch dimension.

        Returns:
            Generated samples: (num_samples, output_dim) with leading batch dimension.
        """

        # Get necessary quantities.
        logits, means, _, _, precision_factors = self.get_mixture_components(
            context)
        batch_size, n_mixtures, output_dim = means.shape

        # We need (batch_size * num_samples) samples in total.
        means, precision_factors = (
            torchutils.repeat_rows(means, num_samples),
            torchutils.repeat_rows(precision_factors, num_samples),
        )

        # Normalize the logits for the coefficients.
        coefficients = F.softmax(logits,
                                 dim=-1)  # [batch_size, num_components]

        # Choose num_samples mixture components per example in the batch.
        choices = torch.multinomial(coefficients,
                                    num_samples=num_samples,
                                    replacement=True).view(
                                        -1)  # [batch_size * num_samples]

        # Create dummy index for indexing means and precision factors.
        ix = torchutils.repeat_rows(torch.arange(batch_size), num_samples)

        # Select means and precision factors.
        chosen_means = means[ix, choices, :]
        chosen_precision_factors = precision_factors[ix, choices, :, :]

        # Batch triangular solve to multiply standard normal samples by inverse
        # of upper triangular precision factor.
        zero_mean_samples, _ = torch.triangular_solve(
            torch.randn(batch_size * num_samples, output_dim,
                        1),  # Need dummy final dimension.
            chosen_precision_factors,
        )

        # Now center samples at chosen means, removing dummy final dimension
        # from triangular solve.
        samples = chosen_means + zero_mean_samples.squeeze(-1)

        return samples.reshape(batch_size, num_samples, output_dim)
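The component-selection step in both MDN examples draws num_samples component indices per batch row, flattens them, and pairs them with a repeated row index. A standalone illustration of that indexing pattern (toy weights; torch.repeat_interleave stands in for torchutils.repeat_rows on a 1-D tensor):

import torch

torch.manual_seed(0)
batch_size, num_samples = 2, 4

# Toy mixture weights, one row per batch element (rows sum to 1).
coefficients = torch.tensor([[0.7, 0.2, 0.1],
                             [0.1, 0.1, 0.8]])

# Draw num_samples component indices per row, then flatten.
choices = torch.multinomial(coefficients, num_samples=num_samples,
                            replacement=True).view(-1)

# Repeat each row index num_samples times to pair with the flat choices.
ix = torch.arange(batch_size).repeat_interleave(num_samples)
print(ix)       # tensor([0, 0, 0, 0, 1, 1, 1, 1])
print(choices)  # flat tensor of 8 component indices in {0, 1, 2}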
Code Example #3
 def test_repeat_rows(self):
     x = torch.randn(2, 3, 4, 5)
     self.assertEqual(torchutils.repeat_rows(x, 1), x)
     y = torchutils.repeat_rows(x, 2)
     self.assertEqual(y.shape, torch.Size([4, 3, 4, 5]))
     self.assertEqual(x[0], y[0])
     self.assertEqual(x[0], y[1])
     self.assertEqual(x[1], y[2])
     self.assertEqual(x[1], y[3])
     with self.assertRaises(Exception):
         torchutils.repeat_rows(x, 0)
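The test pins down the semantics of repeat_rows: each leading-dimension row is repeated consecutively, and a repetition count below 1 raises. A minimal sketch of a function with that behavior (an illustration consistent with the test above, not necessarily the library's implementation):

import torch

def repeat_rows_sketch(x: torch.Tensor, num_reps: int) -> torch.Tensor:
    # [a, b] with num_reps=2 becomes [a, a, b, b] along the leading dim.
    if num_reps < 1:
        raise ValueError("num_reps must be a positive integer")
    n = x.shape[0]
    expanded = x.unsqueeze(1).expand(n, num_reps, *x.shape[1:])
    return expanded.reshape(n * num_reps, *x.shape[1:])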
Code Example #4
    def _sample(self, num_samples, context):
        # Compute parameters.
        means, log_stds = self._compute_params(context)
        stds = torch.exp(log_stds)
        means = torchutils.repeat_rows(means, num_samples)
        stds = torchutils.repeat_rows(stds, num_samples)

        # Generate samples.
        context_size = context.shape[0]
        noise = torch.randn(context_size * num_samples, *self._shape)
        samples = means + stds * noise
        return torchutils.split_leading_dim(samples, [context_size, num_samples])
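Here split_leading_dim undoes the earlier flattening: the [context_size * num_samples, ...] batch is reshaped back to [context_size, num_samples, ...]. A one-line sketch of that behavior (illustrative, not necessarily the library's exact code):

import torch

def split_leading_dim_sketch(x: torch.Tensor, shape) -> torch.Tensor:
    # Reshape the leading dimension into `shape`, keeping trailing dims.
    return x.reshape(*shape, *x.shape[1:])

flat = torch.randn(6, 5)
print(split_leading_dim_sketch(flat, [2, 3]).shape)  # torch.Size([2, 3, 5])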
Code Example #5
File: base.py Project: zhangJianfeng/nflows
    def sample_and_log_prob(self, num_samples, context=None):
        """Generates samples from the flow, together with their log probabilities.

        For flows, this is more efficient than calling `sample` and `log_prob` separately.
        """
        embedded_context = self._embedding_net(context)
        noise, log_prob = self._distribution.sample_and_log_prob(
            num_samples, context=embedded_context
        )

        if embedded_context is not None:
            # Merge the context dimension with sample dimension in order to apply the transform.
            noise = torchutils.merge_leading_dims(noise, num_dims=2)
            embedded_context = torchutils.repeat_rows(
                embedded_context, num_reps=num_samples
            )

        samples, logabsdet = self._transform.inverse(noise, context=embedded_context)

        if embedded_context is not None:
            # Split the context dimension from sample dimension.
            samples = torchutils.split_leading_dim(samples, shape=[-1, num_samples])
            logabsdet = torchutils.split_leading_dim(logabsdet, shape=[-1, num_samples])

        return samples, log_prob - logabsdet
Code Example #6
    def sample_and_log_prob(self, num_samples, context=None):
        """Generates samples from the distribution together with their log probability.

        Args:
            num_samples: int, number of samples to generate.
            context: Tensor or None, conditioning variables. If None, the context is ignored.

        Returns:
            A tuple of:
                * A Tensor containing the samples, with shape [num_samples, ...] if context is None,
                  or [context_size, num_samples, ...] if context is given.
                * A Tensor containing the log probabilities of the samples, with shape
                  [num_samples, ...] if context is None, or [context_size, num_samples, ...] if
                  context is given.
        """
        samples = self.sample(num_samples, context=context)

        if context is not None:
            # Merge the context dimension with sample dimension in order to call log_prob.
            samples = torchutils.merge_leading_dims(samples, num_dims=2)
            context = torchutils.repeat_rows(context, num_reps=num_samples)
            assert samples.shape[0] == context.shape[0]

        log_prob = self.log_prob(samples, context=context)

        if context is not None:
            # Split the context dimension from sample dimension.
            samples = torchutils.split_leading_dim(samples,
                                                   shape=[-1, num_samples])
            log_prob = torchutils.split_leading_dim(log_prob,
                                                    shape=[-1, num_samples])

        return samples, log_prob
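Both variants above pair merge_leading_dims with split_leading_dim: the context and sample dimensions are flattened into one batch dimension so the transform or log_prob sees a plain batch, then split apart again. A sketch of that round trip (illustrative behavior only):

import torch

def merge_leading_dims_sketch(x: torch.Tensor, num_dims: int) -> torch.Tensor:
    # Flatten the first num_dims dimensions into a single batch dimension.
    return x.reshape(-1, *x.shape[num_dims:])

x = torch.randn(4, 7, 2)                           # [context_size, num_samples, dim]
merged = merge_leading_dims_sketch(x, num_dims=2)  # [28, 2]
restored = merged.reshape(4, 7, 2)                 # split back apart
print(torch.equal(restored, x))                    # True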
Code Example #7
    def sample(self, num_samples, context=None):

        # Repeat the context so one set of num_samples is drawn per context
        # row; without a context, draw num_samples samples unconditionally.
        # (The original accessed context.shape[0] unconditionally and would
        # crash for context=None.)
        if context is not None:
            context = torchutils.repeat_rows(context, num_samples)
            total_samples = context.shape[0]
        else:
            total_samples = num_samples

        with torch.no_grad():

            samples = torch.zeros(total_samples, self.features)

            for feature in range(self.features):
                outputs = self.forward(samples, context)
                outputs = outputs.reshape(*samples.shape,
                                          self.num_mixture_components, 3)

                logits, means, unconstrained_stds = (
                    outputs[:, feature, :, 0],
                    outputs[:, feature, :, 1],
                    outputs[:, feature, :, 2],
                )
                logits = torch.log_softmax(logits, dim=-1)
                stds = F.softplus(unconstrained_stds) + self.epsilon

                component_distribution = distributions.Categorical(
                    logits=logits)
                components = component_distribution.sample(
                    (1, )).reshape(-1, 1)
                means, stds = (
                    means.gather(1, components).reshape(-1),
                    stds.gather(1, components).reshape(-1),
                )
                samples[:, feature] = (
                    means + torch.randn(total_samples) * stds).detach()

        return samples.reshape(-1, num_samples, self.features)
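The per-feature selection above uses gather to pick one mixture component per row. A tiny standalone demonstration of that pattern (toy values):

import torch

# Per-row component means, shape [rows, num_components].
means = torch.tensor([[0.0, 1.0, 2.0],
                      [10.0, 11.0, 12.0]])
components = torch.tensor([[2], [0]])            # chosen component per row
print(means.gather(1, components).reshape(-1))   # tensor([ 2., 10.])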
Code Example #8
File: discrete.py Project: zhangJianfeng/nflows
    def _sample(self, num_samples, context):
        # Compute parameters.
        logits = self._compute_params(context)
        probs = torch.sigmoid(logits)
        probs = torchutils.repeat_rows(probs, num_samples)

        # Generate samples.
        context_size = context.shape[0]
        noise = torch.rand(context_size * num_samples, *self._shape)
        samples = (noise < probs).float()
        return torchutils.split_leading_dim(samples, [context_size, num_samples])
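The (noise < probs) comparison is the inverse-CDF trick for Bernoulli sampling: a uniform draw falls below p with probability exactly p. A quick sanity check (toy probability):

import torch

torch.manual_seed(0)
p = torch.full((100_000,), 0.3)
u = torch.rand(100_000)
samples = (u < p).float()
print(samples.mean())  # close to 0.3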
Code Example #9
File: base.py Project: zhangJianfeng/nflows
    def _sample(self, num_samples, context):
        embedded_context = self._embedding_net(context)
        noise = self._distribution.sample(num_samples, context=embedded_context)

        if embedded_context is not None:
            # Merge the context dimension with sample dimension in order to apply the transform.
            noise = torchutils.merge_leading_dims(noise, num_dims=2)
            embedded_context = torchutils.repeat_rows(
                embedded_context, num_reps=num_samples
            )

        samples, _ = self._transform.inverse(noise, context=embedded_context)

        if embedded_context is not None:
            # Split the context dimension from sample dimension.
            samples = torchutils.split_leading_dim(samples, shape=[-1, num_samples])

        return samples