Example #1
def double_qlearning(q_tm1,
                     a_tm1,
                     r_t,
                     pcont_t,
                     q_t_value,
                     q_t_selector,
                     name="DoubleQLearning"):
    """Implements the double Q-learning loss as a TensorFlow op.

  The loss is `0.5` times the squared difference between `q_tm1[a_tm1]` and
  the target `r_t + pcont_t * q_t_value[argmax q_t_selector]`.

  See "Double Q-learning" by van Hasselt.
  (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

  Args:
    q_tm1: Tensor holding Q-values for first timestep in a batch of
      transitions, shape [B x num_actions].
    a_tm1: Tensor holding action indices, shape [B].
    r_t: Tensor holding rewards, shape [B].
    pcont_t: Tensor holding pcontinue values, shape [B].
    q_t_value: Tensor of Q-values for second timestep in a batch of transitions,
      used to estimate the value of the best action, shape [B x num_actions].
    q_t_selector: Tensor of Q-values for second timestep in a batch of
      transitions used to estimate the best action, shape [B x num_actions].
    name: name to prefix ops created within this op.

  Returns:
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape [B].
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape [B]
        * `td_error`: batch of temporal difference errors, shape [B]
        * `best_action`: batch of greedy actions wrt `q_t_selector`, shape [B]
  """
    # Rank and compatibility checks.
    base_ops.wrap_rank_shape_assert(
        [[q_tm1, q_t_value, q_t_selector], [a_tm1, r_t, pcont_t]], [2, 1],
        name)

    # double Q-learning op.
    with tf.name_scope(
            name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t_value,
                          q_t_selector]):

        # Build target and select head to update.
        best_action = tf.argmax(q_t_selector, 1, output_type=tf.int32)
        double_q_bootstrapped = indexing_ops.batched_index(
            q_t_value, best_action)
        target = tf.stop_gradient(r_t + pcont_t * double_q_bootstrapped)
        qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)

        # Temporal difference error and loss.
        # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
        td_error = target - qa_tm1
        loss = 0.5 * tf.square(td_error)
        return base_ops.LossOutput(loss,
                                   DoubleQExtra(target, td_error, best_action))
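A minimal call sketch for the op above, assuming it is importable as in DeepMind's trfl package (`trfl.double_qlearning`) and run in the same TF1 graph/session style as these examples; the values are illustrative only.

import tensorflow as tf
import trfl  # assumed packaging of the op above

# Batch of B=2 transitions over 3 actions.
q_tm1 = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])        # online net at t-1
a_tm1 = tf.constant([0, 2], dtype=tf.int32)                     # actions taken
r_t = tf.constant([1.0, 0.0])
pcont_t = tf.constant([0.99, 0.0])                              # 0 at terminal steps
q_t_value = tf.constant([[1.5, 0.5, 0.1], [0.2, 0.3, 0.4]])     # target net at t
q_t_selector = tf.constant([[1.0, 2.0, 0.0], [0.1, 0.0, 0.5]])  # online net at t

loss, extra = trfl.double_qlearning(
    q_tm1, a_tm1, r_t, pcont_t, q_t_value, q_t_selector)

with tf.Session() as sess:
    # extra.best_action is the argmax over q_t_selector; the bootstrap value is
    # then read from q_t_value at that action.
    print(sess.run([loss, extra.target, extra.best_action]))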
Example #2
def persistent_qlearning(q_tm1,
                         a_tm1,
                         r_t,
                         pcont_t,
                         q_t,
                         action_gap_scale=0.5,
                         name="PersistentQLearning"):
    """Implements the persistent Q-learning loss as a TensorFlow op.

  The loss is `0.5` times the squared difference between `q_tm1[a_tm1]` and
  `r_t + pcont_t * [(1 - action_gap_scale) * max q_t + action_gap_scale * qa_t]`.

  See "Increasing the Action Gap: New Operators for Reinforcement Learning"
  by Bellemare, Ostrovski, Guez et al. (https://arxiv.org/abs/1512.04860).

  Args:
    q_tm1: Tensor holding Q-values for first timestep in a batch of
      transitions, shape [B x num_actions].
    a_tm1: Tensor holding action indices, shape [B].
    r_t: Tensor holding rewards, shape [B].
    pcont_t: Tensor holding pcontinue values, shape [B].
    q_t: Tensor holding Q-values for second timestep in a batch of
      transitions, shape [B x num_actions].
      These values are used for estimating the value of the best action. In
      DQN they come from the target network.
    action_gap_scale: coefficient in [0, 1] for scaling the action gap term.
    name: name to prefix ops created within this op.

  Returns:
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape [B].
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape [B].
        * `td_error`: batch of temporal difference errors, shape [B].
  """
    # Rank and compatibility checks.
    base_ops.wrap_rank_shape_assert([[q_tm1, q_t], [a_tm1, r_t, pcont_t]],
                                    [2, 1], name)
    base_ops.assert_arg_bounded(action_gap_scale, 0, 1, name,
                                "action_gap_scale")

    # persistent Q-learning op.
    with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t]):

        # Build target and select head to update.
        with tf.name_scope("target"):
            max_q_t = tf.reduce_max(q_t, axis=1)
            qa_t = indexing_ops.batched_index(q_t, a_tm1)
            corrected_q_t = (
                1 - action_gap_scale) * max_q_t + action_gap_scale * qa_t
            target = tf.stop_gradient(r_t + pcont_t * corrected_q_t)
        qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)

        # Temporal difference error and loss.
        # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
        td_error = target - qa_tm1
        loss = 0.5 * tf.square(td_error)
        return base_ops.LossOutput(loss, QExtra(target, td_error))
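A short usage sketch, assuming the op is available as `trfl.persistent_qlearning`; with `action_gap_scale=0` it reduces to ordinary Q-learning, while `action_gap_scale=1` bootstraps from repeating the previous action (`q_t[a_tm1]`).

import tensorflow as tf
import trfl  # assumed packaging of the op above

q_tm1 = tf.constant([[1.0, 2.0], [3.0, 4.0]])
a_tm1 = tf.constant([0, 1], dtype=tf.int32)
r_t = tf.constant([0.5, 1.0])
pcont_t = tf.constant([0.9, 0.9])
q_t = tf.constant([[2.0, 1.0], [0.0, 3.0]])  # e.g. from the target network

loss, (target, td_error) = trfl.persistent_qlearning(
    q_tm1, a_tm1, r_t, pcont_t, q_t, action_gap_scale=0.5)

with tf.Session() as sess:
    print(sess.run([loss, target, td_error]))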
Example #3
def sarsa_lambda(q_tm1,
                 a_tm1,
                 r_t,
                 pcont_t,
                 q_t,
                 a_t,
                 lambda_,
                 name="SarsaLambda"):
  """Implements SARSA(lambda) loss as a TensorFlow op.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node77.html).

  Args:
    q_tm1: `Tensor` holding a sequence of Q-values starting at the first
      timestep; shape `[T, B, num_actions]`
    a_tm1: `Tensor` holding a sequence of action indices, shape `[T, B]`
    r_t: Tensor holding a sequence of rewards, shape `[T, B]`
    pcont_t: `Tensor` holding a sequence of pcontinue values, shape `[T, B]`
    q_t: `Tensor` holding a sequence of Q-values for second timestep;
      shape `[T, B, num_actions]`.
    a_t: `Tensor` holding a sequence of action indices for second timestep;
      shape `[T, B]`
    lambda_: a scalar specifying the ratio of mixing between bootstrapped and
      MC returns.
    name: a name of the op.

  Returns:
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape `[T, B]`.
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[T, B]`.
        * `td_error`: batch of temporal difference errors, shape `[T, B]`.
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert(
      [[q_tm1, q_t], [a_tm1, r_t, pcont_t, a_t]], [3, 2], name)

  # SARSALambda op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t, a_t]):

    # Select head to update and build target.
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
    qa_t = indexing_ops.batched_index(q_t, a_t)
    target = sequence_ops.multistep_forward_view(
        r_t, pcont_t, qa_t, lambda_, back_prop=False)
    target = tf.stop_gradient(target)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
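A shape-only sketch of calling the op above on a batch of sequences, assuming it is available as `trfl.sarsa_lambda`; the random tensors stand in for network outputs and logged transitions.

import tensorflow as tf
import trfl  # assumed packaging of the op above

T, B, num_actions = 4, 2, 3  # sequence length, batch size, action count
q_tm1 = tf.random_uniform([T, B, num_actions])
a_tm1 = tf.random_uniform([T, B], maxval=num_actions, dtype=tf.int32)
r_t = tf.random_uniform([T, B])
pcont_t = tf.fill([T, B], 0.99)
q_t = tf.random_uniform([T, B, num_actions])
a_t = tf.random_uniform([T, B], maxval=num_actions, dtype=tf.int32)

# loss, target and td_error all have shape [T, B].
loss, (target, td_error) = trfl.sarsa_lambda(
    q_tm1, a_tm1, r_t, pcont_t, q_t, a_t, lambda_=0.9)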
Example #4
def persistent_qlearning(
    q_tm1, a_tm1, r_t, pcont_t, q_t, action_gap_scale=0.5,
    name="PersistentQLearning"):
  """Implements the persistent Q-learning loss as a TensorFlow op.

  The loss is `0.5` times the squared difference between `q_tm1[a_tm1]` and
  `r_t + pcont_t * [(1 - action_gap_scale) * max q_t + action_gap_scale * qa_t]`.

  See "Increasing the Action Gap: New Operators for Reinforcement Learning"
  by Bellemare, Ostrovski, Guez et al. (https://arxiv.org/abs/1512.04860).

  Args:
    q_tm1: Tensor holding Q-values for first timestep in a batch of
      transitions, shape `[B x num_actions]`.
    a_tm1: Tensor holding action indices, shape `[B]`.
    r_t: Tensor holding rewards, shape `[B]`.
    pcont_t: Tensor holding pcontinue values, shape `[B]`.
    q_t: Tensor holding Q-values for second timestep in a batch of
      transitions, shape `[B x num_actions]`.
      These values are used for estimating the value of the best action. In
      DQN they come from the target network.
    action_gap_scale: coefficient in [0, 1] for scaling the action gap term.
    name: name to prefix ops created within this op.

  Returns:
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape `[B]`.
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
        * `td_error`: batch of temporal difference errors, shape `[B]`.
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert(
      [[q_tm1, q_t], [a_tm1, r_t, pcont_t]], [2, 1], name)
  base_ops.assert_arg_bounded(action_gap_scale, 0, 1, name, "action_gap_scale")

  # persistent Q-learning op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t]):

    # Build target and select head to update.
    with tf.name_scope("target"):
      max_q_t = tf.reduce_max(q_t, axis=1)
      qa_t = indexing_ops.batched_index(q_t, a_tm1)
      corrected_q_t = (1 - action_gap_scale) * max_q_t + action_gap_scale * qa_t
      target = tf.stop_gradient(r_t + pcont_t * corrected_q_t)
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
Example #5
def double_qlearning(
    q_tm1, a_tm1, r_t, pcont_t, q_t_value, q_t_selector,
    name="DoubleQLearning"):
  """Implements the double Q-learning loss as a TensorFlow op.

  The loss is `0.5` times the squared difference between `q_tm1[a_tm1]` and
  the target `r_t + pcont_t * q_t_value[argmax q_t_selector]`.

  See "Double Q-learning" by van Hasselt.
  (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

  Args:
    q_tm1: Tensor holding Q-values for first timestep in a batch of
      transitions, shape `[B x num_actions]`.
    a_tm1: Tensor holding action indices, shape `[B]`.
    r_t: Tensor holding rewards, shape `[B]`.
    pcont_t: Tensor holding pcontinue values, shape `[B]`.
    q_t_value: Tensor of Q-values for second timestep in a batch of transitions,
      used to estimate the value of the best action, shape `[B x num_actions]`.
    q_t_selector: Tensor of Q-values for second timestep in a batch of
      transitions used to estimate the best action, shape `[B x num_actions]`.
    name: name to prefix ops created within this op.

  Returns:
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape `[B]`.
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`
        * `td_error`: batch of temporal difference errors, shape `[B]`
        * `best_action`: batch of greedy actions wrt `q_t_selector`, shape `[B]`
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert(
      [[q_tm1, q_t_value, q_t_selector], [a_tm1, r_t, pcont_t]], [2, 1], name)

  # double Q-learning op.
  with tf.name_scope(
      name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t_value, q_t_selector]):

    # Build target and select head to update.
    best_action = tf.argmax(q_t_selector, 1, output_type=tf.int32)
    double_q_bootstrapped = indexing_ops.batched_index(q_t_value, best_action)
    target = tf.stop_gradient(r_t + pcont_t * double_q_bootstrapped)
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(
        loss, DoubleQExtra(target, td_error, best_action))
Example #6
def qlambda(
    q_tm1, a_tm1, r_t, pcont_t, q_t, lambda_, name="GeneralizedQLambda"):
  """Implements Peng's and Watkins' Q(lambda) loss as a TensorFlow op.

  This function is general enough to implement both Peng's and Watkins'
  Q-lambda algorithms.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node78.html).

  Args:
    q_tm1: `Tensor` holding a sequence of Q-values starting at the first
      timestep; shape `[T, B, num_actions]`
    a_tm1: `Tensor` holding a sequence of action indices, shape `[T, B]`
    r_t: Tensor holding a sequence of rewards, shape `[T, B]`
    pcont_t: `Tensor` holding a sequence of pcontinue values, shape `[T, B]`
    q_t: `Tensor` holding a sequence of Q-values for second timestep;
      shape `[T, B, num_actions]`. In a target network setting,
      this quantity is often supplied by the target network.
    lambda_: a scalar or `Tensor` of shape `[T, B]`
      specifying the ratio of mixing between bootstrapped and MC returns;
      if lambda_ is the same for all time steps then the function implements
      Peng's Q-learning algorithm; if lambda_ = 0 at every sub-optimal action
      and a constant otherwise, then the function implements Watkins'
      Q-learning algorithm. Generally lambda_ can be a Tensor of any values
      in the range [0, 1] supplied by the user.
    name: a name of the op.

  Returns:
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape `[T, B]`.
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[T, B]`.
        * `td_error`: batch of temporal difference errors, shape `[T, B]`.
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert([[q_tm1, q_t]], [3], name)
  if isinstance(lambda_, tf.Tensor) and lambda_.get_shape().ndims > 0:
    base_ops.wrap_rank_shape_assert([[a_tm1, r_t, pcont_t, lambda_]], [2], name)
  else:
    base_ops.wrap_rank_shape_assert([[a_tm1, r_t, pcont_t]], [2], name)

  # QLambda op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t]):

    # Build target and select head to update.
    with tf.name_scope("target"):
      state_values = tf.reduce_max(q_t, axis=2)
      target = sequence_ops.multistep_forward_view(
          r_t, pcont_t, state_values, lambda_, back_prop=False)
      target = tf.stop_gradient(target)
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
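To make the Peng/Watkins distinction in the docstring concrete, here is a hedged sketch assuming the op is available as `trfl.qlambda`: a scalar `lambda_` gives Peng's Q(lambda), while a `[T, B]` tensor that is zero at exploratory steps gives a Watkins-style variant.

import tensorflow as tf
import trfl  # assumed packaging of the op above

T, B, A = 5, 3, 4
q_tm1 = tf.random_uniform([T, B, A])
a_tm1 = tf.random_uniform([T, B], maxval=A, dtype=tf.int32)
r_t = tf.random_uniform([T, B])
pcont_t = tf.fill([T, B], 0.95)
q_t = tf.random_uniform([T, B, A])

# Peng's Q(lambda): one scalar mixing coefficient for every step.
loss_peng, _ = trfl.qlambda(q_tm1, a_tm1, r_t, pcont_t, q_t, lambda_=0.8)

# Watkins-style Q(lambda): a [T, B] lambda that is 0 wherever the behaviour
# action was non-greedy. `is_greedy` would come from your own logs; it is
# random here purely to show the expected shape and dtype.
is_greedy = tf.cast(tf.random_uniform([T, B]) > 0.1, tf.float32)
loss_watkins, _ = trfl.qlambda(
    q_tm1, a_tm1, r_t, pcont_t, q_t, lambda_=0.8 * is_greedy)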
Example #7
def sarsa(q_tm1, a_tm1, r_t, pcont_t, q_t, a_t, name="Sarsa"):
  """Implements the SARSA loss as a TensorFlow op.

  The loss is `0.5` times the squared difference between `q_tm1[a_tm1]` and
  the target `r_t + pcont_t * q_t[a_t]`.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node64.html).

  Args:
    q_tm1: Tensor holding Q-values for first timestep in a batch of
      transitions, shape `[B x num_actions]`.
    a_tm1: Tensor holding action indices, shape `[B]`.
    r_t: Tensor holding rewards, shape `[B]`.
    pcont_t: Tensor holding pcontinue values, shape `[B]`.
    q_t: Tensor holding Q-values for second timestep in a batch of
      transitions, shape `[B x num_actions]`.
    a_t: Tensor holding action indices for second timestep, shape `[B]`.
    name: name to prefix ops created within this op.

  Returns:
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape `[B]`.
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
        * `td_error`: batch of temporal difference errors, shape `[B]`.
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert(
      [[q_tm1, q_t], [a_t, r_t, pcont_t]], [2, 1], name)

  # SARSA op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t, a_t]):

    # Select head to update and build target.
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
    qa_t = indexing_ops.batched_index(q_t, a_t)
    target = tf.stop_gradient(r_t + pcont_t * qa_t)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
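A minimal on-policy usage sketch, assuming the op is available as `trfl.sarsa`; unlike Q-learning, the bootstrap uses the action actually taken at the next step.

import tensorflow as tf
import trfl  # assumed packaging of the op above

q_tm1 = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
a_tm1 = tf.constant([1, 0], dtype=tf.int32)
r_t = tf.constant([0.0, 1.0])
pcont_t = tf.constant([0.99, 0.99])
q_t = tf.constant([[1.0, 0.0, 0.5], [2.0, 1.0, 0.0]])
a_t = tf.constant([2, 0], dtype=tf.int32)  # next actions actually taken

loss, (target, td_error) = trfl.sarsa(q_tm1, a_tm1, r_t, pcont_t, q_t, a_t)

with tf.Session() as sess:
    # loss == 0.5 * (r_t + pcont_t * q_t[a_t] - q_tm1[a_tm1]) ** 2, per element.
    print(sess.run(loss))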
Example #8
def sarsa(q_tm1, a_tm1, r_t, pcont_t, q_t, a_t, name="Sarsa"):
  """Implements the SARSA loss as a TensorFlow op.

  The loss is `0.5` times the squared difference between `q_tm1[a_tm1]` and
  the target `r_t + pcont_t * q_t[a_t]`.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node64.html).

  Args:
    q_tm1: Tensor holding Q-values for first timestep in a batch of
      transitions, shape `[B x num_actions]`.
    a_tm1: Tensor holding action indices, shape `[B]`.
    r_t: Tensor holding rewards, shape `[B]`.
    pcont_t: Tensor holding pcontinue values, shape `[B]`.
    q_t: Tensor holding Q-values for second timestep in a batch of
      transitions, shape `[B x num_actions]`.
    a_t: Tensor holding action indices for second timestep, shape `[B]`.
    name: name to prefix ops created within this op.

  Returns:
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape `[B]`.
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
        * `td_error`: batch of temporal difference errors, shape `[B]`.
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert(
      [[q_tm1, q_t], [a_t, r_t, pcont_t]], [2, 1], name)

  # SARSA op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t, a_t]):

    # Select head to update and build target.
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
    qa_t = indexing_ops.batched_index(q_t, a_t)
    target = tf.stop_gradient(r_t + pcont_t * qa_t)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
Example #9
    def testValueSequence(self):
        """Indexing value functions by action with a minibatch of sequences."""
        values = [[[1.1, 1.2, 1.3], [1.4, 1.5, 1.6]],
                  [[2.1, 2.2, 2.3], [2.4, 2.5, 2.6]],
                  [[3.1, 3.2, 3.3], [3.4, 3.5, 3.6]],
                  [[4.1, 4.2, 4.3], [4.4, 4.5, 4.6]]]
        action_indices = [[0, 2], [1, 0], [2, 1], [0, 2]]
        result = indexing_ops.batched_index(values, action_indices)
        expected_result = [[1.1, 1.6], [2.2, 2.4], [3.3, 3.5], [4.1, 4.6]]

        with self.test_session() as sess:
            self.assertAllClose(sess.run(result), expected_result)
Example #10
    def testOrdinaryValues(self, keepdims):
        """Indexing value functions by action for a minibatch of values."""
        values = [[1.1, 1.2, 1.3], [1.4, 1.5, 1.6], [2.1, 2.2, 2.3],
                  [2.4, 2.5, 2.6], [3.1, 3.2, 3.3], [3.4, 3.5, 3.6],
                  [4.1, 4.2, 4.3], [4.4, 4.5, 4.6]]
        action_indices = [0, 2, 1, 0, 2, 1, 0, 2]
        result = indexing_ops.batched_index(values,
                                            action_indices,
                                            keepdims=keepdims)
        expected_result = [1.1, 1.6, 2.2, 2.4, 3.3, 3.5, 4.1, 4.6]
        if keepdims:
            expected_result = np.expand_dims(expected_result, axis=-1)

        with self.test_session() as sess:
            self.assertAllClose(sess.run(result), expected_result)
Example #11
  def testOrdinaryValues(self):
    """Indexing value functions by action for a minibatch of values."""
    values = [[1.1, 1.2, 1.3],
              [1.4, 1.5, 1.6],
              [2.1, 2.2, 2.3],
              [2.4, 2.5, 2.6],
              [3.1, 3.2, 3.3],
              [3.4, 3.5, 3.6],
              [4.1, 4.2, 4.3],
              [4.4, 4.5, 4.6]]
    action_indices = [0, 2, 1, 0, 2, 1, 0, 2]
    result = indexing_ops.batched_index(values, action_indices)
    expected_result = [1.1, 1.6, 2.2, 2.4, 3.3, 3.5, 4.1, 4.6]

    with self.test_session() as sess:
      self.assertAllClose(sess.run(result), expected_result)
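For reference, the 2-D behaviour exercised by these tests can be reproduced with a plain one-hot reduction; this is only an equivalent sketch for intuition, not the library's actual implementation.

import tensorflow as tf

def batched_index_2d(values, indices):
    # Picks values[b, indices[b]] for every row b of the minibatch.
    one_hot = tf.one_hot(indices, tf.shape(values)[-1], dtype=values.dtype)
    return tf.reduce_sum(values * one_hot, axis=-1)

values = tf.constant([[1.1, 1.2, 1.3], [1.4, 1.5, 1.6]])
indices = tf.constant([0, 2], dtype=tf.int32)

with tf.Session() as sess:
    print(sess.run(batched_index_2d(values, indices)))  # [1.1, 1.6]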
Example #12
def qv_learning(q_tm1, a_tm1, r_t, pcont_t, v_t, name="QVLearning"):
  """Implements the QV loss as a TensorFlow op.

  The loss is `0.5` times the squared difference between `q_tm1[a_tm1]` and
  the target `r_t + pcont_t * v_t`, where `v_t` is separately learned through
  temporal difference learning (c.f. `value_ops.td_learning`).

  See "Two Novel On-policy Reinforcement Learning Algorithms based on
  TD(lambda)-methods" by Wiering and van Hasselt
  (https://ieeexplore.ieee.org/abstract/document/4220845).

  Args:
    q_tm1: Tensor holding Q-values for first timestep in a batch of
      transitions, shape `[B x num_actions]`.
    a_tm1: Tensor holding action indices, shape `[B]`.
    r_t: Tensor holding rewards, shape `[B]`.
    pcont_t: Tensor holding pcontinue values, shape `[B]`.
    v_t: Tensor holding state-values for second timestep in a batch of
      transitions, shape `[B]`.
    name: name to prefix ops created within this op.

  Returns:
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape `[B]`.
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
        * `td_error`: batch of temporal difference errors, shape `[B]`.
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert(
      [[q_tm1], [a_tm1, r_t, pcont_t, v_t]], [2, 1], name)

  # QV op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, v_t]):

    # Build target and select head to update.
    with tf.name_scope("target"):
      target = tf.stop_gradient(r_t + pcont_t * v_t)
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
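QV-learning trains the Q head against a separately learned state-value function; a hedged sketch of the usual pairing follows, assuming trfl also exposes `td_learning(v_tm1, r_t, pcont_t, v_t)` for the V head.

import tensorflow as tf
import trfl  # assumed packaging of the ops

B, num_actions = 2, 3
q_tm1 = tf.random_uniform([B, num_actions])   # Q head at t-1
a_tm1 = tf.constant([0, 2], dtype=tf.int32)
r_t = tf.constant([1.0, 0.0])
pcont_t = tf.constant([0.99, 0.0])
v_tm1 = tf.random_uniform([B])                # V head at t-1
v_t = tf.random_uniform([B])                  # V head at t

# The Q head regresses towards r_t + pcont_t * v_t ...
q_loss, _ = trfl.qv_learning(q_tm1, a_tm1, r_t, pcont_t, v_t)
# ... while the V head itself is trained with ordinary TD(0) learning.
v_loss, _ = trfl.td_learning(v_tm1, r_t, pcont_t, v_t)

total_loss = tf.reduce_mean(q_loss + v_loss)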
Example #13
  def testValueSequence(self):
    """Indexing value functions by action with a minibatch of sequences."""
    values = [[[1.1, 1.2, 1.3], [1.4, 1.5, 1.6]],
              [[2.1, 2.2, 2.3], [2.4, 2.5, 2.6]],
              [[3.1, 3.2, 3.3], [3.4, 3.5, 3.6]],
              [[4.1, 4.2, 4.3], [4.4, 4.5, 4.6]]]
    action_indices = [[0, 2],
                      [1, 0],
                      [2, 1],
                      [0, 2]]
    result = indexing_ops.batched_index(values, action_indices)
    expected_result = [[1.1, 1.6],
                       [2.2, 2.4],
                       [3.3, 3.5],
                       [4.1, 4.6]]

    with self.test_session() as sess:
      self.assertAllClose(sess.run(result), expected_result)
Example #14
    def __init__(
        self,
        obs_spec: dm_env.specs.Array,
        action_spec: dm_env.specs.BoundedArray,
        ensemble: Sequence[snt.AbstractModule],
        target_ensemble: Sequence[snt.AbstractModule],
        batch_size: int,
        agent_discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        optimizer: tf.train.Optimizer,
        mask_prob: float,
        noise_scale: float,
        epsilon_fn: Callable[[int], float] = lambda _: 0.,
        seed: int = None,
    ):
        """Bootstrapped DQN with additive prior functions."""
        # Dqn configurations.
        self._ensemble = ensemble
        self._target_ensemble = target_ensemble
        self._num_actions = action_spec.maximum - action_spec.minimum + 1
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._min_replay_size = min_replay_size
        self._epsilon_fn = epsilon_fn
        self._replay = replay.Replay(capacity=replay_capacity)
        self._mask_prob = mask_prob
        self._noise_scale = noise_scale
        self._rng = np.random.RandomState(seed)
        tf.set_random_seed(seed)

        self._total_steps = 0
        self._total_episodes = 0
        self._active_head = 0
        self._num_ensemble = len(ensemble)
        assert len(ensemble) == len(target_ensemble)

        # Making the tensorflow graph
        session = tf.Session()

        # Placeholders = (obs, action, reward, discount, next_obs, mask, noise)
        o_tm1 = tf.placeholder(shape=(None, ) + obs_spec.shape,
                               dtype=obs_spec.dtype)
        a_tm1 = tf.placeholder(shape=(None, ), dtype=action_spec.dtype)
        r_t = tf.placeholder(shape=(None, ), dtype=tf.float32)
        d_t = tf.placeholder(shape=(None, ), dtype=tf.float32)
        o_t = tf.placeholder(shape=(None, ) + obs_spec.shape,
                             dtype=obs_spec.dtype)
        m_t = tf.placeholder(shape=(None, self._num_ensemble),
                             dtype=tf.float32)
        z_t = tf.placeholder(shape=(None, self._num_ensemble),
                             dtype=tf.float32)

        losses = []
        value_fns = []
        target_updates = []
        for k in range(self._num_ensemble):
            model = self._ensemble[k]
            target_model = self._target_ensemble[k]
            q_values = model(o_tm1)

            train_value = batched_index(q_values, a_tm1)
            target_value = tf.stop_gradient(
                tf.reduce_max(target_model(o_t), axis=-1))
            target_y = r_t + z_t[:, k] + agent_discount * d_t * target_value
            loss = tf.square(train_value - target_y) * m_t[:, k]

            value_fn = session.make_callable(q_values, [o_tm1])
            target_update = update_target_variables(
                target_variables=target_model.get_all_variables(),
                source_variables=model.get_all_variables(),
            )

            losses.append(loss)
            value_fns.append(value_fn)
            target_updates.append(target_update)

        sgd_op = optimizer.minimize(tf.stack(losses))
        self._value_fns = value_fns
        self._sgd_step = session.make_callable(
            sgd_op, [o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t])
        self._update_target_nets = session.make_callable(target_updates)
        session.run(tf.global_variables_initializer())
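The constructor above expects every replayed transition to carry a per-member bootstrap mask `m_t` and additive noise `z_t`. A plausible way to generate them when appending to replay is sketched below; the helper name and the exact recipe are illustrative, not necessarily this agent's code.

import numpy as np

def make_replay_entry(obs, action, reward, discount, next_obs,
                      num_ensemble, mask_prob, noise_scale, rng):
    # One Bernoulli(mask_prob) mask entry and one Gaussian noise entry per
    # ensemble member, matching the [None, num_ensemble] placeholders above.
    mask = rng.binomial(1, mask_prob, num_ensemble).astype(np.float32)   # m_t
    noise = (rng.randn(num_ensemble) * noise_scale).astype(np.float32)   # z_t
    return (obs, action, reward, discount, next_obs, mask, noise)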
Example #15
    def _forward(self, inputs: Any) -> None:
        """Trainer forward pass

        Args:
            inputs (Any): input data from the data table (transitions)
        """

        # Unpack input data as follows:
        # o_tm1 = dictionary of observations one for each agent
        # a_tm1 = dictionary of actions taken from obs in o_tm1
        # e_tm1 [Optional] = extra data that the agents persist in replay.
        # r_t = dictionary of rewards or rewards sequences
        #   (if using N step transitions) ensuing from actions a_tm1
        # d_t = environment discount ensuing from actions a_tm1.
        #   This discount is applied to future rewards after r_t.
        # o_t = dictionary of next observations or next observation sequences
        # e_t = [Optional] = extra data that the agents persist in replay.
        o_tm1, a_tm1, e_tm1, r_t, d_t, o_t, e_t = inputs.data
        s_tm1 = e_tm1["s_t"]
        s_t = e_t["s_t"]

        # Do forward passes through the networks and calculate the losses
        with tf.GradientTape(persistent=True) as tape:
            q_acts = []  # Q vals
            q_targets = []  # Target Q vals
            for agent in self._agents:
                agent_key = self.agent_net_keys[agent]

                o_tm1_feed, o_t_feed, a_tm1_feed = self._get_feed(
                    o_tm1, o_t, a_tm1, agent)
                q_tm1 = self._q_networks[agent_key](o_tm1_feed)
                q_t_value = self._target_q_networks[agent_key](o_t_feed)
                q_t_selector = self._q_networks[agent_key](o_t_feed)
                best_action = tf.argmax(q_t_selector,
                                        axis=1,
                                        output_type=tf.int32)

                # TODO Make use of q_t_selector for fingerprinting. Speak to Claude.
                q_act = batched_index(q_tm1, a_tm1_feed,
                                      keepdims=True)  # [B, 1]
                q_target = batched_index(q_t_value, best_action,
                                         keepdims=True)  # [B, 1]

                q_acts.append(q_act)
                q_targets.append(q_target)

            rewards = tf.concat(
                [tf.reshape(val, (-1, 1)) for val in list(r_t.values())],
                axis=1)
            rewards = tf.reduce_mean(rewards, axis=1)  # [B]

            pcont = tf.concat(
                [tf.reshape(val, (-1, 1)) for val in list(d_t.values())],
                axis=1)
            pcont = tf.reduce_mean(pcont, axis=1)
            discount = tf.cast(self._discount, list(d_t.values())[0].dtype)
            pcont = discount * pcont  # [B]

            q_acts = tf.concat(q_acts, axis=1)  # [B, num_agents]
            q_targets = tf.concat(q_targets, axis=1)  # [B, num_agents]

            q_tot_mixed = self._mixing_network(q_acts, s_tm1)  # [B, 1, 1]
            q_tot_target_mixed = self._target_mixing_network(q_targets,
                                                             s_t)  # [B, 1, 1]

            # Calculate Q loss.
            targets = rewards + pcont * q_tot_target_mixed
            td_error = targets - q_tot_mixed

            # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
            self.loss = 0.5 * tf.reduce_mean(tf.square(td_error))
            self.tape = tape
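The forward pass above only stores `self.loss` and the persistent tape; a minimal sketch of a matching backward step is shown here, assuming hypothetical attributes (`self._optimizer`, Sonnet-style `trainable_variables`) since the trainer's actual update code is not part of this snippet.

    def _backward(self) -> None:
        # Hypothetical sketch: gather the trainable variables of every agent's
        # Q network plus the mixing network, then apply gradients from the tape.
        variables = []
        for agent_key in set(self.agent_net_keys.values()):
            variables += list(self._q_networks[agent_key].trainable_variables)
        variables += list(self._mixing_network.trainable_variables)

        gradients = self.tape.gradient(self.loss, variables)
        self._optimizer.apply(gradients, variables)  # e.g. snt.optimizers.Adam
        del self.tape  # a persistent GradientTape must be released explicitly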
Example #16
def sarse(
    q_tm1, a_tm1, r_t, pcont_t, q_t, probs_a_t, debug=False, name="Sarse"):
  """Implements the SARSE (Expected SARSA) loss as a TensorFlow op.

  The loss is `0.5` times the squared difference between `q_tm1[a_tm1]` and
  the target `r_t + pcont_t * (sum_a probs_a_t[a] * q_t[a])`.

  See "A Theoretical and Empirical Analysis of Expected Sarsa" by Seijen,
  van Hasselt, Whiteson et al.
  (http://www.cs.ox.ac.uk/people/shimon.whiteson/pubs/vanseijenadprl09.pdf).

  Args:
    q_tm1: Tensor holding Q-values for first timestep in a batch of
      transitions, shape `[B x num_actions]`.
    a_tm1: Tensor holding action indices, shape `[B]`.
    r_t: Tensor holding rewards, shape `[B]`.
    pcont_t: Tensor holding pcontinue values, shape `[B]`.
    q_t: Tensor holding Q-values for second timestep in a batch of
      transitions, shape `[B x num_actions]`.
    probs_a_t: Tensor holding action probabilities for second timestep,
      shape `[B x num_actions]`.
    debug: Boolean flag, when set to True adds ops to check whether probs_a_t
      is a batch of (approximately) valid probability distributions.
    name: name to prefix ops created by this function.

  Returns:
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape `[B]`.
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
        * `td_error`: batch of temporal difference errors, shape `[B]`.
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert(
      [[q_tm1, q_t, probs_a_t], [a_tm1, r_t, pcont_t]], [2, 1], name)

  # SARSE (Expected SARSA) op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t, probs_a_t]):

    # Debug ops.
    deps = []
    if debug:
      cumulative_prob = tf.reduce_sum(probs_a_t, axis=1)
      almost_prob = tf.less(tf.abs(tf.subtract(cumulative_prob, 1.0)), 1e-6)
      deps.append(tf.Assert(
          tf.reduce_all(almost_prob),
          ["probs_a_t tensor does not sum to 1", probs_a_t]))

    # With dependency on possible debug ops.
    with tf.control_dependencies(deps):

      # Select head to update and build target.
      qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
      target = tf.stop_gradient(
          r_t + pcont_t * tf.reduce_sum(tf.multiply(q_t, probs_a_t), axis=1))

      # Temporal difference error and loss.
      # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
      td_error = target - qa_tm1
      loss = 0.5 * tf.square(td_error)
      return base_ops.LossOutput(loss, QExtra(target, td_error))
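A short usage sketch, assuming the op is available as `trfl.sarse`; the expectation over next actions is taken under an explicitly supplied policy, here a softmax (Boltzmann) policy over `q_t` purely for illustration.

import tensorflow as tf
import trfl  # assumed packaging of the op above

q_tm1 = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
a_tm1 = tf.constant([1, 0], dtype=tf.int32)
r_t = tf.constant([0.0, 1.0])
pcont_t = tf.constant([0.99, 0.99])
q_t = tf.constant([[1.0, 0.0, 0.5], [2.0, 1.0, 0.0]])

# Action probabilities at the next step; a softmax sums to 1, which keeps the
# debug assertion below satisfied.
probs_a_t = tf.nn.softmax(q_t, axis=-1)

loss, (target, td_error) = trfl.sarse(
    q_tm1, a_tm1, r_t, pcont_t, q_t, probs_a_t, debug=True)

with tf.Session() as sess:
    print(sess.run([loss, target]))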
Example #17
def sarse(
    q_tm1, a_tm1, r_t, pcont_t, q_t, probs_a_t, debug=False, name="Sarse"):
  """Implements the SARSE (Expected SARSA) loss as a TensorFlow op.

  The loss is `0.5` times the squared difference between `q_tm1[a_tm1]` and
  the target `r_t + pcont_t * (sum_a probs_a_t[a] * q_t[a])`.

  See "A Theoretical and Empirical Analysis of Expected Sarsa" by Seijen,
  van Hasselt, Whiteson et al.
  (http://www.cs.ox.ac.uk/people/shimon.whiteson/pubs/vanseijenadprl09.pdf).

  Args:
    q_tm1: Tensor holding Q-values for first timestep in a batch of
      transitions, shape `[B x num_actions]`.
    a_tm1: Tensor holding action indices, shape `[B]`.
    r_t: Tensor holding rewards, shape `[B]`.
    pcont_t: Tensor holding pcontinue values, shape `[B]`.
    q_t: Tensor holding Q-values for second timestep in a batch of
      transitions, shape `[B x num_actions]`.
    probs_a_t: Tensor holding action probabilities for second timestep,
      shape `[B x num_actions]`.
    debug: Boolean flag, when set to True adds ops to check whether probs_a_t
      is a batch of (approximately) valid probability distributions.
    name: name to prefix ops created by this function.

  Returns:
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape `[B]`.
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
        * `td_error`: batch of temporal difference errors, shape `[B]`.
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert(
      [[q_tm1, q_t, probs_a_t], [a_tm1, r_t, pcont_t]], [2, 1], name)

  # SARSE (Expected SARSA) op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t, probs_a_t]):

    # Debug ops.
    deps = []
    if debug:
      cumulative_prob = tf.reduce_sum(probs_a_t, axis=1)
      almost_prob = tf.less(tf.abs(tf.subtract(cumulative_prob, 1.0)), 1e-6)
      deps.append(tf.Assert(
          tf.reduce_all(almost_prob),
          ["probs_a_t tensor does not sum to 1", probs_a_t]))

    # With dependency on possible debug ops.
    with tf.control_dependencies(deps):

      # Select head to update and build target.
      qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
      target = tf.stop_gradient(
          r_t + pcont_t * tf.reduce_sum(tf.multiply(q_t, probs_a_t), axis=1))

      # Temporal difference error and loss.
      # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
      td_error = target - qa_tm1
      loss = 0.5 * tf.square(td_error)
      return base_ops.LossOutput(loss, QExtra(target, td_error))
Example #18
def qlambda(
    q_tm1, a_tm1, r_t, pcont_t, q_t, lambda_, name="GeneralizedQLambda"):
  """Implements Peng's and Watkins' Q(lambda) loss as a TensorFlow op.

  This function is general enough to implement both Peng's and Watkins'
  Q-lambda algorithms.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node78.html).

  Args:
    q_tm1: `Tensor` holding a sequence of Q-values starting at the first
      timestep; shape `[T, B, num_actions]`
    a_tm1: `Tensor` holding a sequence of action indices, shape `[T, B]`
    r_t: Tensor holding a sequence of rewards, shape `[T, B]`
    pcont_t: `Tensor` holding a sequence of pcontinue values, shape `[T, B]`
    q_t: `Tensor` holding a sequence of Q-values for second timestep;
      shape `[T, B, num_actions]`. In a target network setting,
      this quantity is often supplied by the target network.
    lambda_: a scalar or `Tensor` of shape `[T, B]`
      specifying the ratio of mixing between bootstrapped and MC returns;
      if lambda_ is the same for all time steps then the function implements
      Peng's Q-learning algorithm; if lambda_ = 0 at every sub-optimal action
      and a constant otherwise, then the function implements Watkins'
      Q-learning algorithm. Generally lambda_ can be a Tensor of any values
      in the range [0, 1] supplied by the user.
    name: a name of the op.

  Returns:
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape `[T, B]`.
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[T, B]`.
        * `td_error`: batch of temporal difference errors, shape `[T, B]`.
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert([[q_tm1, q_t]], [3], name)
  if isinstance(
      lambda_, tf.Tensor
  ) and lambda_.get_shape().ndims is not None and lambda_.get_shape().ndims > 0:
    base_ops.wrap_rank_shape_assert([[a_tm1, r_t, pcont_t, lambda_]], [2], name)
  else:
    base_ops.wrap_rank_shape_assert([[a_tm1, r_t, pcont_t]], [2], name)

  # QLambda op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t]):

    # Build target and select head to update.
    with tf.name_scope("target"):
      state_values = tf.reduce_max(q_t, axis=2)
      target = sequence_ops.multistep_forward_view(
          r_t, pcont_t, state_values, lambda_, back_prop=False)
      target = tf.stop_gradient(target)
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
Example #19
  def testInputShapeChecks(self):
    """Input shape checks can catch some, but not all, shape problems."""
    # 1. Inputs have incorrect or incompatible ranks:
    for args in [dict(values=[[5, 5]], indices=1),
                 dict(values=[5, 5], indices=[1]),
                 dict(values=[[[5, 5]]], indices=[1]),
                 dict(values=[[5, 5]], indices=[[[1]]]),]:
      with self.assertRaisesRegexp(ValueError, "do not correspond"):
        indexing_ops.batched_index(**args)

    # 2. Inputs have correct, compatible ranks but incompatible sizes:
    for args in [dict(values=[[5, 5]], indices=[1, 1]),
                 dict(values=[[5, 5], [5, 5]], indices=[1]),
                 dict(values=[[[5, 5], [5, 5]]], indices=[[1, 1], [1, 1]]),
                 dict(values=[[[5, 5], [5, 5]]], indices=[[1], [1]]),]:
      with self.assertRaisesRegexp(ValueError, "incompatible shapes"):
        indexing_ops.batched_index(**args)

    # (Correct ranks and sizes work fine, though):
    indexing_ops.batched_index(
        values=[[5, 5]], indices=[1])
    indexing_ops.batched_index(
        values=[[[5, 5], [5, 5]]], indices=[[1, 1]])

    # 3. Shape-checking works with fully-specified placeholders, or even
    # partially-specified placeholders that still provide evidence of having
    # incompatible shapes or incorrect ranks.
    for sizes in [dict(q_size=[4, 3], a_size=[4, 1]),
                  dict(q_size=[4, 2, 3], a_size=[4, 1]),
                  dict(q_size=[4, 3], a_size=[5, None]),
                  dict(q_size=[None, 2, 3], a_size=[4, 1]),
                  dict(q_size=[4, 2, 3], a_size=[None, 1]),
                  dict(q_size=[4, 2, 3], a_size=[5, None]),
                  dict(q_size=[None, None], a_size=[None, None]),
                  dict(q_size=[None, None, None], a_size=[None]),]:
      with self.assertRaises(ValueError):
        indexing_ops.batched_index(
            tf.placeholder(tf.float32, sizes["q_size"]),
            tf.placeholder(tf.int32, sizes["a_size"]))

    # But it can't work with 100% certainty if full shape information is not
    # known ahead of time. These cases generate no errors; some make warnings:
    for sizes in [dict(q_size=None, a_size=None),
                  dict(q_size=None, a_size=[4]),
                  dict(q_size=[4, 2], a_size=None),
                  dict(q_size=[None, 2], a_size=[None]),
                  dict(q_size=[None, 2, None], a_size=[None, 2]),
                  dict(q_size=[4, None, None], a_size=[4, None]),
                  dict(q_size=[None, None], a_size=[None]),
                  dict(q_size=[None, None, None], a_size=[None, None]),]:
      indexing_ops.batched_index(
          tf.placeholder(tf.float32, sizes["q_size"]),
          tf.placeholder(tf.int32, sizes["a_size"]))

    # And it can't detect invalid indices at construction time, either.
    indexing_ops.batched_index(values=[[5, 5, 5]], indices=[1000000000])
Example #20
def _general_off_policy_corrected_multistep_target(r_t,
                                                   pcont_t,
                                                   target_policy_t,
                                                   c_t,
                                                   q_t,
                                                   a_t,
                                                   back_prop=False,
                                                   name=None):
  """Evaluates targets for various off-policy value correction based algorithms.

  `target_policy_t` is the policy that this function aims to evaluate. New
  action-value estimates (target values `T`) must be expressible in this
  recurrent form:
  ```none
  T(x_{t-1}, a_{t-1}) = r_t + γ[ 𝔼_π Q(x_t, .) - c_t Q(x_t, a_t) +
                                                 c_t T(x_t, a_t) ]
  ```
  `T(x_t, a_t)` is an estimate of expected discounted future returns based
  on the current Q value estimates `Q(x_t, a_t)` and rewards `r_t`. The
  evaluated target values can be used as supervised targets for learning the Q
  function itself or as returns for various policy gradient algorithms.
  `Q==T` if convergence is reached. As the formula is recurrent, it will
  evaluate multistep returns for non-zero importance weights `c_t`.

  In the usual moving and target network setup `q_t` should be calculated by
  the target network while the `target_policy_t` may be evaluated by either of
  the networks. If `target_policy_t` is evaluated by the current moving network
  the algorithm implemented will have a similar flavour as double DQN.

  Depending on the choice of c_t, the algorithm can implement:
  ```none
  Importance Sampling             c_t = π(x_t, a_t) / μ(x_t, a_t),
  Harutyunyan's et al. Q(lambda)  c_t = λ,
  Precup's et al. Tree-Backup     c_t = π(x_t, a_t),
  Munos' et al. Retrace           c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).
  ```
  Please refer to page 3 for more details:
  https://arxiv.org/pdf/1606.02647v1.pdf

  Args:
    r_t: 2-D tensor holding rewards received during the transition
      that corresponds to each major index.
      Shape is `[T, B]`.
    pcont_t: 2-D tensor holding pcontinue values received during the
      transition that corresponds to each major index.
      Shape is `[T, B]`.
    target_policy_t:  3-D tensor holding per-action policy probabilities for
      the states encountered just AFTER the transitions that correspond to
      each major index, according to the target policy (i.e. the policy we
      wish to learn). These usually derive from the learning net.
      Shape is `[T, B, num_actions]`.
    c_t: 2-D tensor holding importance weights; see discussion above.
      Shape is `[T, B]`.
    q_t: 3-D tensor holding per-action Q-values for the states
      encountered just AFTER taking the transitions that correspond to each
      major index. Shape is `[T, B, num_actions]`.
    a_t: 2-D tensor holding the indices of actions executed during the
      transition AFTER the transition that corresponds to each major index.
      Shape is `[T, B]`.
    back_prop: whether to backpropagate gradients through time.
    name: name of the op.

  Returns:
    Tensor of shape `[T, B, num_actions]` containing Q values.
  """
  # Formula (4) in https://arxiv.org/pdf/1606.02647v1.pdf can be expressed
  # in a recursive form where T is a new target value:
  # T(x_{t-1}, a_{t-1}) = r_t + γ[ 𝔼_π Q(x_t, .) - c_t Q(x_t, a_t) +
  #                                                c_t T(x_t, a_t) ]
  # This recurrent form allows us to express Retrace by using
  # `scan_discounted_sum`.
  # Define:
  #   T_tm1   = T(x_{t-1}, a_{t-1})
  #   T_t     = T(x_t, a_t)
  #   exp_q_t = 𝔼_π Q(x_t, .)
  #   qa_t    = Q(x_t, a_t)
  # Hence:
  #   T_tm1   = r_t + γ * (exp_q_t - c_t * qa_t) + γ * c_t * T_t
  # Define:
  #   current = r_t + γ * (exp_q_t - c_t * qa_t)
  # Thus:
  #   T_tm1 = scan_discounted_sum(current, γ * c_t, reverse=True)
  args = [r_t, pcont_t, target_policy_t, c_t, q_t, a_t]
  with tf.name_scope(
      name, 'general_returns_based_off_policy_target', values=args):
    exp_q_t = tf.reduce_sum(target_policy_t * q_t, axis=2)
    qa_t = indexing_ops.batched_index(q_t, a_t)
    current = r_t + pcont_t * (exp_q_t - c_t * qa_t)
    initial_value = qa_t[-1]
    return sequence_ops.scan_discounted_sum(
        current,
        pcont_t * c_t,
        initial_value,
        reverse=True,
        back_prop=back_prop)
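To connect the table of `c_t` choices in the docstring to code, here is an illustrative construction of each weighting, given per-step probabilities of the actions actually taken under the target policy (`pi_a_t`) and the behaviour policy (`mu_a_t`); the names and random values are placeholders for shapes only.

import tensorflow as tf

T, B = 4, 2
lambda_ = 0.95
pi_a_t = tf.random_uniform([T, B], minval=0.1, maxval=1.0)  # pi(x_t, a_t)
mu_a_t = tf.random_uniform([T, B], minval=0.1, maxval=1.0)  # mu(x_t, a_t)

c_importance_sampling = pi_a_t / mu_a_t                     # Importance Sampling
c_q_lambda = lambda_ * tf.ones_like(pi_a_t)                 # Harutyunyan's Q(lambda)
c_tree_backup = pi_a_t                                      # Precup's Tree-Backup
c_retrace = lambda_ * tf.minimum(1.0, pi_a_t / mu_a_t)      # Munos' Retrace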
Example #21
def _general_off_policy_corrected_multistep_target(r_t,
                                                   pcont_t,
                                                   target_policy_t,
                                                   c_t,
                                                   q_t,
                                                   a_t,
                                                   back_prop=False,
                                                   name=None):
    """Evaluates targets for various off-policy value correction based algorithms.

  `target_policy_t` is the policy that this function aims to evaluate. New
  action-value estimates (target values `T`) must be expressible in this
  recurrent form:
  ```none
  T(x_{t-1}, a_{t-1}) = r_t + γ[ 𝔼_π Q(x_t, .) - c_t Q(x_t, a_t) +
                                                 c_t T(x_t, a_t) ]
  ```
  `T(x_t, a_t)` is an estimate of expected discounted future returns based
  on the current Q value estimates `Q(x_t, a_t)` and rewards `r_t`. The
  evaluated target values can be used as supervised targets for learning the Q
  function itself or as returns for various policy gradient algorithms.
  `Q==T` if convergence is reached. As the formula is recurrent, it will
  evaluate multistep returns for non-zero importance weights `c_t`.

  In the usual moving and target network setup `q_t` should be calculated by
  the target network while the `target_policy_t` may be evaluated by either of
  the networks. If `target_policy_t` is evaluated by the current moving network
  the algorithm implemented will have a similar flavour as double DQN.

  Depending on the choice of c_t, the algorithm can implement:
  ```none
  Importance Sampling             c_t = π(x_t, a_t) / μ(x_t, a_t),
  Harutyunyan's et al. Q(lambda)  c_t = λ,
  Precup's et al. Tree-Backup     c_t = π(x_t, a_t),
  Munos' et al. Retrace           c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).
  ```
  Please refer to page 3 for more details:
  https://arxiv.org/pdf/1606.02647v1.pdf

  Args:
    r_t: 2-D tensor holding rewards received during the transition
      that corresponds to each major index.
      Shape is `[T, B]`.
    pcont_t: 2-D tensor holding pcontinue values received during the
      transition that corresponds to each major index.
      Shape is `[T, B]`.
    target_policy_t:  3-D tensor holding per-action policy probabilities for
      the states encountered just AFTER the transitions that correspond to
      each major index, according to the target policy (i.e. the policy we
      wish to learn). These usually derive from the learning net.
      Shape is `[T, B, num_actions]`.
    c_t: 2-D tensor holding importance weights; see discussion above.
      Shape is `[T, B]`.
    q_t: 3-D tensor holding per-action Q-values for the states
      encountered just AFTER taking the transitions that correspond to each
      major index. Shape is `[T, B, num_actions]`.
    a_t: 2-D tensor holding the indices of actions executed during the
      transition AFTER the transition that corresponds to each major index.
      Shape is `[T, B]`.
    back_prop: whether to backpropagate gradients through time.
    name: name of the op.

  Returns:
    Tensor of shape `[T, B, num_actions]` containing Q values.
  """
    # Formula (4) in https://arxiv.org/pdf/1606.02647v1.pdf can be expressed
    # in a recursive form where T is a new target value:
    # T(x_{t-1}, a_{t-1}) = r_t + γ[ 𝔼_π Q(x_t, .) - c_t Q(x_t, a_t) +
    #                                                c_t T(x_t, a_t) ]
    # This recurrent form allows us to express Retrace by using
    # `scan_discounted_sum`.
    # Define:
    #   T_tm1   = T(x_{t-1}, a_{t-1})
    #   T_t     = T(x_t, a_t)
    #   exp_q_t = 𝔼_π Q(x_t,.)
    #   qa_t    = Q(x_t, a_t)
    # Hence:
    #   T_tm1   = r_t + γ * (exp_q_t - c_t * qa_t) + γ * c_t * T_t
    # Define:
    #   current = r_t + γ * (exp_q_t - c_t * qa_t)
    # Thus:
    #   T_tm1 = scan_discounted_sum(current, γ * c_t, reverse=True)
    args = [r_t, pcont_t, target_policy_t, c_t, q_t, a_t]
    with tf.name_scope(name,
                       'general_returns_based_off_policy_target',
                       values=args):
        exp_q_t = tf.reduce_sum(target_policy_t * q_t, axis=2)
        qa_t = indexing_ops.batched_index(q_t, a_t)
        current = r_t + pcont_t * (exp_q_t - c_t * qa_t)
        initial_value = qa_t[-1]
        return sequence_ops.scan_discounted_sum(current,
                                                pcont_t * c_t,
                                                initial_value,
                                                reverse=True,
                                                back_prop=back_prop)
Example #22
def retrace_core(lambda_,
                 q_tm1,
                 a_tm1,
                 r_t,
                 pcont_t,
                 target_policy_t,
                 behaviour_policy_t,
                 targnet_q_t,
                 a_t,
                 stop_targnet_gradients=True,
                 name=None):
  """Retrace algorithm core loss calculation op.

  Given a minibatch of temporally-contiguous sequences of Q values, policy
  probabilities, and various other typical RL algorithm inputs, this
  Op creates a subgraph that computes a loss according to the
  Retrace multi-step off-policy value learning algorithm. This Op supports the
  use of target networks, but does not require them.

  This function is the "core" Retrace op only because its arguments are less
  user-friendly and more implementation-convenient. For a more user-friendly
  operator, consider using `retrace`. For more details of Retrace, refer to
  [the arXiv paper](http://arxiv.org/abs/1606.02647).

  Construct the "core" retrace loss subgraph for a batch of sequences.

  Note that two pairs of arguments (one holding target network values; the
  other, actions) are temporally-offset versions of each other and will share
  many values in common (nb: a good setting for using `IndexedSlices`). *This
  op does not include any checks that these pairs of arguments are
  consistent*---that is, it does not ensure that temporally-offset
  arguments really do share the values they are supposed to share.

  In argument descriptions, `T` counts the number of transitions over which
  the Retrace loss is computed, and `B` is the minibatch size. All tensor
  arguments are indexed first by transition, with specific details of this
  indexing in the argument descriptions (pay close attention to "subscripts"
  in variable names).

  Args:
    lambda_: Positive scalar value or 0-D `Tensor` controlling the degree to
      which future timesteps contribute to the loss computed at each
      transition.
    q_tm1: 3-D tensor holding per-action Q-values for the states encountered
      just before taking the transitions that correspond to each major index.
      Since these values are the predicted values we wish to update (in other
      words, the values we intend to change as we learn), in a target network
      setting, these nearly always come from the "non-target" network, which
      we usually call the "learning network".
      Shape is `[T, B, num_actions]`.
    a_tm1: 2-D tensor holding the indices of actions executed during the
      transition that corresponds to each major index.
      Shape is `[T, B]`.
    r_t: 2-D tensor holding rewards received during the transition
      that corresponds to each major index.
      Shape is `[T, B]`.
    pcont_t: 2-D tensor holding pcontinue values received during the
      transition that corresponds to each major index.
      Shape is `[T, B]`.
    target_policy_t: 3-D tensor holding per-action policy probabilities for
      the states encountered just AFTER the transitions that correspond to
      each major index, according to the target policy (i.e. the policy we
      wish to learn). These usually derive from the learning net.
      Shape is `[T, B, num_actions]`.
    behaviour_policy_t: 2-D tensor holding the *behaviour* policy's
      probabilities of having taken action `a_t` at the states encountered
      just AFTER the transitions that correspond to each major index. Derived
      from whatever policy you used to generate the data. All values MUST be
      greater than 0. Shape is `[T, B]`.
    targnet_q_t: 3-D tensor holding per-action Q-values for the states
      encountered just AFTER taking the transitions that correspond to each
      major index. Since these values are used to calculate target values for
      the network, in a target network setting, these should
      probably come from the target network.
      Shape is `[T, B, num_actions]`.
    a_t: 2-D tensor holding the indices of actions executed during the
      transition AFTER the transition that corresponds to each major index.
      Shape is `[T, B]`.
    stop_targnet_gradients: `bool` that enables a sensible default way of
      handling gradients through the Retrace op (essentially, gradients
      are not permitted to involve the `targnet_q_t` input).
      Can be disabled if you require a different arrangement, but
      you'll probably want to block some gradients somewhere.
    name: name to prefix ops created by this function.

  Returns:
    A namedtuple with fields:

    * `loss`: Tensor containing the batch of losses, shape `[T, B]`.
    * `extra`: A namedtuple with fields:
        * `retrace_weights`: Tensor containing batch of retrace weights,
        shape `[T, B]`.
        * `target`: Tensor containing target action values, shape `[T, B]`.
  """
  all_args = [
      lambda_, q_tm1, a_tm1, r_t, pcont_t, target_policy_t, behaviour_policy_t,
      targnet_q_t, a_t
  ]

  with tf.name_scope(name, 'RetraceCore', all_args):
    (lambda_, q_tm1, a_tm1, r_t, pcont_t, target_policy_t, behaviour_policy_t,
     targnet_q_t, a_t) = (
         tf.convert_to_tensor(arg) for arg in all_args)

    # Evaluate importance weights.
    c_t = _retrace_weights(
        indexing_ops.batched_index(target_policy_t, a_t),
        behaviour_policy_t) * lambda_
    # Targets are evaluated by using only Q values from the target network.
    # This provides fixed regression targets until the next target network
    # update.
    target = _general_off_policy_corrected_multistep_target(
        r_t, pcont_t, target_policy_t, c_t, targnet_q_t, a_t,
        not stop_targnet_gradients)

    if stop_targnet_gradients:
      target = tf.stop_gradient(target)
    # Regress Q values of the learning network towards the targets evaluated
    # by using the target network.
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
    delta = target - qa_tm1
    loss = 0.5 * tf.square(delta)

    return base_ops.LossOutput(
        loss, RetraceCoreExtra(retrace_weights=c_t, target=target))
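A shape-only sketch of calling the op above, assuming it is exported as `trfl.retrace_core`; random tensors stand in for the learning network, the target network and the two policies.

import tensorflow as tf
import trfl  # assumed packaging of the op above

T, B, A = 4, 2, 3
q_tm1 = tf.random_uniform([T, B, A])                    # learning network at t-1
a_tm1 = tf.random_uniform([T, B], maxval=A, dtype=tf.int32)
r_t = tf.random_uniform([T, B])
pcont_t = tf.fill([T, B], 0.99)
target_policy_t = tf.nn.softmax(tf.random_uniform([T, B, A]), axis=-1)
behaviour_policy_t = tf.random_uniform([T, B], minval=0.1, maxval=1.0)
targnet_q_t = tf.random_uniform([T, B, A])              # target network at t
a_t = tf.random_uniform([T, B], maxval=A, dtype=tf.int32)

loss, extra = trfl.retrace_core(
    0.95, q_tm1, a_tm1, r_t, pcont_t, target_policy_t,
    behaviour_policy_t, targnet_q_t, a_t)
# loss, extra.retrace_weights and extra.target all have shape [T, B].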
Example #23
  def testFullShapeAvailableAtRuntimeOnly(self):
    """What happens when shape information isn't available statically?

    The short answer is: it still works. The long answer is: it still works, but
    arguments that shouldn't work due to argument shape mismatch can sometimes
    work without raising any errors! This can cause insidious bugs. This test
    verifies correct behaviour and also demonstrates kinds of shape mismatch
    that can go undetected. Look for `!!!DANGER!!!` below.

    Why this is possible: internally, `batched_index` flattens its inputs,
    then transforms the action indices you provide into indices into its
    flattened Q values tensor. So long as these flattened indices don't go
    out-of-bounds, and so long as your arguments are compatible with a few
    other bookkeeping operations, the operation will succeed.

    The moral: always provide as much shape information as you can! See also
    `testInputShapeChecks` for more on what shape checking can accomplish when
    only partial shape information is available.
    """

    ## 1. No shape information is available during construction time.
    q_values = tf.placeholder(tf.float32)
    actions = tf.placeholder(tf.int32)
    values = indexing_ops.batched_index(q_values, actions)

    with self.test_session() as sess:
      # First, correct and compatible Q values and indices work as intended.
      self.assertAllClose(
          [51],
          sess.run(values, feed_dict={q_values: [[50, 51]], actions: [1]}))
      self.assertAllClose(
          [[51, 52]],
          sess.run(values,
                   feed_dict={q_values: [[[50, 51], [52, 53]]],
                              actions: [[1, 0]]}))

      # !!!DANGER!!! These "incompatible" shapes are silently tolerated!
      # (These examples are probably not exhaustive, either!)
      qs_2x2 = [[5, 5], [5, 5]]
      qs_2x2x2 = [[[5, 5], [5, 5]],
                  [[5, 5], [5, 5]]]
      sess.run(values, feed_dict={q_values: qs_2x2, actions: [0]})
      sess.run(values, feed_dict={q_values: qs_2x2, actions: 0})
      sess.run(values, feed_dict={q_values: qs_2x2x2, actions: [[0]]})
      sess.run(values, feed_dict={q_values: qs_2x2x2, actions: [0]})
      sess.run(values, feed_dict={q_values: qs_2x2x2, actions: 0})

    ## 2a. Shape information is only known for the batch size (2-D case).
    q_values = tf.placeholder(tf.float32, shape=[2, None])
    actions = tf.placeholder(tf.int32, shape=[2])
    values = indexing_ops.batched_index(q_values, actions)
    with self.test_session() as sess:
      # Correct and compatible Q values and indices work as intended.
      self.assertAllClose(
          [51, 52],
          sess.run(values,
                   feed_dict={q_values: [[50, 51], [52, 53]], actions: [1, 0]}))
      # There are no really terrible shape errors that go uncaught in this case.

    ## 2b. Shape information is only known for the batch size (3-D case).
    q_values = tf.placeholder(tf.float32, shape=[None, 2, None])
    actions = tf.placeholder(tf.int32, shape=[None, 2])
    values = indexing_ops.batched_index(q_values, actions)
    with self.test_session() as sess:
      # Correct and compatible Q values and indices work as intended.
      self.assertAllClose(
          [[51, 52]],
          sess.run(values,
                   feed_dict={q_values: [[[50, 51], [52, 53]]],
                              actions: [[1, 0]]}))

      # !!!DANGER!!! This "incompatible" shape is silently tolerated!
      sess.run(values, feed_dict={q_values: qs_2x2x2, actions: [[0, 0]]})

    ## 3. Shape information is only known for the sequence length.
    q_values = tf.placeholder(tf.float32, shape=[2, None, None])
    actions = tf.placeholder(tf.int32, shape=[2, None])
    values = indexing_ops.batched_index(q_values, actions)
    with self.test_session() as sess:
      # Correct and compatible Q values and indices work as intended.
      self.assertAllClose(
          [[51, 52], [54, 57]],
          sess.run(values,
                   feed_dict={q_values: [[[50, 51], [52, 53]],
                                         [[54, 55], [56, 57]]],
                              actions: [[1, 0], [0, 1]]}))

      # !!!DANGER!!! This "incompatible" shape is silently tolerated!
      sess.run(values, feed_dict={q_values: qs_2x2x2, actions: [[0], [0]]})

    ## 4a. Shape information is only known for the number of actions (2-D case).
    q_values = tf.placeholder(tf.float32, shape=[None, 2])
    actions = tf.placeholder(tf.int32, shape=[None])
    values = indexing_ops.batched_index(q_values, actions)
    with self.test_session() as sess:
      # Correct and compatible Q values and indices work as intended.
      self.assertAllClose(
          [51, 52],
          sess.run(values,
                   feed_dict={q_values: [[50, 51], [52, 53]], actions: [1, 0]}))

      # !!!DANGER!!! This "incompatible" shape is silently tolerated!
      sess.run(values, feed_dict={q_values: qs_2x2, actions: [0]})

    ## 4b. Shape information is only known for the number of actions (3-D case).
    q_values = tf.placeholder(tf.float32, shape=[None, None, 2])
    actions = tf.placeholder(tf.int32, shape=[None, None])
    values = indexing_ops.batched_index(q_values, actions)
    with self.test_session() as sess:
      # Correct and compatible Q values and indices work as intended.
      self.assertAllClose(
          [[51, 52]],
          sess.run(values,
                   feed_dict={q_values: [[[50, 51], [52, 53]]],
                              actions: [[1, 0]]}))

      # !!!DANGER!!! These "incompatible" shapes are silently tolerated!
      sess.run(values, feed_dict={q_values: qs_2x2x2, actions: [[0, 0]]})
      sess.run(values, feed_dict={q_values: qs_2x2x2, actions: [[0]]})

    ## 5a. Value shape is not known ahead of time.
    q_values = tf.placeholder(tf.float32)
    actions = tf.placeholder(tf.int32, shape=[2])
    values = indexing_ops.batched_index(q_values, actions)
    with self.test_session() as sess:
      # Correct and compatible Q values and indices work as intended.
      self.assertAllClose(
          [51, 52],
          sess.run(values,
                   feed_dict={q_values: [[50, 51], [52, 53]], actions: [1, 0]}))

      # !!!DANGER!!! This "incompatible" shape is silently tolerated!
      sess.run(values, feed_dict={q_values: qs_2x2x2, actions: [0, 0]})

    ## 5b. Action shape is not known ahead of time.
    q_values = tf.placeholder(tf.float32, shape=[None, None, 2])
    actions = tf.placeholder(tf.int32)
    values = indexing_ops.batched_index(q_values, actions)
    with self.test_session() as sess:
      # Correct and compatible Q values and indices work as intended.
      self.assertAllClose(
          [[51, 52], [54, 57]],
          sess.run(values,
                   feed_dict={q_values: [[[50, 51], [52, 53]],
                                         [[54, 55], [56, 57]]],
                              actions: [[1, 0], [0, 1]]}))

      # !!!DANGER!!! This "incompatible" shape is silently tolerated!
      sess.run(values, feed_dict={q_values: qs_2x2x2, actions: [0, 0]})
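As a follow-up to the moral of the test above, here is a small sketch (not part of the original test; the `from trfl import indexing_ops` import path is an assumption) showing that once placeholder shapes are fully specified, a mismatched feed is rejected by `Session.run` instead of being silently tolerated:

import tensorflow as tf
from trfl import indexing_ops  # assumed import path for batched_index

q_values = tf.placeholder(tf.float32, shape=[2, 2])  # [B, num_actions]
actions = tf.placeholder(tf.int32, shape=[2])        # [B]
values = indexing_ops.batched_index(q_values, actions)

with tf.Session() as sess:
  # A correct feed works as intended.
  print(sess.run(values, feed_dict={q_values: [[50, 51], [52, 53]],
                                    actions: [1, 0]}))  # [51. 52.]
  # The 2x2x2 Q-value feed that slipped through above now fails fast:
  # TensorFlow raises "Cannot feed value of shape (2, 2, 2) ..." before
  # the op even runs.
  try:
    sess.run(values, feed_dict={q_values: [[[5, 5], [5, 5]],
                                           [[5, 5], [5, 5]]],
                                actions: [0, 0]})
  except ValueError as e:
    print('rejected incompatible feed:', e)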
Exemple #24
0
def retrace_core(lambda_,
                 q_tm1,
                 a_tm1,
                 r_t,
                 pcont_t,
                 target_policy_t,
                 behaviour_policy_t,
                 targnet_q_t,
                 a_t,
                 stop_targnet_gradients=True,
                 name=None):
    """Retrace algorithm core loss calculation op.

  Given a minibatch of temporally-contiguous sequences of Q values, policy
  probabilities, and various other typical RL algorithm inputs, this
  Op creates a subgraph that computes a loss according to the
  Retrace multi-step off-policy value learning algorithm. This Op supports the
  use of target networks, but does not require them.

  This function is the "core" Retrace op only because its arguments are less
  user-friendly and more implementation-convenient. For a more user-friendly
  operator, consider using `retrace`. For more details of Retrace, refer to
  [the arXiv paper](http://arxiv.org/abs/1606.02647).

  Construct the "core" retrace loss subgraph for a batch of sequences.

  Note that two pairs of arguments (one holding target network values; the
  other, actions) are temporally-offset versions of each other and will share
  many values in common (nb: a good setting for using `IndexedSlices`). *This
  op does not include any checks that these pairs of arguments are
  consistent*---that is, it does not ensure that temporally-offset
  arguments really do share the values they are supposed to share.

  In argument descriptions, `T` counts the number of transitions over which
  the Retrace loss is computed, and `B` is the minibatch size. All tensor
  arguments are indexed first by transition, with specific details of this
  indexing in the argument descriptions (pay close attention to "subscripts"
  in variable names).

  Args:
    lambda_: Positive scalar value or 0-D `Tensor` controlling the degree to
      which future timesteps contribute to the loss computed at each
      transition.
    q_tm1: 3-D tensor holding per-action Q-values for the states encountered
      just before taking the transitions that correspond to each major index.
      Since these values are the predicted values we wish to update (in other
      words, the values we intend to change as we learn), in a target network
      setting, these nearly always come from the "non-target" network, which
      we usually call the "learning network".
      Shape is `[T, B, num_actions]`.
    a_tm1: 2-D tensor holding the indices of actions executed during the
      transition that corresponds to each major index.
      Shape is `[T, B]`.
    r_t: 2-D tensor holding rewards received during the transition
      that corresponds to each major index.
      Shape is `[T, B]`.
    pcont_t: 2-D tensor holding pcontinue values received during the
      transition that corresponds to each major index.
      Shape is `[T, B]`.
    target_policy_t: 3-D tensor holding per-action policy probabilities for
      the states encountered just AFTER the transitions that correspond to
      each major index, according to the target policy (i.e. the policy we
      wish to learn). These usually derive from the learning net.
      Shape is `[T, B, num_actions]`.
    behaviour_policy_t: 2-D tensor holding the *behaviour* policy's
      probabilities of having taken action `a_t` at the states encountered
      just AFTER the transitions that correspond to each major index. Derived
      from whatever policy you used to generate the data. All values MUST be
      greater than 0. Shape is `[T, B]`.
    targnet_q_t: 3-D tensor holding per-action Q-values for the states
      encountered just AFTER taking the transitions that correspond to each
      major index. Since these values are used to calculate target values for
      the network, in a target network setting, these should
      probably come from the target network.
      Shape is `[T, B, num_actions]`.
    a_t: 2-D tensor holding the indices of actions executed during the
      transition AFTER the transition that corresponds to each major index.
      Shape is `[T, B]`.
    stop_targnet_gradients: `bool` that enables a sensible default way of
      handling gradients through the Retrace op (essentially, gradients
      are not permitted to involve the `targnet_q_t` input).
      Can be disabled if you require a different arrangement, but
      you'll probably want to block some gradients somewhere.
    name: name to prefix ops created by this function.

  Returns:
    A namedtuple with fields:

    * `loss`: Tensor containing the batch of losses, shape `[T, B]`.
    * `extra`: A namedtuple with fields:
        * `retrace_weights`: Tensor containing batch of retrace weights,
        shape `[T, B]`.
        * `target`: Tensor containing target action values, shape `[T, B]`.
  """
    all_args = [
        lambda_, q_tm1, a_tm1, r_t, pcont_t, target_policy_t,
        behaviour_policy_t, targnet_q_t, a_t
    ]

    with tf.name_scope(name, 'RetraceCore', all_args):
        (lambda_, q_tm1, a_tm1, r_t, pcont_t, target_policy_t,
         behaviour_policy_t, targnet_q_t, a_t) = (tf.convert_to_tensor(arg)
                                                  for arg in all_args)

        # Evaluate importance weights.
        c_t = _retrace_weights(
            indexing_ops.batched_index(target_policy_t, a_t),
            behaviour_policy_t) * lambda_
        # Targets are evaluated by using only Q values from the target network.
        # This provides fixed regression targets until the next target network
        # update.
        target = _general_off_policy_corrected_multistep_target(
            r_t, pcont_t, target_policy_t, c_t, targnet_q_t, a_t,
            not stop_targnet_gradients)

        if stop_targnet_gradients:
            target = tf.stop_gradient(target)
        # Regress Q values of the learning network towards the targets evaluated
        # by using the target network.
        qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
        delta = target - qa_tm1
        loss = 0.5 * tf.square(delta)

        return base_ops.LossOutput(
            loss, RetraceCoreExtra(retrace_weights=c_t, target=target))
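The `_retrace_weights` helper called above is not included in this excerpt. Per the Retrace paper linked in the docstring, the traces are truncated importance ratios, c_s = lambda * min(1, pi(a_s|x_s) / mu(a_s|x_s)); a minimal sketch consistent with that definition (an assumption about the helper, not its actual source) is:

def _retrace_weights_sketch(target_policy_probs, behaviour_policy_probs):
  """Truncated importance ratios min(1, pi/mu), shape [T, B].

  Args:
    target_policy_probs: pi(a_t | x_t) for the executed actions, shape [T, B].
    behaviour_policy_probs: mu(a_t | x_t) for the same actions, shape [T, B];
      all values must be strictly positive.
  """
  return tf.minimum(1.0, target_policy_probs / behaviour_policy_probs)

`retrace_core` then multiplies the result by `lambda_`, giving the full c_t trace coefficients used when building the multistep target.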
Exemple #25
0
    def __init__(
        self,
        obs_spec: dm_env.specs.Array,
        action_spec: dm_env.specs.BoundedArray,
        q_network: snt.AbstractModule,
        target_q_network: snt.AbstractModule,
        rho_network: snt.AbstractModule,
        l_network: Sequence[snt.AbstractModule],
        target_l_network: Sequence[snt.AbstractModule],
        batch_size: int,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        optimizer_primal: tf.train.Optimizer,
        optimizer_dual: tf.train.Optimizer,
        optimizer_l: tf.train.Optimizer,
        learn_iters: int,
        l_approximators: int,
        min_l: float,
        kappa: float,
        eta1: float,
        eta2: float,
        seed: int = None,
    ):
        """Information seeking learner."""
        # ISL configurations.
        self.q_network = q_network
        self._target_q_network = target_q_network
        self.rho_network = rho_network
        self.l_network = l_network
        self._target_l_network = target_l_network
        self._num_actions = action_spec.maximum - action_spec.minimum + 1
        self._obs_shape = obs_spec.shape
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._optimizer_primal = optimizer_primal
        self._optimizer_dual = optimizer_dual
        self._optimizer_l = optimizer_l
        self._min_replay_size = min_replay_size
        self._replay = replay.Replay(
            capacity=replay_capacity
        )  #ISLReplay(capacity=replay_capacity, average_l=0, mu=0)  #
        self._rng = np.random.RandomState(seed)
        tf.set_random_seed(seed)
        self._kappa = kappa
        self._min_l = min_l
        self._eta1 = eta1
        self._eta2 = eta2
        self._learn_iters = learn_iters
        self._l_approximators = l_approximators
        self._total_steps = 0
        self._total_episodes = 0
        self._learn_iter_counter = 0

        # Making the tensorflow graph
        o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
        q = q_network(tf.expand_dims(o, 0))
        rho = rho_network(tf.expand_dims(o, 0))
        l = []
        for k in range(self._l_approximators):
            l.append(
                tf.concat([
                    l_network[k][a](tf.expand_dims(o, 0))
                    for a in range(self._num_actions)
                ],
                          axis=1))

        # Placeholders = (obs, action, reward, discount, next_obs)
        o_tm1 = tf.placeholder(shape=(None, ) + obs_spec.shape,
                               dtype=obs_spec.dtype)
        a_tm1 = tf.placeholder(shape=(None, ), dtype=action_spec.dtype)
        r_t = tf.placeholder(shape=(None, ), dtype=tf.float32)
        d_t = tf.placeholder(shape=(None, ), dtype=tf.float32)
        o_t = tf.placeholder(shape=(None, ) + obs_spec.shape,
                             dtype=obs_spec.dtype)
        chosen_l = tf.placeholder(shape=1,
                                  dtype=tf.int32,
                                  name='chosen_l_tensor')

        q_tm1 = q_network(o_tm1)
        rho_tm1 = rho_network(o_tm1)
        train_q_value = batched_index(q_tm1, a_tm1)
        train_rho_value = batched_index(rho_tm1, a_tm1)
        train_rho_value_no_grad = tf.stop_gradient(train_rho_value)
        if self._target_update_period > 1:
            q_t = target_q_network(o_t)
        else:
            q_t = q_network(o_t)

        l_tm1_all = tf.stack([
            tf.concat([
                self.l_network[k][a](o_tm1) for a in range(self._num_actions)
            ],
                      axis=1) for k in range(self._l_approximators)
        ],
                             axis=-1)
        l_tm1 = tf.squeeze(tf.gather(l_tm1_all, chosen_l, axis=-1), axis=-1)
        train_l_value = batched_index(l_tm1, a_tm1)

        if self._target_update_period > 1:
            l_online_t_all = tf.stack([
                tf.concat([
                    self.l_network[k][a](o_t) for a in range(self._num_actions)
                ],
                          axis=1) for k in range(self._l_approximators)
            ],
                                      axis=-1)
            l_online_t = tf.squeeze(tf.gather(l_online_t_all,
                                              chosen_l,
                                              axis=-1),
                                    axis=-1)
            l_t_all = tf.stack([
                tf.concat([
                    self._target_l_network[k][a](o_t)
                    for a in range(self._num_actions)
                ],
                          axis=1) for k in range(self._l_approximators)
            ],
                               axis=-1)
            l_t = tf.squeeze(tf.gather(l_t_all, chosen_l, axis=-1), axis=-1)
            max_ind = tf.math.argmax(l_online_t, axis=1)
        else:
            l_t_all = tf.stack([
                tf.concat([
                    self.l_network[k][a](o_t) for a in range(self._num_actions)
                ],
                          axis=1) for k in range(self._l_approximators)
            ],
                               axis=-1)
            l_t = tf.squeeze(tf.gather(l_t_all, chosen_l, axis=-1), axis=-1)
            max_ind = tf.math.argmax(l_t, axis=1)

        soft_max_value = tf.stop_gradient(
            tf.py_function(func=self.soft_max, inp=[q_t, l_t],
                           Tout=tf.float32))
        q_target_value = r_t + discount * d_t * soft_max_value
        delta_primal = train_q_value - q_target_value
        loss_primal = tf.add(eta2 * train_rho_value_no_grad * delta_primal,
                             (1 - eta2) * 0.5 * tf.square(delta_primal),
                             name='loss_q')

        delta_dual = tf.stop_gradient(delta_primal)
        loss_dual = tf.square(delta_dual - train_rho_value, name='loss_rho')

        l_greedy_estimate = tf.add((1 - eta1) * tf.math.abs(delta_primal),
                                   eta1 * tf.math.abs(train_rho_value_no_grad),
                                   name='l_greedy_estimate')
        l_target_value = tf.stop_gradient(
            l_greedy_estimate + discount * d_t * batched_index(l_t, max_ind),
            name='l_target')
        loss_l = 0.5 * tf.square(train_l_value - l_target_value)

        train_op_primal = self._optimizer_primal.minimize(loss_primal)
        train_op_dual = self._optimizer_dual.minimize(loss_dual)
        train_op_l = self._optimizer_l.minimize(loss_l)

        # create target update operations
        if self._target_update_period > 1:
            target_updates = []
            target_update = update_target_variables(
                target_variables=self._target_q_network.get_all_variables(),
                source_variables=self.q_network.get_all_variables(),
            )
            target_updates.append(target_update)
            for k in range(self._l_approximators):
                for a in range(self._num_actions):
                    model = self.l_network[k][a]
                    target_model = self._target_l_network[k][a]
                    target_update = update_target_variables(
                        target_variables=target_model.get_all_variables(),
                        source_variables=model.get_all_variables(),
                    )
                    target_updates.append(target_update)

        # Make session and callables.
        session = tf.Session()
        self._sgd = session.make_callable(
            [train_op_l, train_op_primal, train_op_dual],
            [o_tm1, a_tm1, r_t, d_t, o_t, chosen_l])
        self._q_fn = session.make_callable(q, [o])
        self._rho_fn = session.make_callable(rho, [o])
        self._l_fn = []
        for k in range(self._l_approximators):
            self._l_fn.append(session.make_callable(l[k], [o]))
        if self._target_update_period > 1:
            self._update_target_nets = session.make_callable(target_updates)
        session.run(tf.global_variables_initializer())
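`update_target_variables` is referenced above but its body is not part of this excerpt. A minimal hard-copy sketch with the same calling convention (an assumption, not the original helper) is shown below; ops of this kind are what the agent collects into `target_updates` and exposes through the `_update_target_nets` callable:

import tensorflow as tf

def update_target_variables_sketch(target_variables, source_variables):
    """Returns a single op that copies each source variable onto its target."""
    if len(target_variables) != len(source_variables):
        raise ValueError('target_variables and source_variables must have '
                         'the same length.')
    assign_ops = [
        tf.assign(target_var, source_var)
        for target_var, source_var in zip(target_variables, source_variables)
    ]
    return tf.group(*assign_ops)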