def policy_gradient_loss(
    logits_t: Array,
    a_t: Array,
    adv_t: Array,
    w_t: Array,
    use_stop_gradient: bool = True,
) -> Array:
  """Calculates the policy gradient loss.

  See "Simple Statistical Gradient-Following Algorithms for Connectionist
  Reinforcement Learning" by Williams (1992).
  (http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)

  Args:
    logits_t: rank-2 array; a sequence of unnormalized action preferences.
    a_t: rank-1 integer array; a sequence of actions sampled from the
      preferences `logits_t`.
    adv_t: rank-1 float array; the observed or estimated advantages from
      executing actions `a_t`.
    w_t: rank-1 float array; a per timestep weighting for the loss.
    use_stop_gradient: if True (the default, matching the previous
      behaviour), no gradient flows into `adv_t`. The advantages then act as
      constant per-timestep scales on the log-probability gradient. Set to
      False to also differentiate through `adv_t`.

  Returns:
    Scalar loss, `mean(-log pi(a_t) * adv_t * w_t)`, whose gradient w.r.t.
    `logits_t` corresponds to a policy gradient update.
  """
  chex.assert_rank([logits_t, a_t, adv_t, w_t], [2, 1, 1, 1])
  chex.assert_type([logits_t, a_t, adv_t, w_t], [float, int, float, float])

  log_pi_a_t = distributions.softmax().logprob(a_t, logits_t)
  # By default the advantage is only a weight on grad log pi(a|s). Stopping
  # its gradient keeps this loss from updating whatever produced the
  # advantage estimate.
  if use_stop_gradient:
    adv_t = jax.lax.stop_gradient(adv_t)
  loss_per_timestep = -log_pi_a_t * adv_t
  return jnp.mean(loss_per_timestep * w_t)
def test_softmax_logprob_batch(self, variant):
  """Checks softmax log-probabilities computed for the whole batch at once."""
  logprob = variant(distributions.softmax().logprob)
  # One call evaluates every (sample, logits) pair of the batch.
  batch_logprobs = logprob(self.samples, self.logits)
  np.testing.assert_allclose(
      self.expected_logprobs, batch_logprobs, atol=1e-4)
def test_softmax_probs_batch(self, variant):
  """Checks temperature-10 softmax probabilities for the whole batch."""
  probs_fn = variant(distributions.softmax(temperature=10.).probs)
  # A single call covers every row of self.logits.
  np.testing.assert_allclose(
      self.expected_probs, probs_fn(self.logits), atol=1e-4)
def test_softmax_entropy_batch(self, variant):
  """Checks softmax entropies computed for the whole batch at once."""
  entropy = variant(distributions.softmax().entropy)
  batch_entropy = entropy(self.logits)
  np.testing.assert_allclose(self.expected_entropy, batch_entropy, atol=1e-4)
def test_softmax_entropy(self, variant):
  """Checks softmax entropy one batch element at a time."""
  entropy = variant(distributions.softmax().entropy)
  for row_logits, row_expected in zip(self.logits, self.expected_entropy):
    np.testing.assert_allclose(row_expected, entropy(row_logits), atol=1e-4)
def test_softmax_probs(self, variant):
  """Checks temperature-10 softmax probabilities one batch element at a time."""
  probs_fn = variant(distributions.softmax(temperature=10.).probs)
  for row_logits, row_expected in zip(self.logits, self.expected_probs):
    np.testing.assert_allclose(row_expected, probs_fn(row_logits), atol=1e-4)
def test_softmax_logprob(self, variant):
  """Checks softmax log-probabilities one batch element at a time."""
  logprob = variant(distributions.softmax().logprob)
  cases = zip(self.logits, self.samples, self.expected_logprobs)
  for row_logits, row_samples, row_expected in cases:
    np.testing.assert_allclose(
        row_expected, logprob(row_samples, row_logits), atol=1e-4)
def test_softmax_probs_batch(self, compile_fn, place_fn):
  """Checks temperature-10 softmax probabilities for the whole batch."""
  # No vmap: probs is applied directly to the full batch. compile_fn
  # optionally compiles it.
  probs_fn = compile_fn(distributions.softmax(temperature=10.).probs)
  # place_fn optionally converts the input to a device array.
  batch_logits = place_fn(self.logits)
  np.testing.assert_allclose(
      self.expected_probs, probs_fn(batch_logits), atol=1e-4)
def test_softmax_entropy_batch(self, compile_fn, place_fn):
  """Checks softmax entropies for the whole batch."""
  # No vmap: entropy is applied directly to the full batch. compile_fn
  # optionally compiles it.
  entropy = compile_fn(distributions.softmax().entropy)
  # place_fn optionally converts the input to a device array.
  batch_logits = place_fn(self.logits)
  np.testing.assert_allclose(
      self.expected_entropy, entropy(batch_logits), atol=1e-4)
