def test_ppo_loss(self):
  self.rng_key, key1, key2, key3 = jax_random.split(self.rng_key, num=4)

  B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
  batch_observation_shape = (-1, -1) + OBS

  old_policy_params, _ = ppo.policy_net(key1, batch_observation_shape, A,
                                        [stax.Flatten(2)])
  new_policy_params, policy_apply = ppo.policy_net(
      key2, batch_observation_shape, A, [stax.Flatten(2)])
  value_params, value_apply = ppo.value_net(key3, batch_observation_shape, A,
                                            [stax.Flatten(2)])

  # Generate a batch of observations.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T))
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)

  # Just test that this computes at all.
  _ = ppo.ppo_loss(policy_apply, new_policy_params, old_policy_params,
                   value_apply, value_params, observations, actions, rewards,
                   mask)
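# A minimal numpy sketch of the clipped-surrogate term at the heart of a PPO
# loss, for reference while reading the test above. This is NOT the
# ppo.ppo_loss implementation: the helper name, the epsilon default, and the
# bare-mean reduction are illustrative assumptions, and the real loss also
# involves the value net and the padding mask.
def clipped_objective_sketch(new_probs, old_probs, advantages, epsilon=0.2):
  """Returns the mean clipped-surrogate objective (maximized by PPO)."""
  # Probability ratio pi_new(a|s) / pi_old(a|s), one entry per (batch, time).
  prob_ratio = new_probs / old_probs
  clipped_ratio = np.clip(prob_ratio, 1.0 - epsilon, 1.0 + epsilon)
  # PPO takes the element-wise minimum of the two candidate objectives.
  return np.mean(np.minimum(prob_ratio * advantages,
                            clipped_ratio * advantages))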
def test_collect_trajectories(self):
  observation_shape = (2, 3, 4)
  num_actions = 2
  policy_params, policy_apply = ppo.policy_net(
      self.rng_key,
      (-1, -1) + observation_shape,
      num_actions,
      # Flatten everything except the batch and time-step dimensions.
      [stax.Flatten(2)])

  # We'll get done at time-step #5; starting from 0, that is 6 steps.
  done_time_step = 5
  env = fake_env.FakeEnv(observation_shape, num_actions,
                         done_time_step=done_time_step)

  num_trajectories = 5
  trajectories = ppo.collect_trajectories(
      env, policy_apply, policy_params, num_trajectories,
      policy="categorical-sampling")

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # Observations are one more in number than actions or rewards.
    self.assertEqual((done_time_step + 2,) + observation_shape,
                     observations.shape)
    self.assertEqual((done_time_step + 1,), actions.shape)
    self.assertEqual((done_time_step + 1,), rewards.shape)
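# A hedged sketch of a single rollout with "categorical-sampling", to make
# the shape expectations above concrete. This is not ppo.collect_trajectories;
# it assumes a gym-style env (reset() -> observation, step(action) ->
# (observation, reward, done, info)), which FakeEnv may or may not match.
def rollout_sketch(env, policy_apply, policy_params):
  observations, actions, rewards = [env.reset()], [], []
  done = False
  while not done:
    # Add batch and time axes, then take the distribution at the last step.
    probs = policy_apply(policy_params, np.array(observations)[None, ...])
    action = np.random.choice(len(probs[0, -1]), p=np.array(probs[0, -1]))
    observation, reward, done, _ = env.step(action)
    observations.append(observation)
    actions.append(action)
    rewards.append(reward)
  # One more observation than actions/rewards, matching the test above.
  return np.stack(observations), np.array(actions), np.array(rewards)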
def test_policy_net(self):
  observation_shape = (3, 4)
  num_actions = 2
  policy_params, policy_apply = ppo.policy_net(
      self.rng_key,
      (-1, -1) + observation_shape,
      num_actions,
      # Flatten everything except the batch and time-step dimensions.
      [stax.Flatten(2)])

  # Generate a batch of observations.
  batch = 2
  time_steps = 10
  batch_of_observations = np.random.uniform(
      size=(batch, time_steps) + observation_shape)

  # Apply the policy net on the observations.
  policy_output = policy_apply(policy_params, batch_of_observations)

  # Verify certain expectations on the output.
  self.assertEqual((batch, time_steps, num_actions), policy_output.shape)

  # Also, the last axis normalizes to 1, since these are probabilities.
  sum_actions = np.sum(policy_output, axis=-1)
  self.assertAllClose(np.ones_like(sum_actions), sum_actions)
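# For orientation, a network with the behavior tested above could be built
# roughly as below: the caller's bottom layers, then a Dense(num_actions)
# projection and a Softmax over the last axis. This is a sketch assuming a
# jax.experimental.stax-style serial/Dense/Softmax API, not a quote of
# ppo.policy_net's actual definition.
def policy_net_sketch(rng_key, batch_observation_shape, num_actions,
                      bottom_layers):
  init_fun, apply_fun = stax.serial(
      *(bottom_layers + [stax.Dense(num_actions), stax.Softmax]))
  _, params = init_fun(rng_key, batch_observation_shape)
  return params, apply_fun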
def test_value_net(self):
  observation_shape = (3, 4, 5)
  num_actions = 2
  value_params, value_apply = ppo.value_net(
      self.rng_key,
      (-1, -1) + observation_shape,
      num_actions,
      [stax.Flatten(num_axis_to_keep=2)])

  batch = 2
  time_steps = 10
  batch_of_observations = np.random.uniform(
      size=(batch, time_steps) + observation_shape)
  value_output = value_apply(value_params, batch_of_observations)

  # NOTE: The extra dimension at the end is because of Dense(1).
  self.assertEqual((batch, time_steps, 1), value_output.shape)
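# By analogy with the policy-net sketch, the trailing unit dimension checked
# above suggests a Dense(1) head on top of the bottom layers. Again a sketch
# under the same stax-style assumptions, not ppo.value_net's definition.
def value_net_sketch(rng_key, batch_observation_shape, bottom_layers):
  init_fun, apply_fun = stax.serial(*(bottom_layers + [stax.Dense(1)]))
  _, params = init_fun(rng_key, batch_observation_shape)
  return params, apply_fun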
def test_policy_and_value_net(self):
  observation_shape = (3, 4, 5)
  batch_observation_shape = (-1, -1) + observation_shape
  num_actions = 2
  pnv_params, pnv_apply = ppo.policy_and_value_net(
      self.rng_key, batch_observation_shape, num_actions, [stax.Flatten(2)])

  batch = 2
  time_steps = 10
  batch_of_observations = np.random.uniform(
      size=(batch, time_steps) + observation_shape)
  pnv_output = pnv_apply(pnv_params, batch_of_observations)

  # The output is a pair: first the action probabilities, then the value
  # output.
  self.assertEqual(2, len(pnv_output))
  self.assertEqual((batch, time_steps, num_actions), pnv_output[0].shape)
  self.assertEqual((batch, time_steps, 1), pnv_output[1].shape)
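# A combined net like the one tested above can share the bottom layers and
# fan out into the two heads. A sketch assuming jax.experimental.stax-style
# combinators (serial, FanOut, parallel); the real ppo.policy_and_value_net
# may be structured differently.
def policy_and_value_net_sketch(rng_key, batch_observation_shape, num_actions,
                                bottom_layers):
  init_fun, apply_fun = stax.serial(
      *bottom_layers,
      stax.FanOut(2),
      stax.parallel(
          stax.serial(stax.Dense(num_actions), stax.Softmax),  # policy head
          stax.Dense(1)))  # value head
  _, params = init_fun(rng_key, batch_observation_shape)
  return params, apply_fun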