def test_ppo_loss(self):
  self.rng_key, key1, key2, key3 = jax_random.split(self.rng_key, num=4)
  B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
  batch_observation_shape = (-1, -1) + OBS

  old_policy_params, _ = ppo.policy_net(
      key1, batch_observation_shape, A,
      [layers.Flatten(num_axis_to_keep=2)])
  new_policy_params, policy_apply = ppo.policy_net(
      key2, batch_observation_shape, A,
      [layers.Flatten(num_axis_to_keep=2)])
  value_params, value_apply = ppo.value_net(
      key3, batch_observation_shape, A,
      [layers.Flatten(num_axis_to_keep=2)])

  # Generate a batch of observations.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T))
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)

  # Just test that this computes at all.
  _ = ppo.ppo_loss(policy_apply, new_policy_params, old_policy_params,
                   value_apply, value_params, observations, actions, rewards,
                   mask)
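
# For reference: ppo.ppo_loss is only smoke-tested above. A PPO loss
# typically combines a clipped surrogate policy objective with a value
# loss; the sketch below shows the standard clipped surrogate part in
# plain NumPy. It is a generic formulation, not necessarily the exact
# computation inside ppo.ppo_loss; the function name and the epsilon
# default are illustrative.
def _clipped_surrogate_sketch(new_log_probs, old_log_probs, advantages,
                              epsilon=0.2):
  """Generic PPO clipped surrogate objective (to be maximized)."""
  # Probability ratio r_t = pi_new(a_t | s_t) / pi_old(a_t | s_t).
  ratios = np.exp(new_log_probs - old_log_probs)
  # Clipping the ratio to [1 - epsilon, 1 + epsilon] and taking the
  # elementwise minimum bounds how far one update can move the policy.
  clipped = np.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
  return np.mean(np.minimum(ratios * advantages, clipped * advantages))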
def test_value_net(self):
  observation_shape = (3, 4, 5)
  num_actions = 2
  value_params, value_apply = ppo.value_net(
      self.rng_key, (-1, -1) + observation_shape, num_actions,
      [layers.Flatten(num_axis_to_keep=2)])
  batch = 2
  time_steps = 10
  batch_of_observations = np.random.uniform(
      size=(batch, time_steps) + observation_shape)
  value_output = value_apply(batch_of_observations, value_params)

  # NOTE: The extra dimension at the end is because of Dense(1).
  self.assertEqual((batch, time_steps, 1), value_output.shape)
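
# For reference: the trailing dimension asserted above comes from the value
# head ending in a Dense(1) layer. The sketch below reproduces that shape
# behavior with jax.example_libraries.stax (jax.experimental.stax in older
# JAX releases) instead of ppo.value_net; the hidden size of 64 is
# illustrative, not taken from this repo.
def _value_head_shape_sketch():
  from jax import random as stax_random
  from jax.example_libraries import stax
  init_fn, apply_fn = stax.serial(stax.Dense(64), stax.Relu, stax.Dense(1))
  # Flattened (batch, time, features) observations, as after Flatten.
  _, params = init_fn(stax_random.PRNGKey(0), (-1, 10, 3 * 4 * 5))
  observations = np.random.uniform(size=(2, 10, 3 * 4 * 5))
  values = apply_fn(params, observations)
  assert values.shape == (2, 10, 1)  # The trailing 1 comes from Dense(1).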