def vtrace(
    v_tm1: Array,
    v_t: Array,
    r_t: Array,
    discount_t: Array,
    rho_t: Array,
    lambda_: float = 1.0,
    clip_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True,
) -> Array:
  """Calculates V-Trace errors from importance weights.

  V-trace computes TD-errors from multistep trajectories by applying
  off-policy corrections based on clipped importance sampling ratios.

  See "IMPALA: Scalable Distributed Deep-RL with Importance Weighted
  Actor Learner Architectures" by Espeholt et al.
  (https://arxiv.org/abs/1802.01561).

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_t: importance sampling ratios.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: clip threshold for importance weights.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    V-Trace error.
  """
  chex.assert_rank([v_tm1, v_t, r_t, discount_t, rho_t], [1, 1, 1, 1, 1])
  chex.assert_type([v_tm1, v_t, r_t, discount_t, rho_t],
                   [float, float, float, float, float])

  # Clip importance sampling ratios.
  c_t = jnp.minimum(1.0, rho_t) * lambda_
  clipped_rhos = jnp.minimum(clip_rho_threshold, rho_t)

  # Compute the temporal difference errors.
  td_errors = clipped_rhos * (r_t + discount_t * v_t - v_tm1)

  # Work backwards computing the td-errors.
  err = 0.0
  errors = []
  for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
    err = td_errors[i] + discount_t[i] * c_t[i] * err
    errors.insert(0, err)

  # Return errors.
  if not stop_target_gradients:
    return jnp.array(errors)
  # In TD-like algorithms, we want gradients to only flow in the predictions,
  # and not in the values used to bootstrap. In this case, add the value of
  # the initial state value to get the implied estimates of the returns, stop
  # gradient around such target and then subtract again the initial state
  # value.
  else:
    target_tm1 = jnp.array(errors) + v_tm1
    target_tm1 = jax.lax.stop_gradient(target_tm1)
    return target_tm1 - v_tm1

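# Illustrative usage sketch for `vtrace` (an addition, not part of the original
# source): the example function name and the toy trajectory values below are
# assumptions for demonstration only. All inputs are rank-1 float arrays of
# equal length T, and `rho_t` holds per-step importance ratios
# pi(a_t|s_t) / mu(a_t|s_t) between target and behaviour policies.
def example_vtrace_usage():
  import jax.numpy as jnp

  v_tm1 = jnp.array([1.0, 1.1, 0.9])        # V(s_{t-1}) along the trajectory
  v_t = jnp.array([1.1, 0.9, 0.8])          # V(s_t), i.e. bootstrap values
  r_t = jnp.array([0.0, 1.0, 0.0])          # rewards
  discount_t = jnp.array([0.99, 0.99, 0.])  # 0 marks an episode termination
  rho_t = jnp.array([1.2, 0.7, 1.0])        # importance sampling ratios
  return vtrace(v_tm1, v_t, r_t, discount_t, rho_t)
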
def clipped_surrogate_pg_loss(
    prob_ratios_t: Array,
    adv_t: Array,
    epsilon: Scalar) -> Array:
  """Computes the clipped surrogate policy gradient loss.

  L_clipₜ(θ) = - min(rₜ(θ)Âₜ, clip(rₜ(θ), 1-ε, 1+ε)Âₜ)

  Where rₜ(θ) = π_θ(aₜ| sₜ) / π_θ_old(aₜ| sₜ) and Âₜ are the advantages.

  See Proximal Policy Optimization Algorithms, Schulman et al.:
  https://arxiv.org/abs/1707.06347

  Args:
    prob_ratios_t: Ratio of action probabilities for actions a_t:
        rₜ(θ) = π_θ(aₜ| sₜ) / π_θ_old(aₜ| sₜ)
    adv_t: the observed or estimated advantages from executing actions a_t.
    epsilon: Scalar value corresponding to how much to clip the objective.

  Returns:
    Loss whose gradient corresponds to a clipped surrogate policy gradient
        update.
  """
  chex.assert_rank([prob_ratios_t, adv_t], [1, 1])
  chex.assert_type([prob_ratios_t, adv_t], [float, float])

  clipped_ratios_t = jnp.clip(prob_ratios_t, 1. - epsilon, 1. + epsilon)
  clipped_objective = jnp.fmin(prob_ratios_t * adv_t, clipped_ratios_t * adv_t)
  return -jnp.mean(clipped_objective)

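# Illustrative usage sketch for `clipped_surrogate_pg_loss` (an addition, not
# from the original source): the example function name and numbers are toy
# assumptions. The probability ratios are typically exp(log_pi_new - log_pi_old)
# for the actions actually taken under the old policy.
def example_clipped_surrogate_pg_loss_usage():
  import jax.numpy as jnp

  log_pi_new = jnp.array([-0.9, -1.2, -0.3])  # log pi_theta(a_t|s_t)
  log_pi_old = jnp.array([-1.0, -1.0, -0.5])  # log pi_theta_old(a_t|s_t)
  prob_ratios_t = jnp.exp(log_pi_new - log_pi_old)
  adv_t = jnp.array([0.5, -0.2, 1.3])         # estimated advantages
  return clipped_surrogate_pg_loss(prob_ratios_t, adv_t, epsilon=0.2)
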
def policy_gradient_loss(
    logits_t: Array,
    a_t: Array,
    adv_t: Array,
    w_t: Array,
) -> Array:
  """Calculates the policy gradient loss.

  See "Simple Gradient-Following Algorithms for Connectionist RL" by Williams.
  (http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)

  Args:
    logits_t: a sequence of unnormalized action preferences.
    a_t: a sequence of actions sampled from the preferences `logits_t`.
    adv_t: the observed or estimated advantages from executing actions `a_t`.
    w_t: a per timestep weighting for the loss.

  Returns:
    Loss whose gradient corresponds to a policy gradient update.
  """
  chex.assert_rank([logits_t, a_t, adv_t, w_t], [2, 1, 1, 1])
  chex.assert_type([logits_t, a_t, adv_t, w_t], [float, int, float, float])

  log_pi_a_t = distributions.softmax().logprob(a_t, logits_t)
  adv_t = jax.lax.stop_gradient(adv_t)
  loss_per_timestep = -log_pi_a_t * adv_t
  return jnp.mean(loss_per_timestep * w_t)

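# Illustrative usage sketch for `policy_gradient_loss` (an addition, not part
# of the original source): the example name and values are assumptions, and it
# presumes the `distributions` helper used inside the function above is
# importable in this module, as in the original library.
def example_policy_gradient_loss_usage():
  import jax.numpy as jnp

  logits_t = jnp.array([[1.0, 0.0, -1.0],
                        [0.5, 0.5, 0.0]])   # [T, num_actions]
  a_t = jnp.array([0, 2], dtype=jnp.int32)  # sampled actions
  adv_t = jnp.array([1.0, -0.5])            # advantages for those actions
  w_t = jnp.ones_like(adv_t)                # uniform per-timestep weights
  return policy_gradient_loss(logits_t, a_t, adv_t, w_t)
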
def objective_func(self, params, state, hyperparams, rng, transition_batch, Adv):
    rngs = hk.PRNGSequence(rng)

    # get distribution params from function approximator
    S = self.pi.observation_preprocessor(next(rngs), transition_batch.S)
    dist_params, state_new = self.pi.function(params, state, next(rngs), S, True)

    # compute probability ratios
    A = self.pi.proba_dist.preprocess_variate(next(rngs), transition_batch.A)
    log_pi = self.pi.proba_dist.log_proba(dist_params, A)
    ratio = jnp.exp(log_pi - transition_batch.logP)  # π_new / π_old
    ratio_clip = jnp.clip(ratio, 1 - hyperparams['epsilon'], 1 + hyperparams['epsilon'])

    # clip importance weights to reduce variance
    W = jnp.clip(transition_batch.W, 0.1, 10.)

    # ppo-clip objective
    chex.assert_equal_shape([W, Adv, ratio, ratio_clip])
    chex.assert_rank([W, Adv, ratio, ratio_clip], 1)
    objective = W * jnp.minimum(Adv * ratio, Adv * ratio_clip)

    # also pass auxiliary data to avoid multiple forward passes
    return jnp.mean(objective), (dist_params, log_pi, state_new)

def qpg_loss(
    logits_t: Array,
    q_t: Array,
) -> Array:
  """Computes the QPG (Q-based Policy Gradient) loss.

  See "Actor-Critic Policy Optimization in Partially Observable Multiagent
  Environments" by Srinivasan, Lanctot.
  (https://papers.nips.cc/paper/7602-actor-critic-policy-optimization-in-partially-observable-multiagent-environments.pdf)

  Args:
    logits_t: a sequence of unnormalized action preferences.
    q_t: the observed or estimated action value from executing actions `a_t`
      at time t.

  Returns:
    QPG Loss.
  """
  chex.assert_rank([logits_t, q_t], 2)
  chex.assert_type([logits_t, q_t], float)

  policy_t, advantage_t = _compute_advantages(logits_t, q_t)
  policy_advantages = -policy_t * jax.lax.stop_gradient(advantage_t)
  loss = jnp.mean(jnp.sum(policy_advantages, axis=1), axis=0)
  return loss

def qv_max(
    v_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t: Array,
    stop_target_gradients: bool = True,
) -> Numeric:
  """Calculates the QVMAX temporal difference error.

  See "The QV Family Compared to Other Reinforcement Learning Algorithms" by
  Wiering and van Hasselt (2009).
  (http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.713.1931)

  Args:
    v_tm1: state values at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    stop_target_gradients: bool indicating whether or not to apply stop
      gradient to targets.

  Returns:
    QVMAX temporal difference error.
  """
  chex.assert_rank([v_tm1, r_t, discount_t, q_t], [0, 0, 0, 1])
  chex.assert_type([v_tm1, r_t, discount_t, q_t], float)

  target_tm1 = r_t + discount_t * jnp.max(q_t)
  target_tm1 = jax.lax.select(stop_target_gradients,
                              jax.lax.stop_gradient(target_tm1), target_tm1)
  return target_tm1 - v_tm1

def _compute_posterior(
    self,
    inputs,
    encoder_outputs,
    context_vectors,
):
  """Computes the posterior branch of the DecoderBlock."""
  chex.assert_rank(inputs, 4)
  resolution = inputs.shape[1]
  try:
    encoded_image = encoder_outputs[resolution]
  except KeyError:
    raise KeyError(
        'encoder_outputs does not contain the required '  # pylint: disable=g-doc-exception
        f'resolution ({resolution}). encoder_outputs resolutions '
        f'are {list(encoder_outputs.keys())}.')

  posterior_block = blocks.ResBlock(
      self.bottlenecked_num_channels,
      self.latent_dim * 2,
      use_residual_connection=False,
      precision=self.precision,
      name='posterior_block')

  concatenated_inputs = jnp.concatenate([inputs, encoded_image], axis=3)
  posterior_output = posterior_block(concatenated_inputs, context_vectors)
  posterior_mean, posterior_log_std = jnp.split(posterior_output, 2, axis=3)
  return posterior_mean, posterior_log_std

def test_reshape_shape_forward(filters):
  n_dims = (1, 28, 28, 1)
  new_shape = _get_new_shapes(28, 28, 1, filters)

  params_rng, data_rng = jax.random.split(KEY, 2)
  x = jax.random.uniform(data_rng, shape=n_dims)

  # create layer
  init_func = Squeeze(filter_shape=filters, collapse=None, return_outputs=True)

  # initialize layer
  z_, params, forward_f, inverse_f = init_func(
      rng=params_rng, shape=n_dims, inputs=x)

  # forward transformation
  z, log_abs_det = forward_f(params, x)

  # checks
  chex.assert_tree_all_close(z, z_)
  chex.assert_equal_shape([z, log_abs_det, z_])
  chex.assert_rank(z, 4)
  chex.assert_equal(z.shape[1:], new_shape)

  # inverse transformation
  x_approx, log_abs_det = inverse_f(params, z)

  # checks
  chex.assert_equal_shape([x_approx, x])
  chex.assert_tree_all_close(x_approx, x)

def test_reshape_shape_collapse(filters, collapse):
  n_dims = (1, 28, 28, 1)

  params_rng, data_rng = jax.random.split(KEY, 2)
  x = jax.random.uniform(data_rng, shape=n_dims)

  # create layer
  init_func = Squeeze(filter_shape=filters, collapse=collapse)

  # initialize layer
  params, forward_f, inverse_f = init_func(rng=params_rng, shape=n_dims)

  # forward transformation
  z, log_abs_det = forward_f(params, x)

  # checks
  chex.assert_equal_shape([z, log_abs_det])
  chex.assert_rank(z, 2)

  # inverse transformation
  x_approx, log_abs_det = inverse_f(params, z)

  # checks
  chex.assert_equal_shape([x_approx, x])
  chex.assert_tree_all_close(x_approx, x)

def fix_step_type_on_interruptions(step_type: chex.Array):
  """Returns step_type with a LAST step before almost every FIRST step.

  If the environment crashes or is interrupted while a trajectory is being
  written, the LAST step can be missing before a FIRST step. We add the LAST
  step before each FIRST step, if the step before the FIRST step is a MID
  step, to signal to the agent that the next observation is not connected to
  the current stream of data. Note that the agent must still then
  appropriately handle both `terminations` (e.g. game over in a game) and
  `interruptions` (a timeout or a reset for system maintenance): the value of
  the discount on LAST step will be > 0 on `interruptions`, while it will be 0
  on `terminations`. Similar issues arise in hierarchical RL systems as well.

  Args:
    step_type: an array of `dm_env` step types, with shape `[T, B]`.

  Returns:
    Fixed step_type.
  """
  chex.assert_rank(step_type, 2)
  next_step_type = jnp.concatenate([
      step_type[1:],
      jnp.full(
          step_type[:1].shape, int(dm_env.StepType.MID),
          dtype=step_type.dtype),
  ], axis=0)
  return jnp.where(
      jnp.logical_and(
          jnp.equal(next_step_type, int(dm_env.StepType.FIRST)),
          jnp.equal(step_type, int(dm_env.StepType.MID)),
      ), int(dm_env.StepType.LAST), step_type)

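# Illustrative usage sketch for `fix_step_type_on_interruptions` (an addition,
# not part of the original source): the example name and the hand-written
# step-type stream are assumptions for demonstration only.
def example_fix_step_type_usage():
  import dm_env
  import jax.numpy as jnp

  # A single-batch stream of shape [T, B] = [4, 1] in which a FIRST step
  # follows a MID step without an intervening LAST step (e.g. after a crash).
  step_type = jnp.array([[int(dm_env.StepType.FIRST)],
                         [int(dm_env.StepType.MID)],
                         [int(dm_env.StepType.FIRST)],
                         [int(dm_env.StepType.MID)]])
  # The MID step at index 1 is rewritten to LAST to mark the interruption.
  return fix_step_type_on_interruptions(step_type)
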
def add_ornstein_uhlenbeck_noise(
    key: Array,
    action: Array,
    noise_tm1: Array,
    damping: float,
    stddev: float
) -> Array:
  """Returns continuous action with noise from Ornstein-Uhlenbeck process.

  See "On the theory of Brownian Motion" by Uhlenbeck and Ornstein.
  (https://journals.aps.org/pr/abstract/10.1103/PhysRev.36.823).

  Args:
    key: a key from `jax.random`.
    action: continuous action scalar or vector.
    noise_tm1: noise sampled from OU process in previous timestep.
    damping: parameter for controlling autocorrelation of OU process.
    stddev: standard deviation of noise distribution.

  Returns:
    noisy action, of the same shape as input action.
  """
  chex.assert_rank([action, noise_tm1], 1)
  chex.assert_type([action, noise_tm1], float)

  noise_t = (1. - damping) * noise_tm1 + jax.random.normal(
      key, shape=action.shape) * stddev

  return action + noise_t

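# Illustrative usage sketch for `add_ornstein_uhlenbeck_noise` (an addition,
# not part of the original source): the example name and values are
# assumptions. Across an episode the OU state must be carried forward, i.e.
# the noise to pass in at the next step is `noisy_action - action`.
def example_add_ornstein_uhlenbeck_noise_usage():
  import jax
  import jax.numpy as jnp

  key = jax.random.PRNGKey(0)
  action = jnp.array([0.3, -0.1])       # continuous action vector
  noise_tm1 = jnp.zeros_like(action)    # OU noise state, zero at episode start
  noisy_action = add_ornstein_uhlenbeck_noise(
      key, action, noise_tm1, damping=0.15, stddev=0.2)
  noise_t = noisy_action - action       # carry this into the next call
  return noisy_action, noise_t
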
def rpg_loss(
    logits_t: Array,
    q_t: Array,
    use_stop_gradient: bool = True,
) -> Array:
  """Computes the RPG (Regret Policy Gradient) loss.

  The gradient of this loss adapts the Regret Matching rule by weighting the
  standard PG update with regret.

  See "Actor-Critic Policy Optimization in Partially Observable Multiagent
  Environments" by Srinivasan, Lanctot (https://arxiv.org/abs/1810.09026).

  Args:
    logits_t: a sequence of unnormalized action preferences.
    q_t: the observed or estimated action value from executing actions `a_t`
      at time t.
    use_stop_gradient: bool indicating whether or not to apply stop gradient
      to advantages.

  Returns:
    RPG Loss.
  """
  chex.assert_rank([logits_t, q_t], 2)
  chex.assert_type([logits_t, q_t], float)

  _, adv_t = _compute_advantages(logits_t, q_t, use_stop_gradient)
  regrets_t = jnp.sum(jax.nn.relu(adv_t), axis=1)

  total_regret_t = jnp.mean(regrets_t, axis=0)
  return total_regret_t

def qpg_loss(
    logits_t: Array,
    q_t: Array,
    use_stop_gradient: bool = True,
) -> Array:
  """Computes the QPG (Q-based Policy Gradient) loss.

  See "Actor-Critic Policy Optimization in Partially Observable Multiagent
  Environments" by Srinivasan, Lanctot (https://arxiv.org/abs/1810.09026).

  Args:
    logits_t: a sequence of unnormalized action preferences.
    q_t: the observed or estimated action value from executing actions `a_t`
      at time t.
    use_stop_gradient: bool indicating whether or not to apply stop gradient
      to advantages.

  Returns:
    QPG Loss.
  """
  chex.assert_rank([logits_t, q_t], 2)
  chex.assert_type([logits_t, q_t], float)

  policy_t, advantage_t = _compute_advantages(logits_t, q_t)
  advantage_t = jax.lax.select(use_stop_gradient,
                               jax.lax.stop_gradient(advantage_t), advantage_t)
  policy_advantages = -policy_t * advantage_t
  loss = jnp.mean(jnp.sum(policy_advantages, axis=1), axis=0)
  return loss

def leaky_vtrace(
    v_tm1: Array,
    v_t: Array,
    r_t: Array,
    discount_t: Array,
    rho_t: Array,
    alpha_: float = 1.0,
    lambda_: float = 1.0,
    clip_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True):
  """Calculates Leaky V-Trace errors from importance weights.

  Leaky-Vtrace is a combination of Importance sampling and V-trace, where the
  degree of mixing is controlled by a scalar `alpha` (that may be meta-learnt).

  See "Self-Tuning Deep Reinforcement Learning" by Zahavy et al.
  (https://arxiv.org/abs/2002.12928)

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_t: importance weights at time t.
    alpha_: mixing parameter for Importance Sampling and V-trace.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: clip threshold for importance weights.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    Leaky V-Trace error.
  """
  chex.assert_rank([v_tm1, v_t, r_t, discount_t, rho_t], [1, 1, 1, 1, 1])
  chex.assert_type([v_tm1, v_t, r_t, discount_t, rho_t],
                   [float, float, float, float, float])

  # Mix clipped and unclipped importance sampling ratios.
  c_t = ((1 - alpha_) * rho_t + alpha_ * jnp.minimum(1.0, rho_t)) * lambda_
  clipped_rhos = ((1 - alpha_) * rho_t +
                  alpha_ * jnp.minimum(clip_rho_threshold, rho_t))

  # Compute the temporal difference errors.
  td_errors = clipped_rhos * (r_t + discount_t * v_t - v_tm1)

  # Work backwards computing the td-errors.
  err = 0.0
  errors = []
  for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
    err = td_errors[i] + discount_t[i] * c_t[i] * err
    errors.insert(0, err)

  # Return errors.
  if not stop_target_gradients:
    return jnp.array(errors)
  # In TD-like algorithms, we want gradients to only flow in the predictions,
  # and not in the values used to bootstrap. In this case, add the value of
  # the initial state value to get the implied estimates of the returns, stop
  # gradient around such target and then subtract again the initial state
  # value.
  else:
    target_tm1 = jnp.array(errors) + v_tm1
    return jax.lax.stop_gradient(target_tm1) - v_tm1

def qv_learning(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    v_t: Numeric,
    stop_target_gradients: bool = True,
) -> Numeric:
  """Calculates the QV-learning temporal difference error.

  See "Two Novel On-policy Reinforcement Learning Algorithms based on
  TD(lambda)-methods" by Wiering and van Hasselt
  (https://ieeexplore.ieee.org/abstract/document/4220845.)

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_t: state values at time t.
    stop_target_gradients: bool indicating whether or not to apply stop
      gradient to targets.

  Returns:
    QV-learning temporal difference error.
  """
  chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, v_t], [1, 0, 0, 0, 0])
  chex.assert_type([q_tm1, a_tm1, r_t, discount_t, v_t],
                   [float, int, float, float, float])

  target_tm1 = r_t + discount_t * v_t
  target_tm1 = jax.lax.select(stop_target_gradients,
                              jax.lax.stop_gradient(target_tm1), target_tm1)
  return target_tm1 - q_tm1[a_tm1]

def expected_sarsa(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t: Array,
    probs_a_t: Array,
) -> Numeric:
  """Calculates the expected SARSA (SARSE) temporal difference error.

  See "A Theoretical and Empirical Analysis of Expected Sarsa" by Seijen,
  van Hasselt, Whiteson et al.
  (http://www.cs.ox.ac.uk/people/shimon.whiteson/pubs/vanseijenadprl09.pdf).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    probs_a_t: action probabilities at time t.

  Returns:
    Expected SARSA temporal difference error.
  """
  chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t],
                   [1, 0, 0, 0, 1, 1])
  chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t],
                   [float, int, float, float, float, float])

  target_tm1 = r_t + discount_t * jnp.dot(q_t, probs_a_t)
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]

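# Illustrative usage sketch for `expected_sarsa` (an addition, not part of the
# original source): the example name and single-transition toy values are
# assumptions. The target bootstraps from the expectation of Q(s_t, .) under
# the current policy's action probabilities.
def example_expected_sarsa_usage():
  import jax.numpy as jnp

  q_tm1 = jnp.array([1.0, 2.0, 0.5])      # Q(s_{t-1}, .)
  a_tm1 = 1                               # action taken at t-1
  r_t = 0.5
  discount_t = 0.99
  q_t = jnp.array([1.5, 0.5, 1.0])        # Q(s_t, .)
  probs_a_t = jnp.array([0.1, 0.8, 0.1])  # e.g. epsilon-greedy probabilities
  return expected_sarsa(q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t)
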
def td_learning(
    v_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    v_t: Numeric,
    stop_target_gradients: bool = True,
) -> Numeric:
  """Calculates the TD-learning temporal difference error.

  See "Learning to Predict by the Methods of Temporal Differences" by Sutton.
  (https://link.springer.com/article/10.1023/A:1022633531479).

  Args:
    v_tm1: state values at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_t: state values at time t.
    stop_target_gradients: bool indicating whether or not to apply stop
      gradient to targets.

  Returns:
    TD-learning temporal difference error.
  """
  chex.assert_rank([v_tm1, r_t, discount_t, v_t], 0)
  chex.assert_type([v_tm1, r_t, discount_t, v_t], float)

  target_tm1 = r_t + discount_t * v_t
  target_tm1 = jax.lax.select(stop_target_gradients,
                              jax.lax.stop_gradient(target_tm1), target_tm1)
  return target_tm1 - v_tm1

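# Illustrative usage sketch for `td_learning` (an addition, not part of the
# original source): the example name and numbers are toy assumptions. All
# inputs are scalars describing a single transition; the error is
# r_t + discount_t * V(s_t) - V(s_{t-1}).
def example_td_learning_usage():
  import jax.numpy as jnp

  v_tm1 = jnp.array(1.0)       # V(s_{t-1})
  r_t = jnp.array(0.5)
  discount_t = jnp.array(0.99)
  v_t = jnp.array(1.2)         # bootstrap value V(s_t)
  return td_learning(v_tm1, r_t, discount_t, v_t)
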
def double_q_learning(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t_value: Array,
    q_t_selector: Array,
) -> Numeric:
  """Calculates the double Q-learning temporal difference error.

  See "Double Q-learning" by van Hasselt.
  (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t_value: Q-values at time t.
    q_t_selector: selector Q-values at time t.

  Returns:
    Double Q-learning temporal difference error.
  """
  chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector],
                   [1, 0, 0, 0, 1, 1])
  chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector],
                   [float, int, float, float, float, float])

  target_tm1 = r_t + discount_t * q_t_value[q_t_selector.argmax()]
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]

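# Illustrative usage sketch for `double_q_learning` (an addition, not part of
# the original source): the example name and values are assumptions. The
# greedy action is selected with one set of Q-values (e.g. the online network)
# and evaluated with another (e.g. the target network).
def example_double_q_learning_usage():
  import jax.numpy as jnp

  q_tm1 = jnp.array([1.0, 2.0, 0.5])         # online Q-values at t-1
  a_tm1 = 1                                  # action taken at t-1
  r_t = 0.5
  discount_t = 0.99
  q_t_value = jnp.array([1.2, 0.7, 1.1])     # Q-values used to evaluate
  q_t_selector = jnp.array([1.0, 1.5, 0.9])  # Q-values used to select argmax
  return double_q_learning(q_tm1, a_tm1, r_t, discount_t,
                           q_t_value, q_t_selector)
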
def td_lambda(
    v_tm1: Array,
    r_t: Array,
    discount_t: Array,
    v_t: Array,
    lambda_: Numeric,
    stop_target_gradients: bool = True,
) -> Array:
  """Calculates the TD(lambda) temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node74.html).

  Args:
    v_tm1: sequence of state values at time t-1.
    r_t: sequence of rewards at time t.
    discount_t: sequence of discounts at time t.
    v_t: sequence of state values at time t.
    lambda_: mixing parameter lambda, either a scalar or a sequence.
    stop_target_gradients: bool indicating whether or not to apply stop
      gradient to targets.

  Returns:
    TD(lambda) temporal difference error.
  """
  chex.assert_rank([v_tm1, r_t, discount_t, v_t, lambda_],
                   [1, 1, 1, 1, {0, 1}])
  chex.assert_type([v_tm1, r_t, discount_t, v_t, lambda_], float)

  target_tm1 = multistep.lambda_returns(r_t, discount_t, v_t, lambda_)
  target_tm1 = jax.lax.select(stop_target_gradients,
                              jax.lax.stop_gradient(target_tm1), target_tm1)
  return target_tm1 - v_tm1

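# Illustrative usage sketch for `td_lambda` (an addition, not part of the
# original source): the example name and trajectory values are assumptions,
# and it presumes the `multistep` helper used inside `td_lambda` above is
# importable in this module, as in the original library.
def example_td_lambda_usage():
  import jax.numpy as jnp

  v_tm1 = jnp.array([1.0, 1.1, 0.9])        # V(s_{t-1}) along the trajectory
  r_t = jnp.array([0.0, 1.0, 0.0])
  discount_t = jnp.array([0.99, 0.99, 0.0])  # 0 marks an episode termination
  v_t = jnp.array([1.1, 0.9, 0.8])           # bootstrap values V(s_t)
  return td_lambda(v_tm1, r_t, discount_t, v_t, lambda_=0.9)
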
def persistent_q_learning(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t: Array,
    action_gap_scale: float,
) -> Numeric:
  """Calculates the persistent Q-learning temporal difference error.

  See "Increasing the Action Gap: New Operators for Reinforcement Learning"
  by Bellemare, Ostrovski, Guez et al. (https://arxiv.org/abs/1512.04860).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    action_gap_scale: coefficient in [0, 1] for scaling the action gap term.

  Returns:
    Persistent Q-learning temporal difference error.
  """
  chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t], [1, 0, 0, 0, 1])
  chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t],
                   [float, int, float, float, float])

  corrected_q_t = (
      (1. - action_gap_scale) * jnp.max(q_t)
      + action_gap_scale * q_t[a_tm1]
  )
  target_tm1 = r_t + discount_t * corrected_q_t
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]

def affine_transform(dist_params, scale, shift, value_transform=None):
    chex.assert_rank([dist_params['values'], scale, shift], [2, {0, 1}, {0, 1}])
    values = check_shape(dist_params['values'], 'values')
    quantile_fractions = check_shape(dist_params['quantile_fractions'],
                                     'quantile_fractions')
    batch_size = values.shape[0]

    if isscalar(scale):
        scale = jnp.full(shape=(batch_size, 1), fill_value=jnp.squeeze(scale))
    if isscalar(shift):
        shift = jnp.full(shape=(batch_size, 1), fill_value=jnp.squeeze(shift))

    scale = jnp.reshape(scale, (batch_size, 1))
    shift = jnp.reshape(shift, (batch_size, 1))

    chex.assert_shape(values, (batch_size, self.num_quantiles))
    chex.assert_shape([scale, shift], (batch_size, 1))

    if value_transform is None:
        f = f_inv = lambda x: x
    else:
        f, f_inv = value_transform

    return {
        'values': f(shift + scale * f_inv(values)),
        'quantile_fractions': quantile_fractions}

def _categorical_l2_project(
    z_p: Array,
    probs: Array,
    z_q: Array
) -> Array:
  """Projects a categorical distribution (z_p, p) onto a different support z_q.

  The projection step minimizes an L2-metric over the cumulative distribution
  functions (CDFs) of the source and target distributions.

  Let kq be len(z_q) and kp be len(z_p). This projection works for any
  support z_q, in particular kq need not be equal to kp.

  See "A Distributional Perspective on RL" by Bellemare et al.
  (https://arxiv.org/abs/1707.06887).

  Args:
    z_p: support of distribution p.
    probs: probability values.
    z_q: support to project distribution (z_p, probs) onto.

  Returns:
    Projection of (z_p, p) onto support z_q under Cramer distance.
  """
  chex.assert_rank([z_p, probs, z_q], 1)
  chex.assert_type([z_p, probs, z_q], float)

  kp = z_p.shape[0]
  kq = z_q.shape[0]

  # Construct helper arrays from z_q.
  d_pos = jnp.roll(z_q, shift=-1)
  d_neg = jnp.roll(z_q, shift=1)

  # Clip z_p to be in new support range (vmin, vmax).
  z_p = jnp.clip(z_p, z_q[0], z_q[-1])[None, :]
  assert z_p.shape == (1, kp)

  # Get the distance between atom values in support.
  d_pos = (d_pos - z_q)[:, None]  # z_q[i+1] - z_q[i]
  d_neg = (z_q - d_neg)[:, None]  # z_q[i] - z_q[i-1]
  z_q = z_q[:, None]
  assert z_q.shape == (kq, 1)

  # Ensure that we do not divide by zero, in case of atoms of identical value.
  d_neg = jnp.where(d_neg > 0, 1. / d_neg, jnp.zeros_like(d_neg))
  d_pos = jnp.where(d_pos > 0, 1. / d_pos, jnp.zeros_like(d_pos))

  delta_qp = z_p - z_q  # clip(z_p)[j] - z_q[i]
  d_sign = (delta_qp >= 0.).astype(probs.dtype)
  assert delta_qp.shape == (kq, kp)
  assert d_sign.shape == (kq, kp)

  # Matrix of entries sgn(a_ij) * |a_ij|, with a_ij = clip(z_p)[j] - z_q[i].
  delta_hat = (d_sign * delta_qp * d_pos) - ((1. - d_sign) * delta_qp * d_neg)
  probs = probs[None, :]
  assert delta_hat.shape == (kq, kp)
  assert probs.shape == (1, kp)

  return jnp.sum(jnp.clip(1. - delta_hat, 0., 1.) * probs, axis=-1)

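# Illustrative usage sketch for `_categorical_l2_project` (an addition, not
# part of the original source): the example name and supports are assumptions.
# It projects a 3-atom categorical distribution onto a 5-atom support; the
# returned probabilities live on `z_q` and still sum to one.
def example_categorical_l2_project_usage():
  import jax.numpy as jnp

  z_p = jnp.array([-1.0, 0.0, 1.0])    # source support
  probs = jnp.array([0.2, 0.5, 0.3])   # probabilities on z_p
  z_q = jnp.linspace(-2.0, 2.0, 5)     # target support of a different size
  return _categorical_l2_project(z_p, probs, z_q)
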
def update(self, idx, Adv):
    r"""
    Update the priority weights of transitions previously added to the buffer.

    Parameters
    ----------
    idx : 1d array of ints
        The identifiers of the transitions to be updated.
    Adv : ndarray
        The corresponding updated advantages.

    """
    idx = onp.asarray(idx, dtype='int32')
    Adv = onp.asarray(Adv, dtype='float32')
    chex.assert_equal_shape([idx, Adv])
    chex.assert_rank([idx, Adv], 1)

    idx_lookup = idx % self.capacity  # wrap around
    new_values = onp.where(
        _get_transition_batch_idx(self._storage[idx_lookup]) == idx,  # only update if ids match
        onp.power(onp.abs(Adv) + self.epsilon, self.alpha),
        self._sumtree.values[idx_lookup])
    self._sumtree.set_values(idx_lookup, new_values)

def sarsa(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t: Array,
    a_t: Numeric,
) -> Numeric:
  """Calculates the SARSA temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node64.html.)

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    a_t: action index at time t.

  Returns:
    SARSA temporal difference error.
  """
  chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t, a_t],
                   [1, 0, 0, 0, 1, 0])
  chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t, a_t],
                   [float, int, float, float, float, int])

  target_tm1 = r_t + discount_t * q_t[a_t]
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]

def rpg_loss(
    logits_t: Array,
    q_t: Array,
) -> Array:
  """Computes the RPG (Regret Policy Gradient) loss.

  The gradient of this loss adapts the Regret Matching rule by weighting the
  standard PG update with regret.

  See "Actor-Critic Policy Optimization in Partially Observable Multiagent
  Environments" by Srinivasan, Lanctot.
  (https://papers.nips.cc/paper/7602-actor-critic-policy-optimization-in-partially-observable-multiagent-environments.pdf)

  Args:
    logits_t: a sequence of unnormalized action preferences.
    q_t: the observed or estimated action value from executing actions `a_t`
      at time t.

  Returns:
    RPG Loss.
  """
  chex.assert_rank([logits_t, q_t], 2)
  chex.assert_type([logits_t, q_t], float)

  _, adv_t = _compute_advantages(logits_t, q_t)
  regrets_t = jnp.sum(jax.nn.relu(adv_t), axis=1)

  total_regret_t = jnp.mean(regrets_t, axis=0)
  return total_regret_t

def objective_func(self, params, state, hyperparams, rng, transition_batch, Adv):
    rngs = hk.PRNGSequence(rng)

    # get distribution params from function approximator
    S = self.pi.observation_preprocessor(next(rngs), transition_batch.S)
    dist_params, state_new = self.pi.function(params, state, next(rngs), S, True)

    # compute objective: q(s, a_greedy)
    S = self.q_targ.observation_preprocessor(next(rngs), transition_batch.S)
    A = self.pi.proba_dist.mode(dist_params)
    log_pi = self.pi.proba_dist.log_proba(dist_params, A)
    params_q, state_q = hyperparams['q']['params'], hyperparams['q']['function_state']
    Q, _ = self.q_targ.function_type1(params_q, state_q, next(rngs), S, A, True)

    # clip importance weights to reduce variance
    W = jnp.clip(transition_batch.W, 0.1, 10.)

    # the objective
    chex.assert_equal_shape([W, Q])
    chex.assert_rank([W, Q], 1)
    objective = W * Q

    return jnp.mean(objective), (dist_params, log_pi, state_new)

def dpg_loss(
    a_t: Array,
    dqda_t: Array,
    dqda_clipping: Optional[Scalar] = None
) -> Array:
  """Calculates the deterministic policy gradient (DPG) loss.

  See "Deterministic Policy Gradient Algorithms" by Silver, Lever, Heess,
  Degris, Wierstra, Riedmiller (http://proceedings.mlr.press/v32/silver14.pdf).

  Args:
    a_t: continuous-valued action at time t.
    dqda_t: gradient of Q(s,a) wrt. a, evaluated at time t.
    dqda_clipping: clips the gradient to have norm <= `dqda_clipping`.

  Returns:
    DPG loss.
  """
  chex.assert_rank([a_t, dqda_t], 1)
  chex.assert_type([a_t, dqda_t], float)

  if dqda_clipping is not None:
    dqda_t = _clip_by_l2_norm(dqda_t, dqda_clipping)
  target_tm1 = dqda_t + a_t
  return losses.l2_loss(jax.lax.stop_gradient(target_tm1) - a_t)

def q_learning(
    q_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_t: Array,
    stop_target_gradients: bool = True,
) -> Numeric:
  """Calculates the Q-learning temporal difference error.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/book/ebook/node65.html).

  Args:
    q_tm1: Q-values at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_t: Q-values at time t.
    stop_target_gradients: bool indicating whether or not to apply stop
      gradient to targets.

  Returns:
    Q-learning temporal difference error.
  """
  chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t], [1, 0, 0, 0, 1])
  chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t],
                   [float, int, float, float, float])

  target_tm1 = r_t + discount_t * jnp.max(q_t)
  target_tm1 = jax.lax.select(stop_target_gradients,
                              jax.lax.stop_gradient(target_tm1), target_tm1)
  return target_tm1 - q_tm1[a_tm1]

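# Illustrative usage sketch for `q_learning` (an addition, not part of the
# original source): the example name and single-transition values are toy
# assumptions. The target bootstraps from max_a Q(s_t, a).
def example_q_learning_usage():
  import jax.numpy as jnp

  q_tm1 = jnp.array([1.0, 2.0, 0.5])  # Q(s_{t-1}, .)
  a_tm1 = 1                           # action taken at t-1
  r_t = 0.5
  discount_t = 0.99
  q_t = jnp.array([1.5, 0.5, 1.0])    # Q(s_t, .)
  return q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
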
def discounted_returns(
    r_t: Array,
    discount_t: Array,
    v_t: Array
) -> Array:
  """Calculates a discounted return from a trajectory.

  The returns are computed recursively, from `G_{T-1}` to `G_0`, according to:

    Gₜ = rₜ₊₁ + γₜ₊₁ Gₜ₊₁.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto.
  (http://incompleteideas.net/sutton/book/ebook/node61.html).

  Args:
    r_t: reward sequence at time t.
    discount_t: discount sequence at time t.
    v_t: value sequence or scalar at time t.

  Returns:
    Discounted returns.
  """
  chex.assert_rank([r_t, discount_t, v_t], [1, 1, {0, 1}])
  chex.assert_type([r_t, discount_t, v_t], float)

  # If scalar make into vector.
  bootstrapped_v = jnp.ones_like(discount_t) * v_t
  return lambda_returns(r_t, discount_t, bootstrapped_v, lambda_=1.)

def n_step_bootstrapped_returns(
    r_t: Array,
    discount_t: Array,
    v_t: Array,
    n: int
) -> Array:
  """Computes strided n-step bootstrapped return targets over a sequence.

  The returns are computed in a backwards fashion according to the equation:

    Gₜ = rₜ₊₁ + γₜ₊₁ * (rₜ₊₂ + γₜ₊₂ * (... * (rₜ₊ₙ + γₜ₊ₙ * vₜ₊ₙ))).

  Args:
    r_t: rewards at times [1, ..., T].
    discount_t: discounts at times [1, ..., T].
    v_t: state or state-action values to bootstrap from at times [1, ..., T].
    n: number of steps over which to accumulate reward before bootstrapping.

  Returns:
    estimated bootstrapped returns at times [1, ..., T].
  """
  chex.assert_rank([r_t, discount_t, v_t], 1)
  chex.assert_type([r_t, discount_t, v_t], float)

  seq_len = r_t.shape[0]

  # Pad end of reward and discount sequences with 0 and 1 respectively.
  r_t = jnp.concatenate([r_t, jnp.zeros(n - 1)])
  discount_t = jnp.concatenate([discount_t, jnp.ones(n - 1)])

  # Shift bootstrap values by n and pad end of sequence with last value v_t[-1].
  pad_size = min(n - 1, seq_len)
  targets = jnp.concatenate([v_t[n - 1:], jnp.array([v_t[-1]] * pad_size)])

  # Work backwards to compute discounted, bootstrapped n-step returns.
  for i in reversed(range(n)):
    targets = r_t[i:i + seq_len] + discount_t[i:i + seq_len] * targets
  return targets

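# Illustrative usage sketch for `n_step_bootstrapped_returns` (an addition,
# not part of the original source): the example name and trajectory values are
# assumptions. With n=2, each target accumulates up to two rewards before
# bootstrapping from the value sequence.
def example_n_step_bootstrapped_returns_usage():
  import jax.numpy as jnp

  r_t = jnp.array([0.0, 0.0, 1.0, 0.0])
  discount_t = jnp.array([0.99, 0.99, 0.99, 0.0])  # 0 ends the episode
  v_t = jnp.array([1.0, 0.9, 0.8, 0.7])            # bootstrap values
  return n_step_bootstrapped_returns(r_t, discount_t, v_t, n=2)
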