def test_policy_and_value_net(self):
  observation_shape = (3, 4, 5)
  batch_observation_shape = (-1, -1) + observation_shape
  num_actions = 2
  pnv_params, pnv_apply = ppo.policy_and_value_net(
      self.rng_key,
      batch_observation_shape,
      num_actions,
      [layers.Flatten(num_axis_to_keep=2)])
  batch = 2
  time_steps = 10
  batch_of_observations = np.random.uniform(
      size=(batch, time_steps) + observation_shape)
  pnv_output = pnv_apply(batch_of_observations, pnv_params)

  # Output is a list of two: first the log-probabilities of the actions,
  # then the value predictions.
  self.assertEqual(2, len(pnv_output))
  self.assertEqual((batch, time_steps, num_actions), pnv_output[0].shape)
  self.assertEqual((batch, time_steps, 1), pnv_output[1].shape)
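# A minimal NumPy sketch (an assumption for illustration, not the trax
# implementation) of the shape contract asserted above: a policy-and-value
# net maps a (batch, time) + observation_shape tensor to a list of two
# outputs, action log-probabilities of shape (batch, time, num_actions) and
# value predictions of shape (batch, time, 1), via a body that flattens all
# but the first two axes. `toy_policy_and_value_apply` is a hypothetical
# stand-in for `pnv_apply`; it relies on this module's `np` import.
def toy_policy_and_value_apply(observations, num_actions, seed=0):
  rng = np.random.RandomState(seed)
  b, t = observations.shape[:2]
  flat = observations.reshape(b, t, -1)  # Keep batch and time axes.
  w_policy = rng.normal(size=(flat.shape[-1], num_actions))
  w_value = rng.normal(size=(flat.shape[-1], 1))
  logits = flat @ w_policy  # (batch, time, num_actions)
  logits = logits - logits.max(axis=-1, keepdims=True)  # Stabilize softmax.
  log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
  values = flat @ w_value  # (batch, time, 1)
  return [log_probs, values]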
def test_collect_trajectories_max_timestep(self):
  self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)
  observation_shape = (2, 3, 4)
  num_actions = 2
  pnv_params, pnv_apply = ppo.policy_and_value_net(
      key1,
      (-1, -1) + observation_shape,
      num_actions,
      lambda: [layers.Flatten(num_axis_to_keep=2)])

  def pnv_fun(obs, rng=None):
    rng, r = jax_random.split(rng)
    lp, v = pnv_apply(obs, pnv_params, rng=r)
    return lp, v, rng

  # We'll get done at time-step #5; starting from 0, that is 6 steps.
  done_time_step = 5
  env = fake_env.FakeEnv(
      observation_shape, num_actions, done_time_step=done_time_step)

  num_trajectories = 5

  # Collect trajectories only up to `max_timestep`; this exercises early
  # termination of a trajectory.
  max_timestep = 3
  assert max_timestep < done_time_step

  trajectories = ppo.collect_trajectories(
      env,
      policy_fun=pnv_fun,
      num_trajectories=num_trajectories,
      policy="categorical-sampling",
      max_timestep=max_timestep,
      rng=key2)

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # There is one more observation than there are actions or rewards.
    self.assertEqual((max_timestep,) + observation_shape, observations.shape)
    self.assertEqual((max_timestep - 1,), actions.shape)
    self.assertEqual((max_timestep - 1,), rewards.shape)
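# Hedged sketch of the truncation convention asserted above: a trajectory cut
# at `max_timestep` keeps `max_timestep` observations (including the initial
# one) but only `max_timestep - 1` actions and rewards, since each
# action/reward pair sits between two consecutive observations.
# `truncate_trajectory` is a hypothetical helper, not the internals of
# `ppo.collect_trajectories`.
def truncate_trajectory(observations, actions, rewards, max_timestep):
  return (observations[:max_timestep],
          actions[:max_timestep - 1],
          rewards[:max_timestep - 1])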
def test_combined_loss(self):
  self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)
  B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
  batch_observation_shape = (-1, -1) + OBS

  old_params, _ = ppo.policy_and_value_net(
      key1, batch_observation_shape, A, [layers.Flatten(num_axis_to_keep=2)])
  new_params, net_apply = ppo.policy_and_value_net(
      key2, batch_observation_shape, A, [layers.Flatten(num_axis_to_keep=2)])

  # Generate a batch of observations.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T))
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)

  # Just test that this computes at all.
  new_log_probabs, _ = net_apply(observations, new_params)
  old_log_probabs, value_predictions = net_apply(observations, old_params)

  gamma = 0.99
  lambda_ = 0.95
  epsilon = 0.2
  c1 = 1.0
  c2 = 0.01

  value_loss_1 = ppo.value_loss_given_predictions(
      value_predictions, rewards, mask, gamma=gamma)
  ppo_loss_1 = ppo.ppo_loss_given_predictions(
      new_log_probabs,
      old_log_probabs,
      value_predictions,
      actions,
      rewards,
      mask,
      gamma=gamma,
      lambda_=lambda_,
      epsilon=epsilon)

  (combined_loss, ppo_loss_2, value_loss_2, entropy_bonus) = ppo.combined_loss(
      new_params,
      old_params,
      net_apply,
      observations,
      actions,
      rewards,
      mask,
      gamma=gamma,
      lambda_=lambda_,
      epsilon=epsilon,
      c1=c1,
      c2=c2)

  # Test that these compute at all and are self-consistent.
  self.assertEqual(0.0, entropy_bonus)
  self.assertNear(value_loss_1, value_loss_2, 1e-6)
  self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
  self.assertNear(combined_loss, ppo_loss_2 + (c1 * value_loss_2), 1e-6)
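# Minimal sketch of how `combined_loss` composes its terms, inferred from the
# self-consistency assertions above. Since `entropy_bonus == 0` in this test,
# the assertions cannot distinguish the sign of the entropy term; the
# conventional PPO objective subtracts it from the minimized loss.
# `toy_combined_loss` is a hypothetical helper, not the ppo.py implementation.
def toy_combined_loss(ppo_loss, value_loss, entropy_bonus, c1=1.0, c2=0.01):
  return ppo_loss + c1 * value_loss - c2 * entropy_bonus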
def test_collect_trajectories(self):
  self.rng_key, key1, key2, key3, key4 = jax_random.split(self.rng_key, num=5)
  observation_shape = (2, 3, 4)
  num_actions = 2
  policy_params, policy_apply = ppo.policy_net(
      key1,
      (-1, -1) + observation_shape,
      num_actions,
      # Flatten everything except the batch and time-step dimensions.
      [layers.Flatten(num_axis_to_keep=2)])

  # We'll get done at time-step #5; starting from 0, that is 6 steps.
  done_time_step = 5
  env = fake_env.FakeEnv(
      observation_shape, num_actions, done_time_step=done_time_step)

  def policy_fun(obs, rng=None):
    rng, r = jax_random.split(rng)
    return policy_apply(obs, policy_params, rng=r), (), rng

  num_trajectories = 5
  trajectories = ppo.collect_trajectories(
      env,
      policy_fun=policy_fun,
      num_trajectories=num_trajectories,
      policy="categorical-sampling",
      rng=key2)

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # There is one more observation than there are actions or rewards.
    self.assertEqual((done_time_step + 2,) + observation_shape,
                     observations.shape)
    self.assertEqual((done_time_step + 1,), actions.shape)
    self.assertEqual((done_time_step + 1,), rewards.shape)

  # Test collection using a combined policy-and-value function.
  pnv_params, pnv_apply = ppo.policy_and_value_net(
      key3,
      (-1, -1) + observation_shape,
      num_actions,
      lambda: [layers.Flatten(num_axis_to_keep=2)])

  def pnv_fun(obs, rng=None):
    rng, r = jax_random.split(rng)
    lp, v = pnv_apply(obs, pnv_params, rng=r)
    return lp, v, rng

  trajectories = ppo.collect_trajectories(
      env,
      policy_fun=pnv_fun,
      num_trajectories=num_trajectories,
      policy="categorical-sampling",
      rng=key4)

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # There is one more observation than there are actions or rewards.
    self.assertEqual((done_time_step + 2,) + observation_shape,
                     observations.shape)
    self.assertEqual((done_time_step + 1,), actions.shape)
    self.assertEqual((done_time_step + 1,), rewards.shape)
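# Sketch of the rng-threading convention used by `policy_fun`/`pnv_fun`
# above: each call splits the incoming PRNG key, applies the network with one
# half, and returns the other half so successive calls never reuse
# randomness. `make_rng_threaded_policy` is a hypothetical factory (it uses
# this module's `jax_random` import), not part of ppo.py.
def make_rng_threaded_policy(apply_fn, params):
  def policy_fun(obs, rng=None):
    rng, subkey = jax_random.split(rng)
    return apply_fn(obs, params, rng=subkey), (), rng
  return policy_fun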