Example #1
def vtrace(
    v_tm1: Array,
    v_t: Array,
    r_t: Array,
    discount_t: Array,
    rho_t: Array,
    lambda_: float = 1.0,
    clip_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True,
) -> Array:
    """Calculates V-Trace errors from importance weights.

  V-trace computes TD-errors from multistep trajectories by applying
  off-policy corrections based on clipped importance sampling ratios.

  See "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor
  Learner Architectures" by Espeholt et al. (https://arxiv.org/abs/1802.01561).

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_t: importance sampling ratios.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: clip threshold for importance weights.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    V-Trace error.
  """
    chex.assert_rank([v_tm1, v_t, r_t, discount_t, rho_t], [1, 1, 1, 1, 1])
    chex.assert_type([v_tm1, v_t, r_t, discount_t, rho_t],
                     [float, float, float, float, float])

    # Clip importance sampling ratios.
    c_t = jnp.minimum(1.0, rho_t) * lambda_
    clipped_rhos = jnp.minimum(clip_rho_threshold, rho_t)

    # Compute the temporal difference errors.
    td_errors = clipped_rhos * (r_t + discount_t * v_t - v_tm1)

    # Work backwards computing the td-errors.
    err = 0.0
    errors = []
    for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
        err = td_errors[i] + discount_t[i] * c_t[i] * err
        errors.insert(0, err)

    # Return errors.
    if not stop_target_gradients:
        return jnp.array(errors)
    # In TD-like algorithms, we want gradients to only flow in the predictions,
    # and not in the values used to bootstrap. In this case, add the value of the
    # initial state value to get the implied estimates of the returns, stop
    # gradient around such target and then subtract again the initial state value.
    else:
        target_tm1 = jnp.array(errors) + v_tm1
        target_tm1 = jax.lax.stop_gradient(target_tm1)
    return target_tm1 - v_tm1
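
A minimal usage sketch (not part of the snippet above): it assumes `vtrace` and its `chex`/`jax` imports are in scope, and the trajectory values are purely illustrative.

import jax.numpy as jnp

# Hypothetical 4-step trajectory.
v_tm1 = jnp.array([1.0, 1.1, 0.9, 1.2])          # V(s_{t-1})
v_t = jnp.array([1.1, 0.9, 1.2, 1.3])            # V(s_t), used to bootstrap
r_t = jnp.array([0.0, 1.0, 0.0, 1.0])
discount_t = jnp.array([0.99, 0.99, 0.99, 0.0])  # 0 marks the episode end
rho_t = jnp.array([1.0, 0.7, 1.4, 1.0])          # importance ratios pi / mu

errors = vtrace(v_tm1, v_t, r_t, discount_t, rho_t, lambda_=0.95)
value_loss = 0.5 * jnp.sum(jnp.square(errors))   # gradient flows into v_tm1 only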
Example #2
def clipped_surrogate_pg_loss(
    prob_ratios_t: Array,
    adv_t: Array,
    epsilon: Scalar) -> Array:
  """Computes the clipped surrogate policy gradient loss.

  L_clipₜ(θ) = - min(rₜ(θ)Âₜ, clip(rₜ(θ), 1-ε, 1+ε)Âₜ)

  Where rₜ(θ) = π_θ(aₜ| sₜ) / π_θ_old(aₜ| sₜ) and Âₜ are the advantages.

  See Proximal Policy Optimization Algorithms, Schulman et al.:
  https://arxiv.org/abs/1707.06347

  Args:
    prob_ratios_t: Ratio of action probabilities for actions a_t:
        rₜ(θ) = π_θ(aₜ| sₜ) / π_θ_old(aₜ| sₜ)
    adv_t: the observed or estimated advantages from executing actions a_t.
    epsilon: Scalar value corresponding to how much to clip the objective.

  Returns:
    Loss whose gradient corresponds to a clipped surrogate policy gradient
        update.
  """
  chex.assert_rank([prob_ratios_t, adv_t], [1, 1])
  chex.assert_type([prob_ratios_t, adv_t], [float, float])

  clipped_ratios_t = jnp.clip(prob_ratios_t, 1. - epsilon, 1. + epsilon)
  clipped_objective = jnp.fmin(prob_ratios_t * adv_t, clipped_ratios_t * adv_t)
  return -jnp.mean(clipped_objective)
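
A short usage sketch with made-up numbers (assumes the function and its imports are in scope):

import jax.numpy as jnp

prob_ratios_t = jnp.array([1.05, 0.90, 1.30])  # pi_new(a|s) / pi_old(a|s)
adv_t = jnp.array([0.5, -0.2, 1.5])            # estimated advantages
loss = clipped_surrogate_pg_loss(prob_ratios_t, adv_t, epsilon=0.2)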
Example #3
def policy_gradient_loss(
    logits_t: Array,
    a_t: Array,
    adv_t: Array,
    w_t: Array,
) -> Array:
  """Calculates the policy gradient loss.

  See "Simple Gradient-Following Algorithms for Connectionist RL" by Williams.
  (http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)

  Args:
    logits_t: a sequence of unnormalized action preferences.
    a_t: a sequence of actions sampled from the preferences `logits_t`.
    adv_t: the observed or estimated advantages from executing actions `a_t`.
    w_t: a per timestep weighting for the loss.

  Returns:
    Loss whose gradient corresponds to a policy gradient update.
  """
  chex.assert_rank([logits_t, a_t, adv_t, w_t], [2, 1, 1, 1])
  chex.assert_type([logits_t, a_t, adv_t, w_t], [float, int, float, float])

  log_pi_a_t = distributions.softmax().logprob(a_t, logits_t)
  adv_t = jax.lax.stop_gradient(adv_t)
  loss_per_timestep = -log_pi_a_t * adv_t
  return jnp.mean(loss_per_timestep * w_t)
Example #4
    def objective_func(self, params, state, hyperparams, rng, transition_batch,
                       Adv):
        rngs = hk.PRNGSequence(rng)

        # get distribution params from function approximator
        S = self.pi.observation_preprocessor(next(rngs), transition_batch.S)
        dist_params, state_new = self.pi.function(params, state, next(rngs), S,
                                                  True)

        # compute probability ratios
        A = self.pi.proba_dist.preprocess_variate(next(rngs),
                                                  transition_batch.A)
        log_pi = self.pi.proba_dist.log_proba(dist_params, A)
        ratio = jnp.exp(log_pi - transition_batch.logP)  # π_new / π_old
        ratio_clip = jnp.clip(ratio, 1 - hyperparams['epsilon'],
                              1 + hyperparams['epsilon'])

        # clip importance weights to reduce variance
        W = jnp.clip(transition_batch.W, 0.1, 10.)

        # ppo-clip objective
        chex.assert_equal_shape([W, Adv, ratio, ratio_clip])
        chex.assert_rank([W, Adv, ratio, ratio_clip], 1)
        objective = W * jnp.minimum(Adv * ratio, Adv * ratio_clip)

        # also pass auxiliary data to avoid multiple forward passes
        return jnp.mean(objective), (dist_params, log_pi, state_new)
Example #5
def qpg_loss(
    logits_t: Array,
    q_t: Array,
) -> Array:
  """Computes the QPG (Q-based Policy Gradient) loss.

  See "Actor-Critic Policy Optimization in Partially Observable Multiagent
  Environments" by Srinivasan, Lanctot.
  (https://papers.nips.cc/paper/7602-actor-critic-policy-optimization-in-partially-observable-multiagent-environments.pdf)

  Args:
    logits_t: a sequence of unnormalized action preferences.
    q_t: the observed or estimated action value from executing actions `a_t` at
      time t.

  Returns:
    QPG Loss.
  """
  chex.assert_rank([logits_t, q_t], 2)
  chex.assert_type([logits_t, q_t], float)

  policy_t, advantage_t = _compute_advantages(logits_t, q_t)
  policy_advantages = -policy_t * jax.lax.stop_gradient(advantage_t)
  loss = jnp.mean(jnp.sum(policy_advantages, axis=1), axis=0)
  return loss
Example #6
def qv_max(
    v_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t: Array,
    stop_target_gradients: bool = True,
) -> Numeric:
    """Calculates the QVMAX temporal difference error.

  See "The QV Family Compared to Other Reinforcement Learning Algorithms" by
  Wiering and van Hasselt (2009).
  (http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.713.1931)

  Args:
    v_tm1: state values at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    QVMAX temporal difference error.
  """
    chex.assert_rank([v_tm1, r_t, discount_t, q_t], [0, 0, 0, 1])
    chex.assert_type([v_tm1, r_t, discount_t, q_t], float)

    target_tm1 = r_t + discount_t * jnp.max(q_t)
    target_tm1 = jax.lax.select(stop_target_gradients,
                                jax.lax.stop_gradient(target_tm1), target_tm1)
    return target_tm1 - v_tm1
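
A minimal usage sketch with illustrative scalars (assumes the function and its imports are in scope):

import jax.numpy as jnp

v_tm1 = jnp.array(0.8)             # V(s_{t-1})
r_t = jnp.array(1.0)
discount_t = jnp.array(0.99)
q_t = jnp.array([0.5, 1.2, 0.3])   # Q(s_t, .)
td_error = qv_max(v_tm1, r_t, discount_t, q_t)  # target uses max_a Q(s_t, a)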
Example #7
 def _compute_posterior(
     self,
     inputs,
     encoder_outputs,
     context_vectors,
 ):
     """Computes the posterior branch of the DecoderBlock."""
     chex.assert_rank(inputs, 4)
     resolution = inputs.shape[1]
     try:
         encoded_image = encoder_outputs[resolution]
     except KeyError:
         raise KeyError(
             'encoder_outputs does not contain the required '  # pylint: disable=g-doc-exception
             f'resolution ({resolution}). encoder_outputs resolutions '
             f'are {list(encoder_outputs.keys())}.')
     posterior_block = blocks.ResBlock(self.bottlenecked_num_channels,
                                       self.latent_dim * 2,
                                       use_residual_connection=False,
                                       precision=self.precision,
                                       name='posterior_block')
     concatenated_inputs = jnp.concatenate([inputs, encoded_image], axis=3)
     posterior_output = posterior_block(concatenated_inputs,
                                        context_vectors)
     posterior_mean, posterior_log_std = jnp.split(posterior_output,
                                                   2,
                                                   axis=3)
     return posterior_mean, posterior_log_std
Example #8
def test_reshape_shape_forward(filters):

    n_dims = (1, 28, 28, 1)
    new_shape = _get_new_shapes(28, 28, 1, filters)
    params_rng, data_rng = jax.random.split(KEY, 2)

    x = jax.random.uniform(data_rng, shape=n_dims)

    # create layer
    init_func = Squeeze(filter_shape=filters, collapse=None, return_outputs=True)

    # initialize the layer
    z_, params, forward_f, inverse_f = init_func(rng=params_rng, shape=n_dims, inputs=x)

    # forward transformation
    z, log_abs_det = forward_f(params, x)

    # checks
    chex.assert_tree_all_close(z, z_)
    chex.assert_equal_shape([z, log_abs_det, z_])
    chex.assert_rank(z, 4)
    chex.assert_equal(z.shape[1:], new_shape)

    # inverse transformation
    x_approx, log_abs_det = inverse_f(params, z)

    # checks
    chex.assert_equal_shape([x_approx, x])
    chex.assert_tree_all_close(x_approx, x)
Example #9
def test_reshape_shape_collapse(filters, collapse):

    n_dims = (1, 28, 28, 1)
    params_rng, data_rng = jax.random.split(KEY, 2)

    x = jax.random.uniform(data_rng, shape=n_dims)

    # create layer
    init_func = Squeeze(filter_shape=filters, collapse=collapse)

    # initialize the layer
    params, forward_f, inverse_f = init_func(rng=params_rng, shape=n_dims,)

    # forward transformation
    z, log_abs_det = forward_f(params, x)

    # checks
    chex.assert_equal_shape([z, log_abs_det])
    chex.assert_rank(z, 2)

    # inverse transformation
    x_approx, log_abs_det = inverse_f(params, z)

    # checks
    chex.assert_equal_shape([x_approx, x])
    chex.assert_tree_all_close(x_approx, x)
Example #10
def fix_step_type_on_interruptions(step_type: chex.Array):
    """Returns step_type with a LAST step before almost every FIRST step.

  If the environment crashes or is interrupted while a trajectory is being
  written, the LAST step can be missing before a FIRST step. We add the LAST
  step before each FIRST step, if the step before the FIRST step is a MID step,
  to signal to the agent that the next observation is not connected to the
  current stream of data. Note that the agent must still then appropriately
  handle both `terminations` (e.g. game over in a game) and `interruptions` (a
  timeout or a reset for system maintenance): the value of the discount on LAST
  step will be > 0 on `interruptions`, while it will be 0 on `terminations`.
  Similar issues arise in hierarchical RL systems as well.

  Args:
    step_type: an array of `dm_env` step types, with shape `[T, B]`.

  Returns:
    Fixed step_type.
  """
    chex.assert_rank(step_type, 2)
    next_step_type = jnp.concatenate([
        step_type[1:],
        jnp.full(step_type[:1].shape,
                 int(dm_env.StepType.MID),
                 dtype=step_type.dtype),
    ], axis=0)
    return jnp.where(
        jnp.logical_and(
            jnp.equal(next_step_type, int(dm_env.StepType.FIRST)),
            jnp.equal(step_type, int(dm_env.StepType.MID)),
        ), int(dm_env.StepType.LAST), step_type)
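
A small, hedged example of the repair (assumes `dm_env` and the function above are in scope); the step types are illustrative:

import dm_env
import jax.numpy as jnp

FIRST = int(dm_env.StepType.FIRST)
MID = int(dm_env.StepType.MID)
LAST = int(dm_env.StepType.LAST)

# Shape [T=5, B=1]: the FIRST at index 3 is not preceded by a LAST.
step_type = jnp.array([[FIRST], [MID], [MID], [FIRST], [MID]], dtype=jnp.int32)
fixed = fix_step_type_on_interruptions(step_type)
# fixed[2, 0] is now LAST; all other entries are unchanged.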
Example #11
def add_ornstein_uhlenbeck_noise(
    key: Array,
    action: Array,
    noise_tm1: Array,
    damping: float,
    stddev: float
) -> Array:
  """Returns continuous action with noise from Ornstein-Uhlenbeck process.

  See "On the theory of Brownian Motion" by Uhlenbeck and Ornstein.
  (https://journals.aps.org/pr/abstract/10.1103/PhysRev.36.823).

  Args:
    key: a key from `jax.random`.
    action: continuous action scalar or vector.
    noise_tm1: noise sampled from OU process in previous timestep.
    damping: parameter for controlling autocorrelation of OU process.
    stddev: standard deviation of noise distribution.

  Returns:
    noisy action, of the same shape as input action.
  """
  chex.assert_rank([action, noise_tm1], 1)
  chex.assert_type([action, noise_tm1], float)

  noise_t = (1. - damping) * noise_tm1 + jax.random.normal(
      key, shape=action.shape) * stddev

  return action + noise_t
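
A usage sketch showing how the OU noise state might be threaded across timesteps (an assumed pattern, with illustrative values):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
action = jnp.array([0.2, -0.5])     # continuous action vector
noise = jnp.zeros_like(action)      # OU state carried between steps
for _ in range(3):
  key, subkey = jax.random.split(key)
  noisy_action = add_ornstein_uhlenbeck_noise(
      subkey, action, noise, damping=0.15, stddev=0.2)
  noise = noisy_action - action     # recover the OU state for the next step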
Example #12
def rpg_loss(
    logits_t: Array,
    q_t: Array,
    use_stop_gradient: bool = True,
) -> Array:
    """Computes the RPG (Regret Policy Gradient) loss.

  The gradient of this loss adapts the Regret Matching rule by weighting the
  standard PG update with regret.

  See "Actor-Critic Policy Optimization in Partially Observable Multiagent
  Environments" by Srinivasan, Lanctot (https://arxiv.org/abs/1810.09026).

  Args:
    logits_t: a sequence of unnormalized action preferences.
    q_t: the observed or estimated action value from executing actions `a_t` at
      time t.
    use_stop_gradient: bool indicating whether or not to apply stop gradient to
      advantages.

  Returns:
    RPG Loss.
  """
    chex.assert_rank([logits_t, q_t], 2)
    chex.assert_type([logits_t, q_t], float)

    _, adv_t = _compute_advantages(logits_t, q_t, use_stop_gradient)
    regrets_t = jnp.sum(jax.nn.relu(adv_t), axis=1)

    total_regret_t = jnp.mean(regrets_t, axis=0)
    return total_regret_t
Example #13
def qpg_loss(
    logits_t: Array,
    q_t: Array,
    use_stop_gradient: bool = True,
) -> Array:
    """Computes the QPG (Q-based Policy Gradient) loss.

  See "Actor-Critic Policy Optimization in Partially Observable Multiagent
  Environments" by Srinivasan, Lanctot (https://arxiv.org/abs/1810.09026).

  Args:
    logits_t: a sequence of unnormalized action preferences.
    q_t: the observed or estimated action value from executing actions `a_t` at
      time t.
    use_stop_gradient: bool indicating whether or not to apply stop gradient to
      advantages.

  Returns:
    QPG Loss.
  """
    chex.assert_rank([logits_t, q_t], 2)
    chex.assert_type([logits_t, q_t], float)

    policy_t, advantage_t = _compute_advantages(logits_t, q_t)
    advantage_t = jax.lax.select(use_stop_gradient,
                                 jax.lax.stop_gradient(advantage_t),
                                 advantage_t)
    policy_advantages = -policy_t * advantage_t
    loss = jnp.mean(jnp.sum(policy_advantages, axis=1), axis=0)
    return loss
Example #14
def leaky_vtrace(v_tm1: Array,
                 v_t: Array,
                 r_t: Array,
                 discount_t: Array,
                 rho_t: Array,
                 alpha_: float = 1.0,
                 lambda_: float = 1.0,
                 clip_rho_threshold: float = 1.0,
                 stop_target_gradients: bool = True):
    """Calculates Leaky V-Trace errors from importance weights.

  Leaky-Vtrace is a combination of Importance sampling and V-trace, where the
  degree of mixing is controlled by a scalar `alpha` (that may be meta-learnt).

  See "Self-Tuning Deep Reinforcement Learning"
  by Zahavy et al. (https://arxiv.org/abs/2002.12928)

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_t: importance weights at time t.
    alpha_: mixing parameter for Importance Sampling and V-trace.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: clip threshold for importance weights.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    Leaky V-Trace error.
  """
    chex.assert_rank([v_tm1, v_t, r_t, discount_t, rho_t], [1, 1, 1, 1, 1])
    chex.assert_type([v_tm1, v_t, r_t, discount_t, rho_t],
                     [float, float, float, float, float])

    # Mix clipped and unclipped importance sampling ratios.
    c_t = ((1 - alpha_) * rho_t + alpha_ * jnp.minimum(1.0, rho_t)) * lambda_
    clipped_rhos = ((1 - alpha_) * rho_t +
                    alpha_ * jnp.minimum(clip_rho_threshold, rho_t))

    # Compute the temporal difference errors.
    td_errors = clipped_rhos * (r_t + discount_t * v_t - v_tm1)

    # Work backwards computing the td-errors.
    err = 0.0
    errors = []
    for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
        err = td_errors[i] + discount_t[i] * c_t[i] * err
        errors.insert(0, err)

    # Return errors.
    if not stop_target_gradients:
        return jnp.array(errors)
    # In TD-like algorithms, we want gradients to only flow in the predictions,
    # and not in the values used to bootstrap. In this case, add the value of the
    # initial state value to get the implied estimates of the returns, stop
    # gradient around such target and then subtract again the initial state value.
    else:
        target_tm1 = jnp.array(errors) + v_tm1
        return jax.lax.stop_gradient(target_tm1) - v_tm1
Example #15
def qv_learning(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    v_t: Numeric,
    stop_target_gradients: bool = True,
) -> Numeric:
    """Calculates the QV-learning temporal difference error.

  See "Two Novel On-policy Reinforcement Learning Algorithms based on
  TD(lambda)-methods" by Wiering and van Hasselt
  (https://ieeexplore.ieee.org/abstract/document/4220845).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_t: state values at time t.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    QV-learning temporal difference error.
  """
    chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, v_t], [1, 0, 0, 0, 0])
    chex.assert_type([q_tm1, a_tm1, r_t, discount_t, v_t],
                     [float, int, float, float, float])

    target_tm1 = r_t + discount_t * v_t
    target_tm1 = jax.lax.select(stop_target_gradients,
                                jax.lax.stop_gradient(target_tm1), target_tm1)
    return target_tm1 - q_tm1[a_tm1]
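
A minimal usage sketch with illustrative values (assumes the function and its imports are in scope):

import jax.numpy as jnp

q_tm1 = jnp.array([0.3, 0.7, 0.1])  # Q(s_{t-1}, .)
a_tm1 = jnp.array(1)                # action taken at time t-1
r_t = jnp.array(0.5)
discount_t = jnp.array(0.99)
v_t = jnp.array(0.6)                # V(s_t) from the separate value network
td_error = qv_learning(q_tm1, a_tm1, r_t, discount_t, v_t)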
Example #16
def expected_sarsa(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t: Array,
    probs_a_t: Array,
) -> Numeric:
  """Calculates the expected SARSA (SARSE) temporal difference error.

  See "A Theoretical and Empirical Analysis of Expected Sarsa" by Seijen,
  van Hasselt, Whiteson et al.
  (http://www.cs.ox.ac.uk/people/shimon.whiteson/pubs/vanseijenadprl09.pdf).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    probs_a_t: action probabilities at time t.

  Returns:
    Expected SARSA temporal difference error.
  """
  chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t],
                   [1, 0, 0, 0, 1, 1])
  chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t],
                   [float, int, float, float, float, float])

  target_tm1 = r_t + discount_t * jnp.dot(q_t, probs_a_t)
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
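
A minimal usage sketch with illustrative values (assumes the function and its imports are in scope):

import jax.numpy as jnp

q_tm1 = jnp.array([0.2, 0.9, 0.4])
a_tm1 = jnp.array(0)
r_t = jnp.array(1.0)
discount_t = jnp.array(0.9)
q_t = jnp.array([0.1, 0.8, 0.5])
probs_a_t = jnp.array([0.2, 0.6, 0.2])  # e.g. an epsilon-greedy policy at time t
td_error = expected_sarsa(q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t)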
Example #17
def td_learning(
    v_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    v_t: Numeric,
    stop_target_gradients: bool = True,
) -> Numeric:
    """Calculates the TD-learning temporal difference error.

  See "Learning to Predict by the Methods of Temporal Differences" by Sutton.
  (https://link.springer.com/article/10.1023/A:1022633531479).

  Args:
    v_tm1: state values at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_t: state values at time t.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    TD-learning temporal difference error.
  """
    chex.assert_rank([v_tm1, r_t, discount_t, v_t], 0)
    chex.assert_type([v_tm1, r_t, discount_t, v_t], float)

    target_tm1 = r_t + discount_t * v_t
    target_tm1 = jax.lax.select(stop_target_gradients,
                                jax.lax.stop_gradient(target_tm1), target_tm1)
    return target_tm1 - v_tm1
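
A hedged sketch of why `stop_target_gradients` matters: differentiating a squared TD error with respect to both value estimates shows that no gradient reaches the bootstrap value (values are illustrative).

import jax
import jax.numpy as jnp

def value_loss(v_tm1, v_t):
  # r_t = 1.0, discount_t = 0.9 are arbitrary constants for the example.
  return 0.5 * jnp.square(td_learning(v_tm1, jnp.array(1.0), jnp.array(0.9), v_t))

grads = jax.grad(value_loss, argnums=(0, 1))(jnp.array(0.5), jnp.array(0.8))
# grads[1] == 0.0: the bootstrap target is treated as a constant.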
Example #18
def double_q_learning(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t_value: Array,
    q_t_selector: Array,
) -> Numeric:
  """Calculates the double Q-learning temporal difference error.

  See "Double Q-learning" by van Hasselt.
  (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t_value: Q-values at time t.
    q_t_selector: selector Q-values at time t.

  Returns:
    Double Q-learning temporal difference error.
  """
  chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector],
                   [1, 0, 0, 0, 1, 1])
  chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector],
                   [float, int, float, float, float, float])

  target_tm1 = r_t + discount_t * q_t_value[q_t_selector.argmax()]
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
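
A minimal usage sketch with illustrative values (assumes the function and its imports are in scope):

import jax.numpy as jnp

q_tm1 = jnp.array([1.0, 0.5])
a_tm1 = jnp.array(0)
r_t = jnp.array(0.0)
discount_t = jnp.array(0.99)
q_t_value = jnp.array([0.8, 1.1])     # network used to evaluate the bootstrap
q_t_selector = jnp.array([1.2, 0.7])  # network used to pick the argmax action
td_error = double_q_learning(q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector)
# The bootstrap uses q_t_value[0], since q_t_selector.argmax() == 0.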
Example #19
def td_lambda(
    v_tm1: Array,
    r_t: Array,
    discount_t: Array,
    v_t: Array,
    lambda_: Numeric,
    stop_target_gradients: bool = True,
) -> Array:
    """Calculates the TD(lambda) temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node74.html).

  Args:
    v_tm1: sequence of state values at time t-1.
    r_t: sequence of rewards at time t.
    discount_t: sequence of discounts at time t.
    v_t: sequence of state values at time t.
    lambda_: mixing parameter lambda, either a scalar or a sequence.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    TD(lambda) temporal difference error.
  """
    chex.assert_rank([v_tm1, r_t, discount_t, v_t, lambda_],
                     [1, 1, 1, 1, {0, 1}])
    chex.assert_type([v_tm1, r_t, discount_t, v_t, lambda_], float)

    target_tm1 = multistep.lambda_returns(r_t, discount_t, v_t, lambda_)
    target_tm1 = jax.lax.select(stop_target_gradients,
                                jax.lax.stop_gradient(target_tm1), target_tm1)
    return target_tm1 - v_tm1
Example #20
def persistent_q_learning(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t: Array,
    action_gap_scale: float,
) -> Numeric:
  """Calculates the persistent Q-learning temporal difference error.

  See "Increasing the Action Gap: New Operators for Reinforcement Learning"
  by Bellemare, Ostrovski, Guez et al. (https://arxiv.org/abs/1512.04860).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    action_gap_scale: coefficient in [0, 1] for scaling the action gap term.

  Returns:
    Persistent Q-learning temporal difference error.
  """
  chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t], [1, 0, 0, 0, 1])
  chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t],
                   [float, int, float, float, float])

  corrected_q_t = (
      (1. - action_gap_scale) * jnp.max(q_t)
      + action_gap_scale * q_t[a_tm1]
  )
  target_tm1 = r_t + discount_t * corrected_q_t
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
Example #21
        def affine_transform(dist_params, scale, shift, value_transform=None):
            chex.assert_rank([dist_params['values'], scale, shift],
                             [2, {0, 1}, {0, 1}])
            values = check_shape(dist_params['values'], 'values')
            quantile_fractions = check_shape(dist_params['quantile_fractions'],
                                             'quantile_fractions')
            batch_size = values.shape[0]

            if isscalar(scale):
                scale = jnp.full(shape=(batch_size, 1),
                                 fill_value=jnp.squeeze(scale))
            if isscalar(shift):
                shift = jnp.full(shape=(batch_size, 1),
                                 fill_value=jnp.squeeze(shift))

            scale = jnp.reshape(scale, (batch_size, 1))
            shift = jnp.reshape(shift, (batch_size, 1))

            chex.assert_shape(values, (batch_size, self.num_quantiles))
            chex.assert_shape([scale, shift], (batch_size, 1))

            if value_transform is None:
                f = f_inv = lambda x: x
            else:
                f, f_inv = value_transform

            return {
                'values': f(shift + scale * f_inv(values)),
                'quantile_fractions': quantile_fractions
            }
Example #22
def _categorical_l2_project(
    z_p: Array,
    probs: Array,
    z_q: Array
) -> Array:
  """Projects a categorical distribution (z_p, p) onto a different support z_q.

  The projection step minimizes an L2-metric over the cumulative distribution
  functions (CDFs) of the source and target distributions.

  Let kq be len(z_q) and kp be len(z_p). This projection works for any
  support z_q, in particular kq need not be equal to kp.

  See "A Distributional Perspective on RL" by Bellemare et al.
  (https://arxiv.org/abs/1707.06887).

  Args:
    z_p: support of distribution p.
    probs: probability values.
    z_q: support to project distribution (z_p, probs) onto.

  Returns:
    Projection of (z_p, p) onto support z_q under Cramer distance.
  """
  chex.assert_rank([z_p, probs, z_q], 1)
  chex.assert_type([z_p, probs, z_q], float)

  kp = z_p.shape[0]
  kq = z_q.shape[0]

  # Construct helper arrays from z_q.
  d_pos = jnp.roll(z_q, shift=-1)
  d_neg = jnp.roll(z_q, shift=1)

  # Clip z_p to be in new support range (vmin, vmax).
  z_p = jnp.clip(z_p, z_q[0], z_q[-1])[None, :]
  assert z_p.shape == (1, kp)

  # Get the distance between atom values in support.
  d_pos = (d_pos - z_q)[:, None]  # z_q[i+1] - z_q[i]
  d_neg = (z_q - d_neg)[:, None]  # z_q[i] - z_q[i-1]
  z_q = z_q[:, None]
  assert z_q.shape == (kq, 1)

  # Ensure that we do not divide by zero, in case of atoms of identical value.
  d_neg = jnp.where(d_neg > 0, 1. / d_neg, jnp.zeros_like(d_neg))
  d_pos = jnp.where(d_pos > 0, 1. / d_pos, jnp.zeros_like(d_pos))

  delta_qp = z_p - z_q  # clip(z_p)[j] - z_q[i]
  d_sign = (delta_qp >= 0.).astype(probs.dtype)
  assert delta_qp.shape == (kq, kp)
  assert d_sign.shape == (kq, kp)

  # Matrix of entries sgn(a_ij) * |a_ij|, with a_ij = clip(z_p)[j] - z_q[i].
  delta_hat = (d_sign * delta_qp * d_pos) - ((1. - d_sign) * delta_qp * d_neg)
  probs = probs[None, :]
  assert delta_hat.shape == (kq, kp)
  assert probs.shape == (1, kp)

  return jnp.sum(jnp.clip(1. - delta_hat, 0., 1.) * probs, axis=-1)
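
A small usage sketch projecting a 3-atom distribution onto a 5-atom support (illustrative values; assumes the function and its imports are in scope):

import jax.numpy as jnp

z_p = jnp.array([-1.0, 0.0, 1.0])    # source support
probs = jnp.array([0.2, 0.5, 0.3])   # source probabilities
z_q = jnp.linspace(-2.0, 2.0, 5)     # target support; sizes need not match
projected = _categorical_l2_project(z_p, probs, z_q)
# `projected` sums to 1 and lives on the 5-atom support z_q.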
Example #23
    def update(self, idx, Adv):
        r"""

        Update the priority weights of transitions previously added to the buffer.

        Parameters
        ----------
        idx : 1d array of ints

            The identifiers of the transitions to be updated.

        Adv : ndarray

            The corresponding updated advantages.

        """
        idx = onp.asarray(idx, dtype='int32')
        Adv = onp.asarray(Adv, dtype='float32')
        chex.assert_equal_shape([idx, Adv])
        chex.assert_rank([idx, Adv], 1)

        idx_lookup = idx % self.capacity  # wrap around
        new_values = onp.where(
            _get_transition_batch_idx(
                self._storage[idx_lookup]) == idx,  # only update if ids match
            onp.power(onp.abs(Adv) + self.epsilon, self.alpha),
            self._sumtree.values[idx_lookup])
        self._sumtree.set_values(idx_lookup, new_values)
Example #24
def sarsa(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t: Array,
    a_t: Numeric,
) -> Numeric:
  """Calculates the SARSA temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node64.html).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    a_t: action index at time t.

  Returns:
    SARSA temporal difference error.
  """
  chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t, a_t],
                   [1, 0, 0, 0, 1, 0])
  chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t, a_t],
                   [float, int, float, float, float, int])

  target_tm1 = r_t + discount_t * q_t[a_t]
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
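
A minimal usage sketch with illustrative values (assumes the function and its imports are in scope):

import jax.numpy as jnp

q_tm1 = jnp.array([0.1, 0.4])
a_tm1 = jnp.array(1)
r_t = jnp.array(0.0)
discount_t = jnp.array(0.99)
q_t = jnp.array([0.3, 0.2])
a_t = jnp.array(0)   # action actually taken at time t (on-policy bootstrap)
td_error = sarsa(q_tm1, a_tm1, r_t, discount_t, q_t, a_t)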
Example #25
def rpg_loss(
    logits_t: Array,
    q_t: Array,
) -> Array:
  """Computes the RPG (Regret Policy Gradient) loss.

  The gradient of this loss adapts the Regret Matching rule by weighting the
  standard PG update with regret.

  See "Actor-Critic Policy Optimization in Partially Observable Multiagent
  Environments" by Srinivasan, Lanctot.
  (https://papers.nips.cc/paper/7602-actor-critic-policy-optimization-in-partially-observable-multiagent-environments.pdf)

  Args:
    logits_t: a sequence of unnormalized action preferences.
    q_t: the observed or estimated action value from executing actions `a_t` at
      time t.

  Returns:
    RPG Loss.
  """
  chex.assert_rank([logits_t, q_t], 2)
  chex.assert_type([logits_t, q_t], float)

  _, adv_t = _compute_advantages(logits_t, q_t)
  regrets_t = jnp.sum(jax.nn.relu(adv_t), axis=1)

  total_regret_t = jnp.mean(regrets_t, axis=0)
  return total_regret_t
Example #26
    def objective_func(self, params, state, hyperparams, rng, transition_batch,
                       Adv):
        rngs = hk.PRNGSequence(rng)

        # get distribution params from function approximator
        S = self.pi.observation_preprocessor(next(rngs), transition_batch.S)
        dist_params, state_new = self.pi.function(params, state, next(rngs), S,
                                                  True)

        # compute objective: q(s, a_greedy)
        S = self.q_targ.observation_preprocessor(next(rngs),
                                                 transition_batch.S)
        A = self.pi.proba_dist.mode(dist_params)
        log_pi = self.pi.proba_dist.log_proba(dist_params, A)
        params_q, state_q = hyperparams['q']['params'], hyperparams['q'][
            'function_state']
        Q, _ = self.q_targ.function_type1(params_q, state_q, next(rngs), S, A,
                                          True)

        # clip importance weights to reduce variance
        W = jnp.clip(transition_batch.W, 0.1, 10.)

        # the objective
        chex.assert_equal_shape([W, Q])
        chex.assert_rank([W, Q], 1)
        objective = W * Q

        return jnp.mean(objective), (dist_params, log_pi, state_new)
Example #27
def dpg_loss(
    a_t: Array,
    dqda_t: Array,
    dqda_clipping: Optional[Scalar] = None
) -> Array:
  """Calculates the deterministic policy gradient (DPG) loss.

  See "Deterministic Policy Gradient Algorithms" by Silver, Lever, Heess,
  Degris, Wierstra, Riedmiller (http://proceedings.mlr.press/v32/silver14.pdf).

  Args:
    a_t: continuous-valued action at time t.
    dqda_t: gradient of Q(s,a) wrt. a, evaluated at time t.
    dqda_clipping: clips the gradient to have norm <= `dqda_clipping`.

  Returns:
    DPG loss.
  """
  chex.assert_rank([a_t, dqda_t], 1)
  chex.assert_type([a_t, dqda_t], float)

  if dqda_clipping is not None:
    dqda_t = _clip_by_l2_norm(dqda_t, dqda_clipping)
  target_tm1 = dqda_t + a_t
  return losses.l2_loss(jax.lax.stop_gradient(target_tm1) - a_t)
Example #28
def q_learning(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t: Array,
    stop_target_gradients: bool = True,
) -> Numeric:
    """Calculates the Q-learning temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node65.html).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    Q-learning temporal difference error.
  """
    chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t], [1, 0, 0, 0, 1])
    chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t],
                     [float, int, float, float, float])

    target_tm1 = r_t + discount_t * jnp.max(q_t)
    target_tm1 = jax.lax.select(stop_target_gradients,
                                jax.lax.stop_gradient(target_tm1), target_tm1)
    return target_tm1 - q_tm1[a_tm1]
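
A minimal usage sketch with illustrative values (assumes the function and its imports are in scope):

import jax.numpy as jnp

q_tm1 = jnp.array([0.0, 1.0, 0.5])
a_tm1 = jnp.array(2)
r_t = jnp.array(1.0)
discount_t = jnp.array(0.9)
q_t = jnp.array([0.2, 1.5, 0.3])
td_error = q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
# td_error == (1.0 + 0.9 * 1.5) - 0.5 == 1.85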
Example #29
def discounted_returns(
    r_t: Array,
    discount_t: Array,
    v_t: Array
) -> Array:
  """Calculates a discounted return from a trajectory.

  The returns are computed recursively, from `G_{T-1}` to `G_0`, according to:

    Gₜ = rₜ₊₁ + γₜ₊₁ Gₜ₊₁.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/sutton/book/ebook/node61.html).

  Args:
    r_t: reward sequence at time t.
    discount_t: discount sequence at time t.
    v_t: value sequence or scalar at time t.

  Returns:
    Discounted returns.
  """
  chex.assert_rank([r_t, discount_t, v_t], [1, 1, {0, 1}])
  chex.assert_type([r_t, discount_t, v_t], float)

  # If scalar make into vector.
  bootstrapped_v = jnp.ones_like(discount_t) * v_t
  return lambda_returns(r_t, discount_t, bootstrapped_v, lambda_=1.)
Example #30
def n_step_bootstrapped_returns(r_t: Array, discount_t: Array, v_t: Array,
                                n: int) -> Array:
    """Computes strided n-step bootstrapped return targets over a sequence.

  The returns are computed in a backwards fashion according to the equation:

     Gₜ = rₜ₊₁ + γₜ₊₁ * (rₜ₊₂ + γₜ₊₂ * (... * (rₜ₊ₙ + γₜ₊ₙ * vₜ₊ₙ))).

  Args:
    r_t: rewards at times [1, ..., T].
    discount_t: discounts at times [1, ..., T].
    v_t: state or state-action values to bootstrap from at times [1, ..., T].
    n: number of steps over which to accumulate reward before bootstrapping.

  Returns:
    estimated bootstrapped returns at times [1, ..., T].
  """
    chex.assert_rank([r_t, discount_t, v_t], 1)
    chex.assert_type([r_t, discount_t, v_t], float)
    seq_len = r_t.shape[0]

    # Pad end of reward and discount sequences with 0 and 1 respectively.
    r_t = jnp.concatenate([r_t, jnp.zeros(n - 1)])
    discount_t = jnp.concatenate([discount_t, jnp.ones(n - 1)])

    # Shift bootstrap values by n and pad end of sequence with last value v_t[-1].
    pad_size = min(n - 1, seq_len)
    targets = jnp.concatenate([v_t[n - 1:], jnp.array([v_t[-1]] * pad_size)])

    # Work backwards to compute discounted, bootstrapped n-step returns.
    for i in reversed(range(n)):
        targets = r_t[i:i + seq_len] + discount_t[i:i + seq_len] * targets
    return targets
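
A minimal usage sketch over a 4-step trajectory with n=2 (illustrative values; assumes the function and its imports are in scope):

import jax.numpy as jnp

r_t = jnp.array([1.0, 0.0, 0.0, 1.0])
discount_t = jnp.array([0.9, 0.9, 0.9, 0.0])  # the final 0 ends the episode
v_t = jnp.array([0.5, 0.4, 0.3, 0.2])
targets = n_step_bootstrapped_returns(r_t, discount_t, v_t, n=2)
# targets[0] == r_t[0] + discount_t[0] * (r_t[1] + discount_t[1] * v_t[1])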