def compute_target_topk_q(reward, gamma, next_actions, next_q_values,
                          next_states, terminals):
    """Computes the optimal target Q value with the greedy algorithm.

  This algorithm corresponds to the method "TT" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    reward: [batch_size] tensor, the immediate reward.
    gamma: float, discount factor with the usual RL meaning.
    next_actions: [batch_size, slate_size] tensor, the next slate.
    next_q_values: [batch_size, num_of_documents] tensor, the q values of the
      documents in the next step.
    next_states: [batch_size, 1 + num_of_documents] tensor, the features for the
      user and the docuemnts in the next step.
    terminals: [batch_size] tensor, indicating if this is a terminal step.

  Returns:
    [batch_size] tensor, the target q values.
  """
    slate_size = next_actions.get_shape().as_list()[1]
    scores, score_no_click = _get_unnormalized_scores(next_states)
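    # `scores` are the unnormalized user-document affinity scores and
    # `score_no_click` is the unnormalized score of the user clicking nothing;
    # together they form the choice-model normalizer used below.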

    # Choose the documents with the top affinity_score * Q values to fill the
    # slate and treat it as the optimal slate.
    unnormalized_next_q_target = next_q_values * scores
    _, topk_optimal_slate = tf.math.top_k(unnormalized_next_q_target,
                                          k=slate_size)

    # Get the expected Q-value of the slate containing top-K items.
    # [batch_size, slate_size]
    next_q_values_selected = tf.batch_gather(
        next_q_values, tf.cast(topk_optimal_slate, dtype=tf.int32))

    # Gather the (unnormalized) affinity scores of the selected documents;
    # they are normalized by the choice-model denominator below.
    # [batch_size, slate_size]
    scores_selected = tf.batch_gather(
        scores, tf.cast(topk_optimal_slate, dtype=tf.int32))

    next_q_target_topk = tf.reduce_sum(
        input_tensor=next_q_values_selected * scores_selected,
        axis=1) / (tf.reduce_sum(input_tensor=scores_selected, axis=1) +
                   score_no_click)

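    # Bellman target: immediate reward plus the discounted expected next-step
    # value, with the bootstrap term zeroed out on terminal transitions.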
    return reward + gamma * next_q_target_topk * (
        1. - tf.cast(terminals, tf.float32))
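

# A minimal, self-contained NumPy sketch of the same top-k target computation,
# assuming a single user, 4 candidate documents, a slate of size 2 and made-up
# rewards/scores. It only illustrates the arithmetic above and is not part of
# the agent.
def _toy_topk_target_example():
    import numpy as np

    q = np.array([0.5, 1.0, 0.2, 0.8])        # next-step Q value per document
    scores = np.array([2.0, 1.0, 3.0, 1.0])   # unnormalized affinity scores
    score_no_click = 1.0                      # unnormalized no-click score
    reward, gamma, terminal = 1.0, 0.9, 0.0

    # Rank documents by affinity * Q and keep the top 2 as the "optimal" slate.
    topk = np.argsort(q * scores)[::-1][:2]

    # Expected Q of that slate under the conditional choice model: clicks are
    # distributed proportionally to the selected scores, with score_no_click
    # absorbing the probability of no click at all.
    next_q_topk = np.sum(q[topk] * scores[topk]) / (
        np.sum(scores[topk]) + score_no_click)

    # Bellman target, zeroed out on terminal transitions.
    return reward + gamma * next_q_topk * (1.0 - terminal)  # -> 1.45
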
    def _build_train_op(self):
        """Builds a training op.

    Returns:
      An op performing one step of training from replay data.
    """
        # click_indicator: [B, S]
        # q_values: [B, A]
        # actions: [B, S]
        # slate_q_values: [B, S]
        # replay_click_q: [B]
        click_indicator = self._replay.rewards[:, :,
                                               self._click_response_index]
        slate_q_values = tf.batch_gather(
            self._replay_net_outputs.q_values,
            tf.cast(self._replay.actions, dtype=tf.int32))
        # Only get the Q from the clicked document.
        replay_click_q = tf.reduce_sum(input_tensor=slate_q_values *
                                       click_indicator,
                                       axis=1,
                                       name='replay_click_q')

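        # Bellman targets from the agent's target builder (e.g. the top-k
        # target above); stop_gradient keeps them fixed during this update.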
        target = tf.stop_gradient(self._build_target_q_op())

        clicked = tf.reduce_sum(input_tensor=click_indicator, axis=1)
        clicked_indices = tf.squeeze(tf.where(tf.equal(clicked, 1)), axis=1)
        # clicked_indices is a vector and tf.gather selects the batch dimension.
        q_clicked = tf.gather(replay_click_q, clicked_indices)
        target_clicked = tf.gather(target, clicked_indices)

        def get_train_op():
            loss = tf.reduce_mean(input_tensor=tf.square(q_clicked -
                                                         target_clicked))
            if self.summary_writer is not None:
                with tf.variable_scope('Losses'):
                    tf.summary.scalar('Loss', loss)

            return loss

        loss = tf.cond(pred=tf.greater(tf.reduce_sum(input_tensor=clicked), 0),
                       true_fn=get_train_op,
                       false_fn=lambda: tf.constant(0.),
                       name='')

        return self.optimizer.minimize(loss)
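

# A minimal NumPy sketch of the clicked-only loss used in _build_train_op,
# with a hypothetical batch of 3 transitions and a slate size of 2. Purely
# illustrative; the real op works on replay tensors inside the TF graph.
def _toy_clicked_loss_example():
    import numpy as np

    slate_q_values = np.array([[0.4, 0.7],    # Q of each document on the slate
                               [0.2, 0.9],
                               [0.5, 0.1]])
    click_indicator = np.array([[0., 1.],     # user clicked document 1
                                [0., 0.],     # no click: transition is skipped
                                [1., 0.]])    # user clicked document 0
    target = np.array([1.3, 0.0, 0.8])        # (stop-gradient) Bellman targets

    # Q of the clicked document per transition (0 where nothing was clicked).
    replay_click_q = np.sum(slate_q_values * click_indicator, axis=1)

    # Keep only transitions that actually contain a click, then take the mean
    # squared error against their targets.
    clicked = np.sum(click_indicator, axis=1) == 1
    return np.mean((replay_click_q[clicked] - target[clicked]) ** 2)  # -> 0.225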