Esempio n. 1
0
def policy_gradient_loss(
    logits_t: Array,
    a_t: Array,
    adv_t: Array,
    w_t: Array,
) -> Array:
  """Computes a weighted policy gradient loss over a sequence.

  See "Simple Gradient-Following Algorithms for Connectionist RL" by Williams.
  (http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)

  Args:
    logits_t: rank-2 float array of unnormalized action preferences.
    a_t: rank-1 integer array of the actions sampled from `logits_t`.
    adv_t: rank-1 float array of observed or estimated advantages obtained
      by executing the actions `a_t`.
    w_t: rank-1 float array holding a per timestep weight for the loss.

  Returns:
    The mean over timesteps of the weighted, advantage-scaled negative
    log-probabilities of the actions; its gradient corresponds to a policy
    gradient update.
  """
  chex.assert_rank([logits_t, a_t, adv_t, w_t], [2, 1, 1, 1])
  chex.assert_type([logits_t, a_t, adv_t, w_t], [float, int, float, float])

  # Advantages act as fixed coefficients: no gradient flows through them.
  fixed_adv_t = jax.lax.stop_gradient(adv_t)
  # Log-probability of each taken action under the softmax policy.
  log_prob_taken = distributions.softmax().logprob(a_t, logits_t)
  weighted_loss = (-log_prob_taken * fixed_adv_t) * w_t
  return jnp.mean(weighted_loss)
Esempio n. 2
0
 def test_softmax_logprob_batch(self, variant):
     """Checks softmax log-probabilities computed on the full batch at once."""
     logprob_fn = variant(distributions.softmax().logprob)
     # Evaluate all samples against all logits in a single call.
     batch_logprobs = logprob_fn(self.samples, self.logits)
     np.testing.assert_allclose(
         self.expected_logprobs, batch_logprobs, atol=1e-4)
Esempio n. 3
0
 def test_softmax_probs_batch(self, variant):
     """Checks temperature-10 softmax probabilities on the full batch."""
     probs_fn = variant(distributions.softmax(temperature=10.).probs)
     # Evaluate the whole batch of logits in a single call.
     batch_probs = probs_fn(self.logits)
     np.testing.assert_allclose(self.expected_probs, batch_probs, atol=1e-4)
Esempio n. 4
0
 def test_softmax_entropy_batch(self, variant):
     """Checks softmax entropy computed on the full batch at once."""
     entropy_fn = variant(distributions.softmax().entropy)
     # Evaluate the whole batch of logits in a single call.
     batch_entropy = entropy_fn(self.logits)
     np.testing.assert_allclose(
         self.expected_entropy, batch_entropy, atol=1e-4)
Esempio n. 5
0
 def test_softmax_entropy(self, variant):
     """Checks softmax entropy one batch element at a time."""
     entropy_fn = variant(distributions.softmax().entropy)
     # Pair every element of the logits batch with its expected entropy.
     cases = zip(self.logits, self.expected_entropy)
     for item_logits, item_expected in cases:
         item_entropy = entropy_fn(item_logits)
         np.testing.assert_allclose(item_expected, item_entropy, atol=1e-4)
Esempio n. 6
0
 def test_softmax_probs(self, variant):
     """Checks temperature-10 softmax probabilities one element at a time."""
     probs_fn = variant(distributions.softmax(temperature=10.).probs)
     # Pair every element of the logits batch with its expected output.
     cases = zip(self.logits, self.expected_probs)
     for item_logits, item_expected in cases:
         item_probs = probs_fn(item_logits)
         np.testing.assert_allclose(item_expected, item_probs, atol=1e-4)
Esempio n. 7
0
 def test_softmax_logprob(self, variant):
     """Checks softmax log-probabilities one batch element at a time."""
     logprob_fn = variant(distributions.softmax().logprob)
     # Iterate logits, samples and expected values in lockstep.
     cases = zip(self.logits, self.samples, self.expected_logprobs)
     for item_logits, item_sample, item_expected in cases:
         item_logprob = logprob_fn(item_sample, item_logits)
         np.testing.assert_allclose(item_expected, item_logprob, atol=1e-4)
Esempio n. 8
0
 def test_softmax_probs_batch(self, compile_fn, place_fn):
     """Checks temperature-10 softmax probabilities on the full batch."""
     # `compile_fn` wraps the function under test; `place_fn` prepares inputs.
     probs_fn = compile_fn(distributions.softmax(temperature=10.).probs)
     placed_logits = place_fn(self.logits)
     # Evaluate the whole batch of logits in a single call.
     batch_probs = probs_fn(placed_logits)
     np.testing.assert_allclose(self.expected_probs, batch_probs, atol=1e-4)
Esempio n. 9
0
 def test_softmax_entropy_batch(self, compile_fn, place_fn):
     """Checks softmax entropy computed on the full batch at once."""
     # `compile_fn` wraps the function under test; `place_fn` prepares inputs.
     entropy_fn = compile_fn(distributions.softmax().entropy)
     placed_logits = place_fn(self.logits)
     # Evaluate the whole batch of logits in a single call.
     batch_entropy = entropy_fn(placed_logits)
     np.testing.assert_allclose(
         self.expected_entropy, batch_entropy, atol=1e-4)
