Example #1
File: mdn.py Project: plcrodrigues/lfi
    def sample(self, num_samples, context):
        """
        Generates num_samples independent samples from p(inputs | context).
        NB: Generates num_samples samples for EACH item in the context batch, i.e. returns
        (num_samples * batch_size) samples in total.

        :param num_samples: int
            Number of samples to generate.
        :param context: torch.Tensor [batch_size, context_dim]
            Conditioning variable.
        :return: torch.Tensor [batch_size, num_samples, output_dim]
            Batch of generated samples.
        """

        # Get necessary quantities.
        logits, means, _, _, precision_factors = self.get_mixture_components(
            context)
        batch_size, n_mixtures, output_dim = means.shape

        # We need (batch_size * num_samples) samples in total.
        means, precision_factors = (
            utils.repeat_rows(means, num_samples),
            utils.repeat_rows(precision_factors, num_samples),
        )

        # Normalize the logits for the coefficients.
        coefficients = F.softmax(logits, dim=-1)  # [batch_size, n_mixtures]

        # Choose num_samples mixture components per example in the batch,
        # then flatten to one component index per sample.
        choices = torch.multinomial(
            coefficients, num_samples=num_samples,
            replacement=True).view(-1)  # [batch_size * num_samples]

        # Create dummy index for indexing means and precision factors.
        ix = utils.repeat_rows(torch.arange(batch_size), num_samples)

        # Select means and precision factors.
        chosen_means = means[ix, choices, :]
        chosen_precision_factors = precision_factors[ix, choices, :, :]

        # Batch triangular solve to multiply standard normal samples by inverse
        # of upper triangular precision factor.
        zero_mean_samples, _ = torch.triangular_solve(
            torch.randn(batch_size * num_samples, output_dim,
                        1),  # Need dummy final dimension.
            chosen_precision_factors,
        )

        # Now center the samples at the chosen means, removing the dummy final
        # dimension from the triangular solve.
        samples = chosen_means + zero_mean_samples.squeeze(-1)

        return samples.reshape(batch_size, num_samples, output_dim)
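A standalone sketch of the same sampling trick, with made-up shapes and `torch.linalg.solve_triangular` used in place of the now-deprecated `torch.triangular_solve` (an illustration, not the project's code):

import torch
import torch.nn.functional as F

# Toy shapes for the sketch.
batch_size, num_samples, n_mixtures, output_dim = 3, 5, 4, 2

logits = torch.randn(batch_size, n_mixtures)
means = torch.randn(batch_size, n_mixtures, output_dim)
# Random upper-triangular precision factors with a positive diagonal.
precision_factors = (torch.triu(torch.rand(batch_size, n_mixtures, output_dim, output_dim))
                     + torch.eye(output_dim))

# Pick one mixture component per sample from the softmax coefficients.
coefficients = F.softmax(logits, dim=-1)
choices = torch.multinomial(coefficients, num_samples, replacement=True).view(-1)

# Index the chosen means/factors, repeating each batch row num_samples times.
ix = torch.arange(batch_size).repeat_interleave(num_samples)
chosen_means = means[ix, choices]                # [batch_size * num_samples, output_dim]
chosen_factors = precision_factors[ix, choices]  # [batch_size * num_samples, D, D]

# Solve U z = eps so z has covariance (U^T U)^{-1}, then add the chosen means.
eps = torch.randn(batch_size * num_samples, output_dim, 1)
zero_mean = torch.linalg.solve_triangular(chosen_factors, eps, upper=True)
samples = (chosen_means + zero_mean.squeeze(-1)).reshape(batch_size, num_samples, output_dim)
print(samples.shape)  # torch.Size([3, 5, 2])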
Example #2
File: normal.py Project: plcrodrigues/lfi
    def _sample(self, num_samples, context):
        # Compute parameters.
        means, log_stds = self._compute_params(context)
        stds = torch.exp(log_stds)
        means = utils.repeat_rows(means, num_samples)
        stds = utils.repeat_rows(stds, num_samples)

        # Generate samples.
        context_size = context.shape[0]
        noise = torch.randn(context_size * num_samples, *self._shape)
        samples = means + stds * noise
        return utils.split_leading_dim(samples, [context_size, num_samples])
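For reference, a plausible reading of the `utils.repeat_rows` and `utils.split_leading_dim` helpers used throughout these examples (a sketch of assumed semantics, not the project's implementation; the [1, 2] -> [1, 1, 2, 2] behaviour is stated in a comment in Example #9):

import torch

# Assumed semantics, not the library's actual code.
def repeat_rows(x, num_reps):
    # Repeat each row num_reps times along dim 0, e.g. [a, b] -> [a, a, b, b].
    return torch.repeat_interleave(x, num_reps, dim=0)

def split_leading_dim(x, shape):
    # Reshape the leading dimension into the given shape.
    return x.reshape(*shape, *x.shape[1:])

means = torch.tensor([[0.0], [10.0]])             # [context_size=2, 1]
repeated = repeat_rows(means, 3)                  # rows: 0, 0, 0, 10, 10, 10
print(split_leading_dim(repeated, [2, 3]).shape)  # torch.Size([2, 3, 1])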
Example #3
    def sample(self, num_samples, context=None):

        if context is not None:
            context = utils.repeat_rows(context, num_samples)
            num_total_samples = context.shape[0]
        else:
            # Guard the unconditional case, which would otherwise fail on
            # context.shape[0] below.
            num_total_samples = num_samples

        with torch.no_grad():

            samples = torch.zeros(num_total_samples, self.features)

            for feature in range(self.features):
                outputs = self.forward(samples, context)
                outputs = outputs.reshape(*samples.shape,
                                          self.num_mixture_components, 3)

                logits, means, unconstrained_stds = (
                    outputs[:, feature, :, 0],
                    outputs[:, feature, :, 1],
                    outputs[:, feature, :, 2],
                )
                logits = torch.log_softmax(logits, dim=-1)
                stds = F.softplus(unconstrained_stds) + self.epsilon

                component_distribution = distributions.Categorical(
                    logits=logits)
                components = component_distribution.sample(
                    (1, )).reshape(-1, 1)
                means, stds = (
                    means.gather(1, components).reshape(-1),
                    stds.gather(1, components).reshape(-1),
                )
                samples[:, feature] = (
                    means + torch.randn(num_total_samples) * stds).detach()

        return samples.reshape(-1, num_samples, self.features)
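The per-feature mixture draw above reduces to a Categorical sample followed by a gather; a minimal sketch with toy shapes:

import torch
import torch.nn.functional as F
from torch import distributions

# Toy shapes for the sketch.
rows, num_components = 4, 3
logits = torch.log_softmax(torch.randn(rows, num_components), dim=-1)
means = torch.randn(rows, num_components)
stds = F.softplus(torch.randn(rows, num_components)) + 1e-3

# One component index per row, then gather the matching mean and std.
components = distributions.Categorical(logits=logits).sample((1,)).reshape(-1, 1)
chosen_means = means.gather(1, components).reshape(-1)
chosen_stds = stds.gather(1, components).reshape(-1)
samples = chosen_means + torch.randn(rows) * chosen_stds
print(samples.shape)  # torch.Size([4])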
Example #4
File: base.py Project: plcrodrigues/lfi
    def sample_and_log_prob(self, num_samples, context=None):
        """Generates samples from the distribution together with their log probability.

        Args:
            num_samples: int, number of samples to generate.
            context: Tensor or None, conditioning variables. If None, the context is ignored.

        Returns:
            A tuple of:
                * A Tensor containing the samples, with shape [num_samples, ...] if context is None,
                  or [context_size, num_samples, ...] if context is given.
                * A Tensor containing the log probabilities of the samples, with shape
                  [num_samples, ...] if context is None, or [context_size, num_samples, ...] if
                  context is given.
        """
        samples = self.sample(num_samples, context=context)

        if context is not None:
            # Merge the context dimension with sample dimension in order to call log_prob.
            samples = utils.merge_leading_dims(samples, num_dims=2)
            context = utils.repeat_rows(context, num_reps=num_samples)
            assert samples.shape[0] == context.shape[0]

        log_prob = self.log_prob(samples, context=context)

        if context is not None:
            # Split the context dimension from sample dimension.
            samples = utils.split_leading_dim(samples, shape=[-1, num_samples])
            log_prob = utils.split_leading_dim(log_prob,
                                               shape=[-1, num_samples])

        return samples, log_prob
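Assumed behaviour of `utils.merge_leading_dims` as used above (a sketch, not the project's code): flatten the first num_dims dimensions into one, so the samples line up with the repeated context rows.

import torch

# Assumed behaviour, not the library's actual code.
def merge_leading_dims(x, num_dims):
    # Flatten the first num_dims dimensions into one.
    return x.reshape(-1, *x.shape[num_dims:])

samples = torch.randn(2, 5, 3)                        # [context_size, num_samples, dim]
print(merge_leading_dims(samples, num_dims=2).shape)  # torch.Size([10, 3])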
Example #5
File: sre.py Project: plcrodrigues/lfi
        def _get_log_prob(parameters, observations):

            # num_atoms = parameters.shape[0]
            num_atoms = self._num_atoms if self._num_atoms > 0 else batch_size

            repeated_observations = utils.repeat_rows(observations, num_atoms)

            # Choose between 1 and num_atoms - 1 parameters from the rest
            # of the batch for each observation.
            assert 0 < num_atoms - 1 < batch_size
            probs = ((torch.ones(batch_size, batch_size) -
                      torch.eye(batch_size)) / (batch_size - 1))
            choices = torch.multinomial(probs,
                                        num_samples=num_atoms - 1,
                                        replacement=False)
            contrasting_parameters = parameters[choices]

            atomic_parameters = torch.cat(
                (parameters[:, None, :], contrasting_parameters),
                dim=1).reshape(batch_size * num_atoms, -1)

            inputs = torch.cat((atomic_parameters, repeated_observations),
                               dim=1)

            logits = self._classifier(inputs).reshape(batch_size, num_atoms)

            log_prob = logits[:, 0] - torch.logsumexp(logits, dim=-1)

            return log_prob
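A standalone sketch of the atom construction with toy sizes: for each row of the batch, draw num_atoms - 1 distinct other rows as contrasting parameters, then prepend the row's own parameters so that atom 0 is the true pairing (illustration only):

import torch

# Toy sizes for the sketch.
batch_size, num_atoms, param_dim, obs_dim = 4, 3, 2, 5
parameters = torch.randn(batch_size, param_dim)
observations = torch.randn(batch_size, obs_dim)

# Off-diagonal probabilities: each row can draw any other row, never itself.
probs = (torch.ones(batch_size, batch_size) - torch.eye(batch_size)) / (batch_size - 1)
choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
contrasting = parameters[choices]                     # [batch, num_atoms - 1, dim]

# Prepend each row's own parameters, so atom 0 is the true pairing.
atomic = torch.cat((parameters[:, None, :], contrasting), dim=1)
atomic = atomic.reshape(batch_size * num_atoms, -1)
repeated_obs = torch.repeat_interleave(observations, num_atoms, dim=0)
print(atomic.shape, repeated_obs.shape)  # torch.Size([12, 2]) torch.Size([12, 5])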
Example #6
    def _sample(self, num_samples, context):
        # Compute parameters.
        logits = self._compute_params(context)
        probs = torch.sigmoid(logits)
        probs = utils.repeat_rows(probs, num_samples)

        # Generate samples.
        context_size = context.shape[0]
        noise = torch.rand(context_size * num_samples, *self._shape)
        samples = (noise < probs).float()
        return utils.split_leading_dim(samples, [context_size, num_samples])
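The uniform-noise comparison above is one way to draw Bernoulli samples; `torch.bernoulli` is an equivalent built-in (toy sketch):

import torch

# Toy probabilities.
probs = torch.tensor([[0.1, 0.9], [0.5, 0.5]])

# Manual sampling via a uniform-noise comparison, as in the example above.
samples_manual = (torch.rand_like(probs) < probs).float()
# Built-in equivalent.
samples_builtin = torch.bernoulli(probs)
print(samples_manual.shape, samples_builtin.shape)  # torch.Size([2, 2]) twice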
Example #7
    def _sample(self, num_samples, context):
        noise = self._distribution.sample(num_samples, context=context)

        if context is not None:
            # Merge the context dimension with sample dimension in order to apply the transform.
            noise = utils.merge_leading_dims(noise, num_dims=2)
            context = utils.repeat_rows(context, num_reps=num_samples)

        samples, _ = self._transform.inverse(noise, context=context)

        if context is not None:
            # Split the context dimension from sample dimension.
            samples = utils.split_leading_dim(samples, shape=[-1, num_samples])

        return samples
Example #8
    def sample_and_log_prob(self, num_samples, context=None):
        """Generates samples from the flow, together with their log probabilities.

        For flows, this is more efficient than calling `sample` and `log_prob` separately.
        """
        noise, log_prob = self._distribution.sample_and_log_prob(
            num_samples, context=context)

        if context is not None:
            # Merge the context dimension with sample dimension in order to apply the transform.
            noise = utils.merge_leading_dims(noise, num_dims=2)
            context = utils.repeat_rows(context, num_reps=num_samples)

        samples, logabsdet = self._transform.inverse(noise, context=context)

        if context is not None:
            # Split the context dimension from sample dimension.
            samples = utils.split_leading_dim(samples, shape=[-1, num_samples])
            logabsdet = utils.split_leading_dim(logabsdet,
                                                shape=[-1, num_samples])

        return samples, log_prob - logabsdet
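A quick sanity check of the change-of-variables identity used here, log p_x(x) = log p_z(z) - logabsdet, with a toy affine map standing in for the library's transform API (illustration only):

import torch
from torch import distributions

# Base noise z ~ N(0, 1); toy affine "inverse transform" x = a * z + b.
base = distributions.Normal(0.0, 1.0)
a, b = torch.tensor(2.0), torch.tensor(-1.0)

z = base.sample((5,))
x = a * z + b
logabsdet = torch.log(torch.abs(a))  # log |dx/dz|, constant for an affine map

# Change of variables: log p_x(x) = log p_z(z) - logabsdet.
log_prob_flow = base.log_prob(z) - logabsdet
log_prob_exact = distributions.Normal(b, a).log_prob(x)
print(torch.allclose(log_prob_flow, log_prob_exact))  # True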
Example #9
        def _get_log_prob_proposal_posterior(inputs, context, masks):
            """
            We have two main options when evaluating the proposal posterior.
            (1) Generate atoms from the proposal prior.
            (2) Generate atoms from a more targeted distribution,
            such as the most recent posterior.
            If we choose the latter, it is likely beneficial not to do this in the first
            round, since we would be sampling from a randomly initialized neural density
            estimator.

            :param inputs: torch.Tensor Batch of parameters.
            :param context: torch.Tensor Batch of observations.
            :param masks: torch.Tensor Batch of masks (used only when the combined loss is enabled).
            :return: torch.Tensor [batch_size] log_prob_proposal_posterior
            """

            log_prob_posterior_non_atomic = self._neural_posterior.log_prob(
                inputs, context)

            # just do maximum likelihood in the first round
            if round_ == 0:
                return log_prob_posterior_non_atomic

            num_atoms = self._num_atoms if self._num_atoms > 0 else batch_size

            # Each set of parameter atoms is evaluated using the same observation,
            # so we repeat rows of the context.
            # e.g. [1, 2] -> [1, 1, 2, 2]
            repeated_context = utils.repeat_rows(context, num_atoms)

            # To generate the full set of atoms for a given item in the batch,
            # we sample without replacement num_atoms - 1 times from the rest
            # of the parameters in the batch.
            assert 0 < num_atoms - 1 < batch_size
            probs = ((torch.ones(batch_size, batch_size) -
                      torch.eye(batch_size)) / (batch_size - 1))
            choices = torch.multinomial(probs,
                                        num_samples=num_atoms - 1,
                                        replacement=False)
            contrasting_inputs = inputs[choices]

            # We can now create our sets of atoms from the contrasting parameter sets
            # we have generated.
            atomic_inputs = torch.cat((inputs[:, None, :], contrasting_inputs),
                                      dim=1).reshape(batch_size * num_atoms,
                                                     -1)

            # Evaluate large batch giving (batch_size * num_atoms) log prob posterior evals.
            log_prob_posterior = self._neural_posterior.log_prob(
                atomic_inputs, repeated_context)
            assert utils.notinfnotnan(
                log_prob_posterior), "NaN/inf detected in posterior eval."
            log_prob_posterior = log_prob_posterior.reshape(
                batch_size, num_atoms)

            # Get (batch_size * num_atoms) log prob prior evals.
            if isinstance(self._prior, distributions.Uniform):
                log_prob_prior = self._prior.log_prob(atomic_inputs).sum(-1)
                # log_prob_prior = torch.zeros(log_prob_prior.shape)
            else:
                log_prob_prior = self._prior.log_prob(atomic_inputs)
            log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms)
            assert utils.notinfnotnan(
                log_prob_prior), "NaN/inf detected in prior eval."

            # Compute unnormalized proposal posterior.
            unnormalized_log_prob_proposal_posterior = (log_prob_posterior -
                                                        log_prob_prior)

            # Normalize proposal posterior across discrete set of atoms.
            log_prob_proposal_posterior = (
                unnormalized_log_prob_proposal_posterior[:, 0] -
                torch.logsumexp(unnormalized_log_prob_proposal_posterior,
                                dim=-1))
            assert utils.notinfnotnan(
                log_prob_proposal_posterior
            ), "NaN/inf detected in proposal posterior eval."

            if self._use_combined_loss:
                masks = masks.reshape(-1)

                log_prob_proposal_posterior = (
                    masks * log_prob_posterior_non_atomic +
                    log_prob_proposal_posterior)

            return log_prob_proposal_posterior
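The normalization over atoms above is a log-softmax in which only the first column (each observation paired with its own parameters) is kept; a toy check:

import torch

# Toy unnormalized atom log-probabilities.
batch_size, num_atoms = 3, 4
unnormalized = torch.randn(batch_size, num_atoms)

log_prob = unnormalized[:, 0] - torch.logsumexp(unnormalized, dim=-1)
log_prob_alt = torch.log_softmax(unnormalized, dim=-1)[:, 0]
print(torch.allclose(log_prob, log_prob_alt))  # True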