Example #1
0
 def f(log_probs, advantages, old_log_probs, mask):
     if reweight:  # Use new policy weights for sampled actions instead.
         mask *= jnp.exp(math.stop_gradient(log_probs) - old_log_probs)
     weights = jnp.minimum(awr_weights(advantages, beta), w_max)
     return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
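A minimal self-contained sketch of the same masked, advantage-weighted negative log-likelihood, assuming awr_weights(adv, beta) = exp(adv / beta) (the standard AWR weighting); beta, w_max and the inputs below are illustrative, not taken from the snippet.

import jax.numpy as jnp

def awr_weights(advantages, beta):
    # Assumed AWR weighting: exponentiated advantages with temperature beta.
    return jnp.exp(advantages / beta)

def awr_loss(log_probs, advantages, mask, beta=1.0, w_max=20.0):
    # Clip per-step weights at w_max, then take a masked, weighted mean of -log pi(a|s).
    weights = jnp.minimum(awr_weights(advantages, beta), w_max)
    return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)

log_probs = jnp.log(jnp.array([0.5, 0.25, 0.25]))
advantages = jnp.array([1.0, -1.0, 0.0])
mask = jnp.array([1.0, 1.0, 0.0])  # last step is padding
print(awr_loss(log_probs, advantages, mask))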
Example #2
0
def _WeightedMean(inputs, **unused_kwargs):
  """Returns a layer to compute weighted mean over all values in the input."""
  values, weights = inputs
  return np.sum(values * weights) / np.sum(weights)
Example #3
0
 def MaskedL2Loss(inputs, **unused_kwargs):
   y_hat, y, mask = inputs
   l2 = mask * (y_hat - y)**2
   return np.sum(l2) / np.sum(mask)
Example #4
0
def _L2(inputs, **unused_kwargs):
  """Returns a layer to compute L2 norms of predicted minus target vectors."""
  y_hat, y = inputs
  return np.sum((y_hat - y)**2, axis=-1)
Example #5
0
def _CrossEntropy(inputs, **unused_kwargs):
  """Returns a layer to compute prediction-target cross entropies."""
  y_hat, target_category = inputs
  return -1.0 * np.sum(y_hat * one_hot(target_category, y_hat.shape[-1]),
                       axis=-1)
Example #6
0
 def f(log_probs, advantages, old_log_probs, mask):
   del old_log_probs  # Not used in AWR.
   weights = jnp.minimum(awr_weights(advantages, beta), w_max)
   return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
Example #7
0
def L2Loss(inputs):
    y_hat, y, mask = inputs
    shapes.assert_same_shape(y_hat, y)
    shapes.assert_same_shape(y, mask)
    l2 = mask * (y_hat - y)**2
    return np.sum(l2) / np.sum(mask)
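A quick usage check of the masked L2 mean above, with illustrative plain-numpy arrays (the trax shape assertions are omitted here):

import numpy as np

y_hat = np.array([1.0, 2.0, 3.0])
y = np.array([1.0, 0.0, 3.0])
mask = np.array([1.0, 1.0, 0.0])  # ignore the last position
l2 = mask * (y_hat - y)**2
print(np.sum(l2) / np.sum(mask))  # (0 + 4 + 0) / 2 = 2.0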
Example #8
0
def Sum(x, axis=-1, keepdims=False):
    return np.sum(x, axis=axis, keepdims=keepdims)
Example #9
0
def CrossEntropy(x, axis=-1, **kw):
    del kw
    prediction, target = x
    return np.sum(prediction * core.one_hot(target, prediction.shape[-1]),
                  axis=axis)
Example #10
0
def Sum(axis=-1, keepdims=False):
    return Fn('Sum', lambda x: jnp.sum(x, axis=axis, keepdims=keepdims))
Example #11
0
 def entropy(self, log_probs):
   probs = jnp.exp(log_probs)
   return -jnp.sum(probs * log_probs, axis=-1)
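A quick numeric check of this entropy, with illustrative values: a uniform distribution over four outcomes should give log(4) ≈ 1.386 nats.

import jax.numpy as jnp

log_probs = jnp.log(jnp.full((4,), 0.25))  # uniform over 4 categories
probs = jnp.exp(log_probs)
print(-jnp.sum(probs * log_probs, axis=-1))  # ~1.3863 == log(4)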
Example #12
0
def Sum(x, axis=-1, keepdims=False, **unused_kwargs):
    return np.sum(x, axis=axis, keepdims=keepdims)
Example #13
0
  def forward_unbatched(self, x, *, weights, state, rng, update_state):
    attend_rng, output_rng = jax.random.split(rng)
    w_q, w_v, w_o = weights

    q = np.matmul(x, w_q)
    v = np.matmul(x, w_v)

    if update_state:
      _, old_hash_rng = state
      hash_rng, hash_subrng = jax.random.split(old_hash_rng)
      buckets = self.hash_vectors(q, hash_subrng)
      state = (buckets, hash_rng)
    else:
      buckets, _ = state

    seqlen = x.shape[0]
    assert int(buckets.shape[0]) == self.n_hashes * seqlen

    ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

    # Hash-based sort ("s" at the start of variable names means "sorted")
    sbuckets_and_t, sticker = jax.lax.sort_key_val(
        buckets_and_t, ticker, dimension=-1)
    _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
    sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
    sticker = jax.lax.stop_gradient(sticker)
    undo_sort = jax.lax.stop_gradient(undo_sort)

    st = (sticker % seqlen)
    sq = np.take(q, st, axis=0)
    sv = np.take(v, st, axis=0)

    mask_fn = functools.partial(
        mask_self_attention, causal=self.causal, exclude_self=True)
    q_info = st
    so, slogits = attend(
        sq, k=None, v=sv,
        q_chunk_len=self.chunk_len,
        n_chunks_before=self.n_chunks_before,
        n_chunks_after=self.n_chunks_after,
        mask_fn=mask_fn, q_info=q_info,
        dropout=self.attention_dropout, rng=attend_rng,
        )

    # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would
    # also work, but these helpers include performance optimizations for TPU.
    o = permute_via_gather(so, undo_sort, sticker, axis=0)
    logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1)

    if self.n_hashes > 1:
      o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
      logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
      probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
      o = np.sum(o * probs, axis=0)

    assert o.shape == (seqlen, w_v.shape[-1])
    out = np.matmul(o, w_o)
    out = apply_broadcasted_dropout(out, self.output_dropout, output_rng)
    return out, state
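A small sketch of the sort/undo-sort permutation trick used above: sorting positions by bucket keys, and then sorting the sorted positions again, yields the inverse permutation (illustrative values, not the Reformer internals):

import jax
import jax.numpy as jnp

keys = jnp.array([2, 0, 1, 0])                  # e.g. bucket ids per position
ticker = jnp.arange(keys.shape[0])              # original positions
_, sticker = jax.lax.sort_key_val(keys, ticker, dimension=-1)
_, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
x = jnp.array([10.0, 20.0, 30.0, 40.0])
sx = jnp.take(x, sticker, axis=0)               # gather into sorted order
assert jnp.all(jnp.take(sx, undo_sort, axis=0) == x)  # undo_sort restores the order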
Example #14
0
def WeightedSum(inputs, **unused_kwargs):
  """Returns a layer to compute weighted sum over all values in the input."""
  values, weights = inputs
  return np.sum(values * weights)
Example #15
0
    def train_epoch(self, evaluate=True):
        epoch_start_time = time.time()

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if evaluate and (self.epoch + 1) % self._eval_every_n == 0:
            self.evaluate()
        policy_eval_time = policy_based_utils.get_time(policy_eval_start_time)

        def write_metric(key, value):
            self._train_sw.scalar(key, value, step=self.epoch)
            self._history.append('train', key, self.epoch, value)

        # Get fresh trajectories every time.
        self._should_reset_train_env = True

        trajectory_collection_start_time = time.time()
        logging.vlog(1, 'AWR epoch [% 6d]: collecting trajectories.',
                     self._epoch)
        trajs, _, timing_info, self._model_state = self.collect_trajectories(
            train=True, temperature=1.0, raw_trajectory=True)
        del timing_info
        trajectory_collection_time = policy_based_utils.get_time(
            trajectory_collection_start_time)

        logging.vlog(1, 'AWR epoch [% 6d]: n_trajectories [%s].', self._epoch,
                     len(trajs))

        # Convert these into numpy now.
        def extract_obs_act_rew_dones(traj_np):
            return traj_np[0], traj_np[1], traj_np[2], traj_np[4]

        trajs_np = [extract_obs_act_rew_dones(traj.as_numpy) for traj in trajs]

        # number of new actions.
        new_sample_count = sum(traj[1].shape[0] for traj in trajs_np)
        self._n_observations_seen += new_sample_count
        logging.vlog(1, 'AWR epoch [% 6d]: new_sample_count [%d].',
                     self._epoch, new_sample_count)

        if self._should_write_summaries:
            write_metric('trajs/batch', len(trajs))
            write_metric('trajs/new_sample_count', new_sample_count)

        # The number of trajectories, i.e. `B`, can keep changing from iteration
        # to iteration, since we are capped on the number of observations
        # requested. So let's operate on each trajectory on its own?

        # TODO(afrozm): So should our batches look like (B, T+1, *OBS) or B
        # different examples of (T+1, *OBS) each. Since B can keep changing?

        # Add these to the replay buffer.
        for traj in trajs:
            _ = self._replay_buffer.store(traj)

        rewards = np.array([np.sum(traj[2]) for traj in trajs_np])
        avg_reward = np.mean(rewards)
        std_reward = np.std(rewards)
        max_reward = np.max(rewards)
        min_reward = np.min(rewards)

        self._log('train', 'train/reward_mean_truncated', avg_reward)
        if evaluate and not self._separate_eval and self._should_write_summaries:
            metrics = {'raw': {1.0: {'mean': avg_reward, 'std': std_reward}}}
            policy_based_utils.write_eval_reward_summaries(
                metrics, self._log, self.epoch)

        logging.vlog(
            1, 'AWR epoch [% 6d]: Rewards avg=[%0.2f], max=[%0.2f], '
            'min=[%0.2f].', self.epoch, avg_reward, max_reward, min_reward)

        if self._should_write_summaries:
            write_metric('reward/avg', avg_reward)
            write_metric('reward/std', std_reward)
            write_metric('reward/max', max_reward)
            write_metric('reward/min', min_reward)

        # Wrap these observations/rewards inside ReplayBuffer.
        idx, valid_mask, valid_idx = self._replay_buffer.get_valid_indices()

        # pylint: disable=g-complex-comprehension
        observations = [
            self._replay_buffer.get(
                replay_buffer.ReplayBuffer.OBSERVATIONS_KEY,
                idx[start_idx:end_plus_1_idx])
            for (start_idx,
                 end_plus_1_idx) in self._replay_buffer.iterate_over_paths(idx)
        ]

        rewards = [
            self._replay_buffer.get(replay_buffer.ReplayBuffer.REWARDS_KEY,
                                    idx[start_idx:end_plus_1_idx][:-1])
            for (start_idx,
                 end_plus_1_idx) in self._replay_buffer.iterate_over_paths(idx)
        ]
        # pylint: enable=g-complex-comprehension

        t_final = awr_utils.padding_length(rewards, boundary=self._boundary)
        logging.vlog(1, 'AWR epoch [% 6d]: t_final [%s].', self._epoch,
                     t_final)

        if self._should_write_summaries:
            write_metric('trajs/t_final', t_final)

        # These padded observations are over *all* the non-final observations in
        # the entire replay buffer.
        # Shapes:
        # padded_observations      = (B, T + 1, *OBS)
        # padded_observations_mask = (B, T + 1)
        padded_observations, padded_observations_mask = (
            awr_utils.pad_array_to_length(observations, t_final + 1))

        batch = len(observations)
        self._check_shapes('padded_observations',
                           '(batch, t_final + 1)',
                           padded_observations, (batch, t_final + 1),
                           array_prefix=2)
        self._check_shapes('padded_observations_mask', '(batch, t_final + 1)',
                           padded_observations_mask, (batch, t_final + 1))

        # Shapes:
        # padded_rewards      = (B, T)
        # padded_rewards_mask = (B, T)
        padded_rewards, padded_rewards_mask = awr_utils.pad_array_to_length(
            rewards, t_final)
        self._check_shapes('padded_rewards', '(batch, t_final)',
                           padded_rewards, (batch, t_final))
        self._check_shapes('padded_rewards_mask', '(batch, t_final)',
                           padded_rewards_mask, (batch, t_final))

        # Shapes:
        # lengths = (B,)
        lengths = np.sum(padded_rewards_mask, axis=1, dtype=np.int32)
        self._check_shapes('lengths', '(batch,)', lengths, (batch, ))

        # TODO(pkozakowski): Pass the actual actions here, to enable autoregressive
        # action sampling.
        dummy_actions = np.zeros(
            (batch, t_final + 1) + self._action_shape,
            self._action_dtype,
        )

        # Shapes:
        # log_probabs_traj       = (B, T + 1, #controls, #actions)
        # value_predictions_traj = (B, T + 1)
        log_probabs_traj, value_predictions_traj, self._model_state, unused_rng = (
            self._policy_fun_all_timesteps(padded_observations, lengths,
                                           self._model_state, self._get_rng()))
        self._check_shapes(
            'log_probabs_traj', '(batch, t_final + 1, n_controls, n_actions)',
            log_probabs_traj,
            (batch, t_final + 1, self._n_controls, self._n_actions))
        self._check_shapes('value_predictions_traj', '(batch, t_final + 1)',
                           value_predictions_traj, (batch, t_final + 1))

        # Zero out the padding's value predictions, since the net may give some
        # prediction to the padding observations.
        value_predictions_traj *= padded_observations_mask

        # Compute td-lambda returns, and reshape to match value_predictions_traj.
        list_td_lambda_returns = awr_utils.batched_compute_td_lambda_return(
            padded_rewards, padded_rewards_mask, value_predictions_traj,
            padded_observations_mask, self._gamma, self._td_lambda)

        if logging.vlog_is_on(1) and list_td_lambda_returns:
            l = len(list_td_lambda_returns)
            logging.vlog(1, f'Len of list_td_lambda_returns: {l}.')
            self._log_shape('td_lambda_returns[0]', list_td_lambda_returns[0])

        # pad an extra 0 for each to match lengths of value predictions.
        list_target_values = [
            onp.pad(l, (0, 1), 'constant') for l in list_td_lambda_returns
        ]

        if batch != len(list_target_values):
            raise ValueError(f'batch != len(list_target_values) : '
                             f'{batch} vs {len(list_target_values)}')

        # Shape: (len(idx),)
        target_values = onp.concatenate(list_target_values)
        self._check_shapes('target_values', '(len(idx),)', target_values,
                           (len(idx), ))

        # Shape: (len(idx),)
        vals = self.flatten_vals(value_predictions_traj,
                                 padded_observations_mask)
        self._check_shapes('vals', '(len(idx),)', vals, (len(idx), ))

        # Calculate advantages.
        adv, norm_adv, adv_mean, adv_std = self._calc_adv(
            target_values, vals, valid_mask)
        self._check_shapes('norm_adv', '(len(idx),)', norm_adv, (len(idx), ))

        adv_weights, adv_weights_mean, adv_weights_min, adv_weights_max = (
            self._calc_adv_weights(norm_adv, valid_mask))
        self._check_shapes('adv_weights', '(len(idx),)', adv_weights,
                           (len(idx), ))

        del adv, adv_mean, adv_std
        del adv_weights_min, adv_weights_max, adv_weights_mean

        combined_steps = int(
            np.ceil(self._optimization_steps * new_sample_count /
                    self._num_samples_to_collect))
        optimization_start_time = time.time()
        combined_losses = self._update_combined(combined_steps, valid_idx,
                                                target_values, adv_weights)
        optimization_time = policy_based_utils.get_time(
            optimization_start_time)

        self._epoch += 1

        if self._should_write_summaries:
            write_metric('combined/optimization_steps', combined_steps)
            epoch_time = policy_based_utils.get_time(epoch_start_time)
            timing_dict = {
                'epoch': epoch_time,
                'trajectory_collection': trajectory_collection_time,
                'optimization': optimization_time,
                'policy_eval': policy_eval_time,
            }

            if self._should_write_summaries:
                for k, v in timing_dict.items():
                    write_metric('timing/{}'.format(k), v)

            # Only dump the average post losses.
            if combined_losses:
                for k, v in combined_losses.items():
                    if 'post_entropy' in k:
                        write_metric(k.replace('post_entropy', 'entropy'), v)
                    if 'post_loss' in k:
                        write_metric(k.replace('post_loss', 'loss'), v)

        self.flush_summaries()
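For reference, a hedged single-trajectory sketch of the TD(lambda) return that awr_utils.batched_compute_td_lambda_return is assumed to produce (an illustrative reimplementation, not the library code):

import numpy as onp

def td_lambda_return(rewards, values, gamma, td_lambda):
    # rewards: shape (T,); values: shape (T + 1,), including the bootstrap value.
    t_max = len(rewards)
    returns = onp.zeros(t_max)
    next_return = values[t_max]                 # bootstrap from the final value estimate
    for t in reversed(range(t_max)):
        # G_t = r_t + gamma * ((1 - lambda) * V_{t+1} + lambda * G_{t+1})
        td = rewards[t] + gamma * values[t + 1]
        next_return = td + gamma * td_lambda * (next_return - values[t + 1])
        returns[t] = next_return
    return returns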
Example #16
0
def L2(x, axis=-1, **kw):
    del kw
    prediction, target = x
    return np.sum((prediction - target)**2, axis=axis)
Example #17
0
 def flatten_vals(self, value_predictions_traj, padded_observations_mask):
     batch = len(padded_observations_mask)
     lens = np.sum(padded_observations_mask, axis=1)
     return np.concatenate(
         [value_predictions_traj[b][:int(lens[b])] for b in range(batch)])
Example #18
0
def WeightedMean(x, **kw):
    del kw
    metric, weights = x
    weights_sum = np.sum(weights)
    return np.sum(metric * weights) / weights_sum
Example #19
0
  def train_epoch(self, evaluate=True):
    def write_metric(key, value):
      self._train_sw.scalar(key, value, step=self.epoch)
      self._history.append('train', key, self.epoch, value)

    # Get fresh trajectories every time.
    self._should_reset_train_env = True

    trajectory_collection_start_time = time.time()
    logging.vlog(1, 'AWR epoch [% 6d]: collecting trajectories.', self._epoch)
    trajs, _, timing_info, self._model_state = self.collect_trajectories(
        train=True, temperature=1.0, raw_trajectory=True)
    del timing_info
    trajectory_collection_time = ppo.get_time(trajectory_collection_start_time)

    # Convert these into numpy now.
    def extract_obs_act_rew_dones(traj_np):
      return traj_np[0], traj_np[1], traj_np[2], traj_np[4]

    trajs_np = [extract_obs_act_rew_dones(traj.as_numpy) for traj in trajs]

    # number of new actions.
    new_sample_count = sum(traj[1].shape[0] for traj in trajs_np)

    if self._should_write_summaries:
      write_metric('trajs/batch', len(trajs))
      write_metric('trajs/new_sample_count', new_sample_count)

    # The number of trajectories, i.e. `B`, can keep changing from iteration to
    # iteration, since we are capped on the number of observations requested.
    # So let's operate on each trajectory on its own?

    # TODO(afrozm): So should our batches look like (B, T+1, *OBS) or B
    # different examples of (T+1, *OBS) each. Since B can keep changing?

    # Add these to the replay buffer.
    for traj in trajs:
      _ = self._replay_buffer.store(traj)

    if self._should_write_summaries:
      rewards = np.array([np.sum(traj[2]) for traj in trajs_np])
      avg_reward = np.mean(rewards)
      std_reward = np.std(rewards)
      max_reward = np.max(rewards)
      min_reward = np.min(rewards)

      write_metric('reward/avg', avg_reward)
      write_metric('reward/std', std_reward)
      write_metric('reward/max', max_reward)
      write_metric('reward/min', min_reward)

    # Wrap these observations/rewards inside ReplayBuffer.
    idx, valid_mask, valid_idx = self._replay_buffer.get_valid_indices()

    # pylint: disable=g-complex-comprehension
    observations = [
        self._replay_buffer.get(replay_buffer.ReplayBuffer.OBSERVATIONS_KEY,
                                idx[start_idx:end_plus_1_idx])
        for (start_idx,
             end_plus_1_idx) in self._replay_buffer.iterate_over_paths(idx)
    ]

    rewards = [
        self._replay_buffer.get(replay_buffer.ReplayBuffer.REWARDS_KEY,
                                idx[start_idx:end_plus_1_idx][:-1])
        for (start_idx,
             end_plus_1_idx) in self._replay_buffer.iterate_over_paths(idx)
    ]
    # pylint: enable=g-complex-comprehension

    t_final = awr_utils.padding_length(rewards, boundary=self._boundary)

    if self._should_write_summaries:
      write_metric('trajs/t_final', t_final)

    # These padded observations are over *all* the non-final observations in
    # the entire replay buffer.
    # Shapes:
    # padded_observations      = (B, T + 1, *OBS)
    # padded_observations_mask = (B, T + 1)
    padded_observations, padded_observations_mask = (
        awr_utils.pad_array_to_length(observations, t_final + 1)
    )

    batch = len(observations)
    if ((batch, t_final + 1) != padded_observations.shape[:2] or
        (batch, t_final + 1) != padded_observations_mask.shape):
      raise ValueError(
          f'Shapes mismatch, batch {batch}, t_final {t_final}, '
          f'padded_observations.shape {padded_observations.shape}, '
          f'padded_observations_mask.shape {padded_observations_mask.shape}')

    # Shapes:
    # padded_rewards      = (B, T)
    # padded_rewards_mask = (B, T)
    padded_rewards, padded_rewards_mask = awr_utils.pad_array_to_length(
        rewards, t_final)
    if ((padded_rewards.shape != (batch, t_final)) or
        (padded_rewards_mask.shape != (batch, t_final))):
      raise ValueError(
          f'Shapes mismatch, batch {batch}, t_final {t_final}, '
          f'padded_rewards.shape {padded_rewards.shape}, '
          f'padded_rewards_mask.shape {padded_rewards_mask.shape}')

    # Shapes:
    # log_probabs_traj       = (B, T + 1, #actions)
    # value_predictions_traj = (B, T + 1)
    (log_probabs_traj, value_predictions_traj) = (
        self._policy_and_value_net_apply(
            padded_observations,
            weights=self._policy_and_value_net_weights,
            state=self._model_state,
            rng=self._get_rng(),
        ))

    if ((batch, t_final + 1) != log_probabs_traj.shape[:2] or
        (batch, t_final + 1) != value_predictions_traj.shape):
      raise ValueError(
          f'Shapes mismatch, batch {batch}, t_final {t_final}, '
          f'log_probabs_traj.shape {log_probabs_traj.shape}, '
          f'value_predictions_traj.shape {value_predictions_traj.shape}')

    # Zero out the padding's value predictions, since the net may give some
    # prediction to the padding observations.
    value_predictions_traj *= padded_observations_mask

    # Compute td-lambda returns, and reshape to match value_predictions_traj.
    list_td_lambda_returns = awr_utils.batched_compute_td_lambda_return(
        padded_rewards, padded_rewards_mask, value_predictions_traj,
        padded_observations_mask, self._gamma, self._td_lambda)
    # pad an extra 0 for each to match lengths of value predictions.
    list_target_values = [
        onp.pad(l, (0, 1), 'constant') for l in list_td_lambda_returns
    ]

    if batch != len(list_target_values):
      raise ValueError(f'batch != len(list_target_values) : '
                       f'{batch} vs {len(list_target_values)}')

    # Shape: (len(idx),)
    target_values = onp.concatenate(list_target_values)
    if target_values.shape != (len(idx),):
      raise ValueError(f'target_values.shape != (len(idx),) = '
                       f'{target_values.shape} != ({len(idx)},)')

    vals = self.flatten_vals(value_predictions_traj, padded_observations_mask)

    if vals.shape != target_values.shape:
      raise ValueError(f'vals.shape != target_values.shape : '
                       f'{vals.shape} vs {target_values.shape}')

    # Calculate advantages.
    adv, norm_adv, adv_mean, adv_std = self._calc_adv(
        target_values, vals, valid_mask)

    adv_weights, adv_weights_mean, adv_weights_min, adv_weights_max = (
        self._calc_adv_weights(norm_adv, valid_mask)
    )

    del adv, adv_mean, adv_std
    del adv_weights_min, adv_weights_max, adv_weights_mean

    combined_steps = int(
        np.ceil(self._optimization_steps * new_sample_count /
                self._num_samples_to_collect))
    combined_losses = self._update_combined(combined_steps, valid_idx,
                                            target_values, adv_weights)

    if self._should_write_summaries:
      write_metric('combined/optimization_steps', combined_steps)

      timing_dict = {
          'trajectory_collection': trajectory_collection_time,
          # 'epoch': epoch_time,
          # 'policy_eval': policy_eval_time,
          # 'preprocessing': preprocessing_time,
          # 'log_prob_recompute': log_prob_recompute_time,
          # 'loss_compute': loss_compute_time,
          # 'optimization': optimization_time,
          # 'policy_save': policy_save_time,
      }

      if self._should_write_summaries:
        for k, v in timing_dict.items():
          write_metric('timing/{}'.format(k), v)

      # Only dump the average post losses.
      if combined_losses:
        for k, v in combined_losses.items():
          if 'post_entropy' in k:
            write_metric(k.replace('post_entropy', 'entropy'), v)
          if 'post_loss' in k:
            write_metric(k.replace('post_loss', 'loss'), v)

    self._epoch += 1

    self.flush_summaries()
Example #20
0
    def forward_unbatched(self, x, *, weights, state, update_state):
        w_q, w_v, w_o = weights

        q = np.matmul(x, w_q)
        v = np.matmul(x, w_v)

        if update_state:
            _, old_rng = state
            rng = jax.random.fold_in(old_rng, 0)
            hash_rng = jax.random.fold_in(rng, 1)
            buckets = self.hash_vectors(q, hash_rng)
            state = (buckets, rng)
        else:
            buckets, rng = state

        rng = jax.random.fold_in(rng, 2)

        seqlen = x.shape[0]
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                       ticker,
                                                       dimension=-1)
        _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
        sticker = jax.lax.stop_gradient(sticker)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        st = (sticker % seqlen)
        sq = np.take(q, st, axis=0)
        sv = np.take(v, st, axis=0)

        mask_fn = functools.partial(mask_self_attention,
                                    causal=self.causal,
                                    exclude_self=True)
        q_info = st
        so, slogits = attend(
            sq,
            k=None,
            v=sv,
            q_chunk_len=self.chunk_len,
            n_chunks_before=self.n_chunks_before,
            n_chunks_after=self.n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            dropout=self.attention_dropout,
            rng=rng,
        )

        def unsort_for_output_impl(so, slogits):
            o = np.take(so, undo_sort, axis=0)
            # Sorting is considerably faster than gather, but first we need to get the
            # XLA compiler to abandon the idea of fusing this sort with the input sort
            # (which introduces a computation cycle and leads to a crash).
            # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
            sticker_ = sticker + jax.lax.convert_element_type(
                slogits[0] > 0, sticker.dtype)
            _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
            return o, logits

        def unsort_for_output_vjp(so, slogits):
            """Custom gradient for unsort_for_output."""
            so = jax.lax.stop_gradient(so)
            slogits = jax.lax.stop_gradient(slogits)
            o, logits = unsort_for_output_impl(so, slogits)

            def vjpfun(o_logits_grads):
                so_grad = np.take(o_logits_grads[0], sticker, axis=0)
                # TODO(kitaev): this exists to match the forward pass, but I'm not sure
                # if it's actually required.
                buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
                    o_logits_grads[1][0] > 0, buckets_and_t.dtype)
                _, slogits_grad = jax.lax.sort_key_val(buckets_and_t_,
                                                       o_logits_grads[1],
                                                       dimension=-1)
                return (so_grad, slogits_grad)

            return (o, logits), vjpfun

        unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
        jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
        o, logits = unsort_for_output(so, slogits)

        if self.n_hashes > 1:
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
            o = np.sum(o * probs, axis=0)

        assert o.shape == (seqlen, w_v.shape[-1])
        out = np.matmul(o, w_o)
        return out, state
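A shape-level sketch of the multi-hash combination at the end of forward_unbatched: per-round outputs are averaged with softmax weights computed from each round's logits (illustrative shapes and values):

import jax.numpy as jnp
from jax.scipy.special import logsumexp

n_hashes, seqlen, d_v = 2, 3, 4
o = jnp.ones((n_hashes, seqlen, d_v))                       # per-hash outputs
logits = jnp.zeros((n_hashes, seqlen, 1))                   # per-hash attention logits
probs = jnp.exp(logits - logsumexp(logits, axis=0, keepdims=True))  # sums to 1 over hashes
combined = jnp.sum(o * probs, axis=0)
assert combined.shape == (seqlen, d_v)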