Example #1
def double_q_learning(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t_value: ArrayLike,
    q_t_selector: ArrayLike,
) -> ArrayOrScalar:
    """Calculates the double Q-learning temporal difference error.

  See "Double Q-learning" by van Hasselt.
  (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t_value: Q-values at time t.
    q_t_selector: selector Q-values at time t.

  Returns:
    Double Q-learning temporal difference error.
  """
    base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector],
                     [1, 0, 0, 0, 1, 1])
    base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector],
                     [float, int, float, float, float, float])

    target_tm1 = r_t + discount_t * q_t_value[q_t_selector.argmax()]
    return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
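A minimal usage sketch for the function above, assuming it and `jax.numpy` are in scope; the array values are illustrative only.

import jax.numpy as jnp

q_tm1 = jnp.array([1.0, 2.0, 0.5])          # online Q-values in s_{t-1}
a_tm1 = jnp.array(1)                        # action taken in s_{t-1}
r_t = jnp.array(0.5)
discount_t = jnp.array(0.99)
q_t_value = jnp.array([0.0, 1.5, 3.0])      # Q-values used to evaluate the bootstrap
q_t_selector = jnp.array([2.0, 0.5, 1.0])   # Q-values used to select the bootstrap action

td_error = double_q_learning(q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector)
# q_t_selector.argmax() is 0, so the target is 0.5 + 0.99 * 0.0 = 0.5
# and the error is 0.5 - q_tm1[1] = -1.5.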
Example #2
def persistent_q_learning(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t: ArrayLike,
    action_gap_scale: float,
) -> ArrayOrScalar:
    """Calculates the persistent Q-learning temporal difference error.

  See "Increasing the Action Gap: New Operators for Reinforcement Learning"
  by Bellemare, Ostrovski, Guez et al. (https://arxiv.org/abs/1512.04860).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    action_gap_scale: coefficient in [0, 1] for scaling the action gap term.

  Returns:
    Persistent Q-learning temporal difference error.
  """
    base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t], [1, 0, 0, 0, 1])
    base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t],
                     [float, int, float, float, float])

    corrected_q_t = ((1. - action_gap_scale) * jnp.max(q_t) +
                     action_gap_scale * q_t[a_tm1])
    target_tm1 = r_t + discount_t * corrected_q_t
    return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
Example #3
def discounted_returns(
    r_t: ArrayLike,
    discount_t: ArrayLike,
    v_t: ArrayLike
) -> ArrayLike:
  """Calculates a discounted return from a trajectory.

  The returns are computed recursively, from `G_{T-1}` to `G_0`, according to:

    Gₜ = rₜ₊₁ + γₜ₊₁ Gₜ₊₁.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/sutton/book/ebook/node61.html).

  Args:
    r_t: reward sequence at time t.
    discount_t: discount sequence at time t.
    v_t: value sequence or scalar at time t.

  Returns:
    Discounted returns.
  """
  base.rank_assert([r_t, discount_t, v_t], [1, 1, [0, 1]])
  base.type_assert([r_t, discount_t, v_t], float)

  # If scalar make into vector.
  bootstrapped_v = jnp.ones_like(discount_t) * v_t
  return lambda_returns(r_t, discount_t, bootstrapped_v, lambda_=1.)
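A worked sketch of the recursion above, assuming `discounted_returns` (and the `lambda_returns` helper it calls) is in scope; the numbers are made up.

import jax.numpy as jnp

r_t = jnp.array([1.0, 1.0, 1.0])          # rewards r_1, r_2, r_3
discount_t = jnp.array([0.9, 0.9, 0.0])   # a zero discount marks the end of the episode
v_t = jnp.array(2.0)                      # scalar bootstrap value, broadcast over time

returns = discounted_returns(r_t, discount_t, v_t)
# Computed backwards: G_2 = 1 + 0.0 * 2.0 = 1.0
#                     G_1 = 1 + 0.9 * 1.0 = 1.9
#                     G_0 = 1 + 0.9 * 1.9 = 2.71
# so `returns` is approximately [2.71, 1.9, 1.0].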
Example #4
def qv_max(
    v_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t: ArrayLike,
) -> ArrayOrScalar:
    """Calculates the QVMAX temporal difference error.

  See "The QV Family Compared to Other Reinforcement Learning Algorithms" by
  Wiering and van Hasselt (2009).
  (http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.713.1931)

  Args:
    v_tm1: state values at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.

  Returns:
    QVMAX temporal difference error.
  """
    base.rank_assert([v_tm1, r_t, discount_t, q_t], [0, 0, 0, 1])
    base.type_assert([v_tm1, r_t, discount_t, q_t], float)

    target_tm1 = r_t + discount_t * jnp.max(q_t)
    return jax.lax.stop_gradient(target_tm1) - v_tm1
Example #5
def qv_learning(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    v_t: ArrayOrScalar,
) -> ArrayOrScalar:
    """Calculates the QV-learning temporal difference error.

  See "Two Novel On-policy Reinforcement Learning Algorithms based on
  TD(lambda)-methods" by Wiering and van Hasselt
  (https://ieeexplore.ieee.org/abstract/document/4220845.)

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_t: state values at time t.

  Returns:
    QV-learning temporal difference error.
  """
    base.rank_assert([q_tm1, a_tm1, r_t, discount_t, v_t], [1, 0, 0, 0, 0])
    base.type_assert([q_tm1, a_tm1, r_t, discount_t, v_t],
                     [float, int, float, float, float])

    target_tm1 = r_t + discount_t * v_t
    return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
Example #6
def td_learning(
    v_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    v_t: ArrayOrScalar,
) -> ArrayOrScalar:
    """Calculates the TD-learning temporal difference error.

  See "Learning to Predict by the Methods of Temporal Differences" by Sutton.
  (https://link.springer.com/article/10.1023/A:1022633531479).

  Args:
    v_tm1: state values at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_t: state values at time t.

  Returns:
    TD-learning temporal difference error.
  """

    base.rank_assert([v_tm1, r_t, discount_t, v_t], 0)
    base.type_assert([v_tm1, r_t, discount_t, v_t], float)

    target_tm1 = r_t + discount_t * v_t
    return jax.lax.stop_gradient(target_tm1) - v_tm1
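As a quick sanity check, a hypothetical call with hand-picked scalars, assuming the function above is in scope.

import jax.numpy as jnp

v_tm1 = jnp.array(1.0)      # current estimate of V(s_{t-1})
r_t = jnp.array(0.5)
discount_t = jnp.array(0.9)
v_t = jnp.array(2.0)        # bootstrap value V(s_t)

td_error = td_learning(v_tm1, r_t, discount_t, v_t)
# target = 0.5 + 0.9 * 2.0 = 2.3, so the error is 2.3 - 1.0 = 1.3.
# Gradients flow only into v_tm1; the target is wrapped in stop_gradient.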
Example #7
def policy_gradient_loss(
    logits_t: ArrayLike,
    a_t: ArrayLike,
    adv_t: ArrayLike,
    w_t: ArrayLike,
) -> ArrayLike:
    """Calculates the policy gradient loss.

  See "Simple Gradient-Following Algorithms for Connectionist RL" by Williams.
  (http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)

  Args:
    logits_t: a sequence of unnormalized action preferences.
    a_t: a sequence of actions sampled from the preferences `logits_t`.
    adv_t: the observed or estimated advantages from executing actions `a_t`.
    w_t: a per timestep weighting for the loss.

  Returns:
    Loss whose gradient corresponds to a policy gradient update.
  """
    base.rank_assert([logits_t, a_t, adv_t, w_t], [2, 1, 1, 1])
    base.type_assert([logits_t, a_t, adv_t, w_t], [float, int, float, float])

    log_pi_a = distributions.softmax().logprob(a_t, logits_t)
    adv_t = jax.lax.stop_gradient(adv_t)
    loss_per_timestep = -log_pi_a * adv_t
    return jnp.mean(loss_per_timestep * w_t)
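A shape-oriented sketch, assuming `policy_gradient_loss` is in scope; in practice the logits would come from a policy network and the advantages from a critic or empirical returns.

import jax.numpy as jnp

T, A = 4, 3                               # sequence length, number of actions
logits_t = jnp.zeros((T, A))              # uniform preferences, for illustration
a_t = jnp.array([0, 2, 1, 0])             # sampled actions
adv_t = jnp.array([1.0, -0.5, 0.2, 0.0])  # advantages (treated as constants)
w_t = jnp.ones((T,))                      # uniform per-timestep weights

loss = policy_gradient_loss(logits_t, a_t, adv_t, w_t)
# With uniform logits every log pi(a_t) equals -log(3), so the loss reduces to
# mean(log(3) * adv_t); its gradient w.r.t. logits_t is the REINFORCE update.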
Example #8
def _categorical_l2_project(
    z_p: ArrayLike,
    probs: ArrayLike,
    z_q: ArrayLike
) -> ArrayLike:
  """Projects a categorical distribution (z_p, p) onto a different support z_q.

  The projection step minimizes an L2-metric over the cumulative distribution
  functions (CDFs) of the source and target distributions.

  Let kq be len(z_q) and kp be len(z_p). This projection works for any
  support z_q, in particular kq need not be equal to kp.

  See "A Distributional Perspective on RL" by Bellemare et al.
  (https://arxiv.org/abs/1707.06887).

  Args:
    z_p: support of distribution p.
    probs: probability values.
    z_q: support to project distribution (z_p, probs) onto.

  Returns:
    Projection of (z_p, p) onto support z_q under Cramer distance.
  """
  base.rank_assert([z_p, probs, z_q], 1)
  base.type_assert([z_p, probs, z_q], float)

  kp = z_p.shape[0]
  kq = z_q.shape[0]

  # Construct helper arrays from z_q.
  d_pos = jnp.roll(z_q, shift=-1)
  d_neg = jnp.roll(z_q, shift=1)

  # Clip z_p to be in new support range (vmin, vmax).
  z_p = jnp.clip(z_p, z_q[0], z_q[-1])[None, :]
  assert z_p.shape == (1, kp)

  # Get the distance between atom values in support.
  d_pos = (d_pos - z_q)[:, None]  # z_q[i+1] - z_q[i]
  d_neg = (z_q - d_neg)[:, None]  # z_q[i] - z_q[i-1]
  z_q = z_q[:, None]
  assert z_q.shape == (kq, 1)

  # Ensure that we do not divide by zero, in case of atoms of identical value.
  d_neg = jnp.where(d_neg > 0, 1. / d_neg, jnp.zeros_like(d_neg))
  d_pos = jnp.where(d_pos > 0, 1. / d_pos, jnp.zeros_like(d_pos))

  delta_qp = z_p - z_q  # clip(z_p)[j] - z_q[i]
  d_sign = (delta_qp >= 0.).astype(probs.dtype)
  assert delta_qp.shape == (kq, kp)
  assert d_sign.shape == (kq, kp)

  # Matrix of entries sgn(a_ij) * |a_ij|, with a_ij = clip(z_p)[j] - z_q[i].
  delta_hat = (d_sign * delta_qp * d_pos) - ((1. - d_sign) * delta_qp * d_neg)
  probs = probs[None, :]
  assert delta_hat.shape == (kq, kp)
  assert probs.shape == (1, kp)

  return jnp.sum(jnp.clip(1. - delta_hat, 0., 1.) * probs, axis=-1)
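A small check of the projection, assuming the helper above is in scope: projecting a distribution onto its own evenly spaced support should leave the probabilities unchanged.

import jax.numpy as jnp

z = jnp.array([0.0, 1.0, 2.0])       # source and target support coincide
probs = jnp.array([0.2, 0.5, 0.3])

same = _categorical_l2_project(z, probs, z)
# `same` equals [0.2, 0.5, 0.3]: atoms already on the target support keep their mass.

shifted = _categorical_l2_project(z + 0.5, probs, z)
# Each shifted atom splits its mass between its two neighbours on z, and mass
# falling outside [z[0], z[-1]] is clipped onto the boundary atoms.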
Example #9
def q_learning(
    q_tm1: ArrayLike,
    a_tm1: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    q_t: ArrayLike,
) -> ArrayLike:
  """Calculates the Q-learning temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node65.html).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.

  Returns:
    Q-learning temporal difference error.
  """
  base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t], [1, 0, 0, 0, 1])
  base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t],
                   [float, int, float, float, float])

  target_tm1 = r_t + discount_t * jnp.max(q_t)
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
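A minimal call, assuming `q_learning` is in scope; in training code the inputs would come from a replay buffer and a Q-network.

import jax.numpy as jnp

q_tm1 = jnp.array([1.0, 2.0, 3.0])   # Q(s_{t-1}, .)
a_tm1 = jnp.array(0)
r_t = jnp.array(1.0)
discount_t = jnp.array(0.9)
q_t = jnp.array([2.0, 0.0, 1.0])     # Q(s_t, .)

td_error = q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
# target = 1.0 + 0.9 * max(q_t) = 2.8, error = 2.8 - q_tm1[0] = 1.8.
# A typical loss would be 0.5 * td_error ** 2, with gradients reaching q_tm1 only.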
Example #10
def td_lambda(
    v_tm1: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    v_t: ArrayLike,
    lambda_: ArrayOrScalar,
) -> ArrayLike:
    """Calculates the TD(lambda) temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node74.html).

  Args:
    v_tm1: sequence of state values at time t-1.
    r_t: sequence of rewards at time t.
    discount_t: sequence of discounts at time t.
    v_t: sequence of state values at time t.
    lambda_: mixing parameter lambda, either a scalar or a sequence.

  Returns:
    TD(lambda) temporal difference error.
  """
    base.rank_assert([v_tm1, r_t, discount_t, v_t, lambda_],
                     [1, 1, 1, 1, [0, 1]])
    base.type_assert([v_tm1, r_t, discount_t, v_t, lambda_], float)

    target_tm1 = multistep.lambda_returns(r_t, discount_t, v_t, lambda_)
    return jax.lax.stop_gradient(target_tm1) - v_tm1
Example #11
def add_ornstein_uhlenbeck_noise(key: ArrayLike, action: ArrayLike,
                                 noise_tm1: ArrayLike, damping: float,
                                 stddev: float) -> ArrayLike:
    """Returns continuous action with noise from Ornstein-Uhlenbeck process.

  See "On the theory of Brownian Motion" by Uhlenbeck and Ornstein.
  (https://journals.aps.org/pr/abstract/10.1103/PhysRev.36.823).

  Args:
    key: a key from `jax.random`.
    action: continuous action scalar or vector.
    noise_tm1: noise sampled from OU process in previous timestep.
    damping: parameter for controlling autocorrelation of OU process.
    stddev: standard deviation of noise distribution.

  Returns:
    noisy action, of the same shape as input action.
  """
    base.rank_assert([action, noise_tm1], 1)
    base.type_assert([action, noise_tm1], float)

    noise_t = (1. - damping) * noise_tm1 + jax.random.normal(
        key, shape=action.shape) * stddev

    return action + noise_t
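A sketch of how the Ornstein-Uhlenbeck noise state could be threaded through an acting loop, assuming the function above is in scope; the fixed `action` stands in for whatever the policy outputs.

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
noise = jnp.zeros((2,))            # OU noise state, initialised at zero
action = jnp.array([0.1, -0.3])    # continuous action proposed by the policy

for _ in range(3):                 # e.g. three acting steps
  key, subkey = jax.random.split(key)
  noisy_action = add_ornstein_uhlenbeck_noise(
      subkey, action, noise, damping=0.15, stddev=0.2)
  noise = noisy_action - action    # carry the OU state over to the next step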
Example #12
def expected_sarsa(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t: ArrayLike,
    probs_a_t: ArrayLike,
) -> ArrayOrScalar:
    """Calculates the expected SARSA (SARSE) temporal difference error.

  See "A Theoretical and Empirical Analysis of Expected Sarsa" by Seijen,
  van Hasselt, Whiteson et al.
  (http://www.cs.ox.ac.uk/people/shimon.whiteson/pubs/vanseijenadprl09.pdf).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    probs_a_t: action probabilities at time t.

  Returns:
    Expected SARSA temporal difference error.
  """
    base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t],
                     [1, 0, 0, 0, 1, 1])
    base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t],
                     [float, int, float, float, float, float])

    target_tm1 = r_t + discount_t * jnp.dot(q_t, probs_a_t)
    return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
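A small worked call, assuming `expected_sarsa` is in scope; `probs_a_t` stands in for the target policy's action probabilities in s_t.

import jax.numpy as jnp

q_tm1 = jnp.array([1.0, 0.0])
a_tm1 = jnp.array(0)
r_t = jnp.array(1.0)
discount_t = jnp.array(0.5)
q_t = jnp.array([2.0, 4.0])
probs_a_t = jnp.array([0.25, 0.75])   # e.g. epsilon-greedy w.r.t. q_t with epsilon = 0.5

td_error = expected_sarsa(q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t)
# E[Q(s_t, .)] = 0.25 * 2 + 0.75 * 4 = 3.5,
# target = 1.0 + 0.5 * 3.5 = 2.75, error = 2.75 - 1.0 = 1.75.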
Example #13
def sarsa(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t: ArrayLike,
    a_t: ArrayOrScalar,
) -> ArrayOrScalar:
    """Calculates the SARSA temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node64.html.)

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    a_t: action index at time t.

  Returns:
    SARSA temporal difference error.
  """
    base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t, a_t],
                     [1, 0, 0, 0, 1, 0])
    base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t, a_t],
                     [float, int, float, float, float, int])

    target_tm1 = r_t + discount_t * q_t[a_t]
    return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
Example #14
def importance_corrected_td_errors(
    r_t: ArrayLike,
    discount_t: ArrayLike,
    rho_tm1: ArrayLike,
    lambda_: ArrayLike,
    values: ArrayLike,
) -> ArrayLike:
    """Computes the multistep td errors with per decision importance sampling.

  Given a trajectory of length `T+1`, generated under some policy π, for each
  time-step `t` we can estimate a multistep temporal difference error δₜ(ρ,λ),
  by combining rewards, discounts, and state values, according to a mixing
  parameter `λ` and importance sampling ratios ρₜ = π(aₜ|sₜ) / μ(aₜ|sₜ):

    td-errorₜ = ρₜ δₜ(ρ,λ)
    δₜ(ρ,λ) = δₜ + ρₜ₊₁ λₜ₊₁ γₜ₊₁ δₜ₊₁(ρ,λ),

  where δₜ = rₜ₊₁ + γₜ₊₁ vₜ₊₁ - vₜ is the one step, temporal difference error
  for the agent's state value estimates. This is equivalent to computing
  the λ-return with λₜ = ρₜ (e.g. using the `lambda_returns` function from
  above), and then computing errors as  td-errorₜ = ρₜ(Gₜ - vₜ).

  See "A new Q(λ) with interim forward view and Monte Carlo equivalence"
  by Sutton et al. (http://proceedings.mlr.press/v32/sutton14.html).

  Args:
    r_t: sequence of rewards rₜ for timesteps t in [1, T].
    discount_t: sequence of discounts γₜ for timesteps t in [1, T].
    rho_tm1: sequence of importance ratios for all timesteps t in [0, T-1].
    lambda_: mixing parameter; a scalar, or a sequence of per-timestep values
      for timesteps t in [1, T].
    values: sequence of state values under π for all timesteps t in [0, T].

  Returns:
    Off-policy estimates of the multistep temporal difference errors.
  """
    base.rank_assert([r_t, discount_t, rho_tm1, values], [1, 1, 1, 1])
    base.type_assert([r_t, discount_t, rho_tm1, values], float)

    v_tm1 = values[:-1]  # Predictions to compute errors for.
    v_t = values[1:]  # Values for bootstrapping.
    rho_t = jnp.concatenate(
        (rho_tm1[1:], jnp.array([1.])))  # Unused dummy value.
    lambda_ = jnp.ones_like(
        discount_t) * lambda_  # If scalar, make into vector.

    # Compute the one step temporal difference errors.
    one_step_delta = r_t + discount_t * v_t - v_tm1

    # Work backwards to compute `delta_{T-1}`, ..., `delta_0`.
    delta, errors = 0.0, []
    for i in jnp.arange(one_step_delta.shape[0] - 1, -1, -1):
        delta = one_step_delta[
            i] + discount_t[i] * rho_t[i] * lambda_[i] * delta
        errors.insert(0, delta)

    return rho_tm1 * jnp.array(errors)
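A shape-oriented sketch of a call, assuming the function above is in scope; the importance ratios would come from evaluating π and μ on the logged actions.

import jax.numpy as jnp

T = 4
values = jnp.array([0.0, 1.0, 0.5, 2.0, 1.5])   # v(s_0) ... v(s_T), length T + 1
r_t = jnp.ones((T,))                            # rewards r_1 ... r_T
discount_t = 0.9 * jnp.ones((T,))
rho_tm1 = jnp.array([1.0, 0.8, 1.2, 0.5])       # pi(a|s) / mu(a|s) for t in [0, T-1]

errors = importance_corrected_td_errors(
    r_t, discount_t, rho_tm1, lambda_=1.0, values=values)
# `errors` has shape [T]: each entry is the importance ratio at that state times
# the lambda-mixed, importance-corrected multistep TD error from that state.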
Example #15
def categorical_q_learning(
    q_atoms_tm1: ArrayLike,
    q_logits_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_atoms_t: ArrayLike,
    q_logits_t: ArrayLike,
) -> ArrayOrScalar:
    """Implements Q-learning for categorical Q distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
    Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

  Args:
    q_atoms_tm1: atoms of Q distribution at time t-1.
    q_logits_tm1: logits of Q distribution at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_atoms_t: atoms of Q distribution at time t.
    q_logits_t: logits of Q distribution at time t.

  Returns:
    Categorical Q-learning loss (i.e. temporal difference error).
  """
    base.rank_assert([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t
    ], [1, 2, 0, 0, 0, 1, 2])
    base.type_assert([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t
    ], [float, float, int, float, float, float, float])

    # Scale and shift time-t distribution atoms by discount and reward.
    target_z = r_t + discount_t * q_atoms_t

    # Convert logits to distribution, then find greedy action in state s_t.
    q_t_probs = jax.nn.softmax(q_logits_t)
    q_t_mean = jnp.sum(q_t_probs * q_atoms_t[jnp.newaxis, :], axis=1)
    pi_t = jnp.argmax(q_t_mean)

    # Compute distribution for greedy action.
    p_target_z = q_t_probs[pi_t]

    # Project using the Cramer distance.
    target = jax.lax.stop_gradient(
        _categorical_l2_project(target_z, p_target_z, q_atoms_tm1))

    # Compute loss (i.e. temporal difference error).
    logit_qa_tm1 = q_logits_tm1[a_tm1]
    return distributions.categorical_cross_entropy(labels=target,
                                                   logits=logit_qa_tm1)
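A shape sketch for the distributional loss above, assuming `categorical_q_learning` is in scope; the atoms and (all-zero) logits are placeholders for what a C51-style network would output.

import jax.numpy as jnp

num_actions, num_atoms = 3, 51
atoms = jnp.linspace(-10.0, 10.0, num_atoms)        # shared support of the value distribution

q_logits_tm1 = jnp.zeros((num_actions, num_atoms))  # network output for s_{t-1}
q_logits_t = jnp.zeros((num_actions, num_atoms))    # network output for s_t
a_tm1 = jnp.array(1)
r_t = jnp.array(0.5)
discount_t = jnp.array(0.99)

loss = categorical_q_learning(
    atoms, q_logits_tm1, a_tm1, r_t, discount_t, atoms, q_logits_t)
# `loss` is a scalar: the cross-entropy between the projected target distribution
# and the predicted distribution for the action taken in s_{t-1}.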
Example #16
def categorical_double_q_learning(
    q_atoms_tm1: ArrayLike,
    q_logits_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_atoms_t: ArrayLike,
    q_logits_t: ArrayLike,
    q_t_selector: ArrayLike,
) -> ArrayOrScalar:
    """Implements double Q-learning for categorical Q distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
    Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf)
  and "Double Q-learning" by van Hasselt.
  (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

  Args:
    q_atoms_tm1: atoms of Q distribution at time t-1.
    q_logits_tm1: logits of Q distribution at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_atoms_t: atoms of Q distribution at time t.
    q_logits_t: logits of Q distribution at time t.
    q_t_selector: selector Q-values at time t.

  Returns:
    Categorical double Q-learning loss (i.e. temporal difference error).
  """
    base.rank_assert([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t, q_t_selector
    ], [1, 2, 0, 0, 0, 1, 2, 1])
    base.type_assert([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t, q_t_selector
    ], [float, float, int, float, float, float, float, float])

    # Scale and shift time-t distribution atoms by discount and reward.
    target_z = r_t + discount_t * q_atoms_t

    # Select logits for greedy action in state s_t and convert to distribution.
    p_target_z = jax.nn.softmax(q_logits_t[q_t_selector.argmax()])

    # Project using the Cramer distance.
    target = jax.lax.stop_gradient(
        _categorical_l2_project(target_z, p_target_z, q_atoms_tm1))

    # Compute loss (i.e. temporal difference error).
    logit_qa_tm1 = q_logits_tm1[a_tm1]
    return distributions.categorical_cross_entropy(labels=target,
                                                   logits=logit_qa_tm1)
Example #17
def general_off_policy_returns_from_action_values(
    q_t: ArrayLike,
    a_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    c_t: ArrayLike,
    pi_t: ArrayLike,
) -> ArrayLike:
  """Calculates errors for various off-policy correction algorithms.

  Given a window of experience of length `K`, generated by a behaviour policy μ,
  for each time-step `t` we can estimate the return `G_t` from that step
  onwards, under some target policy π, using the rewards in the trajectory, the
  actions selected by μ and the action-values under π, according to:

    Gₜ = rₜ₊₁ + γₜ₊₁ * (E[q(aₜ₊₁)] - cₜ * q(aₜ₊₁) + cₜ * Gₜ₊₁),

  where, depending on the choice of `c_t`, the algorithm implements:

    Importance Sampling             c_t = π(x_t, a_t) / μ(x_t, a_t),
    Harutyunyan's et al. Q(lambda)  c_t = λ,
    Precup's et al. Tree-Backup     c_t = π(x_t, a_t),
    Munos' et al. Retrace           c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).

  See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al.
  (https://arxiv.org/abs/1606.02647).

  Args:
    q_t: Q-values at time t.
    a_t: action index at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    c_t: importance weights at time t.
    pi_t: target policy probs at time t.

  Returns:
    Off-policy estimates of the multistep lambda returns from each state.
  """
  base.rank_assert([q_t, a_t, r_t, discount_t, c_t, pi_t],
                   [2, 1, 1, 1, 1, 2])
  base.type_assert([q_t, a_t, r_t, discount_t, c_t, pi_t],
                   [float, int, float, float, float, float])

  # Get the expected values and the values of actually selected actions.
  exp_q_t = (pi_t * q_t).sum(axis=-1)
  # The generalized returns are independent of Q-values and cs at the final
  # state.
  q_a_t = base.batched_index(q_t, a_t)[:-1]
  c_t = c_t[:-1]

  return general_off_policy_returns_from_q_and_v(
      q_a_t, exp_q_t, r_t, discount_t, c_t)
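A sketch of how the different correction coefficients can be built before calling the function above, assuming it is in scope; the trajectory quantities below are placeholders.

import jax.numpy as jnp

K, A = 5, 3
q_t = jnp.zeros((K, A))                      # Q-values under pi at times [1, K]
a_t = jnp.zeros((K,), dtype=jnp.int32)       # actions selected by mu
r_t = jnp.ones((K,))
discount_t = 0.9 * jnp.ones((K,))
pi_t = jnp.full((K, A), 1.0 / A)             # target policy probabilities
mu_a_t = jnp.full((K,), 0.5)                 # behaviour probs of the taken actions
lambda_ = 0.95

pi_a_t = jnp.take_along_axis(pi_t, a_t[:, None], axis=-1)[:, 0]
c_t = lambda_ * jnp.minimum(1.0, pi_a_t / mu_a_t)   # Retrace-style coefficients

returns = general_off_policy_returns_from_action_values(
    q_t, a_t, r_t, discount_t, c_t, pi_t)
# c_t = pi_a_t / mu_a_t instead gives per-decision importance sampling, and
# c_t = lambda_ * jnp.ones((K,)) gives Harutyunyan's Q(lambda).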
Example #18
def quantile_q_learning(dist_q_tm1: ArrayLike,
                        tau_q_tm1: ArrayLike,
                        a_tm1: ArrayOrScalar,
                        r_t: ArrayOrScalar,
                        discount_t: ArrayOrScalar,
                        dist_q_t_selector: ArrayLike,
                        dist_q_t: ArrayLike,
                        huber_param: float = 0.) -> ArrayOrScalar:
    """Implements Q-learning for quantile-valued Q distributions.

  See "Distributional Reinforcement Learning with Quantile Regression" by
  Dabney et al. (https://arxiv.org/abs/1710.10044).

  Args:
    dist_q_tm1: Q distribution at time t-1.
    tau_q_tm1: Q distribution probability thresholds.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    dist_q_t_selector: Q distribution at time t for selecting greedy action in
      target policy. This is separate from dist_q_t as in Double Q-Learning, but
      can be computed with the target network and a separate set of samples.
    dist_q_t: target Q distribution at time t.
    huber_param: Huber loss parameter, defaults to 0 (no Huber loss).

  Returns:
    Quantile regression Q learning loss.
  """
    base.rank_assert([
        dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t_selector,
        dist_q_t
    ], [2, 1, 0, 0, 0, 2, 2])
    base.type_assert([
        dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t_selector,
        dist_q_t
    ], [float, float, int, float, float, float, float])

    # Only update the taken actions.
    dist_qa_tm1 = dist_q_tm1[:, a_tm1]

    # Select target action according to greedy policy w.r.t. dist_q_t_selector.
    q_t_selector = jnp.mean(dist_q_t_selector, axis=0)
    a_t = jnp.argmax(q_t_selector)
    dist_qa_t = dist_q_t[:, a_t]

    # Compute target, do not backpropagate into it.
    dist_target = r_t + discount_t * dist_qa_t
    dist_target = jax.lax.stop_gradient(dist_target)

    return _quantile_regression_loss(dist_qa_tm1, tau_q_tm1, dist_target,
                                     huber_param)
Example #19
def vtrace(
    v_tm1: ArrayLike,
    v_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    rho_t: ArrayLike,
    lambda_: float = 1.0,
    clip_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True,
) -> ArrayLike:
  """Calculates V-Trace errors from importance weights.

  See "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor
  Learner Architectures" by Espeholt et al. (https://arxiv.org/abs/1802.01561).

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_t: importance sampling ratios.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: clip threshold for importance weights.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    V-Trace error.
  """
  base.rank_assert(
      [v_tm1, v_t, r_t, discount_t, rho_t], [1, 1, 1, 1, 1])
  base.type_assert(
      [v_tm1, v_t, r_t, discount_t, rho_t], [float, float, float, float, float])

  # Clip importance sampling ratios.
  c_t = jnp.minimum(1.0, rho_t) * lambda_
  clipped_rhos = jnp.minimum(clip_rho_threshold, rho_t)

  # Compute the temporal difference errors.
  td_errors = clipped_rhos * (r_t + discount_t * v_t - v_tm1)

  # Work backwards computing the td-errors.
  err = 0.0
  errors = []
  for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
    err = td_errors[i] + discount_t[i] * c_t[i] * err
    errors.insert(0, err)

  # Add the value of the initial state to get the estimates of the returns.
  target_tm1 = jnp.array(errors) + v_tm1

  # Stop gradients and return temporal difference error.
  if stop_target_gradients:
    target_tm1 = jax.lax.stop_gradient(target_tm1)
  return target_tm1 - v_tm1
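A minimal call, assuming `vtrace` is in scope; `rho_t` stands in for exp(log π - log μ) evaluated on the behaviour actions.

import jax.numpy as jnp

v_tm1 = jnp.array([1.0, 0.5, 0.2])   # v(s_0), v(s_1), v(s_2)
v_t = jnp.array([0.5, 0.2, 0.0])     # v(s_1), v(s_2), v(s_3)
r_t = jnp.array([1.0, 0.0, 1.0])
discount_t = jnp.array([0.9, 0.9, 0.0])
rho_t = jnp.array([1.2, 0.7, 1.0])   # importance sampling ratios

errors = vtrace(v_tm1, v_t, r_t, discount_t, rho_t,
                lambda_=1.0, clip_rho_threshold=1.0)
# `errors` is stop_gradient(v-trace target) - v_tm1, one entry per transition;
# ratios above clip_rho_threshold are clipped before the correction is applied.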
Example #20
def transformed_retrace(
    q_tm1: ArrayLike,
    q_t: ArrayLike,
    a_tm1: ArrayLike,
    a_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    pi_t: ArrayLike,
    mu_t: ArrayLike,
    lambda_: float,
    eps: float = 1e-8,
    stop_target_gradients: bool = True,
    tx_pair: TxPair = IDENTITY_PAIR,
) -> ArrayLike:
    """Calculates transformed Retrace errors.

  See "Recurrent Experience Replay in Distributed Reinforcement Learning" by
  Kapturowski et al. (https://openreview.net/pdf?id=r1lyTjAqYX).

  Args:
    q_tm1: Q-values at time t-1.
    q_t: Q-values at time t.
    a_tm1: action index at time t-1.
    a_t: action index at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    pi_t: target policy probs at time t.
    mu_t: behavior policy probs at time t.
    lambda_: scalar mixing parameter lambda.
    eps: small value to add to mu_t for numerical stability.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.
    tx_pair: TxPair of value function transformation and its inverse.

  Returns:
    Transformed Retrace error.
  """
    base.rank_assert([q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t],
                     [2, 2, 1, 1, 1, 1, 2, 1])
    base.type_assert([q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t],
                     [float, float, int, int, float, float, float, float])

    pi_a_t = base.batched_index(pi_t, a_t)
    c_t = jnp.minimum(1.0, pi_a_t / (mu_t + eps)) * lambda_
    target_tm1 = transformed_general_off_policy_returns_from_action_values(
        tx_pair, q_t, a_t, r_t, discount_t, c_t, pi_t)
    if stop_target_gradients:
        target_tm1 = jax.lax.stop_gradient(target_tm1)

    q_a_tm1 = base.batched_index(q_tm1, a_tm1)
    return target_tm1 - q_a_tm1
Example #21
def retrace(
    q_tm1: ArrayLike,
    q_t: ArrayLike,
    a_tm1: ArrayLike,
    a_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    pi_t: ArrayLike,
    mu_t: ArrayLike,
    lambda_: float,
    eps: float = 1e-8,
    stop_target_gradients: bool = True,
) -> ArrayLike:
    """Calculates Retrace errors.

  See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al.
  (https://arxiv.org/abs/1606.02647).

  Args:
    q_tm1: Q-values at time t-1.
    q_t: Q-values at time t.
    a_tm1: action index at time t-1.
    a_t: action index at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    pi_t: target policy probs at time t.
    mu_t: behavior policy probs at time t.
    lambda_: scalar mixing parameter lambda.
    eps: small value to add to mu_t for numerical stability.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    Retrace error.
  """
    base.rank_assert([q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t],
                     [2, 2, 1, 1, 1, 1, 2, 1])
    base.type_assert([q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t],
                     [float, float, int, int, float, float, float, float])

    pi_a_t = base.batched_index(pi_t, a_t)
    c_t = jnp.minimum(1.0, pi_a_t / (mu_t + eps)) * lambda_
    target_tm1 = multistep.general_off_policy_returns_from_action_values(
        q_t, a_t, r_t, discount_t, c_t, pi_t)

    q_a_tm1 = base.batched_index(q_tm1, a_tm1)

    if stop_target_gradients:
        target_tm1 = jax.lax.stop_gradient(target_tm1)
    return target_tm1 - q_a_tm1
Example #22
def general_off_policy_returns_from_q_and_v(
    q_t: ArrayLike,
    v_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    c_t: ArrayLike,
) -> ArrayLike:
  """Calculates targets for various off-policy evaluation algorithms.

  Given a window of experience of length `K+1`, generated by a behaviour policy
  μ, for each time-step `t` we can estimate the return `G_t` from that step
  onwards, under some target policy π, using the rewards in the trajectory, the
  values under π of the states and actions selected by μ, according to:

    Gₜ = rₜ₊₁ + γₜ₊₁ * (vₜ₊₁ - cₜ₊₁ * q(aₜ₊₁) + cₜ₊₁* Gₜ₊₁),

  where, depending on the choice of `c_t`, the algorithm implements:

    Importance Sampling             c_t = π(x_t, a_t) / μ(x_t, a_t),
    Harutyunyan's et al. Q(lambda)  c_t = λ,
    Precup's et al. Tree-Backup     c_t = π(x_t, a_t),
    Munos' et al. Retrace           c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).

  See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al.
  (https://arxiv.org/abs/1606.02647).

  Args:
    q_t: Q-values under π of actions executed by μ at times [1, ..., K - 1].
    v_t: Values under π at times [1, ..., K].
    r_t: rewards at times [1, ..., K].
    discount_t: discounts at times [1, ..., K].
    c_t: weights at times [1, ..., K - 1].

  Returns:
    Off-policy estimates of the generalized returns from states visited at times
    [0, ..., K - 1].
  """
  base.rank_assert([q_t, v_t, r_t, discount_t, c_t], 1)
  base.type_assert([q_t, v_t, r_t, discount_t, c_t], float)

  # Work backwards to compute `G_K-1`, ..., `G_1`, `G_0`.
  g = r_t[-1] + discount_t[-1] * v_t[-1]  # G_K-1.
  returns = [g]
  for i in jnp.arange(q_t.shape[0] - 1, -1, -1):  # [K - 2, ..., 0]
    g = r_t[i] + discount_t[i] * (v_t[i] - c_t[i] * q_t[i] + c_t[i] * g)
    returns.insert(0, g)

  return jnp.array(returns)
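A small worked call to illustrate the backward recursion, assuming the function above is in scope; with a window of K = 2 the inputs have the lengths described in the Args section.

import jax.numpy as jnp

q_t = jnp.array([1.0])         # q(a_1) under pi, times [1, K - 1]
v_t = jnp.array([2.0, 3.0])    # v(s_1), v(s_2) under pi, times [1, K]
r_t = jnp.array([1.0, 1.0])
discount_t = jnp.array([0.9, 0.9])
c_t = jnp.array([0.5])         # trace-cutting coefficient at time 1

returns = general_off_policy_returns_from_q_and_v(q_t, v_t, r_t, discount_t, c_t)
# G_1 = 1 + 0.9 * 3.0 = 3.7
# G_0 = 1 + 0.9 * (2.0 - 0.5 * 1.0 + 0.5 * 3.7) = 4.015
# so `returns` is approximately [4.015, 3.7].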
Example #23
def vtrace_td_error_and_advantage(
    v_tm1: ArrayLike,
    v_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    rho_t: ArrayLike,
    lambda_: float = 1.0,
    clip_rho_threshold: float = 1.0,
    clip_pg_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True,
) -> VTraceOutput:
  """Calculates V-Trace errors and PG advantage from importance weights.

  See "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor
  Learner Architectures" by Espeholt et al. (https://arxiv.org/abs/1802.01561)

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_t: importance weights at time t.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: clip threshold for importance ratios.
    clip_pg_rho_threshold: clip threshold for policy gradient importance ratios.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    a tuple of V-Trace error, policy gradient advantage, and estimated Q-values.
  """
  base.rank_assert([v_tm1, v_t, r_t, discount_t, rho_t], 1)
  base.type_assert([v_tm1, v_t, r_t, discount_t, rho_t], float)

  errors = value_learning.vtrace(
      v_tm1, v_t, r_t, discount_t, rho_t,
      lambda_, clip_rho_threshold, stop_target_gradients)
  targets_tm1 = errors + v_tm1
  q_bootstrap = jnp.concatenate([
      lambda_ * targets_tm1[1:] + (1 - lambda_) * v_tm1[1:],
      v_t[-1:],
  ], axis=0)
  q_estimate = r_t + discount_t * q_bootstrap
  clipped_pg_rho_tm1 = jnp.minimum(clip_pg_rho_threshold, rho_t)
  pg_advantages = clipped_pg_rho_tm1 * (q_estimate - v_tm1)
  return VTraceOutput(
      errors=errors, pg_advantage=pg_advantages, q_estimate=q_estimate)
Example #24
def add_gaussian_noise(key: ArrayLike, action: ArrayLike,
                       stddev: float) -> ArrayLike:
    """Returns continuous action with noise drawn from a Gaussian distribution.

  Args:
    key: a key from `jax.random`.
    action: continuous action scalar or vector.
    stddev: standard deviation of noise distribution.

  Returns:
    noisy action, of the same shape as input action.
  """
    base.rank_assert(action, [[0, 1]])
    base.type_assert(action, float)

    noise = jax.random.normal(key, shape=action.shape) * stddev
    return action + noise
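A one-line usage sketch, assuming the function above is in scope.

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
action = jnp.array([0.2, -0.7])
noisy_action = add_gaussian_noise(key, action, stddev=0.1)
# Same shape as `action`; a fresh key should be split off for every call.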
Example #25
def categorical_td_learning(
    v_atoms_tm1: ArrayLike,
    v_logits_tm1: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    v_atoms_t: ArrayLike,
    v_logits_t: ArrayLike
) -> ArrayLike:
  """Implements TD-learning for categorical value distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
    Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

  Args:
    v_atoms_tm1: atoms of V distribution at time t-1.
    v_logits_tm1: logits of V distribution at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_atoms_t: atoms of V distribution at time t.
    v_logits_t: logits of V distribution at time t.

  Returns:
    Categorical TD-learning loss (i.e. temporal difference error).
  """
  base.rank_assert(
      [v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t],
      [1, 1, 0, 0, 1, 1])
  base.type_assert(
      [v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t],
      [float, float, float, float, float, float])

  # Scale and shift time-t distribution atoms by discount and reward.
  target_z = r_t + discount_t * v_atoms_t

  # Convert logits to distribution.
  v_t_probs = jax.nn.softmax(v_logits_t)

  # Project using the Cramer distance.
  target = jax.lax.stop_gradient(
      _categorical_l2_project(target_z, v_t_probs, v_atoms_tm1))

  # Compute loss (i.e. temporal difference error).
  return distributions.categorical_cross_entropy(
      labels=target, logits=v_logits_tm1)
Example #26
def transformed_q_lambda(
    q_tm1: ArrayLike,
    a_tm1: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    q_t: ArrayLike,
    lambda_: ArrayLike,
    stop_target_gradients: bool = True,
    tx_pair: TxPair = IDENTITY_PAIR,
) -> ArrayLike:
    """Calculates Peng's or Watkins' Q(lambda) temporal difference error.

  See "General non-linear Bellman equations" by van Hasselt et al.
  (https://arxiv.org/abs/1907.03687).

  Args:
    q_tm1: sequence of Q-values at time t-1.
    a_tm1: sequence of action indices at time t-1.
    r_t: sequence of rewards at time t.
    discount_t: sequence of discounts at time t.
    q_t: sequence of Q-values at time t.
    lambda_: mixing parameter lambda, either a scalar (e.g. Peng's Q(lambda)) or
      a sequence (e.g. Watkins' Q(lambda)).
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.
    tx_pair: TxPair of value function transformation and its inverse.

  Returns:
    Q(lambda) temporal difference error.
  """
    base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t, lambda_],
                     [2, 1, 1, 1, 2, [0, 1]])
    base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t, lambda_],
                     [float, int, float, float, float, float])

    qa_tm1 = base.batched_index(q_tm1, a_tm1)
    v_t = jnp.max(q_t, axis=-1)
    target_tm1 = transformed_lambda_returns(tx_pair, r_t, discount_t, v_t,
                                            lambda_)
    if stop_target_gradients:
        target_tm1 = jax.lax.stop_gradient(target_tm1)

    return target_tm1 - qa_tm1
Example #27
def pixel_control_rewards(
    observations: ArrayLike,
    cell_size: int,
) -> base.ArrayLike:
    """Calculates cumulants for pixel control tasks from an observation sequence.

  The observations are first split into a grid of cells, each of size
  `cell_size` x `cell_size`. For each cell a distinct pseudo reward is computed
  as the average absolute change in pixel
  intensity across all pixels in the cell. The change in intensity is averaged
  across both pixels and channels (e.g. RGB).

  The `observations` provided to this function should be cropped suitably, to
  ensure that the observations' height and width are a multiple of `cell_size`.
  The values of the `observations` tensor should be rescaled to [0, 1].

  See "Reinforcement Learning with Unsupervised Auxiliary Tasks" by Jaderberg,
  Mnih, Czarnecki et al. (https://arxiv.org/abs/1611.05397).

  Args:
    observations: A tensor of shape `[T+1,H,W,C]`, where
      * `T` is the sequence length,
      * `H` is height,
      * `W` is width,
      * `C` is a channel dimension.
    cell_size: The size of each cell.

  Returns:
    A tensor of pixel control rewards calculated from the observation. The
    shape is `[T,H',W']`, where `H'=H/cell_size` and `W'=W/cell_size`.
  """
    base.rank_assert(observations, 4)
    base.type_assert(observations, float)

    # Shape info.
    h = observations.shape[1] // cell_size  # new height.
    w = observations.shape[2] // cell_size  # new width.
    # Calculate the absolute differences across the sequence.
    abs_diff = jnp.abs(observations[1:] - observations[:-1])
    # Average within cells to get the cumulants.
    abs_diff = abs_diff.reshape(
        (-1, h, cell_size, w, cell_size, observations.shape[3]))
    return abs_diff.mean(axis=(2, 4, 5))
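A shape sketch, assuming `pixel_control_rewards` is in scope; the random observations stand in for a cropped, [0, 1]-rescaled frame sequence.

import jax
import jax.numpy as jnp

T, H, W, C = 10, 84, 84, 3
observations = jax.random.uniform(jax.random.PRNGKey(0), (T + 1, H, W, C))

rewards = pixel_control_rewards(observations, cell_size=4)
# `rewards` has shape [10, 21, 21]: one pseudo reward per timestep and per 4x4
# cell, averaged over the pixels and channels inside the cell.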
Example #28
def feature_control_rewards(
    features: ArrayLike,
    cumulant_type='absolute_change',
    discount=None,
) -> base.ArrayLike:
    """Calculates cumulants for feature control tasks from a sequence of features.

  For each feature dimension, a distinct pseudo reward is computed based on the
  change in the feature value between consecutive timesteps. Depending on
  `cumulant_type`, cumulants may be the features themselves, the absolute
  difference between their values in consecutive steps, their increase/decrease,
  or may take the form of a potential-based reward discounted by `discount`.

  See "Reinforcement Learning with Unsupervised Auxiliary Tasks" by Jaderberg,
  Mnih, Czarnecki et al. (https://arxiv.org/abs/1611.05397).

  Args:
    features: A tensor of shape `[T+1,D]` of features.
    cumulant_type: either 'feature' (feature is the reward), `absolute_change`
      (the reward equals the absolute difference between consecutive
      timesteps), `increase` (the reward equals the increase in the
      value of the feature), `decrease` (the reward equals the decrease in the
      value of the feature), or 'potential' (r=gamma*phi_{t+1} - phi_t).
    discount: (optional) discount for potential based rewards.

  Returns:
    A tensor of cumulants calculated from the features. The shape is `[T,D]`.
  """
    base.rank_assert(features, 2)
    base.type_assert(features, float)

    if cumulant_type == 'feature':
        return features[1:]
    elif cumulant_type == 'absolute_change':
        return jnp.abs(features[1:] - features[:-1])
    elif cumulant_type == 'increase':
        return features[1:] - features[:-1]
    elif cumulant_type == 'decrease':
        return features[:-1] - features[1:]
    elif cumulant_type == 'potential':
        return discount * features[1:] - features[:-1]
    else:
        # Guard against silently returning None for an unknown cumulant type.
        raise ValueError(f'Unknown cumulant_type: {cumulant_type}')
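A short sketch of the different cumulant types, assuming the function above is in scope; the feature sequence is random for illustration.

import jax
import jax.numpy as jnp

features = jax.random.normal(jax.random.PRNGKey(0), (6, 8))   # [T+1, D] with T=5, D=8

abs_change = feature_control_rewards(features)                          # default: |phi_{t+1} - phi_t|
increase = feature_control_rewards(features, cumulant_type='increase')  # phi_{t+1} - phi_t
potential = feature_control_rewards(
    features, cumulant_type='potential', discount=0.99)                 # 0.99 * phi_{t+1} - phi_t
# Each result has shape [5, 8]: one cumulant per timestep and feature dimension.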
Example #29
def sarsa_lambda(
    q_tm1: ArrayLike,
    a_tm1: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    q_t: ArrayLike,
    a_t: ArrayLike,
    lambda_: ArrayOrScalar,
    stop_target_gradients: bool = True,
) -> ArrayLike:
    """Calculates the SARSA(lambda) temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node77.html).

  Args:
    q_tm1: sequence of Q-values at time t-1.
    a_tm1: sequence of action indices at time t-1.
    r_t: sequence of rewards at time t.
    discount_t: sequence of discounts at time t.
    q_t: sequence of Q-values at time t.
    a_t: sequence of action indices at time t.
    lambda_: mixing parameter lambda, either a scalar or a sequence.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    SARSA(lambda) temporal difference error.
  """
    base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t, a_t, lambda_],
                     [2, 1, 1, 1, 2, 1, [0, 1]])
    base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t, a_t, lambda_],
                     [float, int, float, float, float, int, float])

    qa_tm1 = base.batched_index(q_tm1, a_tm1)
    qa_t = base.batched_index(q_t, a_t)
    target_tm1 = multistep.lambda_returns(r_t, discount_t, qa_t, lambda_)

    if stop_target_gradients:
        target_tm1 = jax.lax.stop_gradient(target_tm1)
    return target_tm1 - qa_tm1
Example #30
def entropy_loss(
    logits_t: ArrayLike,
    w_t: ArrayLike,
) -> ArrayLike:
    """Calculates the entropy regularization loss.

  See "Function Optimization using Connectionist RL Algorithms" by Williams.
  (https://www.tandfonline.com/doi/abs/10.1080/09540099108946587)

  Args:
    logits_t: a sequence of unnormalized action preferences.
    w_t: a per timestep weighting for the loss.

  Returns:
    Entropy loss.
  """
    base.rank_assert([logits_t, w_t], [2, 1])
    base.type_assert([logits_t, w_t], float)

    entropy_per_timestep = distributions.softmax().entropy(logits_t)
    return -jnp.mean(entropy_per_timestep * w_t)