Example 1
    def test_collect_trajectories(self):
        observation_shape = (2, 3, 4)
        num_actions = 2
        policy_params, policy_apply = ppo.policy_net(
            self.rng_key,
            (-1, -1) + observation_shape,
            num_actions,
            # Flatten everything except the batch and time-step dimensions.
            [layers.Flatten(num_axis_to_keep=2)])

        # `done` is reached at time-step #5 (counting from 0), i.e. after 6 steps.
        done_time_step = 5
        env = fake_env.FakeEnv(observation_shape,
                               num_actions,
                               done_time_step=done_time_step)

        num_trajectories = 5
        trajectories = ppo.collect_trajectories(
            env,
            policy_fun=lambda obs: policy_apply(obs, policy_params),
            num_trajectories=num_trajectories,
            policy="categorical-sampling")

        # Number of trajectories is as expected.
        self.assertEqual(num_trajectories, len(trajectories))

        # Shapes of observations, actions and rewards are as expected.
        for observations, actions, rewards in trajectories:
            # There is one more observation than there are actions or rewards.
            self.assertEqual((done_time_step + 2, ) + observation_shape,
                             observations.shape)
            self.assertEqual((done_time_step + 1, ), actions.shape)
            self.assertEqual((done_time_step + 1, ), rewards.shape)

        # Test collect using a Policy and Value function.
        pnv_params, pnv_apply = ppo.policy_and_value_net(
            self.rng_key, (-1, -1) + observation_shape, num_actions,
            [layers.Flatten(num_axis_to_keep=2)])

        trajectories = ppo.collect_trajectories(
            env,
            policy_fun=lambda obs: pnv_apply(obs, pnv_params)[0],
            num_trajectories=num_trajectories,
            policy="categorical-sampling")

        # Number of trajectories is as expected.
        self.assertEqual(num_trajectories, len(trajectories))

        # Shapes of observations, actions and rewards are as expected.
        for observations, actions, rewards in trajectories:
            # There is one more observation than there are actions or rewards.
            self.assertEqual((done_time_step + 2, ) + observation_shape,
                             observations.shape)
            self.assertEqual((done_time_step + 1, ), actions.shape)
            self.assertEqual((done_time_step + 1, ), rewards.shape)
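A note on the shape assertions above: the off-by-one between observations and actions/rewards falls out of the rollout bookkeeping itself. One observation comes from the reset, and every step appends one action, one reward, and one more observation. A minimal numpy sketch of that loop, where `env_reset`, `env_step` and `pick_action` are hypothetical stand-ins, not part of the ppo module:

import numpy as np

def rollout(env_reset, env_step, pick_action, max_steps=100):
    observations = [env_reset()]      # reset yields the first observation
    actions, rewards = [], []
    for _ in range(max_steps):
        action = pick_action(observations[-1])
        obs, reward, done = env_step(action)
        actions.append(action)
        rewards.append(reward)
        observations.append(obs)      # one extra observation per step
        if done:
            break
    # len(observations) == len(actions) + 1 == len(rewards) + 1, matching the
    # (done_time_step + 2) vs. (done_time_step + 1) assertions above.
    return np.stack(observations), np.array(actions), np.array(rewards)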
Example 2
    def test_policy_and_value_net(self):
        observation_shape = (3, 4, 5)
        batch_observation_shape = (-1, -1) + observation_shape
        num_actions = 2
        pnv_params, pnv_apply = ppo.policy_and_value_net(
            self.rng_key, batch_observation_shape, num_actions,
            [layers.Flatten(num_axis_to_keep=2)])
        batch = 2
        time_steps = 10
        batch_of_observations = np.random.uniform(size=(batch, time_steps) +
                                                  observation_shape)
        pnv_output = pnv_apply(batch_of_observations, pnv_params)

        # The output is a pair: the action probabilities first, then the value output.
        self.assertEqual(2, len(pnv_output))
        self.assertEqual((batch, time_steps, num_actions), pnv_output[0].shape)
        self.assertEqual((batch, time_steps, 1), pnv_output[1].shape)
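The two shapes checked here follow from a shared body with two heads: one of width num_actions for the action distribution and one of width 1 for the value. A toy numpy apply function illustrating just that shape contract; the random projections are stand-ins for the real network, and the reshape mirrors layers.Flatten(num_axis_to_keep=2):

import numpy as np

def toy_policy_and_value_apply(observations, num_actions):
    b, t = observations.shape[:2]
    flat = observations.reshape(b, t, -1)   # keep the batch and time axes
    w_pi = np.random.normal(size=(flat.shape[-1], num_actions))
    w_v = np.random.normal(size=(flat.shape[-1], 1))
    logits = flat @ w_pi
    # Log-softmax over the action axis.
    log_probabs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    values = flat @ w_v
    return log_probabs, values   # (b, t, num_actions) and (b, t, 1)

lp, v = toy_policy_and_value_apply(np.random.uniform(size=(2, 10, 3, 4, 5)), 2)
assert lp.shape == (2, 10, 2) and v.shape == (2, 10, 1)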
Example 3
    def test_collect_trajectories_max_timestep(self):
        self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)
        observation_shape = (2, 3, 4)
        num_actions = 2
        pnv_params, pnv_apply = ppo.policy_and_value_net(
            key1, (-1, -1) + observation_shape, num_actions,
            lambda: [layers.Flatten(num_axis_to_keep=2)])

        def pnv_fun(obs, rng=None):
            rng, r = jax_random.split(rng)
            lp, v = pnv_apply(obs, pnv_params, rng=r)
            return lp, v, rng

        # `done` is reached at time-step #5 (counting from 0), i.e. after 6 steps.
        done_time_step = 5
        env = fake_env.FakeEnv(observation_shape,
                               num_actions,
                               done_time_step=done_time_step)

        num_trajectories = 5

        # Let's collect trajectories only till `max_timestep`.
        max_timestep = 3

        # We're testing the case where the trajectory is stopped early.
        assert max_timestep < done_time_step

        trajectories = ppo.collect_trajectories(
            env,
            policy_fun=pnv_fun,
            num_trajectories=num_trajectories,
            policy="categorical-sampling",
            max_timestep=max_timestep,
            rng=key2)

        # Number of trajectories is as expected.
        self.assertEqual(num_trajectories, len(trajectories))

        # Shapes of observations, actions and rewards are as expected.
        for observations, actions, rewards in trajectories:
            # There is one more observation than there are actions or rewards.
            self.assertEqual((max_timestep, ) + observation_shape,
                             observations.shape)
            self.assertEqual((max_timestep - 1, ), actions.shape)
            self.assertEqual((max_timestep - 1, ), rewards.shape)
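Truncation changes the counts but not the invariant: the rollout stops once max_timestep observations have been recorded, so observations still outnumber actions and rewards by one. A small sketch of the arithmetic the assertions encode; truncated_lengths is a hypothetical helper, written only to make the two cases explicit:

def truncated_lengths(done_time_step, max_timestep):
    # Untruncated, an episode yields done_time_step + 2 observations;
    # truncation caps the count at max_timestep.
    num_observations = min(done_time_step + 2, max_timestep)
    return num_observations, num_observations - 1  # (observations, actions/rewards)

assert truncated_lengths(5, 3) == (3, 2)     # the truncated case tested above
assert truncated_lengths(5, 100) == (7, 6)   # the untruncated case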
Example 4
    def test_combined_loss(self):
        self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)

        B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
        batch_observation_shape = (-1, -1) + OBS

        old_params, _ = ppo.policy_and_value_net(
            key1, batch_observation_shape, A,
            [layers.Flatten(num_axis_to_keep=2)])

        new_params, net_apply = ppo.policy_and_value_net(
            key2, batch_observation_shape, A,
            [layers.Flatten(num_axis_to_keep=2)])

        # Generate a batch of observations, actions, rewards and a mask.
        observations = np.random.uniform(size=(B, T + 1) + OBS)
        actions = np.random.randint(0, A, size=(B, T))
        rewards = np.random.uniform(0, 1, size=(B, T))
        mask = np.ones_like(rewards)

        # Just test that this computes at all.
        new_log_probabs, _ = net_apply(observations, new_params)
        old_log_probabs, value_predictions = net_apply(observations,
                                                       old_params)

        gamma = 0.99
        lambda_ = 0.95
        epsilon = 0.2
        c1 = 1.0
        c2 = 0.01

        value_loss_1 = ppo.value_loss_given_predictions(value_predictions,
                                                        rewards,
                                                        mask,
                                                        gamma=gamma)
        ppo_loss_1 = ppo.ppo_loss_given_predictions(new_log_probabs,
                                                    old_log_probabs,
                                                    value_predictions,
                                                    actions,
                                                    rewards,
                                                    mask,
                                                    gamma=gamma,
                                                    lambda_=lambda_,
                                                    epsilon=epsilon)

        (combined_loss, ppo_loss_2, value_loss_2,
         entropy_bonus) = (ppo.combined_loss(new_params,
                                             old_params,
                                             net_apply,
                                             observations,
                                             actions,
                                             rewards,
                                             mask,
                                             gamma=gamma,
                                             lambda_=lambda_,
                                             epsilon=epsilon,
                                             c1=c1,
                                             c2=c2))

        # Test that these compute at all and are self-consistent.
        self.assertEqual(0.0, entropy_bonus)
        self.assertNear(value_loss_1, value_loss_2, 1e-6)
        self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
        self.assertNear(combined_loss, ppo_loss_2 + (c1 * value_loss_2), 1e-6)
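The final assertNear encodes the usual PPO objective composition. Assuming the conventional sign where the entropy bonus is subtracted (entropy is something to maximize), the combined loss is ppo_loss + c1 * value_loss - c2 * entropy_bonus, which with entropy_bonus == 0, as asserted above, reduces to the expression being checked:

def combined(ppo_loss, value_loss, entropy_bonus, c1=1.0, c2=0.01):
    # Hypothetical restatement of the composition; with entropy_bonus == 0
    # this is exactly ppo_loss + c1 * value_loss, as the test asserts.
    return ppo_loss + c1 * value_loss - c2 * entropy_bonus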
Example 5
    def test_collect_trajectories(self):
        self.rng_key, key1, key2, key3, key4 = jax_random.split(self.rng_key,
                                                                num=5)
        observation_shape = (2, 3, 4)
        num_actions = 2
        policy_params, policy_apply = ppo.policy_net(
            key1,
            (-1, -1) + observation_shape,
            num_actions,
            # Flatten everything except the batch and time-step dimensions.
            [layers.Flatten(num_axis_to_keep=2)])

        # `done` is reached at time-step #5 (counting from 0), i.e. after 6 steps.
        done_time_step = 5
        env = fake_env.FakeEnv(observation_shape,
                               num_actions,
                               done_time_step=done_time_step)

        def policy_fun(obs, rng=None):
            rng, r = jax_random.split(rng)
            return policy_apply(obs, policy_params, rng=r), (), rng

        num_trajectories = 5
        trajectories = ppo.collect_trajectories(
            env,
            policy_fun=policy_fun,
            num_trajectories=num_trajectories,
            policy="categorical-sampling",
            rng=key2)

        # Number of trajectories is as expected.
        self.assertEqual(num_trajectories, len(trajectories))

        # Shapes of observations, actions and rewards are as expected.
        for observations, actions, rewards in trajectories:
            # There is one more observation than there are actions or rewards.
            self.assertEqual((done_time_step + 2, ) + observation_shape,
                             observations.shape)
            self.assertEqual((done_time_step + 1, ), actions.shape)
            self.assertEqual((done_time_step + 1, ), rewards.shape)

        # Test collect using a Policy and Value function.
        pnv_params, pnv_apply = ppo.policy_and_value_net(
            key3, (-1, -1) + observation_shape, num_actions,
            lambda: [layers.Flatten(num_axis_to_keep=2)])

        def pnv_fun(obs, rng=None):
            rng, r = jax_random.split(rng)
            lp, v = pnv_apply(obs, pnv_params, rng=r)
            return lp, v, rng

        trajectories = ppo.collect_trajectories(
            env,
            policy_fun=pnv_fun,
            num_trajectories=num_trajectories,
            policy="categorical-sampling",
            rng=key4)

        # Number of trajectories is as expected.
        self.assertEqual(num_trajectories, len(trajectories))

        # Shapes of observations, actions and rewards are as expected.
        for observations, actions, rewards in trajectories:
            # There is one more observation than there are actions or rewards.
            self.assertEqual((done_time_step + 2, ) + observation_shape,
                             observations.shape)
            self.assertEqual((done_time_step + 1, ), actions.shape)
            self.assertEqual((done_time_step + 1, ), rewards.shape)
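Both policy_fun and pnv_fun above follow the same rng-threading convention: split the incoming key, spend one half on the forward pass, and return the other so each step of trajectory collection samples with a fresh key. The pattern in isolation; apply_fn accepting an rng keyword is an assumption mirroring the wrappers above:

from jax import random as jax_random

def threaded_apply(apply_fn, params, obs, rng):
    rng, subkey = jax_random.split(rng)   # never reuse a key
    out = apply_fn(obs, params, rng=subkey)
    return out, rng                       # hand the fresh key back to the caller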