def categorical_q_learning(
    q_atoms_tm1: Array,
    q_logits_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_atoms_t: Array,
    q_logits_t: Array,
    stop_target_gradients: bool = True,
) -> Numeric:
  """Implements Q-learning for categorical Q distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
  Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

  Args:
    q_atoms_tm1: atoms of Q distribution at time t-1.
    q_logits_tm1: logits of Q distribution at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_atoms_t: atoms of Q distribution at time t.
    q_logits_t: logits of Q distribution at time t.
    stop_target_gradients: bool indicating whether or not to apply stop
      gradient to targets.

  Returns:
    Categorical Q-learning loss (i.e. temporal difference error).
  """
  chex.assert_rank([
      q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t, q_logits_t
  ], [1, 2, 0, 0, 0, 1, 2])
  chex.assert_type([
      q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t, q_logits_t
  ], [float, float, int, float, float, float, float])

  # Scale and shift time-t distribution atoms by discount and reward.
  target_z = r_t + discount_t * q_atoms_t

  # Convert logits to distribution, then find greedy action in state s_t.
  q_t_probs = jax.nn.softmax(q_logits_t)
  q_t_mean = jnp.sum(q_t_probs * q_atoms_t[jnp.newaxis, :], axis=1)
  pi_t = jnp.argmax(q_t_mean)
  # Compute distribution for greedy action.
  p_target_z = q_t_probs[pi_t]

  # Project using the Cramer distance and maybe stop gradient flow to targets.
  target = categorical_l2_project(target_z, p_target_z, q_atoms_tm1)
  target = jax.lax.select(stop_target_gradients,
                          jax.lax.stop_gradient(target), target)

  # Compute loss (i.e. temporal difference error).
  logit_qa_tm1 = q_logits_tm1[a_tm1]
  return distributions.categorical_cross_entropy(
      labels=target, logits=logit_qa_tm1)
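
# Illustrative usage sketch (not part of the library): the loss for a single
# transition with 3 actions and 5 atoms on a shared support. The function name,
# shapes and numbers below are made up for illustration; they only need to
# satisfy the rank/type asserts above.
def _example_categorical_q_learning():
  import jax.numpy as jnp  # Also assumed imported at module level.

  atoms = jnp.linspace(-10., 10., num=5)        # Shared support z_1 ... z_5.
  logits_tm1 = jnp.zeros((3, 5))                # One distribution per action.
  logits_t = jnp.zeros((3, 5))
  # Loss for taking action 1 at time t-1 and then observing reward 1.0.
  return categorical_q_learning(
      q_atoms_tm1=atoms, q_logits_tm1=logits_tm1, a_tm1=1,
      r_t=1.0, discount_t=0.99, q_atoms_t=atoms, q_logits_t=logits_t)
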
def categorical_double_q_learning(
    q_atoms_tm1: Array,
    q_logits_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_atoms_t: Array,
    q_logits_t: Array,
    q_t_selector: Array,
    stop_target_gradients: bool = True,
) -> Numeric:
  """Implements double Q-learning for categorical Q distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
  Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf) and
  "Double Q-learning" by van Hasselt
  (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

  Args:
    q_atoms_tm1: atoms of Q distribution at time t-1.
    q_logits_tm1: logits of Q distribution at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_atoms_t: atoms of Q distribution at time t.
    q_logits_t: logits of Q distribution at time t.
    q_t_selector: selector Q-values at time t.
    stop_target_gradients: bool indicating whether or not to apply stop
      gradient to targets.

  Returns:
    Categorical double Q-learning loss (i.e. temporal difference error).
  """
  chex.assert_rank([
      q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
      q_logits_t, q_t_selector
  ], [1, 2, 0, 0, 0, 1, 2, 1])
  chex.assert_type([
      q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
      q_logits_t, q_t_selector
  ], [float, float, int, float, float, float, float, float])

  # Scale and shift time-t distribution atoms by discount and reward.
  target_z = r_t + discount_t * q_atoms_t

  # Select logits for greedy action in state s_t and convert to distribution.
  p_target_z = jax.nn.softmax(q_logits_t[q_t_selector.argmax()])

  # Project using the Cramer distance and maybe stop gradient flow to targets.
  target = categorical_l2_project(target_z, p_target_z, q_atoms_tm1)
  target = jax.lax.select(stop_target_gradients,
                          jax.lax.stop_gradient(target), target)

  # Compute loss (i.e. temporal difference error).
  logit_qa_tm1 = q_logits_tm1[a_tm1]
  return distributions.categorical_cross_entropy(
      labels=target, logits=logit_qa_tm1)
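
# Illustrative sketch (not part of the library) of one common way to build the
# selector for double Q-learning: the online network's mean Q-values at time t
# choose the action, while a separate (e.g. target) network's distribution
# evaluates it. The function name, network split and values below are
# assumptions made for illustration only.
def _example_categorical_double_q_learning():
  import jax
  import jax.numpy as jnp  # Also assumed imported at module level.

  atoms = jnp.linspace(-10., 10., num=5)
  online_logits_tm1 = jnp.zeros((3, 5))         # Online network, time t-1.
  online_logits_t = jnp.zeros((3, 5))           # Online network, time t.
  target_logits_t = jnp.zeros((3, 5))           # Target network, time t.
  # Selector: mean Q-values under the online network's time-t distributions.
  q_t_selector = jnp.sum(jax.nn.softmax(online_logits_t) * atoms, axis=1)
  return categorical_double_q_learning(
      q_atoms_tm1=atoms, q_logits_tm1=online_logits_tm1, a_tm1=0,
      r_t=0.5, discount_t=0.99, q_atoms_t=atoms, q_logits_t=target_logits_t,
      q_t_selector=q_t_selector)
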
def categorical_q_learning(
    q_atoms_tm1: ArrayLike,
    q_logits_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_atoms_t: ArrayLike,
    q_logits_t: ArrayLike,
) -> ArrayOrScalar:
  """Implements Q-learning for categorical Q distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
  Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

  Args:
    q_atoms_tm1: atoms of Q distribution at time t-1.
    q_logits_tm1: logits of Q distribution at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_atoms_t: atoms of Q distribution at time t.
    q_logits_t: logits of Q distribution at time t.

  Returns:
    Categorical Q-learning loss (i.e. temporal difference error).
  """
  base.rank_assert([
      q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t, q_logits_t
  ], [1, 2, 0, 0, 0, 1, 2])
  base.type_assert([
      q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t, q_logits_t
  ], [float, float, int, float, float, float, float])

  # Scale and shift time-t distribution atoms by discount and reward.
  target_z = r_t + discount_t * q_atoms_t

  # Convert logits to distribution, then find greedy action in state s_t.
  q_t_probs = jax.nn.softmax(q_logits_t)
  q_t_mean = jnp.sum(q_t_probs * q_atoms_t[jnp.newaxis, :], axis=1)
  pi_t = jnp.argmax(q_t_mean)
  # Compute distribution for greedy action.
  p_target_z = q_t_probs[pi_t]

  # Project using the Cramer distance.
  target = jax.lax.stop_gradient(
      _categorical_l2_project(target_z, p_target_z, q_atoms_tm1))

  # Compute loss (i.e. temporal difference error).
  logit_qa_tm1 = q_logits_tm1[a_tm1]
  return distributions.categorical_cross_entropy(
      labels=target, logits=logit_qa_tm1)
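
# The projection helpers `categorical_l2_project` / `_categorical_l2_project`
# used above are not shown in this file. Below is a minimal sketch of the
# standard Cramer/L2 projection from the C51 paper, written under the
# assumption of a (source atoms, source probabilities, target support)
# signature; the library's actual helper may differ in details and edge-case
# handling. It assumes jnp is imported at module level, as elsewhere here.
def _categorical_l2_project_sketch(z_p, probs, z_q):
  """Projects the distribution (z_p, probs) onto the fixed support z_q."""
  # Clip source atoms into the range of the target support [z_q[0], z_q[-1]].
  z_p = jnp.clip(z_p, z_q[0], z_q[-1])[jnp.newaxis, :]            # [1, K_p]
  # Spacing to the next / previous atom of the target support (assumed
  # strictly increasing); the padding values are never used after clipping.
  d_pos = jnp.concatenate([z_q[1:] - z_q[:-1], jnp.ones(1)])[:, jnp.newaxis]
  d_neg = jnp.concatenate([jnp.ones(1), z_q[1:] - z_q[:-1]])[:, jnp.newaxis]
  z_q = z_q[:, jnp.newaxis]                                       # [K_q, 1]
  # Normalised distance from each source atom to each target atom.
  delta = z_p - z_q                                               # [K_q, K_p]
  pos = (delta >= 0.).astype(probs.dtype)
  delta_hat = pos * delta / d_pos - (1. - pos) * delta / d_neg
  # Each source atom's mass is split linearly between its two nearest
  # target atoms; atoms further than one interval away receive zero mass.
  return jnp.sum(jnp.clip(1. - delta_hat, 0., 1.) * probs[jnp.newaxis, :],
                 axis=-1)
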
def categorical_td_learning(
    v_atoms_tm1: Array,
    v_logits_tm1: Array,
    r_t: Numeric,
    discount_t: Numeric,
    v_atoms_t: Array,
    v_logits_t: Array,
    stop_target_gradients: bool = True,
) -> Numeric:
  """Implements TD-learning for categorical value distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
  Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

  Args:
    v_atoms_tm1: atoms of V distribution at time t-1.
    v_logits_tm1: logits of V distribution at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_atoms_t: atoms of V distribution at time t.
    v_logits_t: logits of V distribution at time t.
    stop_target_gradients: bool indicating whether or not to apply stop
      gradient to targets.

  Returns:
    Categorical TD-learning loss (i.e. temporal difference error).
  """
  chex.assert_rank(
      [v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t],
      [1, 1, 0, 0, 1, 1])
  chex.assert_type(
      [v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t],
      [float, float, float, float, float, float])

  # Scale and shift time-t distribution atoms by discount and reward.
  target_z = r_t + discount_t * v_atoms_t

  # Convert logits to distribution.
  v_t_probs = jax.nn.softmax(v_logits_t)

  # Project using the Cramer distance and maybe stop gradient flow to targets.
  target = categorical_l2_project(target_z, v_t_probs, v_atoms_tm1)
  target = jax.lax.select(stop_target_gradients,
                          jax.lax.stop_gradient(target), target)

  # Compute loss (i.e. temporal difference error).
  return distributions.categorical_cross_entropy(
      labels=target, logits=v_logits_tm1)
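
# Illustrative usage sketch (not part of the library): a single transition for
# a categorical state-value distribution with 5 atoms. The function name and
# numbers are placeholders; shapes follow the rank asserts above.
def _example_categorical_td_learning():
  import jax.numpy as jnp  # Also assumed imported at module level.

  atoms = jnp.linspace(-10., 10., num=5)        # Support of the V distributions.
  v_logits_tm1 = jnp.zeros(5)                   # V distribution at time t-1.
  v_logits_t = jnp.zeros(5)                     # V distribution at time t.
  return categorical_td_learning(
      v_atoms_tm1=atoms, v_logits_tm1=v_logits_tm1,
      r_t=1.0, discount_t=0.99, v_atoms_t=atoms, v_logits_t=v_logits_t)
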
def categorical_td_learning(
    v_atoms_tm1: ArrayLike,
    v_logits_tm1: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    v_atoms_t: ArrayLike,
    v_logits_t: ArrayLike
) -> ArrayLike:
  """Implements TD-learning for categorical value distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
  Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

  Args:
    v_atoms_tm1: atoms of V distribution at time t-1.
    v_logits_tm1: logits of V distribution at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_atoms_t: atoms of V distribution at time t.
    v_logits_t: logits of V distribution at time t.

  Returns:
    Categorical TD-learning loss (i.e. temporal difference error).
  """
  base.rank_assert(
      [v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t],
      [1, 1, 0, 0, 1, 1])
  base.type_assert(
      [v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t],
      [float, float, float, float, float, float])

  # Scale and shift time-t distribution atoms by discount and reward.
  target_z = r_t + discount_t * v_atoms_t

  # Convert logits to distribution.
  v_t_probs = jax.nn.softmax(v_logits_t)

  # Project using the Cramer distance.
  target = jax.lax.stop_gradient(
      _categorical_l2_project(target_z, v_t_probs, v_atoms_tm1))

  # Compute loss (i.e. temporal difference error).
  return distributions.categorical_cross_entropy(
      labels=target, logits=v_logits_tm1)
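
# The losses above operate on single transitions (note the rank-0 asserts on
# rewards and discounts). A batched loss can be obtained with jax.vmap; this is
# an illustrative sketch with made-up batch shapes, mapping over everything
# except the (shared) atom supports. The function name is a placeholder.
def _example_batched_categorical_td_learning():
  import jax
  import jax.numpy as jnp  # Also assumed imported at module level.

  batch, num_atoms = 32, 5
  atoms = jnp.linspace(-10., 10., num=num_atoms)
  v_logits_tm1 = jnp.zeros((batch, num_atoms))
  v_logits_t = jnp.zeros((batch, num_atoms))
  r_t = jnp.ones((batch,))
  discount_t = jnp.full((batch,), 0.99)

  batched_loss = jax.vmap(
      categorical_td_learning, in_axes=(None, 0, 0, 0, None, 0))
  # Returns one loss per transition in the batch.
  return batched_loss(atoms, v_logits_tm1, r_t, discount_t, atoms, v_logits_t)
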