def test_value_loss(self):
  rewards = np.array([
      [1, 2, 4, 8, 16, 32, 64, 128],
      [1, 1, 1, 1, 1, 1, 1, 1],
  ])
  rewards_mask = np.array([
      [1, 1, 1, 1, 1, 0, 0, 0],
      [1, 1, 1, 1, 1, 1, 1, 0],
  ])
  gamma = 0.5

  # Random observations and a value function that returns a constant value.
  # NOTE: Observations have an extra time-step.
  B, T = rewards.shape  # pylint: disable=invalid-name
  observation_shape = (210, 160, 3)  # atari pong
  random_observations = np.random.uniform(size=(B, T + 1) + observation_shape)

  def value_net_apply(observations, params, rng=None):
    del params, rng
    # pylint: disable=invalid-name
    B, T_p_1, OBS = (observations.shape[0], observations.shape[1],
                     observations.shape[2:])
    del OBS
    return np.ones((B, T_p_1, 1))
    # pylint: enable=invalid-name

  value_prediction = value_net_apply(random_observations, [])

  with jax.disable_jit():
    value_loss = ppo.value_loss_given_predictions(
        value_prediction, rewards, rewards_mask, gamma)

  self.assertNear(53.3637084961, value_loss, 1e-6)
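
def test_value_loss_reference(self):
  # A minimal reference sketch for the constant asserted above, not the
  # library's implementation: it assumes the value loss is the mean, over
  # masked time-steps, of the squared error between the predicted values
  # and the discounted rewards-to-go
  #     r2g_t = r_t + gamma * r2g_{t+1},
  # which reproduces 53.3637084961 for the inputs in test_value_loss.
  rewards = np.array([
      [1, 2, 4, 8, 16, 32, 64, 128],
      [1, 1, 1, 1, 1, 1, 1, 1],
  ], dtype=np.float64)
  rewards_mask = np.array([
      [1, 1, 1, 1, 1, 0, 0, 0],
      [1, 1, 1, 1, 1, 1, 1, 0],
  ], dtype=np.float64)
  gamma = 0.5

  # Accumulate rewards-to-go right-to-left over the masked rewards.
  masked_rewards = rewards * rewards_mask
  rewards_to_go = np.zeros_like(masked_rewards)
  running = np.zeros(rewards.shape[0])
  for t in reversed(range(rewards.shape[1])):
    running = masked_rewards[:, t] + gamma * running
    rewards_to_go[:, t] = running

  # The value net in test_value_loss predicts a constant 1.0 everywhere.
  squared_error = (rewards_to_go - 1.0) ** 2
  expected_loss = np.sum(squared_error * rewards_mask) / np.sum(rewards_mask)
  self.assertNear(53.3637084961, expected_loss, 1e-6)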

def test_combined_loss(self):
  self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)
  B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
  batch_observation_shape = (-1, -1) + OBS

  old_params, _ = ppo.policy_and_value_net(
      key1, batch_observation_shape, A,
      [layers.Flatten(num_axis_to_keep=2)])
  new_params, net_apply = ppo.policy_and_value_net(
      key2, batch_observation_shape, A,
      [layers.Flatten(num_axis_to_keep=2)])

  # Generate a batch of observations.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T))
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)

  # Just test that this computes at all.
  new_log_probabs, _ = net_apply(observations, new_params)
  old_log_probabs, value_predictions = net_apply(observations, old_params)

  gamma = 0.99
  lambda_ = 0.95
  epsilon = 0.2
  c1 = 1.0
  c2 = 0.01

  value_loss_1 = ppo.value_loss_given_predictions(
      value_predictions, rewards, mask, gamma=gamma)
  ppo_loss_1 = ppo.ppo_loss_given_predictions(
      new_log_probabs, old_log_probabs, value_predictions, actions,
      rewards, mask, gamma=gamma, lambda_=lambda_, epsilon=epsilon)
  (combined_loss, ppo_loss_2, value_loss_2, entropy_bonus) = ppo.combined_loss(
      new_params, old_params, net_apply, observations, actions, rewards,
      mask, gamma=gamma, lambda_=lambda_, epsilon=epsilon, c1=c1, c2=c2)

  # Test that these compute at all and are self-consistent.
  self.assertEqual(0.0, entropy_bonus)
  self.assertNear(value_loss_1, value_loss_2, 1e-6)
  self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
  self.assertNear(combined_loss, ppo_loss_2 + (c1 * value_loss_2), 1e-6)
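
def test_clipped_objective_sketch(self):
  # A minimal sketch of the clipped surrogate that PPO losses such as
  # ppo.ppo_loss_given_predictions are built around (the standard PPO
  # objective; this illustration skips the library's advantage-estimation
  # and masking pipeline, and the log-probabilities below are made up):
  #     ratio_t = exp(new_log_prob_t - old_log_prob_t)
  #     objective_t = min(ratio_t * A_t, clip(ratio_t, 1-eps, 1+eps) * A_t)
  epsilon = 0.2
  new_log_probabs = np.array([-0.5, -1.0, -2.0])
  old_log_probabs = np.array([-0.7, -0.9, -2.0])
  advantages = np.array([1.0, -1.0, 0.5])

  ratios = np.exp(new_log_probabs - old_log_probabs)
  clipped_ratios = np.clip(ratios, 1 - epsilon, 1 + epsilon)
  objective = np.minimum(ratios * advantages, clipped_ratios * advantages)
  loss = -np.mean(objective)

  # exp(0.2) ~ 1.221 is clipped to 1.2, capping the first step's
  # contribution; the loss stays finite for any probability ratio.
  self.assertTrue(np.isfinite(loss))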