Example #1
def persistent_q_learning(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t: ArrayLike,
    action_gap_scale: float,
) -> ArrayOrScalar:
    """Calculates the persistent Q-learning temporal difference error.

  See "Increasing the Action Gap: New Operators for Reinforcement Learning"
  by Bellemare, Ostrovski, Guez et al. (https://arxiv.org/abs/1512.04860).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    action_gap_scale: coefficient in [0, 1] for scaling the action gap term.

  Returns:
    Persistent Q-learning temporal difference error.
  """
    base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t], [1, 0, 0, 0, 1])
    base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t],
                     [float, int, float, float, float])

    corrected_q_t = ((1. - action_gap_scale) * jnp.max(q_t) +
                     action_gap_scale * q_t[a_tm1])
    target_tm1 = r_t + discount_t * corrected_q_t
    return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
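
A minimal usage sketch with illustrative numbers, assuming `jax.numpy` is imported as `jnp` and that the `base` assertion helpers accept Python scalars for the rank-0 arguments:

q_tm1 = jnp.array([1.0, 2.0, 3.0])  # Q-values for three actions at time t-1.
q_t = jnp.array([1.5, 0.5, 2.5])    # Q-values for the same actions at time t.
td_error = persistent_q_learning(
    q_tm1, a_tm1=1, r_t=0.5, discount_t=0.9, q_t=q_t, action_gap_scale=0.5)
# corrected_q_t = 0.5 * max(q_t) + 0.5 * q_t[1] = 1.25 + 0.25 = 1.5, so the
# target is 0.5 + 0.9 * 1.5 = 1.85 and the error is 1.85 - q_tm1[1] = -0.15.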
Example #2
def sarsa(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t: ArrayLike,
    a_t: ArrayOrScalar,
) -> ArrayOrScalar:
    """Calculates the SARSA temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node64.html).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    a_t: action index at time t.

  Returns:
    SARSA temporal difference error.
  """
    base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t, a_t],
                     [1, 0, 0, 0, 1, 0])
    base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t, a_t],
                     [float, int, float, float, float, int])

    target_tm1 = r_t + discount_t * q_t[a_t]
    return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
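
An illustrative call with toy values; unlike Q-learning, the bootstrap uses the action actually taken at time t:

q_tm1 = jnp.array([1.0, 2.0, 3.0])
q_t = jnp.array([2.0, 1.0, 0.0])
td_error = sarsa(q_tm1, a_tm1=0, r_t=1.0, discount_t=0.5, q_t=q_t, a_t=0)
# target = 1.0 + 0.5 * q_t[0] = 2.0, so the error is 2.0 - q_tm1[0] = 1.0.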
Example #3
def td_learning(
    v_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    v_t: ArrayOrScalar,
) -> ArrayOrScalar:
    """Calculates the TD-learning temporal difference error.

  See "Learning to Predict by the Methods of Temporal Differences" by Sutton.
  (https://link.springer.com/article/10.1023/A:1022633531479).

  Args:
    v_tm1: state values at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_t: state values at time t.

  Returns:
    TD-learning temporal difference error.
  """

    base.rank_assert([v_tm1, r_t, discount_t, v_t], 0)
    base.type_assert([v_tm1, r_t, discount_t, v_t], float)

    target_tm1 = r_t + discount_t * v_t
    return jax.lax.stop_gradient(target_tm1) - v_tm1
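
A one-line usage sketch with made-up scalars:

td_error = td_learning(v_tm1=1.0, r_t=1.0, discount_t=0.9, v_t=2.0)
# target = 1.0 + 0.9 * 2.0 = 2.8, so the TD error is 2.8 - 1.0 = 1.8.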
Example #4
def qv_max(
    v_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t: ArrayLike,
) -> ArrayOrScalar:
    """Calculates the QVMAX temporal difference error.

  See "The QV Family Compared to Other Reinforcement Learning Algorithms" by
  Wiering and van Hasselt (2009).
  (http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.713.1931)

  Args:
    v_tm1: state values at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.

  Returns:
    QVMAX temporal difference error.
  """
    base.rank_assert([v_tm1, r_t, discount_t, q_t], [0, 0, 0, 1])
    base.type_assert([v_tm1, r_t, discount_t, q_t], float)

    target_tm1 = r_t + discount_t * jnp.max(q_t)
    return jax.lax.stop_gradient(target_tm1) - v_tm1
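
A usage sketch with illustrative values; the state value is regressed towards a Q-learning style max-bootstrapped target:

td_error = qv_max(v_tm1=1.0, r_t=0.0, discount_t=0.9, q_t=jnp.array([1.0, 2.0]))
# target = 0.0 + 0.9 * max(q_t) = 1.8, so the error is 1.8 - 1.0 = 0.8.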
Example #5
def expected_sarsa(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t: ArrayLike,
    probs_a_t: ArrayLike,
) -> ArrayOrScalar:
    """Calculates the expected SARSA (SARSE) temporal difference error.

  See "A Theoretical and Empirical Analysis of Expected Sarsa" by Seijen,
  van Hasselt, Whiteson et al.
  (http://www.cs.ox.ac.uk/people/shimon.whiteson/pubs/vanseijenadprl09.pdf).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    probs_a_t: action probabilities at time t.

  Returns:
    Expected SARSA temporal difference error.
  """
    base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t],
                     [1, 0, 0, 0, 1, 1])
    base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t],
                     [float, int, float, float, float, float])

    target_tm1 = r_t + discount_t * jnp.dot(q_t, probs_a_t)
    return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
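
A usage sketch with toy numbers; the bootstrap is the expectation of q_t under the target policy probabilities:

q_tm1 = jnp.array([0.0, 1.0])
q_t = jnp.array([1.0, 2.0])
probs_a_t = jnp.array([0.5, 0.5])  # Target policy probabilities at time t.
td_error = expected_sarsa(
    q_tm1, a_tm1=0, r_t=0.0, discount_t=1.0, q_t=q_t, probs_a_t=probs_a_t)
# target = 0.0 + 1.0 * (0.5 * 1.0 + 0.5 * 2.0) = 1.5, error = 1.5 - 0.0 = 1.5.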
Example #6
def discounted_returns(
    r_t: ArrayLike,
    discount_t: ArrayLike,
    v_t: ArrayLike
) -> ArrayLike:
  """Calculates a discounted return from a trajectory.

  The returns are computed recursively, from `G_{T-1}` to `G_0`, according to:

    Gₜ = rₜ₊₁ + γₜ₊₁ Gₜ₊₁.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/sutton/book/ebook/node61.html).

  Args:
    r_t: reward sequence at time t.
    discount_t: discount sequence at time t.
    v_t: value sequence or scalar at time t.

  Returns:
    Discounted returns.
  """
  base.rank_assert([r_t, discount_t, v_t], [1, 1, [0, 1]])
  base.type_assert([r_t, discount_t, v_t], float)

  # If scalar make into vector.
  bootstrapped_v = jnp.ones_like(discount_t) * v_t
  return lambda_returns(r_t, discount_t, bootstrapped_v, lambda_=1.)
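
A worked example of the recursion with made-up values (relies on the `lambda_returns` helper referenced above being available in the same module):

r_t = jnp.array([1.0, 1.0, 1.0])         # r_1, r_2, r_3.
discount_t = jnp.array([0.5, 0.5, 0.5])  # γ_1, γ_2, γ_3.
returns = discounted_returns(r_t, discount_t, v_t=jnp.array(0.0))
# G_2 = 1.0 + 0.5 * 0.0 = 1.0
# G_1 = 1.0 + 0.5 * G_2 = 1.5
# G_0 = 1.0 + 0.5 * G_1 = 1.75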
Example #7
def double_q_learning(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t_value: ArrayLike,
    q_t_selector: ArrayLike,
) -> ArrayOrScalar:
    """Calculates the double Q-learning temporal difference error.

  See "Double Q-learning" by van Hasselt.
  (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t_value: Q-values at time t.
    q_t_selector: selector Q-values at time t.

  Returns:
    Double Q-learning temporal difference error.
  """
    base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector],
                     [1, 0, 0, 0, 1, 1])
    base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector],
                     [float, int, float, float, float, float])

    target_tm1 = r_t + discount_t * q_t_value[q_t_selector.argmax()]
    return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
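
A sketch showing the decoupling of action selection and evaluation, with illustrative values:

q_tm1 = jnp.array([1.0, 2.0])
q_t_value = jnp.array([3.0, 4.0])      # Network used to evaluate the action.
q_t_selector = jnp.array([10.0, 1.0])  # Network used to select the action.
td_error = double_q_learning(
    q_tm1, a_tm1=0, r_t=0.0, discount_t=1.0,
    q_t_value=q_t_value, q_t_selector=q_t_selector)
# The selector picks action 0, so the bootstrap is q_t_value[0] = 3.0 rather
# than max(q_t_value) = 4.0; the error is 3.0 - q_tm1[0] = 2.0.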
Example #8
def qv_learning(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    v_t: ArrayOrScalar,
) -> ArrayOrScalar:
    """Calculates the QV-learning temporal difference error.

  See "Two Novel On-policy Reinforcement Learning Algorithms based on
  TD(lambda)-methods" by Wiering and van Hasselt
  (https://ieeexplore.ieee.org/abstract/document/4220845).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_t: state values at time t.

  Returns:
    QV-learning temporal difference error.
  """
    base.rank_assert([q_tm1, a_tm1, r_t, discount_t, v_t], [1, 0, 0, 0, 0])
    base.type_assert([q_tm1, a_tm1, r_t, discount_t, v_t],
                     [float, int, float, float, float])

    target_tm1 = r_t + discount_t * v_t
    return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
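
A usage sketch with toy values; the Q-value of the taken action is regressed towards a target built from the separately learned state value:

q_tm1 = jnp.array([1.0, 2.0])
td_error = qv_learning(q_tm1, a_tm1=0, r_t=1.0, discount_t=0.9, v_t=0.5)
# target = 1.0 + 0.9 * 0.5 = 1.45, so the error is 1.45 - q_tm1[0] = 0.45.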
Example #9
def _categorical_l2_project(
    z_p: ArrayLike,
    probs: ArrayLike,
    z_q: ArrayLike
) -> ArrayLike:
  """Projects a categorical distribution (z_p, p) onto a different support z_q.

  The projection step minimizes an L2-metric over the cumulative distribution
  functions (CDFs) of the source and target distributions.

  Let kq be len(z_q) and kp be len(z_p). This projection works for any
  support z_q, in particular kq need not be equal to kp.

  See "A Distributional Perspective on RL" by Bellemare et al.
  (https://arxiv.org/abs/1707.06887).

  Args:
    z_p: support of distribution p.
    probs: probability values.
    z_q: support to project distribution (z_p, probs) onto.

  Returns:
    Projection of (z_p, p) onto support z_q under Cramer distance.
  """
  base.rank_assert([z_p, probs, z_q], 1)
  base.type_assert([z_p, probs, z_q], float)

  kp = z_p.shape[0]
  kq = z_q.shape[0]

  # Construct helper arrays from z_q.
  d_pos = jnp.roll(z_q, shift=-1)
  d_neg = jnp.roll(z_q, shift=1)

  # Clip z_p to be in new support range (vmin, vmax).
  z_p = jnp.clip(z_p, z_q[0], z_q[-1])[None, :]
  assert z_p.shape == (1, kp)

  # Get the distance between atom values in support.
  d_pos = (d_pos - z_q)[:, None]  # z_q[i+1] - z_q[i]
  d_neg = (z_q - d_neg)[:, None]  # z_q[i] - z_q[i-1]
  z_q = z_q[:, None]
  assert z_q.shape == (kq, 1)

  # Ensure that we do not divide by zero, in case of atoms of identical value.
  d_neg = jnp.where(d_neg > 0, 1. / d_neg, jnp.zeros_like(d_neg))
  d_pos = jnp.where(d_pos > 0, 1. / d_pos, jnp.zeros_like(d_pos))

  delta_qp = z_p - z_q  # clip(z_p)[j] - z_q[i]
  d_sign = (delta_qp >= 0.).astype(probs.dtype)
  assert delta_qp.shape == (kq, kp)
  assert d_sign.shape == (kq, kp)

  # Matrix of entries |a_ij| / b_ij, with a_ij = clip(z_p)[j] - z_q[i] and b_ij
  # the width of the support cell adjacent to z_q[i] on the side of z_p[j].
  delta_hat = (d_sign * delta_qp * d_pos) - ((1. - d_sign) * delta_qp * d_neg)
  probs = probs[None, :]
  assert delta_hat.shape == (kq, kp)
  assert probs.shape == (1, kp)

  return jnp.sum(jnp.clip(1. - delta_hat, 0., 1.) * probs, axis=-1)
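
A small worked example with a hypothetical one-atom source distribution:

z_p = jnp.array([1.0])       # Single source atom at value 1.
probs = jnp.array([1.0])     # All probability mass on that atom.
z_q = jnp.array([0.0, 2.0])  # Coarser target support.
projected = _categorical_l2_project(z_p, probs, z_q)
# The atom at 1.0 sits halfway between the target atoms 0.0 and 2.0, so its
# mass is split proportionally: projected == [0.5, 0.5].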
Example #10
def policy_gradient_loss(
    logits_t: ArrayLike,
    a_t: ArrayLike,
    adv_t: ArrayLike,
    w_t: ArrayLike,
) -> ArrayLike:
    """Calculates the policy gradient loss.

  See "Simple Gradient-Following Algorithms for Connectionist RL" by Williams.
  (http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)

  Args:
    logits_t: a sequence of unnormalized action preferences.
    a_t: a sequence of actions sampled from the preferences `logits_t`.
    adv_t: the observed or estimated advantages from executing actions `a_t`.
    w_t: a per timestep weighting for the loss.

  Returns:
    Loss whose gradient corresponds to a policy gradient update.
  """
    base.rank_assert([logits_t, a_t, adv_t, w_t], [2, 1, 1, 1])
    base.type_assert([logits_t, a_t, adv_t, w_t], [float, int, float, float])

    log_pi_a = distributions.softmax().logprob(a_t, logits_t)
    adv_t = jax.lax.stop_gradient(adv_t)
    loss_per_timestep = -log_pi_a * adv_t
    return jnp.mean(loss_per_timestep * w_t)
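
A usage sketch with a two-step toy trajectory; it relies on the `distributions.softmax` helper referenced above being available in the module:

logits_t = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # Two timesteps, two actions.
a_t = jnp.array([0, 1])                         # Sampled actions.
adv_t = jnp.array([1.0, -0.5])                  # Advantage estimates.
w_t = jnp.ones(2)                               # Uniform per-step weights.
loss = policy_gradient_loss(logits_t, a_t, adv_t, w_t)
# loss = mean_t(-log π(a_t|s_t) * adv_t * w_t); its gradient w.r.t. logits_t
# is the REINFORCE policy gradient with the given advantages.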
Example #11
def td_lambda(
    v_tm1: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    v_t: ArrayLike,
    lambda_: ArrayOrScalar,
) -> ArrayLike:
    """Calculates the TD(lambda) temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node74.html).

  Args:
    v_tm1: sequence of state values at time t-1.
    r_t: sequence of rewards at time t.
    discount_t: sequence of discounts at time t.
    v_t: sequence of state values at time t.
    lambda_: mixing parameter lambda, either a scalar or a sequence.

  Returns:
    TD(lambda) temporal difference error.
  """
    base.rank_assert([v_tm1, r_t, discount_t, v_t, lambda_],
                     [1, 1, 1, 1, [0, 1]])
    base.type_assert([v_tm1, r_t, discount_t, v_t, lambda_], float)

    target_tm1 = multistep.lambda_returns(r_t, discount_t, v_t, lambda_)
    return jax.lax.stop_gradient(target_tm1) - v_tm1
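
A sketch over a three-step toy trajectory (relies on `multistep.lambda_returns` as referenced above):

v_tm1 = jnp.array([1.0, 1.0, 1.0])      # Values of s_0, s_1, s_2.
v_t = jnp.array([1.0, 1.0, 0.0])        # Values of s_1, s_2, s_3.
r_t = jnp.array([0.0, 0.0, 1.0])
discount_t = jnp.array([0.9, 0.9, 0.9])
td_errors = td_lambda(v_tm1, r_t, discount_t, v_t, lambda_=0.8)
# lambda_=0 recovers one-step TD errors; lambda_=1 gives errors against the
# discounted return bootstrapped only at the end of the trajectory.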
Example #12
def q_learning(
    q_tm1: ArrayLike,
    a_tm1: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    q_t: ArrayLike,
) -> ArrayLike:
  """Calculates the Q-learning temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node65.html).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.

  Returns:
    Q-learning temporal difference error.
  """
  base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t], [1, 0, 0, 0, 1])
  base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t],
                   [float, int, float, float, float])

  target_tm1 = r_t + discount_t * jnp.max(q_t)
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
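
A usage sketch mirroring the SARSA example above, but bootstrapping from the greedy action:

q_tm1 = jnp.array([1.0, 2.0, 3.0])
q_t = jnp.array([1.5, 0.5, 2.5])
td_error = q_learning(q_tm1, a_tm1=1, r_t=0.5, discount_t=0.9, q_t=q_t)
# target = 0.5 + 0.9 * max(q_t) = 2.75, so the error is 2.75 - q_tm1[1] = 0.75.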
Example #13
def add_ornstein_uhlenbeck_noise(key: ArrayLike, action: ArrayLike,
                                 noise_tm1: ArrayLike, damping: float,
                                 stddev: float) -> ArrayLike:
    """Returns continuous action with noise from Ornstein-Uhlenbeck process.

  See "On the theory of Brownian Motion" by Uhlenbeck and Ornstein.
  (https://journals.aps.org/pr/abstract/10.1103/PhysRev.36.823).

  Args:
    key: a key from `jax.random`.
    action: continuous action scalar or vector.
    noise_tm1: noise sampled from OU process in previous timestep.
    damping: parameter for controlling autocorrelation of OU process.
    stddev: standard deviation of noise distribution.

  Returns:
    noisy action, of the same shape as input action.
  """
    base.rank_assert([action, noise_tm1], 1)
    base.type_assert([action, noise_tm1], float)

    noise_t = (1. - damping) * noise_tm1 + jax.random.normal(
        key, shape=action.shape) * stddev

    return action + noise_t
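
A sketch of one exploration step with illustrative settings; in a control loop the OU state would be threaded through time by the caller:

key = jax.random.PRNGKey(0)          # Illustrative PRNG key.
action = jnp.array([0.2, -0.3])
noise_tm1 = jnp.zeros_like(action)   # No noise carried over yet.
noisy_action = add_ornstein_uhlenbeck_noise(
    key, action, noise_tm1, damping=0.15, stddev=0.2)
# The OU state to pass as noise_tm1 at the next step is noisy_action - action.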
Example #14
def categorical_kl_divergence(
    p_logits: ArrayLike,
    q_logits: ArrayLike,
    temperature: float = 1.
) -> ArrayLike:
  """Compute the KL between two categorical distributions from their logits.

  Args:
    p_logits: unnormalized logits for the first distribution.
    q_logits: unnormalized logits for the second distribution.
    temperature: the temperature for the softmax distribution, defaults to 1.

  Returns:
    the kl divergence between the distributions.
  """
  base.type_assert([p_logits, q_logits], float)

  p_logits /= temperature
  q_logits /= temperature

  p = jax.nn.softmax(p_logits)
  log_p = jax.nn.log_softmax(p_logits)
  log_q = jax.nn.log_softmax(q_logits)
  kl = jnp.sum(p * (log_p - log_q), axis=-1)
  return jax.nn.relu(kl)  # Guard against numerical issues giving negative KL.
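
A quick sketch with arbitrary logits; the KL is zero whenever the two logit vectors differ only by a constant shift:

p_logits = jnp.array([1.0, 0.0, 0.0])
q_logits = jnp.array([0.0, 0.0, 1.0])
kl = categorical_kl_divergence(p_logits, q_logits)   # > 0 here.
categorical_kl_divergence(p_logits, p_logits + 2.0)  # ~0 (shift invariance).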
Example #15
 def test_mixed_inputs_should_not_raise(self):
     a_float = 1.
     an_int = 2
     a_np_float = np.asarray([3., 4.])
     a_jax_int = jnp.asarray([5, 6])
     base.type_assert([a_float, an_int, a_np_float, a_jax_int],
                      [float, int, float, int])
Example #16
 def test_unsupported_type_should_raise(self):
     a_float = 1.
     an_int = 2
     a_np_float = np.asarray([3., 4.])
     a_jax_int = jnp.asarray([5, 6])
     with self.assertRaises(ValueError):
         base.type_assert([a_float, an_int, a_np_float, a_jax_int],
                           [complex, complex, float, int])
Example #17
 def test_different_length_should_raise(self):
     a_float = 1.
     an_int = 2
     a_np_float = np.asarray([3., 4.])
     a_jax_int = jnp.asarray([5, 6])
     with self.assertRaises(ValueError):
         base.type_assert([a_float, an_int, a_np_float, a_jax_int],
                          [int, float, int])
Example #18
 def test_mixed_inputs_should_raise(self):
     a_float = 1.
     an_int = 2
     a_np_float = np.asarray([3., 4.])
     a_jax_int = jnp.asarray([5, 6])
     with self.assertRaises(ValueError):
         base.type_assert([a_float, an_int, a_np_float, a_jax_int],
                          [float, int, float, float])
Example #19
def importance_corrected_td_errors(
    r_t: ArrayLike,
    discount_t: ArrayLike,
    rho_tm1: ArrayLike,
    lambda_: ArrayLike,
    values: ArrayLike,
) -> ArrayLike:
    """Computes the multistep td errors with per decision importance sampling.

  Given a trajectory of length `T+1`, generated under some policy π, for each
  time-step `t` we can estimate a multistep temporal difference error δₜ(ρ,λ),
  by combining rewards, discounts, and state values, according to a mixing
  parameter `λ` and importance sampling ratios ρₜ = π(aₜ|sₜ) / μ(aₜ|sₜ):

    td-errorₜ = ρₜ δₜ(ρ,λ)
    δₜ(ρ,λ) = δₜ + ρₜ₊₁ λₜ₊₁ γₜ₊₁ δₜ₊₁(ρ,λ),

  where δₜ = rₜ₊₁ + γₜ₊₁ vₜ₊₁ - vₜ is the one step, temporal difference error
  for the agent's state value estimates. This is equivalent to computing
  the λ-return with λₜ = ρₜ (e.g. using the `lambda_returns` function from
  above), and then computing errors as  td-errorₜ = ρₜ(Gₜ - vₜ).

  See "A new Q(λ) with interim forward view and Monte Carlo equivalence"
  by Sutton et al. (http://proceedings.mlr.press/v32/sutton14.html).

  Args:
    r_t: sequence of rewards rₜ for timesteps t in [1, T].
    discount_t: sequence of discounts γₜ for timesteps t in [1, T].
    rho_tm1: sequence of importance ratios for all timesteps t in [0, T-1].
    lambda_: mixing parameter; a scalar, or a sequence of per-timestep values
      for times in [1, T].
    values: sequence of state values under π for all timesteps t in [0, T].

  Returns:
    Off-policy estimates of the multistep temporal difference errors from each
    state.
  """
    base.rank_assert([r_t, discount_t, rho_tm1, values], [1, 1, 1, 1])
    base.type_assert([r_t, discount_t, rho_tm1, values], float)

    v_tm1 = values[:-1]  # Predictions to compute errors for.
    v_t = values[1:]  # Values for bootstrapping.
    rho_t = jnp.concatenate(
        (rho_tm1[1:], jnp.array([1.])))  # Unused dummy value.
    lambda_ = jnp.ones_like(
        discount_t) * lambda_  # If scalar, make into vector.

    # Compute the one step temporal difference errors.
    one_step_delta = r_t + discount_t * v_t - v_tm1

    # Work backwards to compute `delta_{T-1}`, ..., `delta_0`.
    delta, errors = 0.0, []
    for i in jnp.arange(one_step_delta.shape[0] - 1, -1, -1):
        delta = one_step_delta[
            i] + discount_t[i] * rho_t[i] * lambda_[i] * delta
        errors.insert(0, delta)

    return rho_tm1 * jnp.array(errors)
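
A sketch over a three-transition window with made-up importance ratios:

values = jnp.array([1.0, 1.0, 1.0, 0.0])  # v(s_0), ..., v(s_3), so T = 3.
r_t = jnp.array([0.0, 0.0, 1.0])
discount_t = jnp.array([0.9, 0.9, 0.9])
rho_tm1 = jnp.array([1.0, 0.5, 2.0])      # π/μ ratios for a_0, a_1, a_2.
errors = importance_corrected_td_errors(
    r_t, discount_t, rho_tm1, lambda_=1.0, values=values)
# With all ratios equal to 1 this reduces to the on-policy λ-return errors.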
Example #20
def categorical_q_learning(
    q_atoms_tm1: ArrayLike,
    q_logits_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_atoms_t: ArrayLike,
    q_logits_t: ArrayLike,
) -> ArrayOrScalar:
    """Implements Q-learning for categorical Q distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
    Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

  Args:
    q_atoms_tm1: atoms of Q distribution at time t-1.
    q_logits_tm1: logits of Q distribution at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_atoms_t: atoms of Q distribution at time t.
    q_logits_t: logits of Q distribution at time t.

  Returns:
    Categorical Q-learning loss (i.e. temporal difference error).
  """
    base.rank_assert([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t
    ], [1, 2, 0, 0, 0, 1, 2])
    base.type_assert([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t
    ], [float, float, int, float, float, float, float])

    # Scale and shift time-t distribution atoms by discount and reward.
    target_z = r_t + discount_t * q_atoms_t

    # Convert logits to distribution, then find greedy action in state s_t.
    q_t_probs = jax.nn.softmax(q_logits_t)
    q_t_mean = jnp.sum(q_t_probs * q_atoms_t[jnp.newaxis, :], axis=1)
    pi_t = jnp.argmax(q_t_mean)

    # Compute distribution for greedy action.
    p_target_z = q_t_probs[pi_t]

    # Project using the Cramer distance.
    target = jax.lax.stop_gradient(
        _categorical_l2_project(target_z, p_target_z, q_atoms_tm1))

    # Compute loss (i.e. temporal difference error).
    logit_qa_tm1 = q_logits_tm1[a_tm1]
    return distributions.categorical_cross_entropy(labels=target,
                                                   logits=logit_qa_tm1)
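
A sketch with two actions and a three-atom support; values are purely illustrative:

q_atoms = jnp.array([-1.0, 0.0, 1.0])      # Shared support of the return.
q_logits_tm1 = jnp.zeros((2, 3))           # Uniform distributions at t-1.
q_logits_t = jnp.array([[0.0, 0.0, 0.0],
                        [0.0, 0.0, 5.0]])  # Action 1 concentrated near +1.
loss = categorical_q_learning(
    q_atoms, q_logits_tm1, a_tm1=0, r_t=0.0, discount_t=0.9,
    q_atoms_t=q_atoms, q_logits_t=q_logits_t)
# Action 1 has the highest mean at time t; its distribution is shifted to
# r_t + discount_t * atoms, projected back onto q_atoms, and used as the
# cross-entropy target for the logits of action 0 at time t-1.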
Example #21
def categorical_double_q_learning(
    q_atoms_tm1: ArrayLike,
    q_logits_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_atoms_t: ArrayLike,
    q_logits_t: ArrayLike,
    q_t_selector: ArrayLike,
) -> ArrayOrScalar:
    """Implements double Q-learning for categorical Q distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
    Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf)
  and "Double Q-learning" by van Hasselt.
  (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

  Args:
    q_atoms_tm1: atoms of Q distribution at time t-1.
    q_logits_tm1: logits of Q distribution at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_atoms_t: atoms of Q distribution at time t.
    q_logits_t: logits of Q distribution at time t.
    q_t_selector: selector Q-values at time t.

  Returns:
    Categorical double Q-learning loss (i.e. temporal difference error).
  """
    base.rank_assert([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t, q_t_selector
    ], [1, 2, 0, 0, 0, 1, 2, 1])
    base.type_assert([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t, q_t_selector
    ], [float, float, int, float, float, float, float, float])

    # Scale and shift time-t distribution atoms by discount and reward.
    target_z = r_t + discount_t * q_atoms_t

    # Select logits for greedy action in state s_t and convert to distribution.
    p_target_z = jax.nn.softmax(q_logits_t[q_t_selector.argmax()])

    # Project using the Cramer distance.
    target = jax.lax.stop_gradient(
        _categorical_l2_project(target_z, p_target_z, q_atoms_tm1))

    # Compute loss (i.e. temporal difference error).
    logit_qa_tm1 = q_logits_tm1[a_tm1]
    return distributions.categorical_cross_entropy(labels=target,
                                                   logits=logit_qa_tm1)
Example #22
def general_off_policy_returns_from_action_values(
    q_t: ArrayLike,
    a_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    c_t: ArrayLike,
    pi_t: ArrayLike,
) -> ArrayLike:
  """Calculates errors for various off-policy correction algorithms.

  Given a window of experience of length `K`, generated by a behaviour policy μ,
  for each time-step `t` we can estimate the return `G_t` from that step
  onwards, under some target policy π, using the rewards in the trajectory, the
  actions selected by μ and the action-values under π, according to the equation:

    Gₜ = rₜ₊₁ + γₜ₊₁ * (E[q(aₜ₊₁)] - cₜ * q(aₜ₊₁) + cₜ * Gₜ₊₁),

  where, depending on the choice of `c_t`, the algorithm implements:

    Importance Sampling              c_t = π(x_t, a_t) / μ(x_t, a_t),
    Harutyunyan et al.'s Q(lambda)   c_t = λ,
    Precup et al.'s Tree-Backup      c_t = π(x_t, a_t),
    Munos et al.'s Retrace           c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).

  See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al.
  (https://arxiv.org/abs/1606.02647).

  Args:
    q_t: Q-values at time t.
    a_t: action index at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    c_t: importance weights at time t.
    pi_t: target policy probs at time t.

  Returns:
    Off-policy estimates of the generalized returns from each state.
  """
  base.rank_assert([q_t, a_t, r_t, discount_t, c_t, pi_t],
                   [2, 1, 1, 1, 1, 2])
  base.type_assert([q_t, a_t, r_t, discount_t, c_t, pi_t],
                   [float, int, float, float, float, float])

  # Get the expected values and the values of actually selected actions.
  exp_q_t = (pi_t * q_t).sum(axis=-1)
  # The generalized returns are independent of Q-values and cs at the final
  # state.
  q_a_t = base.batched_index(q_t, a_t)[:-1]
  c_t = c_t[:-1]

  return general_off_policy_returns_from_q_and_v(
      q_a_t, exp_q_t, r_t, discount_t, c_t)
Example #23
def vtrace(
    v_tm1: ArrayLike,
    v_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    rho_t: ArrayLike,
    lambda_: float = 1.0,
    clip_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True,
) -> ArrayLike:
  """Calculates V-Trace errors from policy logits.

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_t: importance sampling ratios.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: clip threshold for importance weights.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    V-Trace error.
  """
  base.rank_assert(
      [v_tm1, v_t, r_t, discount_t, rho_t], [1, 1, 1, 1, 1])
  base.type_assert(
      [v_tm1, v_t, r_t, discount_t, rho_t], [float, float, float, float, float])

  # Clip importance sampling ratios.
  c_t = jnp.minimum(1.0, rho_t) * lambda_
  clipped_rhos = jnp.minimum(clip_rho_threshold, rho_t)

  # Compute the temporal difference errors.
  td_errors = clipped_rhos * (r_t + discount_t * v_t - v_tm1)

  # Work backwards computing the td-errors.
  err = 0.0
  errors = []
  for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
    err = td_errors[i] + discount_t[i] * c_t[i] * err
    errors.insert(0, err)

  # Add the value of the initial state to get the estimates of the returns.
  target_tm1 = jnp.array(errors) + v_tm1

  # Stop gradients and return temporal difference error.
  if stop_target_gradients:
    target_tm1 = jax.lax.stop_gradient(target_tm1)
  return target_tm1 - v_tm1
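
A usage sketch over a three-step toy trajectory; with all importance ratios equal to 1 and lambda_=1 the output reduces to on-policy multistep TD errors:

v_tm1 = jnp.array([1.0, 1.0, 1.0])
v_t = jnp.array([1.0, 1.0, 0.0])
r_t = jnp.array([0.0, 0.0, 1.0])
discount_t = jnp.array([0.9, 0.9, 0.9])
rho_t = jnp.array([1.0, 2.0, 0.5])  # π/μ ratios; clipped inside the function.
vtrace_errors = vtrace(v_tm1, v_t, r_t, discount_t, rho_t, lambda_=1.0)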
Example #24
def quantile_q_learning(dist_q_tm1: ArrayLike,
                        tau_q_tm1: ArrayLike,
                        a_tm1: ArrayOrScalar,
                        r_t: ArrayOrScalar,
                        discount_t: ArrayOrScalar,
                        dist_q_t_selector: ArrayLike,
                        dist_q_t: ArrayLike,
                        huber_param: float = 0.) -> ArrayOrScalar:
    """Implements Q-learning for quantile-valued Q distributions.

  See "Distributional Reinforcement Learning with Quantile Regression" by
  Dabney et al. (https://arxiv.org/abs/1710.10044).

  Args:
    dist_q_tm1: Q distribution at time t-1.
    tau_q_tm1: Q distribution probability thresholds.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    dist_q_t_selector: Q distribution at time t for selecting greedy action in
      target policy. This is separate from dist_q_t as in Double Q-Learning, but
      can be computed with the target network and a separate set of samples.
    dist_q_t: target Q distribution at time t.
    huber_param: Huber loss parameter, defaults to 0 (no Huber loss).

  Returns:
    Quantile regression Q learning loss.
  """
    base.rank_assert([
        dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t_selector,
        dist_q_t
    ], [2, 1, 0, 0, 0, 2, 2])
    base.type_assert([
        dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t_selector,
        dist_q_t
    ], [float, float, int, float, float, float, float])

    # Only update the taken actions.
    dist_qa_tm1 = dist_q_tm1[:, a_tm1]

    # Select target action according to greedy policy w.r.t. dist_q_t_selector.
    q_t_selector = jnp.mean(dist_q_t_selector, axis=0)
    a_t = jnp.argmax(q_t_selector)
    dist_qa_t = dist_q_t[:, a_t]

    # Compute target, do not backpropagate into it.
    dist_target = r_t + discount_t * dist_qa_t
    dist_target = jax.lax.stop_gradient(dist_target)

    return _quantile_regression_loss(dist_qa_tm1, tau_q_tm1, dist_target,
                                     huber_param)
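
A shape-level sketch with placeholder values (relies on the `_quantile_regression_loss` helper referenced above); distributions are represented as [num_quantiles, num_actions] arrays:

num_quantiles, num_actions = 3, 2
tau = jnp.array([1.0 / 6, 3.0 / 6, 5.0 / 6])          # Quantile midpoints.
dist_q_tm1 = jnp.zeros((num_quantiles, num_actions))
dist_q_t = jnp.ones((num_quantiles, num_actions))
loss = quantile_q_learning(
    dist_q_tm1, tau, a_tm1=0, r_t=0.5, discount_t=0.9,
    dist_q_t_selector=dist_q_t, dist_q_t=dist_q_t, huber_param=1.0)
# The quantiles of the taken action at t-1 are regressed towards the stopped
# target samples r_t + discount_t * dist_q_t[:, greedy action].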
Example #25
def transformed_retrace(
    q_tm1: ArrayLike,
    q_t: ArrayLike,
    a_tm1: ArrayLike,
    a_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    pi_t: ArrayLike,
    mu_t: ArrayLike,
    lambda_: float,
    eps: float = 1e-8,
    stop_target_gradients: bool = True,
    tx_pair: TxPair = IDENTITY_PAIR,
) -> ArrayLike:
    """Calculates transformed Retrace errors.

  See "Recurrent Experience Replay in Distributed Reinforcement Learning" by
  Kapturowski et al. (https://openreview.net/pdf?id=r1lyTjAqYX).

  Args:
    q_tm1: Q-values at time t-1.
    q_t: Q-values at time t.
    a_tm1: action index at time t-1.
    a_t: action index at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    pi_t: target policy probs at time t.
    mu_t: behavior policy probs at time t.
    lambda_: scalar mixing parameter lambda.
    eps: small value to add to mu_t for numerical stability.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.
    tx_pair: TxPair of value function transformation and its inverse.

  Returns:
    Transformed Retrace error.
  """
    base.rank_assert([q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t],
                     [2, 2, 1, 1, 1, 1, 2, 1])
    base.type_assert([q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t],
                     [float, float, int, int, float, float, float, float])

    pi_a_t = base.batched_index(pi_t, a_t)
    c_t = jnp.minimum(1.0, pi_a_t / (mu_t + eps)) * lambda_
    target_tm1 = transformed_general_off_policy_returns_from_action_values(
        tx_pair, q_t, a_t, r_t, discount_t, c_t, pi_t)
    if stop_target_gradients:
        target_tm1 = jax.lax.stop_gradient(target_tm1)

    q_a_tm1 = base.batched_index(q_tm1, a_tm1)
    return target_tm1 - q_a_tm1
Example #26
def retrace(
    q_tm1: ArrayLike,
    q_t: ArrayLike,
    a_tm1: ArrayLike,
    a_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    pi_t: ArrayLike,
    mu_t: ArrayLike,
    lambda_: float,
    eps: float = 1e-8,
    stop_target_gradients: bool = True,
) -> ArrayLike:
    """Calculates Retrace errors.

  See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al.
  (https://arxiv.org/abs/1606.02647).

  Args:
    q_tm1: Q-values at time t-1.
    q_t: Q-values at time t.
    a_tm1: action index at time t-1.
    a_t: action index at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    pi_t: target policy probs at time t.
    mu_t: behavior policy probs at time t.
    lambda_: scalar mixing parameter lambda.
    eps: small value to add to mu_t for numerical stability.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    Retrace error.
  """
    base.rank_assert([q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t],
                     [2, 2, 1, 1, 1, 1, 2, 1])
    base.type_assert([q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t],
                     [float, float, int, int, float, float, float, float])

    pi_a_t = base.batched_index(pi_t, a_t)
    c_t = jnp.minimum(1.0, pi_a_t / (mu_t + eps)) * lambda_
    target_tm1 = multistep.general_off_policy_returns_from_action_values(
        q_t, a_t, r_t, discount_t, c_t, pi_t)

    q_a_tm1 = base.batched_index(q_tm1, a_tm1)

    if stop_target_gradients:
        target_tm1 = jax.lax.stop_gradient(target_tm1)
    return target_tm1 - q_a_tm1
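
A shape-level sketch for a two-transition window with made-up values (relies on `base.batched_index` and `multistep.general_off_policy_returns_from_action_values` as used above):

num_steps, num_actions = 2, 3
q_tm1 = jnp.zeros((num_steps, num_actions))
q_t = jnp.ones((num_steps, num_actions))
a_tm1 = jnp.array([0, 1])
a_t = jnp.array([1, 2])
r_t = jnp.array([1.0, 1.0])
discount_t = jnp.array([0.9, 0.9])
pi_t = jnp.full((num_steps, num_actions), 1.0 / num_actions)  # Target policy.
mu_t = jnp.array([0.5, 0.25])  # Behaviour probabilities of the taken actions.
errors = retrace(q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t,
                 lambda_=0.95)
# c_t = λ min(1, π(a_t)/μ(a_t)) truncates the importance weights, as in the
# Retrace(λ) operator of Munos et al.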
Example #27
def log_loss(
    predictions: ArrayLike,
    targets: ArrayLike,
) -> ArrayLike:
    """Calculates the log loss of predictions wrt targets.

  Args:
    predictions: a vector of probabilities of arbitrary shape.
    targets: a vector of probabilities of shape compatible with predictions.

  Returns:
    a vector of the same shape as `predictions`.
  """
    base.type_assert([predictions, targets], float)
    return -jnp.log(likelihood(predictions, targets))
Example #28
def general_off_policy_returns_from_q_and_v(
    q_t: ArrayLike,
    v_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    c_t: ArrayLike,
) -> ArrayLike:
  """Calculates targets for various off-policy evaluation algorithms.

  Given a window of experience of length `K+1`, generated by a behaviour policy
  μ, for each time-step `t` we can estimate the return `G_t` from that step
  onwards, under some target policy π, using the rewards in the trajectory, the
  values under π of the states and actions selected by μ, according to the
  equation:

    Gₜ = rₜ₊₁ + γₜ₊₁ * (vₜ₊₁ - cₜ₊₁ * q(aₜ₊₁) + cₜ₊₁* Gₜ₊₁),

  where, depending on the choice of `c_t`, the algorithm implements:

    Importance Sampling              c_t = π(x_t, a_t) / μ(x_t, a_t),
    Harutyunyan et al.'s Q(lambda)   c_t = λ,
    Precup et al.'s Tree-Backup      c_t = π(x_t, a_t),
    Munos et al.'s Retrace           c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).

  See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al.
  (https://arxiv.org/abs/1606.02647).

  Args:
    q_t: Q-values under π of actions executed by μ at times [1, ..., K - 1].
    v_t: Values under π at times [1, ..., K].
    r_t: rewards at times [1, ..., K].
    discount_t: discounts at times [1, ..., K].
    c_t: weights at times [1, ..., K - 1].

  Returns:
    Off-policy estimates of the generalized returns from states visited at times
    [0, ..., K - 1].
  """
  base.rank_assert([q_t, v_t, r_t, discount_t, c_t], 1)
  base.type_assert([q_t, v_t, r_t, discount_t, c_t], float)

  # Work backwards to compute `G_K-1`, ..., `G_1`, `G_0`.
  g = r_t[-1] + discount_t[-1] * v_t[-1]  # G_K-1.
  returns = [g]
  for i in jnp.arange(q_t.shape[0] - 1, -1, -1):  # [K - 2, ..., 0]
    g = r_t[i] + discount_t[i] * (v_t[i] - c_t[i] * q_t[i] + c_t[i] * g)
    returns.insert(0, g)

  return jnp.array(returns)
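
A worked example for a window of K = 2 transitions with trace coefficients fixed to 1:

q_t = jnp.array([1.0])        # q(a_1) under π, for times [1, ..., K-1].
v_t = jnp.array([2.0, 3.0])   # v(s_1), v(s_2) under π.
r_t = jnp.array([0.5, 0.5])
discount_t = jnp.array([0.9, 0.9])
c_t = jnp.array([1.0])
returns = general_off_policy_returns_from_q_and_v(q_t, v_t, r_t, discount_t, c_t)
# G_1 = 0.5 + 0.9 * 3.0 = 3.2
# G_0 = 0.5 + 0.9 * (2.0 - 1.0 * 1.0 + 1.0 * G_1) = 0.5 + 0.9 * 4.2 = 4.28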
Example #29
def add_gaussian_noise(key: ArrayLike, action: ArrayLike,
                       stddev: float) -> ArrayLike:
    """Returns continuous action with noise drawn from a Gaussian distribution.

  Args:
    key: a key from `jax.random`.
    action: continuous action scalar or vector.
    stddev: standard deviation of noise distribution.

  Returns:
    noisy action, of the same shape as input action.
  """
    base.type_assert(action, float)

    noise = jax.random.normal(key, shape=action.shape) * stddev
    return action + noise
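
A one-step sketch with an illustrative key and action:

key = jax.random.PRNGKey(42)
action = jnp.array([0.2, -0.3])
noisy_action = add_gaussian_noise(key, action, stddev=0.1)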
Example #30
def likelihood(predictions: ArrayLike, targets: ArrayLike) -> ArrayLike:
    """Calculates the likelihood of predictions wrt targets.

  Args:
    predictions: a vector of arbitrary shape.
    targets: a vector of shape compatible with predictions.

  Returns:
    a vector of the same shape as `predictions`.
  """
    base.type_assert([predictions, targets], float)
    likelihood_vals = predictions**targets * (1. - predictions)**(1. - targets)
    # Note: 0**0 evaluates to NaN on TPUs, manually set these cases to 1.
    filter_indices = jnp.logical_or(
        jnp.logical_and(targets == 1, predictions == 1),
        jnp.logical_and(targets == 0, predictions == 0))
    return jnp.where(filter_indices, 1, likelihood_vals)
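
A small sketch covering both `likelihood` and `log_loss` above; the last entry exercises the 0**0 guard:

predictions = jnp.array([0.9, 0.2, 1.0])
targets = jnp.array([1.0, 0.0, 1.0])
likelihood(predictions, targets)  # -> [0.9, 0.8, 1.0]; 1**1 * 0**0 maps to 1.
log_loss(predictions, targets)    # -> [-log(0.9), -log(0.8), 0.0].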