Example #1
def transformed_q_lambda(
    q_tm1: Array,
    a_tm1: Array,
    r_t: Array,
    discount_t: Array,
    q_t: Array,
    lambda_: Array,
    stop_target_gradients: bool = True,
    tx_pair: TxPair = IDENTITY_PAIR,
) -> Array:
    """Calculates Peng's or Watkins' Q(lambda) temporal difference error.

  See "General non-linear Bellman equations" by van Hasselt et al.
  (https://arxiv.org/abs/1907.03687).

  Args:
    q_tm1: sequence of Q-values at time t-1.
    a_tm1: sequence of action indices at time t-1.
    r_t: sequence of rewards at time t.
    discount_t: sequence of discounts at time t.
    q_t: sequence of Q-values at time t.
    lambda_: mixing parameter lambda, either a scalar (e.g. Peng's Q(lambda)) or
      a sequence (e.g. Watkins' Q(lambda)).
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.
    tx_pair: TxPair of value function transformation and its inverse.

  Returns:
    Q(lambda) temporal difference error.
  """
    chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t, lambda_],
                     [2, 1, 1, 1, 2, {0, 1}])
    chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t, lambda_],
                     [float, int, float, float, float, float])

    qa_tm1 = base.batched_index(q_tm1, a_tm1)
    v_t = jnp.max(q_t, axis=-1)
    target_tm1 = transformed_lambda_returns(tx_pair, r_t, discount_t, v_t,
                                            lambda_, stop_target_gradients)
    return target_tm1 - qa_tm1
Example #2
def q_lambda(
    q_tm1: Array,
    a_tm1: Array,
    r_t: Array,
    discount_t: Array,
    q_t: Array,
    lambda_: Numeric,
    stop_target_gradients: bool = True,
) -> Array:
    """Calculates Peng's or Watkins' Q(lambda) temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node78.html).

  Args:
    q_tm1: sequence of Q-values at time t-1.
    a_tm1: sequence of action indices at time t-1.
    r_t: sequence of rewards at time t.
    discount_t: sequence of discounts at time t.
    q_t: sequence of Q-values at time t.
    lambda_: mixing parameter lambda, either a scalar (e.g. Peng's Q(lambda)) or
      a sequence (e.g. Watkins' Q(lambda)).
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    Q(lambda) temporal difference error.
  """
    chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t, lambda_],
                     [2, 1, 1, 1, 2, {0, 1}])
    chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t, lambda_],
                     [float, int, float, float, float, float])

    qa_tm1 = base.batched_index(q_tm1, a_tm1)
    v_t = jnp.max(q_t, axis=-1)
    target_tm1 = multistep.lambda_returns(r_t, discount_t, v_t, lambda_)

    if stop_target_gradients:
        target_tm1 = jax.lax.stop_gradient(target_tm1)
    return target_tm1 - qa_tm1
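A minimal usage sketch (my own, not part of the snippet) that mirrors what `q_lambda` computes for a 3-step trajectory: index the taken actions into `q_tm1`, bootstrap with `max(q_t, axis=-1)`, run the lambda-return recursion backwards, and form the TD errors.

import jax.numpy as jnp

# Illustrative sketch mirroring q_lambda above (not a library call).
# Toy trajectory of T=3 transitions with 2 actions; shapes follow the docstring.
q_tm1 = jnp.array([[1.0, 2.0], [3.0, 1.0], [0.5, 0.5]])   # [T, A]
a_tm1 = jnp.array([1, 0, 0])                              # [T]
r_t = jnp.array([1.0, 0.0, 2.0])                          # [T]
discount_t = jnp.array([0.9, 0.9, 0.0])                   # [T]
q_t = jnp.array([[2.0, 1.0], [1.0, 1.5], [4.0, 0.0]])     # [T, A]
lambda_ = 0.8

# Q-values of the actions actually taken (what base.batched_index does).
qa_tm1 = q_tm1[jnp.arange(q_tm1.shape[0]), a_tm1]
# Greedy bootstrap values at each next state.
v_t = jnp.max(q_t, axis=-1)

# Backwards recursion: G_t = r_{t+1} + γ_{t+1} [(1 - λ) v_{t+1} + λ G_{t+1}].
g = v_t[-1]
returns = []
for i in reversed(range(r_t.shape[0])):
    g = r_t[i] + discount_t[i] * ((1 - lambda_) * v_t[i] + lambda_ * g)
    returns.insert(0, g)
target_tm1 = jnp.array(returns)

td_errors = target_tm1 - qa_tm1   # one TD error per transition
print(td_errors)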
Example #3
def kl_divergence_with_probs(p=None, q=None, epsilon=1e-20):
    """Compute the KL between two categorical distributions from their probabilities.

  Args:
    p: [..., dim] array with probs for the first distribution.
    q: [..., dim] array with probs for the second distribution.
    epsilon: a small float to normalize probabilities with.

  Returns:
    an array of KL divergence terms taken over the last axis.
  """
    chex.assert_type([p, q], float)
    chex.assert_equal_shape([p, q])

    log_p = jnp.log(p + epsilon)
    log_q = jnp.log(q + epsilon)
    kl = jnp.sum(p * (log_p - log_q), axis=-1)

    # KL divergence should be non-negative; clamping at zero guards against
    # numerical error.
    loss = jax.nn.relu(kl)

    return loss
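As a quick sanity check (a sketch of mine, not library code), the same epsilon-smoothed formula applied to two rows of probabilities gives a positive KL where the distributions differ and roughly zero where they match:

import jax
import jax.numpy as jnp

p = jnp.array([[0.7, 0.2, 0.1],
               [0.1, 0.1, 0.8]])
q = jnp.array([[0.6, 0.3, 0.1],
               [0.1, 0.1, 0.8]])
epsilon = 1e-20

# Same computation as kl_divergence_with_probs: sum_i p_i (log p_i - log q_i).
kl = jnp.sum(p * (jnp.log(p + epsilon) - jnp.log(q + epsilon)), axis=-1)
kl = jax.nn.relu(kl)   # clamp tiny negative values caused by rounding
print(kl)              # first entry > 0, second entry ≈ 0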
Example #4
def categorical_td_learning(v_atoms_tm1: Array, v_logits_tm1: Array,
                            r_t: Numeric, discount_t: Numeric,
                            v_atoms_t: Array, v_logits_t: Array) -> Numeric:
    """Implements TD-learning for categorical value distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
    Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

  Args:
    v_atoms_tm1: atoms of V distribution at time t-1.
    v_logits_tm1: logits of V distribution at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_atoms_t: atoms of V distribution at time t.
    v_logits_t: logits of V distribution at time t.

  Returns:
    Categorical TD-learning loss (i.e. temporal difference error).
  """
    chex.assert_rank(
        [v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t],
        [1, 1, 0, 0, 1, 1])
    chex.assert_type(
        [v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t],
        [float, float, float, float, float, float])

    # Scale and shift time-t distribution atoms by discount and reward.
    target_z = r_t + discount_t * v_atoms_t

    # Convert logits to distribution.
    v_t_probs = jax.nn.softmax(v_logits_t)

    # Project using the Cramer distance.
    target = jax.lax.stop_gradient(
        _categorical_l2_project(target_z, v_t_probs, v_atoms_tm1))

    # Compute loss (i.e. temporal difference error).
    return distributions.categorical_cross_entropy(labels=target,
                                                   logits=v_logits_tm1)
Example #5
def _quantile_regression_loss(
    dist_src: Array,
    tau_src: Array,
    dist_target: Array,
    huber_param: float = 0.
) -> Numeric:
  """Compute (Huber) QR loss between two discrete quantile-valued distributions.

  See "Distributional Reinforcement Learning with Quantile Regression" by
  Dabney et al. (https://arxiv.org/abs/1710.10044).

  Args:
    dist_src: source probability distribution.
    tau_src: source distribution probability thresholds.
    dist_target: target probability distribution.
    huber_param: Huber loss parameter, defaults to 0 (no Huber loss).

  Returns:
    Quantile regression loss.
  """
  chex.assert_rank([dist_src, tau_src, dist_target], 1)
  chex.assert_type([dist_src, tau_src, dist_target], float)

  # Calculate quantile error.
  delta = dist_target[None, :] - dist_src[:, None]
  delta_neg = (delta < 0.).astype(jnp.float32)
  delta_neg = jax.lax.stop_gradient(delta_neg)
  weight = jnp.abs(tau_src[:, None] - delta_neg)

  # Calculate Huber loss.
  if huber_param > 0.:
    loss = clipping.huber_loss(delta, huber_param)
  else:
    loss = jnp.abs(delta)
  loss *= weight

  # Average over target-samples dimension, sum over src-samples dimension.
  return jnp.sum(jnp.mean(loss, axis=-1))
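For intuition, here is a tiny hand-rolled instance of the same pairwise construction (my sketch, with `huber_param = 0` so the loss is plain absolute error), using three source quantiles and three target samples:

import jax
import jax.numpy as jnp

dist_src = jnp.array([0.0, 1.0, 2.0])      # source quantile values
tau_src = jnp.array([1/6, 3/6, 5/6])       # quantile midpoints for 3 quantiles
dist_target = jnp.array([0.5, 1.5, 2.5])   # target quantile values

# Pairwise errors: delta[i, j] = target_j - src_i.
delta = dist_target[None, :] - dist_src[:, None]
delta_neg = jax.lax.stop_gradient((delta < 0.).astype(jnp.float32))
weight = jnp.abs(tau_src[:, None] - delta_neg)

loss = jnp.abs(delta) * weight              # huber_param == 0 -> absolute error
qr_loss = jnp.sum(jnp.mean(loss, axis=-1))  # mean over target dim, sum over src dim
print(qr_loss)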
Example #6
def dpg_loss(a_t: Array,
             dqda_t: Array,
             dqda_clipping: Optional[Scalar] = None) -> Array:
    """Calculates the deterministic policy gradient (DPG) loss.

  See "Deterministic Policy Gradient Algorithms" by Silver, Lever, Heess,
  Degris, Wierstra, Riedmiller (http://proceedings.mlr.press/v32/silver14.pdf).

  Args:
    a_t: continuous-valued action at time t.
    dqda_t: gradient of Q(s,a) wrt. a, evaluated at time t.
    dqda_clipping: clips the gradient to have norm <= `dqda_clipping`.

  Returns:
    DPG loss.
  """
    chex.assert_rank([a_t, dqda_t], 1)
    chex.assert_type([a_t, dqda_t], float)

    if dqda_clipping is not None:
        dqda_t = _clip_by_l2_norm(dqda_t, dqda_clipping)
    target_tm1 = dqda_t + a_t
    return losses.l2_loss(jax.lax.stop_gradient(target_tm1) - a_t)
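The point of the construction is that an L2 loss against the stop-gradient pseudo-target `a_t + dqda_t` has gradient `-dqda_t` with respect to the action, so minimising it performs gradient ascent on Q. A small check (my own sketch, with no clipping and `losses.l2_loss` written out as `0.5 * err**2`):

import jax
import jax.numpy as jnp

dqda_t = jnp.array([0.3, -0.7])   # gradient of Q(s, a) w.r.t. a

def surrogate(a_t):
    # Same construction as dpg_loss above, summed to a scalar for jax.grad.
    target = jax.lax.stop_gradient(dqda_t + a_t)
    return jnp.sum(0.5 * (target - a_t) ** 2)

a_t = jnp.array([0.1, 0.2])
# The gradient of the surrogate w.r.t. the action is -dqda_t, so a gradient
# descent step on this loss moves the action along +dqda_t (ascent on Q).
print(jax.grad(surrogate)(a_t))   # ≈ [-0.3, 0.7]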
Example #7
def l2_loss(
    predictions: chex.Array,
    targets: Optional[chex.Array] = None,
) -> chex.Array:
    """Calculates the L2 loss for a set of predictions.

  Note: the 0.5 term is standard in "Pattern Recognition and Machine Learning"
  by Bishop, but not "The Elements of Statistical Learning" by Tibshirani.

  References:
    [Chris Bishop, 2006](https://bit.ly/3eeP0ga)

  Args:
    predictions: a vector of arbitrary shape.
    targets: a vector of shape compatible with predictions; if not provided
      then it is assumed to be zero.

  Returns:
    the squared error loss.
  """
    chex.assert_type([predictions], float)
    errors = (predictions - targets) if (targets is not None) else predictions
    return 0.5 * (errors)**2
Example #8
def l2_loss(
    predictions: chex.Array,
    targets: Optional[chex.Array] = None,
) -> chex.Array:
    """Calculates the L2 loss for a set of predictions.

  Note: the 0.5 term is standard in "Pattern Recognition and Machine Learning"
  by Bishop, but not "The Elements of Statistical Learning" by Tibshirani.

  References:
    [Chris Bishop, 2006](https://bit.ly/3eeP0ga)

  Args:
    predictions: a vector of arbitrary shape `[...]`.
    targets: a vector with shape broadcastable to that of `predictions`;
      if not provided then it is assumed to be a vector of zeros.

  Returns:
    elementwise squared differences, with same shape as `predictions`.
  """
    chex.assert_type([predictions], float)
    errors = (predictions - targets) if (targets is not None) else predictions
    return 0.5 * (errors)**2
Example #9
def persistent_q_learning(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t: Array,
    action_gap_scale: float,
    stop_target_gradients: bool = True,
) -> Numeric:
    """Calculates the persistent Q-learning temporal difference error.

  See "Increasing the Action Gap: New Operators for Reinforcement Learning"
  by Bellemare, Ostrovski, Guez et al. (https://arxiv.org/abs/1512.04860).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    action_gap_scale: coefficient in [0, 1] for scaling the action gap term.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    Persistent Q-learning temporal difference error.
  """
    chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t], [1, 0, 0, 0, 1])
    chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t],
                     [float, int, float, float, float])

    corrected_q_t = ((1. - action_gap_scale) * jnp.max(q_t) +
                     action_gap_scale * q_t[a_tm1])
    target_tm1 = r_t + discount_t * corrected_q_t
    target_tm1 = jax.lax.select(stop_target_gradients,
                                jax.lax.stop_gradient(target_tm1), target_tm1)
    return target_tm1 - q_tm1[a_tm1]
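A worked single-transition example (my own numbers) of the persistent Q-learning target, where the bootstrap mixes the greedy value with the value of repeating the previous action:

import jax.numpy as jnp

q_tm1 = jnp.array([1.0, 2.0, 0.0])
a_tm1 = 1
r_t = 0.5
discount_t = 0.9
q_t = jnp.array([1.5, 1.0, 3.0])
action_gap_scale = 0.5

# Bootstrap value mixes the greedy value with the value of repeating a_tm1.
corrected_q_t = (1. - action_gap_scale) * jnp.max(q_t) + action_gap_scale * q_t[a_tm1]
target_tm1 = r_t + discount_t * corrected_q_t   # 0.5 + 0.9 * (0.5*3.0 + 0.5*1.0) = 2.3
print(target_tm1 - q_tm1[a_tm1])                # TD error ≈ 0.3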
Example #10
def expected_sarsa(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t: Array,
    probs_a_t: Array,
    stop_target_gradients: bool = True,
) -> Numeric:
    """Calculates the expected SARSA (SARSE) temporal difference error.

  See "A Theoretical and Empirical Analysis of Expected Sarsa" by Seijen,
  van Hasselt, Whiteson et al.
  (http://www.cs.ox.ac.uk/people/shimon.whiteson/pubs/vanseijenadprl09.pdf).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    probs_a_t: action probabilities at time t.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    Expected SARSA temporal difference error.
  """
    chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t],
                     [1, 0, 0, 0, 1, 1])
    chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t],
                     [float, int, float, float, float, float])

    target_tm1 = r_t + discount_t * jnp.dot(q_t, probs_a_t)
    target_tm1 = jax.lax.select(stop_target_gradients,
                                jax.lax.stop_gradient(target_tm1), target_tm1)
    return target_tm1 - q_tm1[a_tm1]
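A one-transition numeric sketch (not from the library) of the expected SARSA target, which bootstraps with the expectation of the next-state Q-values under the policy's action probabilities:

import jax.numpy as jnp

q_tm1 = jnp.array([1.0, 2.0])
a_tm1 = 0
r_t = 1.0
discount_t = 0.9
q_t = jnp.array([2.0, 4.0])
probs_a_t = jnp.array([0.25, 0.75])   # action probabilities of the policy at time t

# Bootstrap with the expected Q-value under the policy at time t.
expected_q_t = jnp.dot(q_t, probs_a_t)         # 0.25*2 + 0.75*4 = 3.5
target_tm1 = r_t + discount_t * expected_q_t   # 1 + 0.9*3.5 = 4.15
print(target_tm1 - q_tm1[a_tm1])               # TD error ≈ 3.15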
Example #11
def log_cosh(
    predictions: chex.Array,
    targets: Optional[chex.Array] = None,
) -> chex.Array:
    """Calculates the log-cosh loss for a set of predictions.

  log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)`
  for large x.  It is a twice differentiable alternative to the Huber loss.

  References:
    [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym)

  Args:
    predictions: a vector of arbitrary shape.
    targets: a vector of shape compatible with predictions; if not provided
      then it is assumed to be zero.

  Returns:
    the log-cosh loss.
  """
    chex.assert_type([predictions], float)
    errors = (predictions - targets) if (targets is not None) else predictions
    # log(cosh(x)) = log((exp(x) + exp(-x))/2) = log(exp(x) + exp(-x)) - log(2)
    return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype)
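A quick check of the two regimes mentioned in the docstring (my own sketch, restating the same stable formulation):

import jax.numpy as jnp

def log_cosh_ref(x):
    # Same stable formulation as above: log(cosh(x)) = logaddexp(x, -x) - log(2).
    return jnp.logaddexp(x, -x) - jnp.log(2.0)

small, large = jnp.array(0.01), jnp.array(10.0)
print(log_cosh_ref(small), 0.5 * small**2)                  # ≈ equal for small x
print(log_cosh_ref(large), jnp.abs(large) - jnp.log(2.0))   # ≈ equal for large x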
Example #12
def cosine_distance(
    predictions: chex.Array,
    targets: chex.Array,
    epsilon: float = 0.,
) -> chex.Array:
    r"""Computes the cosine distance between targets and predictions.

  The cosine **similarity** is a measure of similarity between vectors defined
  as the cosine of the angle between them, which is also the inner product of
  those vectors normalized to have unit norm. The cosine **distance**,
  implemented here, measures instead the **dissimilarity** as `1 - cos(\theta)`.

  References:
    [Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)

  Args:
    predictions: The predicted vector.
    targets: Ground truth target vector.
    epsilon: minimum norm for terms in the denominator of the cosine similarity.

  Returns:
    cosine distance values.
  """
    chex.assert_equal_shape([targets, predictions])
    chex.assert_type([targets, predictions], float)
    # vectorize norm fn, to treat all dimensions except the last as batch dims.
    batched_norm_fn = jnp.vectorize(utils.safe_norm,
                                    signature='(k)->()',
                                    excluded={1})
    # normalise the last dimension of targets and predictions.
    unit_targets = targets / jnp.expand_dims(batched_norm_fn(targets, epsilon),
                                             axis=-1)
    unit_predictions = predictions / jnp.expand_dims(
        batched_norm_fn(predictions, epsilon), axis=-1)
    # cosine distance = 1 - cosine similarity.
    return 1. - jnp.sum(unit_targets * unit_predictions, axis=-1)
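A by-hand sketch for a single pair of vectors; `utils.safe_norm` is assumed here to behave like an ordinary norm with an epsilon floor, which is an assumption on my part rather than something shown above:

import jax.numpy as jnp

predictions = jnp.array([1.0, 0.0])
targets = jnp.array([1.0, 1.0])
epsilon = 1e-12

# Normalise each vector; the epsilon floor stands in for utils.safe_norm.
def unit(v):
    return v / jnp.maximum(jnp.linalg.norm(v), epsilon)

cosine_similarity = jnp.sum(unit(predictions) * unit(targets))
print(1.0 - cosine_similarity)   # 1 - cos(45°) ≈ 0.2929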
Example #13
def discounted_returns(r_t: Array, discount_t: Array, v_t: Array) -> Array:
    """Calculates a discounted return from a trajectory.

  The returns are computed recursively, from `G_{T-1}` to `G_0`, according to:

    Gₜ = rₜ₊₁ + γₜ₊₁ Gₜ₊₁.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/sutton/book/ebook/node61.html).

  Args:
    r_t: reward sequence at time t.
    discount_t: discount sequence at time t.
    v_t: value sequence or scalar at time t.

  Returns:
    Discounted returns.
  """
    chex.assert_rank([r_t, discount_t, v_t], [1, 1, {0, 1}])
    chex.assert_type([r_t, discount_t, v_t], float)

    # If scalar make into vector.
    bootstrapped_v = jnp.ones_like(discount_t) * v_t
    return lambda_returns(r_t, discount_t, bootstrapped_v, lambda_=1.)
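A short hand computation (my own sketch) of discounted returns for a 3-step trajectory; with `lambda_ = 1` the recursion reduces to Gₜ = rₜ₊₁ + γₜ₊₁ Gₜ₊₁, bootstrapped from the final value:

import jax.numpy as jnp

r_t = jnp.array([1.0, 0.0, 2.0])
discount_t = jnp.array([0.5, 0.5, 0.5])
v_t = 4.0   # scalar bootstrap value, broadcast to every step by the function

# Backwards recursion G_t = r_{t+1} + γ_{t+1} G_{t+1}, starting from v_t.
g = v_t
returns = []
for i in reversed(range(r_t.shape[0])):
    g = r_t[i] + discount_t[i] * g
    returns.insert(0, g)
print(jnp.array(returns))   # [2.0, 2.0, 4.0]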
Example #14
def double_q_learning(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t_value: Array,
    q_t_selector: Array,
    stop_target_gradients: bool = True,
) -> Numeric:
    """Calculates the double Q-learning temporal difference error.

  See "Double Q-learning" by van Hasselt.
  (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t_value: Q-values at time t.
    q_t_selector: selector Q-values at time t.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    Double Q-learning temporal difference error.
  """
    chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector],
                     [1, 0, 0, 0, 1, 1])
    chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector],
                     [float, int, float, float, float, float])

    target_tm1 = r_t + discount_t * q_t_value[q_t_selector.argmax()]
    target_tm1 = jax.lax.select(stop_target_gradients,
                                jax.lax.stop_gradient(target_tm1), target_tm1)
    return target_tm1 - q_tm1[a_tm1]
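A single-transition numeric sketch (my own numbers) of the double Q-learning decoupling: the selector Q-values pick the action, the value Q-values evaluate it:

import jax.numpy as jnp

q_tm1 = jnp.array([1.0, 2.0])
a_tm1 = 0
r_t = 1.0
discount_t = 0.9
q_t_value = jnp.array([3.0, 1.0])      # e.g. target-network Q-values
q_t_selector = jnp.array([0.0, 5.0])   # e.g. online-network Q-values

# Selector network chooses the action, value network evaluates it.
best_action = q_t_selector.argmax()                      # action 1
target_tm1 = r_t + discount_t * q_t_value[best_action]   # 1 + 0.9*1.0 = 1.9
print(target_tm1 - q_tm1[a_tm1])                         # TD error ≈ 0.9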
Example #15
def huber_loss(x: Array, delta: float = 1.) -> Array:
    """Huber loss, similar to L2 loss close to zero, L1 loss away from zero.

  See "Robust Estimation of a Location Parameter" by Huber.
  (https://projecteuclid.org/download/pdf_1/euclid.aoms/1177703732).

  Args:
    x: a vector of arbitrary shape.
    delta: the bounds for the huber loss transformation, defaults at 1.

  Note `grad(huber_loss(x))` is equivalent to `grad(0.5 * clip_gradient(x)**2)`.

  Returns:
    a vector of same shape of `x`.
  """
    chex.assert_type(x, float)

    # 0.5 * x^2                  if |x| <= d
    # 0.5 * d^2 + d * (|x| - d)  if |x| > d
    abs_x = jnp.abs(x)
    quadratic = jnp.minimum(abs_x, delta)
    # Same as max(abs_x - delta, 0) but avoids potentially doubling gradient.
    linear = abs_x - quadratic
    return 0.5 * quadratic**2 + delta * linear
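A small numeric sketch (restating the same formula) showing the quadratic region inside `|x| <= delta` and the linear region outside it:

import jax.numpy as jnp

x = jnp.array([0.5, -0.5, 2.0, -3.0])
delta = 1.0

abs_x = jnp.abs(x)
quadratic = jnp.minimum(abs_x, delta)
linear = abs_x - quadratic
loss = 0.5 * quadratic**2 + delta * linear
print(loss)   # [0.125, 0.125, 1.5, 2.5]: quadratic inside |x| <= 1, linear outside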
Example #16
def sigmoid_binary_cross_entropy(logits, labels):
    """Computes sigmoid cross entropy given logits and multiple class labels.

  Measures the probability error in discrete classification tasks in which
  each class is an independent binary prediction and different classes are
  not mutually exclusive. This may be used for multilabel image classification;
  for instance, a model may predict that an image contains both a cat and a dog.

  References:
    [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)

  Args:
    logits: unnormalized log probabilities.
    labels: the probability for that class.

  Returns:
    a sigmoid cross entropy loss.
  """
    chex.assert_equal_shape([logits, labels])
    chex.assert_type([logits, labels], float)
    log_p = jax.nn.log_sigmoid(logits)
    # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter more numerically stable
    log_not_p = jax.nn.log_sigmoid(-logits)
    return -labels * log_p - (1. - labels) * log_not_p
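A three-class numeric sketch (my own, restating the same formula) with independent binary labels; note the 0-logit entry gives exactly `log(2)`:

import jax
import jax.numpy as jnp

logits = jnp.array([2.0, -1.0, 0.0])
labels = jnp.array([1.0, 0.0, 1.0])   # independent binary targets

log_p = jax.nn.log_sigmoid(logits)
log_not_p = jax.nn.log_sigmoid(-logits)   # log(1 - sigmoid(x)), numerically stable
loss = -labels * log_p - (1. - labels) * log_not_p
print(loss)   # ≈ [0.127, 0.313, 0.693]; small when confident and correct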
Example #17
 def test_adam(self):
   init_fn, update_fn = optimizers.get_optimizer(
       ConfigDict({
           'optimizer': 'adam',
           'l2_decay_factor': None,
           'batch_size': 50,
           'total_accumulated_batch_size': 100,  # Use gradient accumulation.
           'opt_hparams': {
               'beta1': 0.9,
               'beta2': 0.999,
               'epsilon': 1e-7,
               'weight_decay': 0.0,
           }
       }))
   del update_fn
   optimizer_state = init_fn({'foo': jnp.ones(10)})
   # Test that we can extract 'count'.
   chex.assert_type(extract_field(optimizer_state, 'count'), int)
   # Test that we can extract 'nu'.
   chex.assert_shape(extract_field(optimizer_state, 'nu')['foo'], (10,))
   # Test that we can extract 'mu'.
   chex.assert_shape(extract_field(optimizer_state, 'mu')['foo'], (10,))
   # Test that attempting to extract a nonexistent field "abc" returns None.
   chex.assert_equal(extract_field(optimizer_state, 'abc'), None)
Example #18
def cosine_distance(
    predictions: chex.Array,
    targets: chex.Array,
    epsilon: float = 0.,
) -> chex.Array:
    r"""Computes the cosine distance between targets and predictions.

  The cosine **distance**, implemented here, measures the **dissimilarity**
  of two vectors as the opposite of cosine **similarity**: `1 - cos(\theta)`.

  References:
    [Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)

  Args:
    predictions: The predicted vectors, with shape `[..., dim]`.
    targets: Ground truth target vectors, with shape `[..., dim]`.
    epsilon: minimum norm for terms in the denominator of the cosine similarity.

  Returns:
    cosine distances, with shape `[...]`.
  """
    chex.assert_type([predictions, targets], float)
    # cosine distance = 1 - cosine similarity.
    return 1. - cosine_similarity(predictions, targets, epsilon)
Example #19
def categorical_kl_divergence(p_logits: Array,
                              q_logits: Array,
                              temperature: float = 1.) -> Array:
    """Compute the KL between two categorical distributions from their logits.

  Args:
    p_logits: unnormalized logits for the first distribution.
    q_logits: unnormalized logits for the second distribution.
    temperature: the temperature for the softmax distribution, defaults at 1.

  Returns:
    the kl divergence between the distributions.
  """
    chex.assert_type([p_logits, q_logits], float)

    p_logits /= temperature
    q_logits /= temperature

    p = jax.nn.softmax(p_logits)
    log_p = jax.nn.log_softmax(p_logits)
    log_q = jax.nn.log_softmax(q_logits)
    kl = jnp.sum(p * (log_p - log_q), axis=-1)
    return jax.nn.relu(
        kl)  # Guard against numerical issues giving negative KL.
Example #20
def add_ornstein_uhlenbeck_noise(key: Array, action: Array, noise_tm1: Array,
                                 damping: float, stddev: float) -> Array:
    """Returns continuous action with noise from Ornstein-Uhlenbeck process.

  See "On the theory of Brownian Motion" by Uhlenbeck and Ornstein.
  (https://journals.aps.org/pr/abstract/10.1103/PhysRev.36.823).

  Args:
    key: a key from `jax.random`.
    action: continuous action scalar or vector.
    noise_tm1: noise sampled from OU process in previous timestep.
    damping: parameter for controlling autocorrelation of OU process.
    stddev: standard deviation of noise distribution.

  Returns:
    noisy action, of the same shape as input action.
  """
    chex.assert_rank([action, noise_tm1], 1)
    chex.assert_type([action, noise_tm1], float)

    noise_t = (1. - damping) * noise_tm1 + jax.random.normal(
        key, shape=action.shape) * stddev

    return action + noise_t
Example #21
def lambda_returns(
    r_t: Array,
    discount_t: Array,
    v_t: Array,
    lambda_: Numeric = 1.,
) -> Array:
  """Estimates a multistep truncated lambda return from a trajectory.

  Given a trajectory of length `T+1`, generated under some policy π, for each
  time-step `t` we can estimate a target return `G_t`, by combining rewards,
  discounts, and state values, according to a mixing parameter `lambda`.

  The parameter `lambda_`  mixes the different multi-step bootstrapped returns,
  corresponding to accumulating `k` rewards and then bootstrapping using `v_t`.

    rₜ₊₁ + γₜ₊₁ vₜ₊₁
    rₜ₊₁ + γₜ₊₁ rₜ₊₂ + γₜ₊₁ γₜ₊₂ vₜ₊₂
    rₜ₊₁ + γₜ₊₁ rₜ₊₂ + γₜ₊₁ γₜ₊₂ rₜ₊₃ + γₜ₊₁ γₜ₊₂ γₜ₊₃ vₜ₊₃

  The returns are computed recursively, from `G_{T-1}` to `G_0`, according to:

    Gₜ = rₜ₊₁ + γₜ₊₁ [(1 - λₜ₊₁) vₜ₊₁ + λₜ₊₁ Gₜ₊₁].

  In the `on-policy` case, we estimate a return target `G_t` for the same
  policy π that was used to generate the trajectory. In this setting the
  parameter `lambda_` is typically a fixed scalar factor. Depending
  on how values `v_t` are computed, this function can be used to construct
  targets for different multistep reinforcement learning updates:

    TD(λ):  `v_t` contains the state value estimates for each state under π.
    Q(λ):  `v_t = max(q_t, axis=-1)`, where `q_t` estimates the action values.
    Sarsa(λ):  `v_t = q_t[..., a_t]`, where `q_t` estimates the action values.

  In the `off-policy` case, the mixing factor is a function of state, and
  different definitions of `lambda` implement different off-policy corrections:

    Per-decision importance sampling:  λₜ = λ ρₜ = λ [π(aₜ|sₜ) / μ(aₜ|sₜ)]
    V-trace, as instantiated in IMPALA:  λₜ = min(1, ρₜ)

  Note that the second option is equivalent to applying per-decision importance
  sampling, but using an adaptive λ(ρₜ) = min(1/ρₜ, 1), such that the effective
  bootstrap parameter at time t becomes λₜ = λ(ρₜ) * ρₜ = min(1, ρₜ).
  This is the interpretation used in the ABQ(ζ) algorithm (Mahmood 2017).

  Of course this can be augmented to include an additional factor λ.  For
  instance we could use V-trace with a fixed additional parameter λ = 0.9, by
  setting λₜ = 0.9 * min(1, ρₜ) or, alternatively (but not equivalently),
  λₜ = min(0.9, ρₜ).

  Estimated returns are then often used to define a TD error, e.g. ρₜ(Gₜ - vₜ).

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/sutton/book/ebook/node74.html).

  Args:
    r_t: sequence of rewards rₜ for timesteps t in [1, T].
    discount_t: sequence of discounts γₜ for timesteps t in [1, T].
    v_t: sequence of state values estimates under π for timesteps t in [1, T].
    lambda_: mixing parameter; a scalar or a vector for timesteps t in [1, T].

  Returns:
    Multistep lambda returns.
  """
  chex.assert_rank([r_t, discount_t, v_t, lambda_], [1, 1, 1, {0, 1}])
  chex.assert_type([r_t, discount_t, v_t, lambda_], float)
  chex.assert_equal_shape([r_t, discount_t, v_t])

  # If scalar make into vector.
  lambda_ = jnp.ones_like(discount_t) * lambda_

  # Work backwards to compute `G_{T-1}`, ..., `G_0`.
  returns = []
  g = v_t[-1]
  for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
    g = r_t[i] + discount_t[i] * ((1-lambda_[i]) * v_t[i] + lambda_[i] * g)
    returns.insert(0, g)

  return jnp.array(returns)
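A sketch (mine, restating the recursion above with a scalar `lambda_`) that checks the two limiting cases: λ = 0 recovers one-step TD targets, λ = 1 recovers fully discounted returns bootstrapped from the final value.

import jax.numpy as jnp

def lam_returns(r_t, discount_t, v_t, lambda_):
    # Same backwards recursion as lambda_returns above, scalar lambda_ only.
    g = v_t[-1]
    out = []
    for i in reversed(range(r_t.shape[0])):
        g = r_t[i] + discount_t[i] * ((1 - lambda_) * v_t[i] + lambda_ * g)
        out.insert(0, g)
    return jnp.array(out)

r_t = jnp.array([1.0, 0.0, 2.0])
discount_t = jnp.array([0.9, 0.9, 0.9])
v_t = jnp.array([1.0, 2.0, 3.0])

# λ = 0: one-step TD targets r_{t+1} + γ_{t+1} v_{t+1}.
print(lam_returns(r_t, discount_t, v_t, 0.0))   # [1.9, 1.8, 4.7]
# λ = 1: discounted returns bootstrapped from v_t[-1].
print(lam_returns(r_t, discount_t, v_t, 1.0))   # ≈ [4.807, 4.23, 4.7]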
Example #22
def vtrace(
    v_tm1: Array,
    v_t: Array,
    r_t: Array,
    discount_t: Array,
    rho_tm1: Array,
    lambda_: float = 1.0,
    clip_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True,
) -> Array:
    """Calculates V-Trace errors from importance weights.

  V-trace computes TD-errors from multistep trajectories by applying
  off-policy corrections based on clipped importance sampling ratios.

  See "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor
  Learner Architectures" by Espeholt et al. (https://arxiv.org/abs/1802.01561).

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_tm1: importance sampling ratios.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: clip threshold for importance weights.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    V-Trace error.
  """
    chex.assert_rank([v_tm1, v_t, r_t, discount_t, rho_tm1], [1, 1, 1, 1, 1])
    chex.assert_type([v_tm1, v_t, r_t, discount_t, rho_tm1],
                     [float, float, float, float, float])
    chex.assert_equal_shape([v_tm1, v_t, r_t, discount_t, rho_tm1])

    # Clip importance sampling ratios.
    c_t = jnp.minimum(1.0, rho_tm1) * lambda_
    clipped_rhos = jnp.minimum(clip_rho_threshold, rho_tm1)

    # Compute the temporal difference errors.
    td_errors = clipped_rhos * (r_t + discount_t * v_t - v_tm1)

    # Work backwards computing the td-errors.
    err = 0.0
    errors = []
    for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
        err = td_errors[i] + discount_t[i] * c_t[i] * err
        errors.insert(0, err)

    # Return errors.
    if not stop_target_gradients:
        return jnp.array(errors)
    # In TD-like algorithms, we want gradients to flow only through the
    # predictions, and not through the values used to bootstrap. To do so, add
    # back the state values to recover the implied return targets, stop the
    # gradient through those targets, and then subtract the state values again.
    else:
        target_tm1 = jnp.array(errors) + v_tm1
        target_tm1 = jax.lax.stop_gradient(target_tm1)
    return target_tm1 - v_tm1
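A 2-step numeric sketch (my own) of the V-trace recursion: clip the importance ratios, form importance-weighted TD errors, then accumulate them backwards (gradient handling omitted):

import jax.numpy as jnp

v_tm1 = jnp.array([1.0, 2.0])
v_t = jnp.array([2.0, 3.0])
r_t = jnp.array([1.0, 0.5])
discount_t = jnp.array([0.9, 0.9])
rho_tm1 = jnp.array([1.5, 0.5])   # π/μ importance sampling ratios
lambda_ = 1.0
clip_rho_threshold = 1.0

c_t = jnp.minimum(1.0, rho_tm1) * lambda_
clipped_rhos = jnp.minimum(clip_rho_threshold, rho_tm1)
td_errors = clipped_rhos * (r_t + discount_t * v_t - v_tm1)   # [1.8, 0.6]

# Backwards accumulation of the corrected errors.
err = 0.0
errors = []
for i in reversed(range(v_t.shape[0])):
    err = td_errors[i] + discount_t[i] * c_t[i] * err
    errors.insert(0, err)
print(jnp.array(errors))   # ≈ [2.34, 0.6]; targets would be errors + v_tm1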
Example #23
def leaky_vtrace(v_tm1: Array,
                 v_t: Array,
                 r_t: Array,
                 discount_t: Array,
                 rho_tm1: Array,
                 alpha_: float = 1.0,
                 lambda_: float = 1.0,
                 clip_rho_threshold: float = 1.0,
                 stop_target_gradients: bool = True):
    """Calculates Leaky V-Trace errors from importance weights.

  Leaky-Vtrace is a combination of Importance sampling and V-trace, where the
  degree of mixing is controlled by a scalar `alpha` (that may be meta-learnt).

  See "Self-Tuning Deep Reinforcement Learning"
  by Zahavy et al. (https://arxiv.org/abs/2002.12928)

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_tm1: importance weights at time t-1.
    alpha_: mixing parameter for Importance Sampling and V-trace.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: clip threshold for importance weights.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    Leaky V-Trace error.
  """
    chex.assert_rank([v_tm1, v_t, r_t, discount_t, rho_tm1], [1, 1, 1, 1, 1])
    chex.assert_type([v_tm1, v_t, r_t, discount_t, rho_tm1],
                     [float, float, float, float, float])
    chex.assert_equal_shape([v_tm1, v_t, r_t, discount_t, rho_tm1])

    # Mix clipped and unclipped importance sampling ratios.
    c_t = (
        (1 - alpha_) * rho_tm1 + alpha_ * jnp.minimum(1.0, rho_tm1)) * lambda_
    clipped_rhos = ((1 - alpha_) * rho_tm1 +
                    alpha_ * jnp.minimum(clip_rho_threshold, rho_tm1))

    # Compute the temporal difference errors.
    td_errors = clipped_rhos * (r_t + discount_t * v_t - v_tm1)

    # Work backwards computing the td-errors.
    err = 0.0
    errors = []
    for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
        err = td_errors[i] + discount_t[i] * c_t[i] * err
        errors.insert(0, err)

    # Return errors.
    if not stop_target_gradients:
        return jnp.array(errors)
    # In TD-like algorithms, we want gradients to flow only through the
    # predictions, and not through the values used to bootstrap. To do so, add
    # back the state values to recover the implied return targets, stop the
    # gradient through those targets, and then subtract the state values again.
    else:
        target_tm1 = jnp.array(errors) + v_tm1
        return jax.lax.stop_gradient(target_tm1) - v_tm1
Example #24
def power(x: Array, p: float) -> Array:
    """Power transform; `power_tx(_, 1/p)` is the inverse of `power_tx(_, p)`."""
    chex.assert_type(x, float)
    q = jnp.sqrt(p)
    return jnp.sign(x) * (jnp.power(jnp.abs(x) / q + 1., p) - 1) / q
Example #25
 def test_sample_dtype(self, dtype):
   dist = self.distrax_cls(preferences=self.preferences, dtype=dtype)
   samples = self.variant(dist.sample)(seed=self.key)
   self.assertEqual(samples.dtype, dist.dtype)
   chex.assert_type(samples, dtype)
Example #26
def signed_parabolic(x: Array, eps: float = 1e-3) -> Array:
    """Signed parabolic transform, inverse of signed_hyperbolic."""
    chex.assert_type(x, float)
    z = jnp.sqrt(1 + 4 * eps * (eps + 1 + jnp.abs(x))) / 2 / eps - 1 / 2 / eps
    return jnp.sign(x) * (jnp.square(z) - 1)
Example #27
def hyperbolic_arcsin(x: Array) -> Array:
    """Hyperbolic arcsinus transform."""
    chex.assert_type(x, float)
    return jnp.arcsinh(x)
Example #28
def signed_hyperbolic(x: Array, eps: float = 1e-3) -> Array:
    """Signed hyperbolic transform, inverse of signed_parabolic."""
    chex.assert_type(x, float)
    return jnp.sign(x) * (jnp.sqrt(jnp.abs(x) + 1) - 1) + eps * x
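A round-trip check (my sketch; both formulas are restated inline so it runs standalone) that `signed_parabolic` above inverts `signed_hyperbolic`:

import jax.numpy as jnp

eps = 1e-3

def signed_hyperbolic_ref(x):
    # Restatement of signed_hyperbolic above.
    return jnp.sign(x) * (jnp.sqrt(jnp.abs(x) + 1) - 1) + eps * x

def signed_parabolic_ref(x):
    # Restatement of signed_parabolic above.
    z = jnp.sqrt(1 + 4 * eps * (eps + 1 + jnp.abs(x))) / 2 / eps - 1 / 2 / eps
    return jnp.sign(x) * (jnp.square(z) - 1)

x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
print(signed_parabolic_ref(signed_hyperbolic_ref(x)))   # ≈ x, up to float precision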
Example #29
def signed_expm1(x: Array) -> Array:
    """Signed exponential of x - 1, inverse of signed_logp1."""
    chex.assert_type(x, float)
    return jnp.sign(x) * (jnp.exp(jnp.abs(x)) - 1)
Example #30
def signed_logp1(x: Array) -> Array:
    """Signed logarithm of x + 1."""
    chex.assert_type(x, float)
    return jnp.sign(x) * jnp.log1p(jnp.abs(x))