Esempio n. 1
0
def td_lambda(
    v_tm1: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    v_t: ArrayLike,
    lambda_: ArrayOrScalar,
) -> ArrayLike:
    """Calculates the TD(lambda) temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node74.html).

  Args:
    v_tm1: sequence of state values at time t-1.
    r_t: sequence of rewards at time t.
    discount_t: sequence of discounts at time t.
    v_t: sequence of state values at time t.
    lambda_: mixing parameter lambda, either a scalar or a sequence.

  Returns:
    TD(lambda) temporal difference error.
  """
    base.rank_assert([v_tm1, r_t, discount_t, v_t, lambda_],
                     [1, 1, 1, 1, [0, 1]])
    base.type_assert([v_tm1, r_t, discount_t, v_t, lambda_], float)

    target_tm1 = multistep.lambda_returns(r_t, discount_t, v_t, lambda_)
    return jax.lax.stop_gradient(target_tm1) - v_tm1
Esempio n. 2
0
def td_lambda(
    v_tm1: Array,
    r_t: Array,
    discount_t: Array,
    v_t: Array,
    lambda_: Numeric,
    stop_target_gradients: bool = True,
) -> Array:
    """Calculates the TD(lambda) temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node74.html).

  Args:
    v_tm1: sequence of state values at time t-1.
    r_t: sequence of rewards at time t.
    discount_t: sequence of discounts at time t.
    v_t: sequence of state values at time t.
    lambda_: mixing parameter lambda, either a scalar or a sequence.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.
  Returns:
    TD(lambda) temporal difference error.
  """
    chex.assert_rank([v_tm1, r_t, discount_t, v_t, lambda_],
                     [1, 1, 1, 1, {0, 1}])
    chex.assert_type([v_tm1, r_t, discount_t, v_t, lambda_], float)

    target_tm1 = multistep.lambda_returns(r_t, discount_t, v_t, lambda_)
    target_tm1 = jax.lax.select(stop_target_gradients,
                                jax.lax.stop_gradient(target_tm1), target_tm1)
    return target_tm1 - v_tm1
Esempio n. 3
0
 def test_reduces_to_lambda_returns(self):
     """Test function is the same as lambda_returns when n is sequence length."""
     lambda_t = 0.75
     n = len(self.r_t[0])
     expected = multistep.lambda_returns(self.r_t[0], self.discount_t[0],
                                         self.v_t[0], lambda_t)
     actual = multistep.n_step_bootstrapped_returns(self.r_t[0],
                                                    self.discount_t[0],
                                                    self.v_t[0], n,
                                                    lambda_t)
     np.testing.assert_allclose(expected, actual, rtol=1e-5)
Esempio n. 4
0
def sarsa_lambda(
    q_tm1: ArrayLike,
    a_tm1: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    q_t: ArrayLike,
    a_t: ArrayLike,
    lambda_: ArrayOrScalar,
    stop_target_gradients: bool = True,
) -> ArrayLike:
    """Calculates the SARSA(lambda) temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node77.html).

  Args:
    q_tm1: sequence of Q-values at time t-1.
    a_tm1: sequence of action indices at time t-1.
    r_t: sequence of rewards at time t.
    discount_t: sequence of discounts at time t.
    q_t: sequence of Q-values at time t.
    a_t: sequence of action indices at time t.
    lambda_: mixing parameter lambda, either a scalar or a sequence.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    SARSA(lambda) temporal difference error.
  """
    base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t, a_t, lambda_],
                     [2, 1, 1, 1, 2, 1, [0, 1]])
    base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t, a_t, lambda_],
                     [float, int, float, float, float, int, float])

    qa_tm1 = base.batched_index(q_tm1, a_tm1)
    qa_t = base.batched_index(q_t, a_t)
    target_tm1 = multistep.lambda_returns(r_t, discount_t, qa_t, lambda_)

    if stop_target_gradients:
        target_tm1 = jax.lax.stop_gradient(target_tm1)
    return target_tm1 - qa_tm1
Esempio n. 5
0
def q_lambda(
    q_tm1: Array,
    a_tm1: Array,
    r_t: Array,
    discount_t: Array,
    q_t: Array,
    lambda_: Numeric,
    stop_target_gradients: bool = True,
) -> Array:
    """Calculates Peng's or Watkins' Q(lambda) temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node78.html).

  Args:
    q_tm1: sequence of Q-values at time t-1.
    a_tm1: sequence of action indices at time t-1.
    r_t: sequence of rewards at time t.
    discount_t: sequence of discounts at time t.
    q_t: sequence of Q-values at time t.
    lambda_: mixing parameter lambda, either a scalar (e.g. Peng's Q(lambda)) or
      a sequence (e.g. Watkin's Q(lambda)).
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    Q(lambda) temporal difference error.
  """
    chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t, lambda_],
                     [2, 1, 1, 1, 2, {0, 1}])
    chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t, lambda_],
                     [float, int, float, float, float, float])

    qa_tm1 = base.batched_index(q_tm1, a_tm1)
    v_t = jnp.max(q_t, axis=-1)
    target_tm1 = multistep.lambda_returns(r_t, discount_t, v_t, lambda_)

    target_tm1 = jax.lax.select(stop_target_gradients,
                                jax.lax.stop_gradient(target_tm1), target_tm1)
    return target_tm1 - qa_tm1