def persistent_q_learning(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t: ArrayLike,
    action_gap_scale: float,
) -> ArrayOrScalar:
  """Calculates the persistent Q-learning temporal difference error.

  See "Increasing the Action Gap: New Operators for Reinforcement Learning"
  by Bellemare, Ostrovski, Guez et al. (https://arxiv.org/abs/1512.04860).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    action_gap_scale: coefficient in [0, 1] for scaling the action gap term.

  Returns:
    Persistent Q-learning temporal difference error.
  """
  base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t], [1, 0, 0, 0, 1])
  base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t],
                   [float, int, float, float, float])

  corrected_q_t = ((1. - action_gap_scale) * jnp.max(q_t)
                   + action_gap_scale * q_t[a_tm1])
  target_tm1 = r_t + discount_t * corrected_q_t
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
def sarsa(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t: ArrayLike,
    a_t: ArrayOrScalar,
) -> ArrayOrScalar:
  """Calculates the SARSA temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node64.html.)

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    a_t: action index at time t.

  Returns:
    SARSA temporal difference error.
  """
  base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t, a_t],
                   [1, 0, 0, 0, 1, 0])
  base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t, a_t],
                   [float, int, float, float, float, int])

  target_tm1 = r_t + discount_t * q_t[a_t]
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
def td_learning(
    v_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    v_t: ArrayOrScalar,
) -> ArrayOrScalar:
  """Calculates the TD-learning temporal difference error.

  See "Learning to Predict by the Methods of Temporal Differences" by Sutton.
  (https://link.springer.com/article/10.1023/A:1022633531479).

  Args:
    v_tm1: state values at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_t: state values at time t.

  Returns:
    TD-learning temporal difference error.
  """
  base.rank_assert([v_tm1, r_t, discount_t, v_t], 0)
  base.type_assert([v_tm1, r_t, discount_t, v_t], float)

  target_tm1 = r_t + discount_t * v_t
  return jax.lax.stop_gradient(target_tm1) - v_tm1
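# A minimal usage sketch for `td_learning` (the concrete numbers below are
# illustrative assumptions, not part of the library): with v_tm1 = 1.0,
# r_t = 1.0, discount_t = 0.9 and v_t = 2.0, the bootstrapped target is
# 1.0 + 0.9 * 2.0 = 2.8, so the TD error is 2.8 - 1.0 = 1.8. Batches of
# transitions are typically handled by wrapping the function in `jax.vmap`.
td_error_example = td_learning(
    v_tm1=jnp.array(1.),
    r_t=jnp.array(1.),
    discount_t=jnp.array(0.9),
    v_t=jnp.array(2.))  # ~= 1.8
batched_td_learning = jax.vmap(td_learning)  # Maps over a leading batch axis.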
def qv_max(
    v_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t: ArrayLike,
) -> ArrayOrScalar:
  """Calculates the QVMAX temporal difference error.

  See "The QV Family Compared to Other Reinforcement Learning Algorithms" by
  Wiering and van Hasselt (2009).
  (http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.713.1931)

  Args:
    v_tm1: state values at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.

  Returns:
    QVMAX temporal difference error.
  """
  base.rank_assert([v_tm1, r_t, discount_t, q_t], [0, 0, 0, 1])
  base.type_assert([v_tm1, r_t, discount_t, q_t], float)

  target_tm1 = r_t + discount_t * jnp.max(q_t)
  return jax.lax.stop_gradient(target_tm1) - v_tm1
def expected_sarsa(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t: ArrayLike,
    probs_a_t: ArrayLike,
) -> ArrayOrScalar:
  """Calculates the expected SARSA (SARSE) temporal difference error.

  See "A Theoretical and Empirical Analysis of Expected Sarsa" by van Seijen,
  van Hasselt, Whiteson et al.
  (http://www.cs.ox.ac.uk/people/shimon.whiteson/pubs/vanseijenadprl09.pdf).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    probs_a_t: action probabilities at time t.

  Returns:
    Expected SARSA temporal difference error.
  """
  base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t],
                   [1, 0, 0, 0, 1, 1])
  base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t],
                   [float, int, float, float, float, float])

  target_tm1 = r_t + discount_t * jnp.dot(q_t, probs_a_t)
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
def discounted_returns(
    r_t: ArrayLike,
    discount_t: ArrayLike,
    v_t: ArrayLike
) -> ArrayLike:
  """Calculates a discounted return from a trajectory.

  The returns are computed recursively, from `G_{T-1}` to `G_0`, according to:

    Gₜ = rₜ₊₁ + γₜ₊₁ Gₜ₊₁.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/sutton/book/ebook/node61.html).

  Args:
    r_t: reward sequence at time t.
    discount_t: discount sequence at time t.
    v_t: value sequence or scalar at time t.

  Returns:
    Discounted returns.
  """
  base.rank_assert([r_t, discount_t, v_t], [1, 1, [0, 1]])
  base.type_assert([r_t, discount_t, v_t], float)

  # If scalar make into vector.
  bootstrapped_v = jnp.ones_like(discount_t) * v_t
  return lambda_returns(r_t, discount_t, bootstrapped_v, lambda_=1.)
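# A small worked example for `discounted_returns` (the trajectory below is an
# illustrative assumption): with rewards [1, 1, 1], discounts [0.9, 0.9, 0.9]
# and a zero bootstrap value, the backwards recursion gives
# G_2 = 1.0, G_1 = 1.0 + 0.9 * 1.0 = 1.9 and G_0 = 1.0 + 0.9 * 1.9 = 2.71.
example_returns = discounted_returns(
    r_t=jnp.array([1., 1., 1.]),
    discount_t=jnp.array([0.9, 0.9, 0.9]),
    v_t=jnp.array(0.))  # ~= [2.71, 1.9, 1.0]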
def double_q_learning(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_t_value: ArrayLike,
    q_t_selector: ArrayLike,
) -> ArrayOrScalar:
  """Calculates the double Q-learning temporal difference error.

  See "Double Q-learning" by van Hasselt.
  (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t_value: Q-values at time t.
    q_t_selector: selector Q-values at time t.

  Returns:
    Double Q-learning temporal difference error.
  """
  base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector],
                   [1, 0, 0, 0, 1, 1])
  base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector],
                   [float, int, float, float, float, float])

  target_tm1 = r_t + discount_t * q_t_value[q_t_selector.argmax()]
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
def qv_learning(
    q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    v_t: ArrayOrScalar,
) -> ArrayOrScalar:
  """Calculates the QV-learning temporal difference error.

  See "Two Novel On-policy Reinforcement Learning Algorithms based on
  TD(lambda)-methods" by Wiering and van Hasselt
  (https://ieeexplore.ieee.org/abstract/document/4220845.)

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_t: state values at time t.

  Returns:
    QV-learning temporal difference error.
  """
  base.rank_assert([q_tm1, a_tm1, r_t, discount_t, v_t], [1, 0, 0, 0, 0])
  base.type_assert([q_tm1, a_tm1, r_t, discount_t, v_t],
                   [float, int, float, float, float])

  target_tm1 = r_t + discount_t * v_t
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
def _categorical_l2_project(
    z_p: ArrayLike,
    probs: ArrayLike,
    z_q: ArrayLike
) -> ArrayLike:
  """Projects a categorical distribution (z_p, p) onto a different support z_q.

  The projection step minimizes an L2-metric over the cumulative distribution
  functions (CDFs) of the source and target distributions.

  Let kq be len(z_q) and kp be len(z_p). This projection works for any
  support z_q, in particular kq need not be equal to kp.

  See "A Distributional Perspective on RL" by Bellemare et al.
  (https://arxiv.org/abs/1707.06887).

  Args:
    z_p: support of distribution p.
    probs: probability values.
    z_q: support to project distribution (z_p, probs) onto.

  Returns:
    Projection of (z_p, p) onto support z_q under Cramer distance.
  """
  base.rank_assert([z_p, probs, z_q], 1)
  base.type_assert([z_p, probs, z_q], float)

  kp = z_p.shape[0]
  kq = z_q.shape[0]

  # Construct helper arrays from z_q.
  d_pos = jnp.roll(z_q, shift=-1)
  d_neg = jnp.roll(z_q, shift=1)

  # Clip z_p to be in new support range (vmin, vmax).
  z_p = jnp.clip(z_p, z_q[0], z_q[-1])[None, :]
  assert z_p.shape == (1, kp)

  # Get the distance between atom values in support.
  d_pos = (d_pos - z_q)[:, None]  # z_q[i+1] - z_q[i]
  d_neg = (z_q - d_neg)[:, None]  # z_q[i] - z_q[i-1]
  z_q = z_q[:, None]
  assert z_q.shape == (kq, 1)

  # Ensure that we do not divide by zero, in case of atoms of identical value.
  d_neg = jnp.where(d_neg > 0, 1. / d_neg, jnp.zeros_like(d_neg))
  d_pos = jnp.where(d_pos > 0, 1. / d_pos, jnp.zeros_like(d_pos))

  delta_qp = z_p - z_q  # clip(z_p)[j] - z_q[i]
  d_sign = (delta_qp >= 0.).astype(probs.dtype)
  assert delta_qp.shape == (kq, kp)
  assert d_sign.shape == (kq, kp)

  # Matrix of entries sgn(a_ij) * |a_ij|, with a_ij = clip(z_p)[j] - z_q[i].
  delta_hat = (d_sign * delta_qp * d_pos) - ((1. - d_sign) * delta_qp * d_neg)
  probs = probs[None, :]
  assert delta_hat.shape == (kq, kp)
  assert probs.shape == (1, kp)

  return jnp.sum(jnp.clip(1. - delta_hat, 0., 1.) * probs, axis=-1)
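# A small worked example of the projection (the supports and probabilities are
# illustrative assumptions): projecting a two-atom distribution with mass 0.5
# on 0.0 and 0.5 on 1.0 onto the support [0.0, 0.5, 1.0] keeps the mass on the
# coinciding atoms, so the projected probabilities are [0.5, 0.0, 0.5].
projected_probs = _categorical_l2_project(
    z_p=jnp.array([0., 1.]),
    probs=jnp.array([0.5, 0.5]),
    z_q=jnp.array([0., 0.5, 1.]))  # ~= [0.5, 0.0, 0.5]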
def policy_gradient_loss(
    logits_t: ArrayLike,
    a_t: ArrayLike,
    adv_t: ArrayLike,
    w_t: ArrayLike,
) -> ArrayLike:
  """Calculates the policy gradient loss.

  See "Simple Gradient-Following Algorithms for Connectionist RL" by Williams.
  (http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)

  Args:
    logits_t: a sequence of unnormalized action preferences.
    a_t: a sequence of actions sampled from the preferences `logits_t`.
    adv_t: the observed or estimated advantages from executing actions `a_t`.
    w_t: a per timestep weighting for the loss.

  Returns:
    Loss whose gradient corresponds to a policy gradient update.
  """
  base.rank_assert([logits_t, a_t, adv_t, w_t], [2, 1, 1, 1])
  base.type_assert([logits_t, a_t, adv_t, w_t], [float, int, float, float])

  log_pi_a = distributions.softmax().logprob(a_t, logits_t)
  adv_t = jax.lax.stop_gradient(adv_t)
  loss_per_timestep = -log_pi_a * adv_t
  return jnp.mean(loss_per_timestep * w_t)
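# A minimal usage sketch for `policy_gradient_loss` (the toy shapes and values
# are illustrative assumptions): uniform logits over 4 actions for a 3-step
# trajectory, with unit advantages and weights. Each action then has
# probability 0.25, so the loss is -log(0.25) ~= 1.386.
pg_loss_example = policy_gradient_loss(
    logits_t=jnp.zeros((3, 4)),
    a_t=jnp.array([0, 1, 2]),
    adv_t=jnp.ones(3),
    w_t=jnp.ones(3))  # ~= 1.386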
def td_lambda(
    v_tm1: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    v_t: ArrayLike,
    lambda_: ArrayOrScalar,
) -> ArrayLike:
  """Calculates the TD(lambda) temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node74.html).

  Args:
    v_tm1: sequence of state values at time t-1.
    r_t: sequence of rewards at time t.
    discount_t: sequence of discounts at time t.
    v_t: sequence of state values at time t.
    lambda_: mixing parameter lambda, either a scalar or a sequence.

  Returns:
    TD(lambda) temporal difference error.
  """
  base.rank_assert([v_tm1, r_t, discount_t, v_t, lambda_],
                   [1, 1, 1, 1, [0, 1]])
  base.type_assert([v_tm1, r_t, discount_t, v_t, lambda_], float)

  target_tm1 = multistep.lambda_returns(r_t, discount_t, v_t, lambda_)
  return jax.lax.stop_gradient(target_tm1) - v_tm1
def q_learning(
    q_tm1: ArrayLike,
    a_tm1: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    q_t: ArrayLike,
) -> ArrayLike:
  """Calculates the Q-learning temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node65.html).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.

  Returns:
    Q-learning temporal difference error.
  """
  base.rank_assert([q_tm1, a_tm1, r_t, discount_t, q_t], [1, 0, 0, 0, 1])
  base.type_assert([q_tm1, a_tm1, r_t, discount_t, q_t],
                   [float, int, float, float, float])

  target_tm1 = r_t + discount_t * jnp.max(q_t)
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]
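# A minimal usage sketch for `q_learning` (the transition below is an
# illustrative assumption): with q_tm1 = [1, 2, 3], a_tm1 = 0, r_t = 1,
# discount_t = 0.9 and q_t = [2, 1, 0], the target is 1 + 0.9 * max(q_t) = 2.8
# and the TD error is 2.8 - q_tm1[0] = 1.8. A batch of transitions is usually
# handled by wrapping the function in `jax.vmap`.
q_td_error_example = q_learning(
    q_tm1=jnp.array([1., 2., 3.]),
    a_tm1=jnp.array(0),
    r_t=jnp.array(1.),
    discount_t=jnp.array(0.9),
    q_t=jnp.array([2., 1., 0.]))  # ~= 1.8
batched_q_learning = jax.vmap(q_learning)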
def add_ornstein_uhlenbeck_noise(key: ArrayLike,
                                 action: ArrayLike,
                                 noise_tm1: ArrayLike,
                                 damping: float,
                                 stddev: float) -> ArrayLike:
  """Returns continuous action with noise from Ornstein-Uhlenbeck process.

  See "On the theory of Brownian Motion" by Uhlenbeck and Ornstein.
  (https://journals.aps.org/pr/abstract/10.1103/PhysRev.36.823).

  Args:
    key: a key from `jax.random`.
    action: continuous action scalar or vector.
    noise_tm1: noise sampled from OU process in previous timestep.
    damping: parameter for controlling autocorrelation of OU process.
    stddev: standard deviation of noise distribution.

  Returns:
    noisy action, of the same shape as input action.
  """
  base.rank_assert([action, noise_tm1], 1)
  base.type_assert([action, noise_tm1], float)

  noise_t = (1. - damping) * noise_tm1 + jax.random.normal(
      key, shape=action.shape) * stddev
  return action + noise_t
def categorical_kl_divergence(
    p_logits: ArrayLike,
    q_logits: ArrayLike,
    temperature: float = 1.
) -> ArrayLike:
  """Compute the KL between two categorical distributions from their logits.

  Args:
    p_logits: unnormalized logits for the first distribution.
    q_logits: unnormalized logits for the second distribution.
    temperature: the temperature for the softmax distribution, defaults at 1.

  Returns:
    the kl divergence between the distributions.
  """
  base.type_assert([p_logits, q_logits], float)

  p_logits /= temperature
  q_logits /= temperature

  p = jax.nn.softmax(p_logits)
  log_p = jax.nn.log_softmax(p_logits)
  log_q = jax.nn.log_softmax(q_logits)
  kl = jnp.sum(p * (log_p - log_q), axis=-1)
  return jax.nn.relu(kl)  # Guard against numerical issues giving negative KL.
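# A quick sanity-check sketch for `categorical_kl_divergence` (the logits are
# illustrative assumptions): the KL is computed over the last axis, so batches
# of logits work directly; identical rows give (approximately) zero KL.
logits_p_example = jnp.array([[1., 0., -1.], [0., 0., 0.]])
logits_q_example = jnp.array([[1., 0., -1.], [2., 0., 0.]])
kl_example = categorical_kl_divergence(logits_p_example, logits_q_example)
# kl_example[0] ~= 0 (identical rows), kl_example[1] > 0.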
def test_mixed_inputs_should_not_raise(self):
  a_float = 1.
  an_int = 2
  a_np_float = np.asarray([3., 4.])
  a_jax_int = jnp.asarray([5, 6])
  base.type_assert([a_float, an_int, a_np_float, a_jax_int],
                   [float, int, float, int])
def test_unsupported_type_should_raise(self):
  a_float = 1.
  an_int = 2
  a_np_float = np.asarray([3., 4.])
  a_jax_int = jnp.asarray([5, 6])
  with self.assertRaises(ValueError):
    base.type_assert([a_float, an_int, a_np_float, a_jax_int],
                     [complex, complex, float, int])
def test_different_length_should_raise(self):
  a_float = 1.
  an_int = 2
  a_np_float = np.asarray([3., 4.])
  a_jax_int = jnp.asarray([5, 6])
  with self.assertRaises(ValueError):
    base.type_assert([a_float, an_int, a_np_float, a_jax_int],
                     [int, float, int])
def test_mixed_inputs_should_raise(self):
  a_float = 1.
  an_int = 2
  a_np_float = np.asarray([3., 4.])
  a_jax_int = jnp.asarray([5, 6])
  with self.assertRaises(ValueError):
    base.type_assert([a_float, an_int, a_np_float, a_jax_int],
                     [float, int, float, float])
def importance_corrected_td_errors(
    r_t: ArrayLike,
    discount_t: ArrayLike,
    rho_tm1: ArrayLike,
    lambda_: ArrayLike,
    values: ArrayLike,
) -> ArrayLike:
  """Computes the multistep td errors with per decision importance sampling.

  Given a trajectory of length `T+1`, generated under some policy π, for each
  time-step `t` we can estimate a multistep temporal difference error δₜ(ρ,λ),
  by combining rewards, discounts, and state values, according to a mixing
  parameter `λ` and importance sampling ratios ρₜ = π(aₜ|sₜ) / μ(aₜ|sₜ):

    td-errorₜ = ρₜ δₜ(ρ,λ)
    δₜ(ρ,λ) = δₜ + ρₜ₊₁ λₜ₊₁ γₜ₊₁ δₜ₊₁(ρ,λ),

  where δₜ = rₜ₊₁ + γₜ₊₁ vₜ₊₁ - vₜ is the one step, temporal difference error
  for the agent's state value estimates. This is equivalent to computing the
  λ-return with λₜ = ρₜ (e.g. using the `lambda_returns` function from above),
  and then computing errors as  td-errorₜ = ρₜ(Gₜ - vₜ).

  See "A new Q(λ) with interim forward view and Monte Carlo equivalence" by
  Sutton et al. (http://proceedings.mlr.press/v32/sutton14.html).

  Args:
    r_t: sequence of rewards rₜ for timesteps t in [1, T].
    discount_t: sequence of discounts γₜ for timesteps t in [1, T].
    rho_tm1: sequence of importance ratios for all timesteps t in [0, T-1].
    lambda_: mixing parameter; scalar or have per timestep values in [1, T].
    values: sequence of state values under π for all timesteps t in [0, T].

  Returns:
    Off-policy estimates of the multistep temporal difference errors.
  """
  base.rank_assert([r_t, discount_t, rho_tm1, values], [1, 1, 1, 1])
  base.type_assert([r_t, discount_t, rho_tm1, values], float)

  v_tm1 = values[:-1]  # Predictions to compute errors for.
  v_t = values[1:]  # Values for bootstrapping.
  rho_t = jnp.concatenate((rho_tm1[1:], jnp.array([1.])))  # Unused dummy value.
  lambda_ = jnp.ones_like(discount_t) * lambda_  # If scalar, make into vector.

  # Compute the one step temporal difference errors.
  one_step_delta = r_t + discount_t * v_t - v_tm1

  # Work backwards to compute `delta_{T-1}`, ..., `delta_0`.
  delta, errors = 0.0, []
  for i in jnp.arange(one_step_delta.shape[0] - 1, -1, -1):
    delta = one_step_delta[i] + discount_t[i] * rho_t[i] * lambda_[i] * delta
    errors.insert(0, delta)

  return rho_tm1 * jnp.array(errors)
def categorical_q_learning(
    q_atoms_tm1: ArrayLike,
    q_logits_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_atoms_t: ArrayLike,
    q_logits_t: ArrayLike,
) -> ArrayOrScalar:
  """Implements Q-learning for categorical Q distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
  Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

  Args:
    q_atoms_tm1: atoms of Q distribution at time t-1.
    q_logits_tm1: logits of Q distribution at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_atoms_t: atoms of Q distribution at time t.
    q_logits_t: logits of Q distribution at time t.

  Returns:
    Categorical Q-learning loss (i.e. temporal difference error).
  """
  base.rank_assert([
      q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t, q_logits_t
  ], [1, 2, 0, 0, 0, 1, 2])
  base.type_assert([
      q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t, q_logits_t
  ], [float, float, int, float, float, float, float])

  # Scale and shift time-t distribution atoms by discount and reward.
  target_z = r_t + discount_t * q_atoms_t

  # Convert logits to distribution, then find greedy action in state s_t.
  q_t_probs = jax.nn.softmax(q_logits_t)
  q_t_mean = jnp.sum(q_t_probs * q_atoms_t[jnp.newaxis, :], axis=1)
  pi_t = jnp.argmax(q_t_mean)

  # Compute distribution for greedy action.
  p_target_z = q_t_probs[pi_t]

  # Project using the Cramer distance.
  target = jax.lax.stop_gradient(
      _categorical_l2_project(target_z, p_target_z, q_atoms_tm1))

  # Compute loss (i.e. temporal difference error).
  logit_qa_tm1 = q_logits_tm1[a_tm1]
  return distributions.categorical_cross_entropy(
      labels=target, logits=logit_qa_tm1)
def categorical_double_q_learning(
    q_atoms_tm1: ArrayLike,
    q_logits_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_atoms_t: ArrayLike,
    q_logits_t: ArrayLike,
    q_t_selector: ArrayLike,
) -> ArrayOrScalar:
  """Implements double Q-learning for categorical Q distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
  Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf) and
  "Double Q-learning" by van Hasselt.
  (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

  Args:
    q_atoms_tm1: atoms of Q distribution at time t-1.
    q_logits_tm1: logits of Q distribution at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_atoms_t: atoms of Q distribution at time t.
    q_logits_t: logits of Q distribution at time t.
    q_t_selector: selector Q-values at time t.

  Returns:
    Categorical double Q-learning loss (i.e. temporal difference error).
  """
  base.rank_assert([
      q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
      q_logits_t, q_t_selector
  ], [1, 2, 0, 0, 0, 1, 2, 1])
  base.type_assert([
      q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
      q_logits_t, q_t_selector
  ], [float, float, int, float, float, float, float, float])

  # Scale and shift time-t distribution atoms by discount and reward.
  target_z = r_t + discount_t * q_atoms_t

  # Select logits for greedy action in state s_t and convert to distribution.
  p_target_z = jax.nn.softmax(q_logits_t[q_t_selector.argmax()])

  # Project using the Cramer distance.
  target = jax.lax.stop_gradient(
      _categorical_l2_project(target_z, p_target_z, q_atoms_tm1))

  # Compute loss (i.e. temporal difference error).
  logit_qa_tm1 = q_logits_tm1[a_tm1]
  return distributions.categorical_cross_entropy(
      labels=target, logits=logit_qa_tm1)
def general_off_policy_returns_from_action_values(
    q_t: ArrayLike,
    a_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    c_t: ArrayLike,
    pi_t: ArrayLike,
) -> ArrayLike:
  """Calculates targets for various off-policy correction algorithms.

  Given a window of experience of length `K`, generated by a behaviour policy
  μ, for each time-step `t` we can estimate the return `G_t` from that step
  onwards, under some target policy π, using the rewards in the trajectory, the
  actions selected by μ and the action-values under π, according to equation:

    Gₜ = rₜ₊₁ + γₜ₊₁ * (E[q(aₜ₊₁)] - cₜ * q(aₜ₊₁) + cₜ * Gₜ₊₁),

  where, depending on the choice of `c_t`, the algorithm implements:

    Importance Sampling             c_t = π(x_t, a_t) / μ(x_t, a_t),
    Harutyunyan's et al. Q(lambda)  c_t = λ,
    Precup's et al. Tree-Backup     c_t = π(x_t, a_t),
    Munos' et al. Retrace           c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).

  See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al.
  (https://arxiv.org/abs/1606.02647).

  Args:
    q_t: Q-values at time t.
    a_t: action index at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    c_t: importance weights at time t.
    pi_t: target policy probs at time t.

  Returns:
    Off-policy estimates of the multistep lambda returns from each state.
  """
  base.rank_assert([q_t, a_t, r_t, discount_t, c_t, pi_t], [2, 1, 1, 1, 1, 2])
  base.type_assert([q_t, a_t, r_t, discount_t, c_t, pi_t],
                   [float, int, float, float, float, float])

  # Get the expected values and the values of actually selected actions.
  exp_q_t = (pi_t * q_t).sum(axis=-1)
  # The generalized returns are independent of Q-values and cs at the final
  # state.
  q_a_t = base.batched_index(q_t, a_t)[:-1]
  c_t = c_t[:-1]

  return general_off_policy_returns_from_q_and_v(
      q_a_t, exp_q_t, r_t, discount_t, c_t)
def vtrace(
    v_tm1: ArrayLike,
    v_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    rho_t: ArrayLike,
    lambda_: float = 1.0,
    clip_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True,
) -> ArrayLike:
  """Calculates V-Trace errors from importance sampling ratios.

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_t: importance sampling ratios.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: clip threshold for importance weights.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    V-Trace error.
  """
  base.rank_assert([v_tm1, v_t, r_t, discount_t, rho_t], [1, 1, 1, 1, 1])
  base.type_assert([v_tm1, v_t, r_t, discount_t, rho_t],
                   [float, float, float, float, float])

  # Clip importance sampling ratios.
  c_t = jnp.minimum(1.0, rho_t) * lambda_
  clipped_rhos = jnp.minimum(clip_rho_threshold, rho_t)

  # Compute the temporal difference errors.
  td_errors = clipped_rhos * (r_t + discount_t * v_t - v_tm1)

  # Work backwards computing the td-errors.
  err = 0.0
  errors = []
  for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
    err = td_errors[i] + discount_t[i] * c_t[i] * err
    errors.insert(0, err)

  # Add the value of the initial state to get the estimates of the returns.
  target_tm1 = jnp.array(errors) + v_tm1

  # Stop gradients and return temporal difference error.
  if stop_target_gradients:
    target_tm1 = jax.lax.stop_gradient(target_tm1)
  return target_tm1 - v_tm1
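# A hedged sketch of how the V-trace inputs are typically assembled (the toy
# trajectory and the way the ratios are derived here are assumptions, not
# library code): the importance ratios rho_t are the exponentiated difference
# between target and behaviour log-probabilities of the actions actually taken.
behaviour_logp_example = jnp.log(jnp.array([0.5, 0.25, 0.25]))
target_logp_example = jnp.log(jnp.array([0.4, 0.4, 0.2]))
rho_example = jnp.exp(target_logp_example - behaviour_logp_example)
values_example = jnp.array([1., 1., 1., 1.])
vtrace_errors_example = vtrace(
    v_tm1=values_example[:-1],
    v_t=values_example[1:],
    r_t=jnp.array([0., 0., 1.]),
    discount_t=jnp.array([0.9, 0.9, 0.]),
    rho_t=rho_example,
    lambda_=1.0)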
def quantile_q_learning(
    dist_q_tm1: ArrayLike,
    tau_q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    dist_q_t_selector: ArrayLike,
    dist_q_t: ArrayLike,
    huber_param: float = 0.
) -> ArrayOrScalar:
  """Implements Q-learning for quantile-valued Q distributions.

  See "Distributional Reinforcement Learning with Quantile Regression" by
  Dabney et al. (https://arxiv.org/abs/1710.10044).

  Args:
    dist_q_tm1: Q distribution at time t-1.
    tau_q_tm1: Q distribution probability thresholds.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    dist_q_t_selector: Q distribution at time t for selecting greedy action in
      target policy. This is separate from dist_q_t as in Double Q-Learning,
      but can be computed with the target network and a separate set of
      samples.
    dist_q_t: target Q distribution at time t.
    huber_param: Huber loss parameter, defaults to 0 (no Huber loss).

  Returns:
    Quantile regression Q learning loss.
  """
  base.rank_assert([
      dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t_selector,
      dist_q_t
  ], [2, 1, 0, 0, 0, 2, 2])
  base.type_assert([
      dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t_selector,
      dist_q_t
  ], [float, float, int, float, float, float, float])

  # Only update the taken actions.
  dist_qa_tm1 = dist_q_tm1[:, a_tm1]

  # Select target action according to greedy policy w.r.t. dist_q_t_selector.
  q_t_selector = jnp.mean(dist_q_t_selector, axis=0)
  a_t = jnp.argmax(q_t_selector)
  dist_qa_t = dist_q_t[:, a_t]

  # Compute target, do not backpropagate into it.
  dist_target = r_t + discount_t * dist_qa_t
  dist_target = jax.lax.stop_gradient(dist_target)

  return _quantile_regression_loss(dist_qa_tm1, tau_q_tm1, dist_target,
                                   huber_param)
def transformed_retrace(
    q_tm1: ArrayLike,
    q_t: ArrayLike,
    a_tm1: ArrayLike,
    a_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    pi_t: ArrayLike,
    mu_t: ArrayLike,
    lambda_: float,
    eps: float = 1e-8,
    stop_target_gradients: bool = True,
    tx_pair: TxPair = IDENTITY_PAIR,
) -> ArrayLike:
  """Calculates transformed Retrace errors.

  See "Recurrent Experience Replay in Distributed Reinforcement Learning" by
  Kapturowski et al. (https://openreview.net/pdf?id=r1lyTjAqYX).

  Args:
    q_tm1: Q-values at time t-1.
    q_t: Q-values at time t.
    a_tm1: action index at time t-1.
    a_t: action index at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    pi_t: target policy probs at time t.
    mu_t: behavior policy probs at time t.
    lambda_: scalar mixing parameter lambda.
    eps: small value to add to mu_t for numerical stability.
    stop_target_gradients: bool indicating whether or not to apply stop
      gradient to targets.
    tx_pair: TxPair of value function transformation and its inverse.

  Returns:
    Transformed Retrace error.
  """
  base.rank_assert([q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t],
                   [2, 2, 1, 1, 1, 1, 2, 1])
  base.type_assert([q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t],
                   [float, float, int, int, float, float, float, float])

  pi_a_t = base.batched_index(pi_t, a_t)
  c_t = jnp.minimum(1.0, pi_a_t / (mu_t + eps)) * lambda_
  target_tm1 = transformed_general_off_policy_returns_from_action_values(
      tx_pair, q_t, a_t, r_t, discount_t, c_t, pi_t)
  if stop_target_gradients:
    target_tm1 = jax.lax.stop_gradient(target_tm1)
  q_a_tm1 = base.batched_index(q_tm1, a_tm1)
  return target_tm1 - q_a_tm1
def retrace(
    q_tm1: ArrayLike,
    q_t: ArrayLike,
    a_tm1: ArrayLike,
    a_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    pi_t: ArrayLike,
    mu_t: ArrayLike,
    lambda_: float,
    eps: float = 1e-8,
    stop_target_gradients: bool = True,
) -> ArrayLike:
  """Calculates Retrace errors.

  See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al.
  (https://arxiv.org/abs/1606.02647).

  Args:
    q_tm1: Q-values at time t-1.
    q_t: Q-values at time t.
    a_tm1: action index at time t-1.
    a_t: action index at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    pi_t: target policy probs at time t.
    mu_t: behavior policy probs at time t.
    lambda_: scalar mixing parameter lambda.
    eps: small value to add to mu_t for numerical stability.
    stop_target_gradients: bool indicating whether or not to apply stop
      gradient to targets.

  Returns:
    Retrace error.
  """
  base.rank_assert([q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t],
                   [2, 2, 1, 1, 1, 1, 2, 1])
  base.type_assert([q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t],
                   [float, float, int, int, float, float, float, float])

  pi_a_t = base.batched_index(pi_t, a_t)
  c_t = jnp.minimum(1.0, pi_a_t / (mu_t + eps)) * lambda_
  target_tm1 = multistep.general_off_policy_returns_from_action_values(
      q_t, a_t, r_t, discount_t, c_t, pi_t)
  q_a_tm1 = base.batched_index(q_tm1, a_tm1)
  if stop_target_gradients:
    target_tm1 = jax.lax.stop_gradient(target_tm1)
  return target_tm1 - q_a_tm1
def log_loss(
    predictions: ArrayLike,
    targets: ArrayLike,
) -> ArrayLike:
  """Calculates the log loss of predictions wrt targets.

  Args:
    predictions: a vector of probabilities of arbitrary shape.
    targets: a vector of probabilities of shape compatible with predictions.

  Returns:
    a vector of same shape of `predictions`.
  """
  base.type_assert([predictions, targets], float)

  return -jnp.log(likelihood(predictions, targets))
def general_off_policy_returns_from_q_and_v(
    q_t: ArrayLike,
    v_t: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    c_t: ArrayLike,
) -> ArrayLike:
  """Calculates targets for various off-policy evaluation algorithms.

  Given a window of experience of length `K+1`, generated by a behaviour policy
  μ, for each time-step `t` we can estimate the return `G_t` from that step
  onwards, under some target policy π, using the rewards in the trajectory, the
  values under π of states and actions selected by μ, according to equation:

    Gₜ = rₜ₊₁ + γₜ₊₁ * (vₜ₊₁ - cₜ₊₁ * q(aₜ₊₁) + cₜ₊₁ * Gₜ₊₁),

  where, depending on the choice of `c_t`, the algorithm implements:

    Importance Sampling             c_t = π(x_t, a_t) / μ(x_t, a_t),
    Harutyunyan's et al. Q(lambda)  c_t = λ,
    Precup's et al. Tree-Backup     c_t = π(x_t, a_t),
    Munos' et al. Retrace           c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).

  See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al.
  (https://arxiv.org/abs/1606.02647).

  Args:
    q_t: Q-values under π of actions executed by μ at times [1, ..., K - 1].
    v_t: Values under π at times [1, ..., K].
    r_t: rewards at times [1, ..., K].
    discount_t: discounts at times [1, ..., K].
    c_t: weights at times [1, ..., K - 1].

  Returns:
    Off-policy estimates of the generalized returns from states visited at
    times [0, ..., K - 1].
  """
  base.rank_assert([q_t, v_t, r_t, discount_t, c_t], 1)
  base.type_assert([q_t, v_t, r_t, discount_t, c_t], float)

  # Work backwards to compute `G_K-1`, ..., `G_1`, `G_0`.
  g = r_t[-1] + discount_t[-1] * v_t[-1]  # G_K-1.
  returns = [g]
  for i in jnp.arange(q_t.shape[0] - 1, -1, -1):  # [K - 2, ..., 0]
    g = r_t[i] + discount_t[i] * (v_t[i] - c_t[i] * q_t[i] + c_t[i] * g)
    returns.insert(0, g)

  return jnp.array(returns)
def add_gaussian_noise(key: ArrayLike,
                       action: ArrayLike,
                       stddev: float) -> ArrayLike:
  """Returns continuous action with noise drawn from a Gaussian distribution.

  Args:
    key: a key from `jax.random`.
    action: continuous action scalar or vector.
    stddev: standard deviation of noise distribution.

  Returns:
    noisy action, of the same shape as input action.
  """
  base.type_assert(action, float)

  noise = jax.random.normal(key, shape=action.shape) * stddev
  return action + noise
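# A minimal usage sketch for `add_gaussian_noise` (the action values and
# stddev are illustrative assumptions): split a fresh PRNG key per step and
# perturb a continuous action vector.
example_key = jax.random.PRNGKey(42)
example_key, noise_key = jax.random.split(example_key)
noisy_action_example = add_gaussian_noise(
    noise_key, action=jnp.array([0.1, -0.3]), stddev=0.2)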
def likelihood(predictions: ArrayLike, targets: ArrayLike) -> ArrayLike:
  """Calculates the likelihood of predictions wrt targets.

  Args:
    predictions: a vector of arbitrary shape.
    targets: a vector of shape compatible with predictions.

  Returns:
    a vector of same shape of `predictions`.
  """
  base.type_assert([predictions, targets], float)

  likelihood_vals = predictions**targets * (1. - predictions)**(1. - targets)
  # Note: 0**0 evaluates to NaN on TPUs, manually set these cases to 1.
  filter_indices = jnp.logical_or(
      jnp.logical_and(targets == 1, predictions == 1),
      jnp.logical_and(targets == 0, predictions == 0))
  return jnp.where(filter_indices, 1, likelihood_vals)