def _build_train_op(self):
  """Builds the training op for Rainbow.

  Returns:
    train_op: An op performing one step of training.
    weighted_loss: The importance-sampling-weighted loss, one entry per
      replayed transition.
  """
  target_distribution = tf.stop_gradient(self._build_target_distribution())

  # size of indices: batch_size x 1.
  indices = tf.range(tf.shape(self._replay_logits)[0])[:, None]
  # size of reshaped_actions: batch_size x 2.
  reshaped_actions = tf.concat([indices, self._replay.actions[:, None]], 1)
  # For each element of the batch, fetch the logits for its selected action.
  chosen_action_logits = tf.gather_nd(self._replay_logits, reshaped_actions)

  loss = tf.nn.softmax_cross_entropy_with_logits(
      labels=target_distribution,
      logits=chosen_action_logits)

  optimizer = tf.train.AdamOptimizer(
      learning_rate=self.learning_rate,
      epsilon=self.optimizer_epsilon)

  # Use the square root of the loss as each transition's new priority.
  update_priorities_op = self._replay.tf_set_priority(
      self._replay.indices, tf.sqrt(loss + 1e-10))

  # Importance-sampling weights: 1 / sqrt(priority), normalized so the
  # largest weight is 1, to correct for the non-uniform sampling.
  target_priorities = self._replay.tf_get_priority(self._replay.indices)
  target_priorities = tf.math.add(target_priorities, 1e-10)
  target_priorities = 1.0 / tf.sqrt(target_priorities)
  target_priorities /= tf.reduce_max(target_priorities)

  weighted_loss = target_priorities * loss

  with tf.control_dependencies([update_priorities_op]):
    return optimizer.minimize(tf.reduce_mean(weighted_loss)), weighted_loss
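# The importance-sampling weighting above can be sanity-checked in isolation.
# A minimal numpy sketch (the toy `priorities` values are illustrative
# assumptions, not values read from the replay buffer):
#
#   import numpy as np
#   priorities = np.array([4.0, 1.0, 0.25]) + 1e-10
#   weights = 1.0 / np.sqrt(priorities)
#   weights /= weights.max()
#   # weights ~= [0.25, 0.5, 1.0]: high-priority transitions, which are
#   # sampled more often, are down-weighted to correct the sampling bias.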
def _build_target_distribution(self):
  """Builds the C51 target distribution, as in Bellemare et al. (2017).

  Returns:
    The target distribution to regress the chosen action's logits towards,
    of size batch_size x num_atoms.
  """
  self._reshape_networks()
  batch_size = tf.shape(self._replay.rewards)[0]
  # size of rewards: batch_size x 1
  rewards = self._replay.rewards[:, None]
  # size of tiled_support: batch_size x num_atoms
  tiled_support = tf.tile(self.support, [batch_size])
  tiled_support = tf.reshape(tiled_support, [batch_size, self.num_atoms])

  # Incorporate terminal states into the discount factor.
  is_terminal_multiplier = 1. - tf.cast(self._replay.terminals, tf.float32)
  # size of gamma_with_terminal: batch_size x 1
  gamma_with_terminal = self.cumulative_gamma * is_terminal_multiplier
  gamma_with_terminal = gamma_with_terminal[:, None]
  # size of target_support: batch_size x num_atoms
  target_support = rewards + gamma_with_terminal * tiled_support

  # size of next_probabilities: batch_size x num_actions x num_atoms
  next_probabilities = tf.contrib.layers.softmax(self._replay_next_logits)
  # size of next_qt: batch_size x num_actions
  next_qt = tf.reduce_sum(self.support * next_probabilities, 2)
  # size of next_qt_argmax: batch_size x 1
  next_qt_argmax = tf.argmax(
      next_qt + self._replay.next_legal_actions, axis=1)[:, None]
  batch_indices = tf.range(tf.to_int64(batch_size))[:, None]
  # size of next_qt_argmax: batch_size x 2
  next_qt_argmax = tf.concat([batch_indices, next_qt_argmax], axis=1)
  # size of next_probabilities: batch_size x num_atoms
  next_probabilities = tf.gather_nd(next_probabilities, next_qt_argmax)
  return project_distribution(target_support, next_probabilities,
                              self.support)
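# `project_distribution` (defined elsewhere in this module) redistributes the
# probability mass sitting on `target_support` back onto the fixed support, as
# in the C51 projection of Bellemare et al. (2017). A rough single-sample
# numpy sketch of that projection (a hedged illustration, not the module's
# implementation, which is vectorized over the batch):
#
#   import numpy as np
#
#   def project_sketch(target_support, probs, support):
#     vmin, vmax = support[0], support[-1]
#     dz = support[1] - support[0]
#     # Fractional atom positions of the (clipped) target support.
#     b = (np.clip(target_support, vmin, vmax) - vmin) / dz
#     lower, upper = np.floor(b).astype(int), np.ceil(b).astype(int)
#     out = np.zeros_like(support)
#     for p, bi, lo, hi in zip(probs, b, lower, upper):
#       if lo == hi:               # target lands exactly on an atom
#         out[lo] += p
#       else:                      # split mass between the two neighbours
#         out[lo] += p * (hi - bi)
#         out[hi] += p * (bi - lo)
#     return out
#
#   support = np.linspace(-10., 10., 51)
#   target = 1.0 + 0.99 * support        # r + gamma * z for one transition
#   probs = np.full(51, 1. / 51)
#   assert np.isclose(project_sketch(target, probs, support).sum(), 1.0)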