def compute_target_topk_q(reward, gamma, next_actions, next_q_values,
                          next_states, terminals):
  """Computes the optimal target Q value with the greedy algorithm.

  This algorithm corresponds to the method "TT" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    reward: [batch_size] tensor, the immediate reward.
    gamma: float, discount factor with the usual RL meaning.
    next_actions: [batch_size, slate_size] tensor, the next slate.
    next_q_values: [batch_size, num_of_documents] tensor, the q values of the
      documents in the next step.
    next_states: [batch_size, 1 + num_of_documents] tensor, the features for
      the user and the documents in the next step.
    terminals: [batch_size] tensor, indicating if this is a terminal step.

  Returns:
    [batch_size] tensor, the target q values.
  """
  slate_size = next_actions.get_shape().as_list()[1]
  scores, score_no_click = _get_unnormalized_scores(next_states)

  # Choose the documents with top affinity_scores * Q values to fill a slate
  # and treat it as if it is the optimal slate.
  unnormalized_next_q_target = next_q_values * scores
  _, topk_optimal_slate = tf.math.top_k(
      unnormalized_next_q_target, k=slate_size)

  # Get the expected Q-value of the slate containing top-K items.
  # [batch_size, slate_size]
  next_q_values_selected = tf.batch_gather(
      next_q_values, tf.cast(topk_optimal_slate, dtype=tf.int32))

  # Get normalized affinity scores on the slate.
  # [batch_size, slate_size]
  scores_selected = tf.batch_gather(
      scores, tf.cast(topk_optimal_slate, dtype=tf.int32))

  next_q_target_topk = tf.reduce_sum(
      input_tensor=next_q_values_selected * scores_selected, axis=1) / (
          tf.reduce_sum(input_tensor=scores_selected, axis=1) + score_no_click)
  return reward + gamma * next_q_target_topk * (
      1. - tf.cast(terminals, tf.float32))
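
# A minimal, self-contained sketch (not part of the original module) of the
# "TT" target math above, using toy numbers. The affinity scores, no-click
# score, and Q values below are made up for illustration, and tf.gather with
# batch_dims=1 stands in for tf.batch_gather.
def _toy_tt_target_example():
  """Returns the score-weighted top-k Q target for a toy batch of size 1."""
  import tensorflow as tf  # Assumed available; this module already uses tf.

  scores = tf.constant([[0.1, 0.4, 0.3, 0.2]])          # [batch, num_docs]
  score_no_click = tf.constant([0.5])                   # [batch]
  next_q_values = tf.constant([[1.0, 0.5, 2.0, 0.1]])   # [batch, num_docs]

  # Rank documents by affinity * Q and keep the top slate_size=2 of them.
  _, topk = tf.math.top_k(next_q_values * scores, k=2)  # picks docs 2 and 1
  q_sel = tf.gather(next_q_values, topk, batch_dims=1)  # [[2.0, 0.5]]
  s_sel = tf.gather(scores, topk, batch_dims=1)         # [[0.3, 0.4]]

  # Expected Q of the slate under the conditional-click model:
  # (0.6 + 0.2) / (0.3 + 0.4 + 0.5) = 0.8 / 1.2 ~= 0.667.
  return tf.reduce_sum(q_sel * s_sel, axis=1) / (
      tf.reduce_sum(s_sel, axis=1) + score_no_click)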
def _build_train_op(self):
  """Builds a training op.

  Returns:
    An op performing one step of training from replay data.
  """
  # click_indicator: [B, S]
  # q_values: [B, A]
  # actions: [B, S]
  # slate_q_values: [B, S]
  # replay_click_q: [B]
  click_indicator = self._replay.rewards[:, :, self._click_response_index]
  slate_q_values = tf.batch_gather(
      self._replay_net_outputs.q_values,
      tf.cast(self._replay.actions, dtype=tf.int32))
  # Only get the Q from the clicked document.
  replay_click_q = tf.reduce_sum(
      input_tensor=slate_q_values * click_indicator,
      axis=1,
      name='replay_click_q')

  target = tf.stop_gradient(self._build_target_q_op())

  clicked = tf.reduce_sum(input_tensor=click_indicator, axis=1)
  clicked_indices = tf.squeeze(tf.where(tf.equal(clicked, 1)), axis=1)
  # clicked_indices is a vector and tf.gather selects the batch dimension.
  q_clicked = tf.gather(replay_click_q, clicked_indices)
  target_clicked = tf.gather(target, clicked_indices)

  def get_train_op():
    loss = tf.reduce_mean(input_tensor=tf.square(q_clicked - target_clicked))
    if self.summary_writer is not None:
      with tf.variable_scope('Losses'):
        tf.summary.scalar('Loss', loss)
    return loss

  loss = tf.cond(
      pred=tf.greater(tf.reduce_sum(input_tensor=clicked), 0),
      true_fn=get_train_op,
      false_fn=lambda: tf.constant(0.),
      name='')

  return self.optimizer.minimize(loss)
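
# A minimal sketch (not part of the original module) of how the training op
# above restricts the TD loss to transitions where a click occurred. The Q
# values, targets, and click indicators below are toy numbers for illustration.
def _toy_clicked_loss_example():
  """Returns the mean squared TD error over clicked transitions only."""
  import tensorflow as tf  # Assumed available; this module already uses tf.

  replay_click_q = tf.constant([1.0, 0.0, 2.0])  # Q of the clicked doc, [B].
  target = tf.constant([1.5, 0.7, 2.5])          # Target Q values, [B].
  clicked = tf.constant([1.0, 0.0, 1.0])         # 1 if a click occurred, [B].

  # Keep only the transitions where a click happened (indices 0 and 2).
  clicked_indices = tf.squeeze(tf.where(tf.equal(clicked, 1)), axis=1)
  q_clicked = tf.gather(replay_click_q, clicked_indices)  # [1.0, 2.0]
  target_clicked = tf.gather(target, clicked_indices)     # [1.5, 2.5]

  # Mean of (0.5^2, 0.5^2) = 0.25; the unclicked transition contributes nothing.
  return tf.reduce_mean(tf.square(q_clicked - target_clicked))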