Example 1
  def _build_train_op(self):
    """Builds the training op for Rainbow.

    Returns:
      train_op: An op performing one step of training.
    """
    target_distribution = tf.stop_gradient(self._build_target_distribution())

    # size of indices: batch_size x 1.
    indices = tf.range(tf.shape(self._replay_logits)[0])[:, None]
    # size of reshaped_actions: batch_size x 2.
    reshaped_actions = tf.concat([indices, self._replay.actions[:, None]], 1)
    # For each element of the batch, fetch the logits for its selected action.
    chosen_action_logits = tf.gather_nd(self._replay_logits, reshaped_actions)

    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=target_distribution,
        logits=chosen_action_logits)

    optimizer = tf.train.AdamOptimizer(
        learning_rate=self.learning_rate,
        epsilon=self.optimizer_epsilon)

    # Store sqrt(loss + 1e-10) as the new priority; the small constant keeps
    # priorities strictly positive, so the inverse-sqrt correction below never
    # divides by zero.
    update_priorities_op = self._replay.tf_set_priority(
        self._replay.indices, tf.sqrt(loss + 1e-10))

    # Weight each transition's loss by the inverse square root of its stored
    # priority, normalized so the largest weight in the batch is 1.0.
    target_priorities = self._replay.tf_get_priority(self._replay.indices)
    target_priorities = tf.math.add(target_priorities, 1e-10)
    target_priorities = 1.0 / tf.sqrt(target_priorities)
    target_priorities /= tf.reduce_max(target_priorities)

    weighted_loss = target_priorities * loss

    # Make sure the priorities are updated before the training step runs.
    with tf.control_dependencies([update_priorities_op]):
      return optimizer.minimize(tf.reduce_mean(weighted_loss)), weighted_loss
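
For context, the priority handling above follows the usual prioritized-replay recipe: each sampled transition's priority is reset to sqrt(loss + 1e-10), and the loss is corrected by the inverse square root of the stored priorities, normalized by the batch maximum. A minimal NumPy sketch of that weighting, with made-up loss and priority values for illustration:

import numpy as np

# Hypothetical per-transition cross-entropy losses and stored priorities.
loss = np.array([0.8, 0.1, 2.5, 0.4])
priorities = np.array([0.9, 0.3, 1.6, 0.5])

# New priorities, as in tf_set_priority above; the epsilon avoids zero priority.
new_priorities = np.sqrt(loss + 1e-10)

# Importance-sampling-style correction: inverse sqrt of the stored priority,
# normalized so the largest weight in the batch is 1.0.
weights = 1.0 / np.sqrt(priorities + 1e-10)
weights /= weights.max()

weighted_loss = weights * loss
print(new_priorities, weighted_loss.mean())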
Example 2
  def _build_target_distribution(self):
    """Builds the categorical (C51) target distribution for the Rainbow update.

    Returns:
      The Bellman target distribution projected onto self.support;
      shape: batch_size x num_atoms.
    """
    self._reshape_networks()
    batch_size = tf.shape(self._replay.rewards)[0]
    # size of rewards: batch_size x 1
    rewards = self._replay.rewards[:, None]
    # size of tiled_support: batch_size x num_atoms
    tiled_support = tf.tile(self.support, [batch_size])
    tiled_support = tf.reshape(tiled_support, [batch_size, self.num_atoms])

    is_terminal_multiplier = 1. - tf.cast(self._replay.terminals, tf.float32)
    # Incorporate the terminal state into the discount factor.
    # size of gamma_with_terminal: batch_size x 1
    gamma_with_terminal = self.cumulative_gamma * is_terminal_multiplier
    gamma_with_terminal = gamma_with_terminal[:, None]

    # size of target_support: batch_size x num_atoms
    target_support = rewards + gamma_with_terminal * tiled_support
    # size of next_probabilities: batch_size x num_actions x num_atoms
    next_probabilities = tf.contrib.layers.softmax(
        self._replay_next_logits)

    # size of next_qt: batch_size x num_actions
    next_qt = tf.reduce_sum(self.support * next_probabilities, 2)
    # size of next_qt_argmax: batch_size x 1
    next_qt_argmax = tf.argmax(
        next_qt + self._replay.next_legal_actions, axis=1)[:, None]
    batch_indices = tf.range(tf.to_int64(batch_size))[:, None]
    # size of next_qt_argmax: batch_size x 2
    next_qt_argmax = tf.concat([batch_indices, next_qt_argmax], axis=1)
    # size of next_probabilities: batch_size x num_atoms
    next_probabilities = tf.gather_nd(next_probabilities, next_qt_argmax)
    return project_distribution(target_support, next_probabilities,
                                self.support)
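
project_distribution is defined elsewhere in the agent and is not shown in this snippet. As a point of reference, the standard C51 categorical projection it is expected to perform can be sketched in NumPy as follows; the function body below is an assumption made for illustration, not the library's implementation:

import numpy as np

def project_distribution_sketch(target_support, weights, support):
  """Projects (target_support, weights) onto the fixed support, C51-style."""
  v_min, v_max = support[0], support[-1]
  delta_z = support[1] - support[0]
  num_atoms = support.shape[0]
  batch_size = target_support.shape[0]
  projected = np.zeros((batch_size, num_atoms))
  clipped = np.clip(target_support, v_min, v_max)
  b = (clipped - v_min) / delta_z  # fractional atom index
  lower = np.floor(b).astype(int)
  upper = np.ceil(b).astype(int)
  for i in range(batch_size):
    for j in range(num_atoms):
      if lower[i, j] == upper[i, j]:
        # b landed exactly on an atom: give it all of the mass.
        projected[i, lower[i, j]] += weights[i, j]
      else:
        # Split the mass between the two neighbouring atoms.
        projected[i, lower[i, j]] += weights[i, j] * (upper[i, j] - b[i, j])
        projected[i, upper[i, j]] += weights[i, j] * (b[i, j] - lower[i, j])
  return projected

support = np.linspace(-10.0, 10.0, 51)
target_support = 1.0 + 0.99 * np.tile(support, (2, 1))  # r + gamma * z
weights = np.full((2, 51), 1.0 / 51)                     # uniform next-state probs
print(project_distribution_sketch(target_support, weights, support).sum(axis=1))  # ~1.0 per row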