def test_collect_trajectories(self):
  observation_shape = (2, 3, 4)
  num_actions = 2
  policy_params, policy_apply = ppo.policy_net(
      self.rng_key,
      (-1, -1) + observation_shape,
      num_actions,
      # flatten except batch and time-step dimensions.
      [layers.Flatten(num_axis_to_keep=2)])

  # We'll get done at time-step #5, starting from 0, therefore in 6 steps.
  done_time_step = 5
  env = fake_env.FakeEnv(observation_shape,
                         num_actions,
                         done_time_step=done_time_step)

  num_trajectories = 5
  trajectories = ppo.collect_trajectories(
      env,
      policy_fun=lambda obs: policy_apply(obs, policy_params),
      num_trajectories=num_trajectories,
      policy="categorical-sampling")

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # observations are one more in number than rewards or actions.
    self.assertEqual((done_time_step + 2,) + observation_shape,
                     observations.shape)
    self.assertEqual((done_time_step + 1,), actions.shape)
    self.assertEqual((done_time_step + 1,), rewards.shape)

  # Test collect using a Policy and Value function.
  pnv_params, pnv_apply = ppo.policy_and_value_net(
      self.rng_key,
      (-1, -1) + observation_shape,
      num_actions,
      [layers.Flatten(num_axis_to_keep=2)])

  trajectories = ppo.collect_trajectories(
      env,
      policy_fun=lambda obs: pnv_apply(obs, pnv_params)[0],
      num_trajectories=num_trajectories,
      policy="categorical-sampling")

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # observations are one more in number than rewards or actions.
    self.assertEqual((done_time_step + 2,) + observation_shape,
                     observations.shape)
    self.assertEqual((done_time_step + 1,), actions.shape)
    self.assertEqual((done_time_step + 1,), rewards.shape)
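
# A minimal sketch (assumed semantics, not the trax implementation) of what
# the Flatten layer passed to the policy nets above is expected to do:
# observations arrive as (batch, time) + observation_shape, and everything
# except the batch and time axes is collapsed before the dense policy head.
def _sketch_flatten_keeping_batch_and_time():
  import numpy as np
  batch_size, time_steps, observation_shape = 1, 7, (2, 3, 4)
  observations = np.zeros((batch_size, time_steps) + observation_shape)
  # Keep the first two axes, flatten the rest into a single feature axis.
  flattened = observations.reshape(observations.shape[:2] + (-1,))
  assert flattened.shape == (batch_size, time_steps, 2 * 3 * 4)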
def test_collect_trajectories(self):
  observation_shape = (2, 3, 4)
  num_actions = 2
  policy_params, policy_apply = ppo.policy_net(
      self.rng_key,
      (-1, -1) + observation_shape,
      num_actions,
      # flatten except batch and time-step dimensions.
      [stax.Flatten(2)])

  # We'll get done at time-step #5, starting from 0, therefore in 6 steps.
  done_time_step = 5
  env = fake_env.FakeEnv(observation_shape,
                         num_actions,
                         done_time_step=done_time_step)

  num_trajectories = 5
  trajectories = ppo.collect_trajectories(env,
                                          policy_apply,
                                          policy_params,
                                          num_trajectories,
                                          policy="categorical-sampling")

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # observations are one more in number than rewards or actions.
    self.assertEqual((done_time_step + 2,) + observation_shape,
                     observations.shape)
    self.assertEqual((done_time_step + 1,), actions.shape)
    self.assertEqual((done_time_step + 1,), rewards.shape)
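
# A worked sketch of the shape arithmetic the assertions above rely on
# (values are illustrative, matching the test constants): if the env signals
# `done` at 0-based time-step index `done_time_step`, a full rollout takes
# done_time_step + 1 actions, records one reward per action, and records
# done_time_step + 2 observations (the reset observation plus one per step).
def _sketch_full_rollout_lengths():
  done_time_step = 5
  num_actions_taken = done_time_step + 1    # Actions at t = 0..5, i.e. 6.
  num_rewards = num_actions_taken           # One reward per action: 6.
  num_observations = num_actions_taken + 1  # Reset obs + one per step: 7.
  assert (num_observations, num_actions_taken, num_rewards) == (7, 6, 6)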
def test_collect_trajectories_max_timestep(self):
  self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)
  observation_shape = (2, 3, 4)
  num_actions = 2
  pnv_params, pnv_apply = ppo.policy_and_value_net(
      key1,
      (-1, -1) + observation_shape,
      num_actions,
      lambda: [layers.Flatten(num_axis_to_keep=2)])

  def pnv_fun(obs, rng=None):
    rng, r = jax_random.split(rng)
    lp, v = pnv_apply(obs, pnv_params, rng=r)
    return lp, v, rng

  # We'll get done at time-step #5, starting from 0, therefore in 6 steps.
  done_time_step = 5
  env = fake_env.FakeEnv(observation_shape,
                         num_actions,
                         done_time_step=done_time_step)

  num_trajectories = 5

  # Let's collect trajectories only till `max_timestep`.
  max_timestep = 3

  # We're testing when we early stop the trajectory.
  assert max_timestep < done_time_step

  trajectories = ppo.collect_trajectories(
      env,
      policy_fun=pnv_fun,
      num_trajectories=num_trajectories,
      policy="categorical-sampling",
      max_timestep=max_timestep,
      rng=key2)

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # observations are one more in number than rewards or actions.
    self.assertEqual((max_timestep,) + observation_shape,
                     observations.shape)
    self.assertEqual((max_timestep - 1,), actions.shape)
    self.assertEqual((max_timestep - 1,), rewards.shape)
def test_collect_trajectories_max_timestep(self):
  observation_shape = (2, 3, 4)
  num_actions = 2
  policy_params, policy_apply = ppo.policy_net(
      self.rng_key,
      (-1, -1) + observation_shape,
      num_actions,
      # flatten except batch and time-step dimensions.
      [layers.Flatten(num_axis_to_keep=2)])

  # We'll get done at time-step #5, starting from 0, therefore in 6 steps.
  done_time_step = 5
  env = fake_env.FakeEnv(observation_shape,
                         num_actions,
                         done_time_step=done_time_step)

  num_trajectories = 5

  # Let's collect trajectories only till `max_timestep`.
  max_timestep = 3

  # We're testing when we early stop the trajectory.
  assert max_timestep < done_time_step

  trajectories = ppo.collect_trajectories(env,
                                          policy_apply,
                                          policy_params,
                                          num_trajectories,
                                          policy="categorical-sampling",
                                          max_timestep=max_timestep)

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # observations are one more in number than rewards or actions.
    self.assertEqual((max_timestep,) + observation_shape, observations.shape)
    self.assertEqual((max_timestep - 1,), actions.shape)
    self.assertEqual((max_timestep - 1,), rewards.shape)
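
# A worked sketch of the truncated-rollout arithmetic checked in the two
# max_timestep tests above (an assumption about what `max_timestep` means,
# matching the assertions): when collection stops after `max_timestep`
# observations, only max_timestep - 1 actions (and rewards) fit between them.
def _sketch_truncated_rollout_lengths():
  max_timestep = 3
  num_observations = max_timestep       # 3 observations recorded.
  num_actions_taken = max_timestep - 1  # 2 actions between those observations.
  num_rewards = num_actions_taken       # One reward per action: 2.
  assert (num_observations, num_actions_taken, num_rewards) == (3, 2, 2)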
def test_collect_trajectories(self):
  self.rng_key, key1, key2, key3, key4 = jax_random.split(self.rng_key,
                                                          num=5)
  observation_shape = (2, 3, 4)
  num_actions = 2
  policy_params, policy_apply = ppo.policy_net(
      key1,
      (-1, -1) + observation_shape,
      num_actions,
      # flatten except batch and time-step dimensions.
      [layers.Flatten(num_axis_to_keep=2)])

  # We'll get done at time-step #5, starting from 0, therefore in 6 steps.
  done_time_step = 5
  env = fake_env.FakeEnv(observation_shape,
                         num_actions,
                         done_time_step=done_time_step)

  def policy_fun(obs, rng=None):
    rng, r = jax_random.split(rng)
    return policy_apply(obs, policy_params, rng=r), (), rng

  num_trajectories = 5
  trajectories = ppo.collect_trajectories(
      env,
      policy_fun=policy_fun,
      num_trajectories=num_trajectories,
      policy="categorical-sampling",
      rng=key2)

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # observations are one more in number than rewards or actions.
    self.assertEqual((done_time_step + 2,) + observation_shape,
                     observations.shape)
    self.assertEqual((done_time_step + 1,), actions.shape)
    self.assertEqual((done_time_step + 1,), rewards.shape)

  # Test collect using a Policy and Value function.
  pnv_params, pnv_apply = ppo.policy_and_value_net(
      key3,
      (-1, -1) + observation_shape,
      num_actions,
      lambda: [layers.Flatten(num_axis_to_keep=2)])

  def pnv_fun(obs, rng=None):
    rng, r = jax_random.split(rng)
    lp, v = pnv_apply(obs, pnv_params, rng=r)
    return lp, v, rng

  trajectories = ppo.collect_trajectories(
      env,
      policy_fun=pnv_fun,
      num_trajectories=num_trajectories,
      policy="categorical-sampling",
      rng=key4)

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # observations are one more in number than rewards or actions.
    self.assertEqual((done_time_step + 2,) + observation_shape,
                     observations.shape)
    self.assertEqual((done_time_step + 1,), actions.shape)
    self.assertEqual((done_time_step + 1,), rewards.shape)
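
# A minimal sketch of the environment interface these tests assume from
# `fake_env.FakeEnv` (the names and behaviour here are assumptions, not the
# real implementation): a gym-style env with a fixed observation shape that
# ignores actions and signals `done` at a configurable 0-based time-step.
def _sketch_fake_env(observation_shape, done_time_step=5):
  import numpy as np

  class SketchFakeEnv(object):
    """Hypothetical stand-in for fake_env.FakeEnv, for illustration only."""

    def __init__(self):
      self._t = 0

    def reset(self):
      self._t = 0
      return np.zeros(observation_shape)

    def step(self, action):
      del action  # The sketch env ignores the action.
      done = (self._t == done_time_step)
      self._t += 1
      return np.zeros(observation_shape), 0.0, done, {}

  return SketchFakeEnv()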