def elbo_with_policy(self, rng, params, x, policy, train, context=None):
  """Computes the ELBO for AO-ARMs using uniform distribution over policy.

  Args:
    rng: Random number key.
    params: Parameters for the apply_fn.
    x: Input image.
    policy: An array of integers describing the generative model,
      parallelizing sampling steps if integers are missing. For example, the
      list [0, 2, 4, 5] indicates that steps 0 & 1 should be generated in
      parallel, then 2 & 3 in parallel, then 4 (individually), and then
      5, ..., n_steps - 1 (in parallel).
    train: Is the model in train or eval mode?
    context: Optional context to condition on.

  Returns:
    elbo: Batch of stochastic ELBO estimates.
    elbo_per_t: Batch of per-timestep ELBO terms (before policy weighting).
    nce: Batch of direct negative cross-entropy estimates.
    t: Batch of timesteps that were sampled.
  """
  d = np.prod(x.shape[1:])
  batch_size = x.shape[0]
  rng_stage, rng_perm, rng_t, rng_dropout = jax.random.split(rng, 4)

  # Get random stage s ~ Unif({0, 1, ..., num_stages-1}).
  stage = jax.random.randint(
      rng_stage, shape=(batch_size,), minval=0, maxval=self.num_stages)
  x_future, x_past = self.corrupt(x, stage)

  # Get random permutation sigma ~ Unif(S_n_steps) for a stage.
  sigma_in_stage = ardm_utils.get_batch_permutations(
      rng_perm, x.shape[0], self.num_steps_per_stage)

  # Sample t from policy.
  t_in_stage, _, weight_policy = self.sample_policy_t(rng_t, policy[stage])
  t = t_in_stage + stage * self.num_steps_per_stage

  already_predicted, _ = ardm_utils.get_selection_for_sigma_and_t(
      sigma_in_stage, t_in_stage, self.config.mask_shape)
  to_predict = (1 - already_predicted)

  model_inp = already_predicted * x_future + to_predict * x_past

  net_out = self.apply_fn(
      {'params': params},
      model_inp, t, self.prepare_additional_input(stage, already_predicted),
      train,
      context=context,
      rngs={'dropout': rng_dropout} if train else None)

  log_prob_future_given_past = self.log_prob_for_x_future_given_past(
      x_future, x_past, net_out, stage)

  log_prob = log_prob_future_given_past * to_predict
  log_prob = util_fns.sum_except_batch(log_prob)

  # Negative cross-entropy.
  nce = log_prob / d / np.log(2)

  # Reweight for summation over i.
  reweighting_factor_expectation_i = 1. / (
      self.num_steps_per_stage - t_in_stage)
  elbo_per_t = reweighting_factor_expectation_i * log_prob

  # Reweight for expectation over policy and stages.
  elbo = elbo_per_t * weight_policy * self.num_stages

  elbo = elbo / d / np.log(2)
  elbo_per_t = elbo_per_t / d / np.log(2)

  return elbo, elbo_per_t, nce, t
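
# Illustrative sketch (hypothetical helper, not part of this module): how a
# policy array maps to parallel generation ranges. A policy lists the left
# edges of step ranges; integers missing from the list are generated in
# parallel with the step that precedes them.
def policy_to_ranges(policy, n_steps):
  """Returns (left, right) step ranges that are generated in parallel."""
  lefts = list(policy)
  rights = list(policy[1:]) + [n_steps]
  return list(zip(lefts, rights))

# For example, with n_steps = 8:
#   policy_to_ranges([0, 2, 4, 5], 8) == [(0, 2), (2, 4), (4, 5), (5, 8)]
# i.e. steps 0 & 1 in parallel, then 2 & 3, then 4 alone, then 5, 6, 7.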
def elbo_with_policy(self, rng, params, x, policy, train, context=None):
  """Computes the ELBO for AO-ARMs using uniform distribution over policy.

  Args:
    rng: Random number key.
    params: Parameters for the apply_fn.
    x: Input image.
    policy: An array of integers describing the generative model,
      parallelizing sampling steps if integers are missing. For example, the
      list [0, 2, 4, 5] indicates that steps 0 & 1 should be generated in
      parallel, then 2 & 3 in parallel, then 4 (individually), and then
      5, ..., n_steps - 1 (in parallel).
    train: Is the model in train or eval mode?
    context: Anything the model might want to condition on.

  Returns:
    elbo: Batch of stochastic ELBO estimates.
    elbo_per_t: Batch of per-timestep ELBO terms (before policy weighting).
    ce: Batch of the direct cross-entropy loss.
    t: Batch of timesteps that were sampled.
  """
  d = np.prod(x.shape[1:])
  batch_size = x.shape[0]
  rng_perm, rng_t, rng_drop = jax.random.split(rng, 3)

  # Get random sigma ~ Unif(S_n_steps).
  sigmas = ardm_utils.get_batch_permutations(
      rng_perm, x.shape[0], self.num_steps)

  # Sample t from policy.
  t, _, weight_policy = self.sample_policy_t(rng_t, batch_size, policy)

  prev_selection, _ = ardm_utils.get_selection_for_sigma_and_t(
      sigmas, t, self.config.mask_shape)
  future_selection = (1. - prev_selection)

  corrupted = self.corrupt(x, prev_selection)

  net_out = self.apply_fn(
      {'params': params}, corrupted, t, prev_selection, train,
      rngs={'dropout': rng_drop} if train else None,
      context=context)

  log_px_sigma_geq_t = self.logprob_fn(x, net_out)
  log_px_sigma_geq_t = future_selection.reshape(
      log_px_sigma_geq_t.shape) * log_px_sigma_geq_t
  log_px_sigma_geq_t = util_fns.sum_except_batch(log_px_sigma_geq_t)

  ce = log_px_sigma_geq_t / d / np.log(2)

  # Reweight for expectation over i.
  reweighting_factor_expectation_i = 1. / (self.num_steps - t)
  elbo_per_t = reweighting_factor_expectation_i * log_px_sigma_geq_t

  # Reweight for expectation over policy.
  elbo = elbo_per_t * weight_policy

  elbo = elbo / d / np.log(2)
  elbo_per_t = elbo_per_t / d / np.log(2)

  return elbo, elbo_per_t, ce, t
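
# Sanity-check sketch (illustrative, not part of the model code): the
# 1 / (num_steps - t) factor above makes the single-t estimator unbiased.
# We use a toy model whose per-token log-probs do not depend on the
# conditioning, so the exact log-likelihood is simply their sum; all names
# below are hypothetical stand-ins.
import numpy as np

toy_rng = np.random.default_rng(0)
D = 8                                    # number of steps / tokens
token_logp = toy_rng.normal(size=D)      # stand-in for log p(x_k | x_past)
exact = token_logp.sum()

num_samples = 100_000
estimates = np.empty(num_samples)
for i in range(num_samples):
  sigma = toy_rng.permutation(D)         # uniform generation order
  t = toy_rng.integers(D)                # uniform timestep, i.e. weight_policy == D
  remaining = sigma[t:]                  # tokens not yet generated at step t
  estimates[i] = D * token_logp[remaining].sum() / (D - t)

# The Monte Carlo mean matches the exact sum up to sampling noise.
assert abs(estimates.mean() - exact) < 0.1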
def log_prob_with_policy_and_sigma(self, rng, params, x, policy, sigmas,
                                   train, context=None):
  """Expected log prob with specific policy and generation order sigma.

  Computes the log probability for AO-ARMs using a specific policy _and_ a
  specific permutation sigma. The given permutation makes this exact (hence
  log prob), the policy ensures that the estimator has reasonable variance.

  Args:
    rng: Random number key.
    params: Parameters for the apply_fn.
    x: Input image.
    policy: An array of integers describing the generative model,
      parallelizing sampling steps if integers are missing. For example, the
      list [0, 2, 4, 5] indicates that steps 0 & 1 should be generated in
      parallel, then 2 & 3 in parallel, then 4 (individually), and then
      5, ..., n_steps - 1 (in parallel).
    sigmas: An array describing the generation order that is being enforced.
    train: Is the model in train or eval mode?
    context: Anything the model might want to condition on.

  Returns:
    log_prob: Batch of stochastic log probability estimates.
  """
  d = np.prod(x.shape[1:])
  batch_size = x.shape[0]

  # Expand the dimensions of sigma if only a single order is given.
  if len(sigmas.shape) == 1:
    sigmas = jnp.repeat(sigmas[None], repeats=batch_size, axis=0)
  assert sigmas.shape == (batch_size, self.num_steps), (
      f'{sigmas.shape} does not match')

  rng_t, rng_drop = jax.random.split(rng, 2)

  # Sample t from policy.
  left_t, right_t, weight_policy = self.sample_policy_t(
      rng_t, batch_size, policy)
  num_tokens_in_parallel = right_t - left_t

  prev_selection, current_selection = ardm_utils.get_selections_for_sigma_and_range(
      sigmas, left_t, right_t, self.config.mask_shape)

  corrupted = self.corrupt(x, prev_selection)

  net_out = self.apply_fn(
      {'params': params}, corrupted, left_t, prev_selection, train,
      rngs={'dropout': rng_drop} if train else None,
      context=context)

  log_px_sigma_geq_t = self.logprob_fn(x, net_out)
  current_selection = current_selection.reshape(log_px_sigma_geq_t.shape)
  log_px_sigma_geq_t = current_selection * log_px_sigma_geq_t
  log_px_sigma_geq_t = util_fns.sum_except_batch(log_px_sigma_geq_t)

  # Reweight for expectation over policy.
  log_prob = log_px_sigma_geq_t / num_tokens_in_parallel * weight_policy
  log_prob = log_prob / d / np.log(2)

  return log_prob
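
# Illustrative sketch (not the ardm_utils implementation): one way to build
# the "already generated" and "generated now" masks for a permutation sigma
# and a parallel step range [left_t, right_t). Here sigma[t] is the flat
# position generated at step t, so argsort(sigma) gives each position's
# generation step. All names are hypothetical.
import jax.numpy as jnp

def selections_for_sigma_and_range(sigma, left_t, right_t):
  step_of_position = jnp.argsort(sigma)  # step at which each position is generated
  prev_selection = (step_of_position < left_t).astype(jnp.float32)
  current_selection = ((step_of_position >= left_t)
                       & (step_of_position < right_t)).astype(jnp.float32)
  return prev_selection, current_selection

# Example with 5 positions: generation order 2, 0, 4, 1, 3 with steps 1 and 2
# done in parallel (left_t=1, right_t=3) gives prev = {position 2} and
# current = {positions 0, 4}.
prev, cur = selections_for_sigma_and_range(jnp.array([2, 0, 4, 1, 3]), 1, 3)
# prev == [0., 0., 1., 0., 0.],  cur == [1., 0., 0., 0., 1.]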
def log_prob_with_policy_and_sigma(self, rng, params, x, policy, sigmas,
                                   train, context=None):
  """Expected log prob with specific policy and generation order sigma.

  Computes the log probability for AO-ARMs using a specific policy _and_ a
  specific permutation sigma. The given permutation makes this exact (hence
  log prob), the policy ensures that the estimator has reasonable variance.

  Args:
    rng: Random number key.
    params: Parameters for the apply_fn.
    x: Input.
    policy: An array of integers describing the generative model,
      parallelizing sampling steps if integers are missing. For example, the
      list [0, 2, 4, 5] indicates that steps 0 & 1 should be generated in
      parallel, then 2 & 3 in parallel, then 4 (individually), and then
      5, ..., n_steps - 1 (in parallel).
    sigmas: An array describing the generation order that is being enforced.
    train: Is the model in train or eval mode?
    context: Anything the model might want to condition on.

  Returns:
    log_prob: Batch of stochastic log probability estimates.
  """
  d = np.prod(x.shape[1:])
  batch_size = x.shape[0]

  assert sigmas.shape == (self.num_stages, self.num_steps_per_stage)
  sigmas = jnp.repeat(sigmas[:, None], repeats=batch_size, axis=1)
  assert policy.shape[0] == self.num_stages

  rng_stage, rng_t, rng_dropout = jax.random.split(rng, 3)

  # Get random stage s ~ Unif({0, 1, ..., num_stages-1}).
  stage = jax.random.randint(
      rng_stage, shape=(batch_size,), minval=0, maxval=self.num_stages)
  x_future, x_past = self.corrupt(x, stage)

  # Retrieve the relevant permutation.
  sigma_in_stage = sigmas[stage, jnp.arange(batch_size)]

  # Sample t from policy.
  t_in_stage, right_t_in_stage, weight_policy = self.sample_policy_t(
      rng_t, policy[stage])
  num_tokens_in_parallel = right_t_in_stage - t_in_stage
  t = t_in_stage + stage * self.num_steps_per_stage

  already_predicted, to_predict = ardm_utils.get_selections_for_sigma_and_range(
      sigma_in_stage, t_in_stage, right_t_in_stage, self.config.mask_shape)

  model_inp = already_predicted * x_future + (1 - already_predicted) * x_past

  net_out = self.apply_fn(
      {'params': params},
      model_inp, t, self.prepare_additional_input(stage, already_predicted),
      train,
      context=context,
      rngs={'dropout': rng_dropout} if train else None)

  log_prob_future_given_past = self.log_prob_for_x_future_given_past(
      x_future, x_past, net_out, stage)

  log_prob = log_prob_future_given_past * to_predict
  log_prob = util_fns.sum_except_batch(log_prob)

  # Reweight for expectation over policy and stages.
  log_prob = log_prob / num_tokens_in_parallel * weight_policy * self.num_stages

  log_prob = log_prob / d / np.log(2)
  return log_prob
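
# Illustrative sketch (toy shapes, hypothetical values): the per-example
# permutation lookup used above. sigmas holds one fixed order per stage, is
# broadcast over the batch, and is then indexed by each example's sampled
# stage, so example b uses the enforced order of its own stage.
import jax.numpy as jnp

num_stages, steps_per_stage, batch_size = 2, 3, 4
sigmas = jnp.array([[0, 2, 1],    # order for stage 0
                    [1, 0, 2]])   # order for stage 1
sigmas = jnp.repeat(sigmas[:, None], repeats=batch_size, axis=1)  # (2, 4, 3)
stage = jnp.array([0, 1, 1, 0])   # sampled stage per example
sigma_in_stage = sigmas[stage, jnp.arange(batch_size)]            # (4, 3)
# sigma_in_stage == [[0, 2, 1], [1, 0, 2], [1, 0, 2], [0, 2, 1]]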