def test_softmax_logprob_batch(self, compile_fn, place_fn):
  """Checks softmax log-probabilities for the whole batch."""
  # No vmap: logprob is applied directly to the full batch. compile_fn
  # optionally compiles it.
  logprob = compile_fn(distributions.softmax().logprob)
  # place_fn optionally converts each input to a device array.
  batch_logits, batch_samples = tree_map(
      place_fn, (self.logits, self.samples))
  np.testing.assert_allclose(
      self.expected_logprobs, logprob(batch_samples, batch_logits), atol=1e-4)
def test_softmax_probs(self, compile_fn, place_fn):
  """Checks temperature-10 softmax probabilities one batch element at a time."""
  # compile_fn optionally compiles; place_fn optionally converts each row to
  # a device array before the call.
  probs_fn = compile_fn(distributions.softmax(temperature=10.).probs)
  for row_logits, row_expected in zip(self.logits, self.expected_probs):
    np.testing.assert_allclose(
        row_expected, probs_fn(place_fn(row_logits)), atol=1e-4)
def test_softmax_logprob(self, compile_fn, place_fn):
  """Checks softmax log-probabilities one batch element at a time."""
  # compile_fn optionally compiles; place_fn optionally converts each input
  # to a device array.
  logprob = compile_fn(distributions.softmax().logprob)
  cases = zip(self.logits, self.samples, self.expected_logprobs)
  for row_logits, row_samples, row_expected in cases:
    placed_logits, placed_samples = tree_map(
        place_fn, (row_logits, row_samples))
    np.testing.assert_allclose(
        row_expected, logprob(placed_samples, placed_logits), atol=1e-4)
def test_single_double_q_learning_eq_batch(self):
  """Tests equivalence to categorical_q_learning when q_t_selector == q_t.

  The selector values are the expected values of the target distribution
  itself. With that selector, categorical double Q-learning must give the
  same result as categorical Q-learning.
  """
  # self.atoms is captured by closure instead of being passed as a vmapped
  # argument, so every batch element shares the same atom support.

  @self.variant
  @jax.vmap
  def double_q_fn(q_logits_tm1, a_tm1, r_t, discount_t, q_logits_t,
                  q_t_selector):
    return value_learning.categorical_double_q_learning(
        q_atoms_tm1=self.atoms,
        q_logits_tm1=q_logits_tm1,
        a_tm1=a_tm1,
        r_t=r_t,
        discount_t=discount_t,
        q_atoms_t=self.atoms,
        q_logits_t=q_logits_t,
        q_t_selector=q_t_selector)

  @self.variant
  @jax.vmap
  def q_fn(q_logits_tm1, a_tm1, r_t, discount_t, q_logits_t):
    return value_learning.categorical_q_learning(
        q_atoms_tm1=self.atoms,
        q_logits_tm1=q_logits_tm1,
        a_tm1=a_tm1,
        r_t=r_t,
        discount_t=discount_t,
        q_atoms_t=self.atoms,
        q_logits_t=q_logits_t)

  softmax_dist = distributions.softmax()
  # Prepend two singleton axes so self.atoms broadcasts against the
  # probabilities of self.q_logits_t. The atoms sit on their last axis,
  # which is the axis summed over below. The leading axes are presumably
  # [batch, action]; TODO confirm against the fixture shapes.
  atoms = jnp.expand_dims(jnp.expand_dims(self.atoms, 0), 0)
  target_probs = softmax_dist.probs(self.q_logits_t)
  q_t_selector = jnp.sum(target_probs * atoms, axis=-1)

  shared_args = (self.q_logits_tm1, self.a_tm1, self.r_t, self.discount_t,
                 self.q_logits_t)
  actual = double_q_fn(*shared_args, q_t_selector)
  expected = q_fn(*shared_args)
  np.testing.assert_allclose(expected, actual)
def entropy_loss(
    logits_t: Array,
    w_t: Array,
) -> Array:
  """Calculates the entropy regularization loss.

  See "Function Optimization using Connectionist Reinforcement Learning
  Algorithms" by Williams & Peng (1991).
  (https://www.tandfonline.com/doi/abs/10.1080/09540099108946587)

  Args:
    logits_t: rank-2 float array; a sequence of unnormalized action
      preferences.
    w_t: rank-1 float array; a per timestep weighting for the loss.

  Returns:
    Scalar equal to the negated mean of the weighted per-timestep softmax
    entropies. Minimizing it therefore increases policy entropy.
  """
  chex.assert_rank([logits_t, w_t], [2, 1])
  chex.assert_type([logits_t, w_t], float)

  policy = distributions.softmax()
  weighted_entropy_t = policy.entropy(logits_t) * w_t
  return -jnp.mean(weighted_entropy_t)