Esempio n. 10
0
 def test_softmax_logprob_batch(self, compile_fn, place_fn):
     """Checks softmax log-probabilities computed on the full batch at once."""
     # `compile_fn` wraps the function under test.
     logprob_fn = compile_fn(distributions.softmax().logprob)
     # Apply `place_fn` to both inputs, logits first and samples second.
     placed_inputs = tree_map(place_fn, (self.logits, self.samples))
     placed_logits, placed_samples = placed_inputs
     # Evaluate all samples against all logits in a single call.
     batch_logprobs = logprob_fn(placed_samples, placed_logits)
     np.testing.assert_allclose(
         self.expected_logprobs, batch_logprobs, atol=1e-4)
Esempio n. 11
0
 def test_softmax_probs(self, compile_fn, place_fn):
     """Checks temperature-10 softmax probabilities one element at a time."""
     # `compile_fn` wraps the function under test.
     probs_fn = compile_fn(distributions.softmax(temperature=10.).probs)
     # Pair every element of the logits batch with its expected output.
     cases = zip(self.logits, self.expected_probs)
     for item_logits, item_expected in cases:
         # `place_fn` prepares each input before it is fed to the function.
         item_probs = probs_fn(place_fn(item_logits))
         np.testing.assert_allclose(item_expected, item_probs, atol=1e-4)
Esempio n. 12
0
 def test_softmax_logprob(self, compile_fn, place_fn):
     """Checks softmax log-probabilities one batch element at a time."""
     # `compile_fn` wraps the function under test.
     logprob_fn = compile_fn(distributions.softmax().logprob)
     # Iterate logits, samples and expected values in lockstep.
     cases = zip(self.logits, self.samples, self.expected_logprobs)
     for item_logits, item_sample, item_expected in cases:
         # Apply `place_fn` to both inputs, logits first and sample second.
         placed_inputs = tree_map(place_fn, (item_logits, item_sample))
         placed_logits, placed_sample = placed_inputs
         item_logprob = logprob_fn(placed_sample, placed_logits)
         np.testing.assert_allclose(item_expected, item_logprob, atol=1e-4)
Esempio n. 13
0
    def test_single_double_q_learning_eq_batch(self):
        """Checks double Q-learning equals Q-learning when q_t_selector == q_t.

        The selector fed to `categorical_double_q_learning` is derived from
        the same `q_logits_t` that `categorical_q_learning` uses, so both
        functions are expected to return matching values.
        """

        # Atoms are captured from `self` rather than passed through vmap.
        @self.variant
        @jax.vmap
        def q_learning_fn(q_logits_tm1, a_tm1, r_t, discount_t, q_logits_t):
            return value_learning.categorical_q_learning(
                q_atoms_tm1=self.atoms,
                q_logits_tm1=q_logits_tm1,
                a_tm1=a_tm1,
                r_t=r_t,
                discount_t=discount_t,
                q_atoms_t=self.atoms,
                q_logits_t=q_logits_t)

        @self.variant
        @jax.vmap
        def double_q_learning_fn(q_logits_tm1, a_tm1, r_t, discount_t,
                                 q_logits_t, q_t_selector):
            return value_learning.categorical_double_q_learning(
                q_atoms_tm1=self.atoms,
                q_logits_tm1=q_logits_tm1,
                a_tm1=a_tm1,
                r_t=r_t,
                discount_t=discount_t,
                q_atoms_t=self.atoms,
                q_logits_t=q_logits_t,
                q_t_selector=q_t_selector)

        # Reference value: plain categorical Q-learning.
        expected = q_learning_fn(self.q_logits_tm1, self.a_tm1, self.r_t,
                                 self.discount_t, self.q_logits_t)

        # Prepend two singleton axes so atoms broadcast against the
        # softmax probabilities of `q_logits_t`.
        atoms = jnp.expand_dims(self.atoms, 0)
        atoms = jnp.expand_dims(atoms, 0)
        # Selector: atom values weighted by their softmax probabilities,
        # summed over the last axis.
        probs_t = distributions.softmax().probs(self.q_logits_t)
        q_t_selector = jnp.sum(probs_t * atoms, axis=-1)

        # Double Q-learning value using the selector built from q_logits_t.
        actual = double_q_learning_fn(self.q_logits_tm1, self.a_tm1,
                                      self.r_t, self.discount_t,
                                      self.q_logits_t, q_t_selector)

        # Both estimates must agree.
        np.testing.assert_allclose(expected, actual)
Esempio n. 14
0
def entropy_loss(
    logits_t: Array,
    w_t: Array,
) -> Array:
  """Computes a weighted entropy regularization loss over a sequence.

  See "Function Optimization using Connectionist RL Algorithms" by Williams.
  (https://www.tandfonline.com/doi/abs/10.1080/09540099108946587)

  Args:
    logits_t: rank-2 float array of unnormalized action preferences.
    w_t: rank-1 float array holding a per timestep weight for the loss.

  Returns:
    The negated mean over timesteps of the weighted softmax entropy.
  """
  chex.assert_rank([logits_t, w_t], [2, 1])
  chex.assert_type([logits_t, w_t], float)

  softmax_policy = distributions.softmax()
  weighted_entropy = softmax_policy.entropy(logits_t) * w_t
  # Negated so that minimizing the loss increases the weighted entropy.
  return -jnp.mean(weighted_entropy)