Example 1
    def elbo_with_policy(self, rng, params, x, policy, train, context=None):
        """Computes the ELBO for AO-ARMs using uniform distribution over policy.

    Args:
      rng: random number key.
      params: parameters for the apply_fn.
      x: input image
      policy: An array of integers describing the generative model,
        parallelizing sampling steps if integers are missing. For example, the
        list [0, 2, 4, 5] indicates that steps 0 & 1 should be generated in
        parallel, then 2 & 3 in parallel, then 4 (individually), and then
        5, ..., n_steps - 1 (in parallel).
      train: Is the model in train or eval mode?
      context: Optional context to condition on.

    Returns:
      elbo: batch of stochastic elbo estimates.
      elbo_per_t: batch of per-timestep elbo estimates.
      nce: batch of negative cross-entropy estimates.
      t: batch of timesteps that were sampled.
    """
        d = np.prod(x.shape[1:])
        batch_size = x.shape[0]

        rng_stage, rng_perm, rng_t, rng_dropout = jax.random.split(rng, 4)

        # Get random stage s ~ Unif({0, 1, ..., num_stages-1})
        stage = jax.random.randint(rng_stage,
                                   shape=(batch_size, ),
                                   minval=0,
                                   maxval=self.num_stages)

        x_future, x_past = self.corrupt(x, stage)

        # Get random permutation sigma ~ Unif(S_n_steps) for a stage.
        sigma_in_stage = ardm_utils.get_batch_permutations(
            rng_perm, x.shape[0], self.num_steps_per_stage)

        # Sample t from policy.
        t_in_stage, _, weight_policy = self.sample_policy_t(
            rng_t, policy[stage])
        t = t_in_stage + stage * self.num_steps_per_stage

        already_predicted, _ = ardm_utils.get_selection_for_sigma_and_t(
            sigma_in_stage, t_in_stage, self.config.mask_shape)
        to_predict = (1 - already_predicted)

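        # Already-generated positions take their values from x_future; the
        # remaining positions keep their values from x_past.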
        model_inp = already_predicted * x_future + to_predict * x_past

        net_out = self.apply_fn(
            {'params': params},
            model_inp,
            t,
            self.prepare_additional_input(stage, already_predicted),
            train,
            context=context,
            rngs={'dropout': rng_dropout} if train else None)

        log_prob_future_given_past = self.log_prob_for_x_future_given_past(
            x_future, x_past, net_out, stage)

        log_prob = log_prob_future_given_past * to_predict

        log_prob = util_fns.sum_except_batch(log_prob)

        # Negative cross-entropy.
        nce = log_prob / d / np.log(2)

        # Reweigh for summation over i.
        reweighting_factor_expectation_i = 1. / (self.num_steps_per_stage -
                                                 t_in_stage)
        elbo_per_t = reweighting_factor_expectation_i * log_prob

        # Reweigh for expectation over policy and stages.
        elbo = elbo_per_t * weight_policy * self.num_stages

        elbo = elbo / d / np.log(2)
        elbo_per_t = elbo_per_t / d / np.log(2)

        return elbo, elbo_per_t, nce, t
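The policy convention described in the docstring (entries are left endpoints; missing integers are generated in parallel) can be made concrete with a small helper. This is only a sketch: even_budget_policy and the choice of evenly spaced endpoints are illustrative assumptions, not part of the codebase.

import numpy as np

def even_budget_policy(num_steps, budget):
    # Hypothetical helper: returns the left endpoints of `budget` roughly equal
    # segments over [0, num_steps). Missing integers between consecutive entries
    # are generated in parallel within a single model call.
    return np.linspace(0, num_steps, num=budget, endpoint=False).astype(np.int32)

# even_budget_policy(8, 4) -> array([0, 2, 4, 6], dtype=int32)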
Example 2
    def elbo_with_policy(self, rng, params, x, policy, train, context=None):
        """Computes the ELBO for AO-ARMs using uniform distribution over policy.

    Args:
      rng: Random number key.
      params: Parameters for the apply_fn.
      x: Input image.
      policy: An array of integers describing the generative model,
        parallelizing sampling steps if integers are missing. For example, the
        list [0, 2, 4, 5] indicates that steps 0 & 1 should be generated in
        parallel, then 2 & 3 in parallel, then 4 (individually), and then
        5, ..., n_steps - 1 (in parallel).
      train: Is the model in train or eval mode?
      context: Anything the model might want to condition on.

    Returns:
      elbo: batch of stochastic elbo estimates.
      elbo_per_t: batch of per-timestep elbo estimates.
      ce: batch of the direct cross-entropy loss.
      t: batch of timesteps that were sampled.
    """
        d = np.prod(x.shape[1:])
        batch_size = x.shape[0]

        rng_perm, rng_t, rng_drop = jax.random.split(rng, 3)

        # Get random sigma ~ Unif(S_n_steps)
        sigmas = ardm_utils.get_batch_permutations(rng_perm, x.shape[0],
                                                   self.num_steps)

        # Sample t from policy.
        t, _, weight_policy = self.sample_policy_t(rng_t, batch_size, policy)

        prev_selection, _ = ardm_utils.get_selection_for_sigma_and_t(
            sigmas, t, self.config.mask_shape)
        future_selection = (1. - prev_selection)

        corrupted = self.corrupt(x, prev_selection)

        net_out = self.apply_fn({'params': params},
                                corrupted,
                                t,
                                prev_selection,
                                train,
                                rngs={'dropout': rng_drop} if train else None,
                                context=context)

        log_px_sigma_geq_t = self.logprob_fn(x, net_out)

        log_px_sigma_geq_t = future_selection.reshape(
            log_px_sigma_geq_t.shape) * log_px_sigma_geq_t
        log_px_sigma_geq_t = util_fns.sum_except_batch(log_px_sigma_geq_t)

        ce = log_px_sigma_geq_t / d / np.log(2)

        # Reweigh for expectation over i.
        reweighting_factor_expectation_i = 1. / (self.num_steps - t)
        elbo_per_t = reweighting_factor_expectation_i * log_px_sigma_geq_t

        # Reweigh for expectation over policy.
        elbo = elbo_per_t * weight_policy

        elbo = elbo / d / np.log(2)
        elbo_per_t = elbo_per_t / d / np.log(2)

        return elbo, elbo_per_t, ce, t
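For orientation, here is a minimal sketch of what helpers like ardm_utils.get_batch_permutations and ardm_utils.get_selection_for_sigma_and_t could compute, under two assumptions: sigmas[b, i] holds the position generated at step i, and mask_shape multiplies out to num_steps. The real implementations may differ; the helper names below are stand-ins.

import jax
import jax.numpy as jnp

def batch_permutations(rng, batch_size, num_steps):
    # One independent permutation of {0, ..., num_steps - 1} per batch element.
    rngs = jax.random.split(rng, batch_size)
    return jnp.stack([jax.random.permutation(r, num_steps) for r in rngs])

def selection_for_sigma_and_t(sigmas, t, mask_shape):
    # argsort inverts each permutation: for every position it gives the step at
    # which that position is generated. Steps before t are already generated
    # ("previous"); the step equal to t is generated now ("current").
    step_of_position = jnp.argsort(sigmas, axis=-1)
    prev = (step_of_position < t[:, None]).astype(jnp.float32)
    current = (step_of_position == t[:, None]).astype(jnp.float32)
    batch_size = sigmas.shape[0]
    return (prev.reshape(batch_size, *mask_shape),
            current.reshape(batch_size, *mask_shape))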
Example 3
    def log_prob_with_policy_and_sigma(self,
                                       rng,
                                       params,
                                       x,
                                       policy,
                                       sigmas,
                                       train,
                                       context=None):
        """Expected log prob with specific policy and generation order sigma.

    Computes the log probability for AO-ARMs using a specific policy _and_ a
    specific permutation sigma. The given permutation makes this exact (hence
    log prob); the policy ensures that the estimator has reasonable variance.

    Args:
      rng: Random number key.
      params: Parameters for the apply_fn.
      x: Input image.
      policy: An array of integers describing the generative model,
        parallelizing sampling steps if integers are missing. For example, the
        list [0, 2, 4, 5] indicates that steps 0 & 1 should be generated in
        parallel, then 2 & 3 in parallel, then 4 (individually), and then
        5, ..., n_steps - 1 (in parallel).
      sigmas: An array describing the generation order that is being enforced.
      train: Is the model in train or eval mode?
      context: Anything the model might want to condition on.

    Returns:
      log_prob: batch of stochastic log probability estimates.
    """
        d = np.prod(x.shape[1:])
        batch_size = x.shape[0]

        # Expand the dimensions of sigma if only a single order is given.
        if len(sigmas.shape) == 1:
            sigmas = jnp.repeat(sigmas[None], repeats=batch_size, axis=0)
        assert sigmas.shape == (batch_size, self.num_steps), (
            f'{sigmas.shape} does not match')

        rng_t, rng_drop = jax.random.split(rng, 2)

        # Sample t from policy.
        left_t, right_t, weight_policy = self.sample_policy_t(
            rng_t, batch_size, policy)
        num_tokens_in_parallel = right_t - left_t

        prev_selection, current_selection = ardm_utils.get_selections_for_sigma_and_range(
            sigmas, left_t, right_t, self.config.mask_shape)

        corrupted = self.corrupt(x, prev_selection)

        net_out = self.apply_fn({'params': params},
                                corrupted,
                                left_t,
                                prev_selection,
                                train,
                                rngs={'dropout': rng_drop} if train else None,
                                context=context)

        log_px_sigma_geq_t = self.logprob_fn(x, net_out)

        current_selection = current_selection.reshape(log_px_sigma_geq_t.shape)
        log_px_sigma_geq_t = current_selection * log_px_sigma_geq_t
        log_px_sigma_geq_t = util_fns.sum_except_batch(log_px_sigma_geq_t)

        # Reweigh for expectation over policy.
        log_prob = log_px_sigma_geq_t / num_tokens_in_parallel * weight_policy
        log_prob = log_prob / d / np.log(2)

        return log_prob
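To see where num_tokens_in_parallel = right_t - left_t comes from: per the docstring, each policy entry is the left endpoint of a segment that runs to the next entry (or to num_steps for the last one). A worked example with made-up numbers:

import numpy as np

policy = np.array([0, 2, 4, 5])         # left endpoints of the parallel segments
num_steps = 8                           # assumed total number of generation steps

left = policy
right = np.append(policy[1:], num_steps)
num_tokens_in_parallel = right - left   # -> array([2, 2, 1, 3])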
Example 4
    def log_prob_with_policy_and_sigma(self,
                                       rng,
                                       params,
                                       x,
                                       policy,
                                       sigmas,
                                       train,
                                       context=None):
        """Expected log prob with specific policy and generation order sigma.

    Computes the log probability for AO-ARMs using a specific policy _and_ a
    specific permutation sigma. The given permutation makes this exact (hence
    log prob); the policy ensures that the estimator has reasonable variance.

    Args:
      rng: Random number key.
      params: Parameters for the apply_fn.
      x: Input.
      policy: An array of integers describing the generative model,
        parallelizing sampling steps if integers are missing. For example, the
        list [0, 2, 4, 5] indicates that steps 0 & 1 should be generated in
        parallel, then 2 & 3 in parallel, then 4 (individually), and then
        5, ..., n_steps - 1 (in parallel).
      sigmas: An array describing the generation order that is being enforced.
      train: Is the model in train or eval mode?
      context: Anything the model might want to condition on.

    Returns:
      log_prob: batch of stochastic log probability estimates.
    """
        d = np.prod(x.shape[1:])
        batch_size = x.shape[0]

        assert sigmas.shape == (self.num_stages, self.num_steps_per_stage)
        sigmas = jnp.repeat(sigmas[:, None], repeats=batch_size, axis=1)
        assert policy.shape[0] == self.num_stages

        rng_stage, rng_t, rng_dropout = jax.random.split(rng, 3)

        # Get random stage s ~ Unif({0, 1, ..., num_stages-1})
        stage = jax.random.randint(rng_stage,
                                   shape=(batch_size, ),
                                   minval=0,
                                   maxval=self.num_stages)

        x_future, x_past = self.corrupt(x, stage)

        # Retrieve the relevant permutation.
        sigma_in_stage = sigmas[stage, jnp.arange(batch_size)]

        # Sample t from policy.
        t_in_stage, right_t_in_stage, weight_policy = self.sample_policy_t(
            rng_t, policy[stage])
        num_tokens_in_parallel = right_t_in_stage - t_in_stage
        t = t_in_stage + stage * self.num_steps_per_stage

        already_predicted, to_predict = ardm_utils.get_selections_for_sigma_and_range(
            sigma_in_stage, t_in_stage, right_t_in_stage,
            self.config.mask_shape)

        model_inp = already_predicted * x_future + (1 -
                                                    already_predicted) * x_past

        net_out = self.apply_fn(
            {'params': params},
            model_inp,
            t,
            self.prepare_additional_input(stage, already_predicted),
            train,
            context=context,
            rngs={'dropout': rng_dropout} if train else None)

        log_prob_future_given_past = self.log_prob_for_x_future_given_past(
            x_future, x_past, net_out, stage)

        log_prob = log_prob_future_given_past * to_predict

        log_prob = util_fns.sum_except_batch(log_prob)

        # Reweigh for expectation over policy and stages.
        log_prob = log_prob / num_tokens_in_parallel * weight_policy * self.num_stages
        log_prob = log_prob / d / np.log(2)

        return log_prob
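The stage-indexed lookup sigmas[stage, jnp.arange(batch_size)] above picks, for every batch element, the permutation that belongs to its sampled stage. A tiny standalone illustration with made-up shapes:

import jax.numpy as jnp

sigmas = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)  # (num_stages, batch_size, num_steps_per_stage)
stage = jnp.array([1, 0, 1])                     # sampled stage for each of the 3 batch elements
sigma_in_stage = sigmas[stage, jnp.arange(3)]    # shape (3, 4); row b equals sigmas[stage[b], b]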