Example #1
    def test_ppo_loss(self):
        self.rng_key, key1, key2, key3 = jax_random.split(self.rng_key, num=4)

        B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
        batch_observation_shape = (-1, -1) + OBS

        old_policy_params, _ = ppo.policy_net(key1, batch_observation_shape, A,
                                              [stax.Flatten(2)])

        new_policy_params, policy_apply = ppo.policy_net(
            key2, batch_observation_shape, A, [stax.Flatten(2)])

        value_params, value_apply = ppo.value_net(key3,
                                                  batch_observation_shape, A,
                                                  [stax.Flatten(2)])

        # Generate a batch of observations.

        observations = np.random.uniform(size=(B, T + 1) + OBS)
        actions = np.random.randint(0, A, size=(B, T))
        rewards = np.random.uniform(0, 1, size=(B, T))
        mask = np.ones_like(rewards)

        # Just test that this computes at all.
        _ = ppo.ppo_loss(policy_apply, new_policy_params, old_policy_params,
                         value_apply, value_params, observations, actions,
                         rewards, mask)
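
A quick reference for what a PPO loss is built around: the clipped surrogate objective. The sketch below is a self-contained illustration of the standard formulation in plain NumPy, not the internals of ppo.ppo_loss; the names clipped_objective, advantages and epsilon are assumptions made for this example.

import numpy as np

def clipped_objective(log_probs_new, log_probs_old, advantages, epsilon=0.2):
    # Probability ratios r_t = pi_new(a_t | s_t) / pi_old(a_t | s_t).
    ratios = np.exp(log_probs_new - log_probs_old)
    clipped = np.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
    # Mean over batch and time of min(r_t * A_t, clip(r_t) * A_t).
    return np.mean(np.minimum(ratios * advantages, clipped * advantages))

# Per-timestep log-probabilities of the taken actions under both policies.
B, T = 2, 10
log_probs_new = np.log(np.random.uniform(0.1, 1.0, size=(B, T)))
log_probs_old = np.log(np.random.uniform(0.1, 1.0, size=(B, T)))
advantages = np.random.normal(size=(B, T))
objective = clipped_objective(log_probs_new, log_probs_old, advantages)
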
Example #2
    def test_collect_trajectories(self):
        observation_shape = (2, 3, 4)
        num_actions = 2
        policy_params, policy_apply = ppo.policy_net(
            self.rng_key,
            (-1, -1) + observation_shape,
            num_actions,
            # flatten except batch and time
            # step dimensions.
            [stax.Flatten(2)])

        # We'll get done at time-step #5, starting from 0, therefore in 6 steps.
        done_time_step = 5
        env = fake_env.FakeEnv(observation_shape,
                               num_actions,
                               done_time_step=done_time_step)

        num_trajectories = 5
        trajectories = ppo.collect_trajectories(env,
                                                policy_apply,
                                                policy_params,
                                                num_trajectories,
                                                policy="categorical-sampling")

        # Number of trajectories is as expected.
        self.assertEqual(num_trajectories, len(trajectories))

        # Shapes of observations, actions and rewards are as expected.
        for observations, actions, rewards in trajectories:
            # observations are one more in number than rewards or actions.
            self.assertEqual((done_time_step + 2, ) + observation_shape,
                             observations.shape)
            self.assertEqual((done_time_step + 1, ), actions.shape)
            self.assertEqual((done_time_step + 1, ), rewards.shape)
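
The test relies on a fake_env.FakeEnv helper that is not shown in this listing. The sketch below is an assumption about what such an environment could look like, with a gym-style reset/step interface that reports done at a fixed 0-indexed step; only the constructor arguments mirror the call above.

import numpy as np

class FakeEnv:
    """Random observations and rewards; done at step `done_time_step`."""

    def __init__(self, observation_shape, num_actions, done_time_step=5):
        self.observation_shape = observation_shape
        self.num_actions = num_actions
        self.done_time_step = done_time_step
        self._t = 0

    def reset(self):
        self._t = 0
        return np.random.uniform(size=self.observation_shape)

    def step(self, action):
        # done turns True on step index `done_time_step`, which gives
        # done_time_step + 1 actions/rewards and one extra observation
        # per trajectory, matching the shape checks above.
        done = self._t == self.done_time_step
        self._t += 1
        observation = np.random.uniform(size=self.observation_shape)
        reward = np.random.uniform()
        return observation, reward, done, {}
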
Example #3
    def test_policy_net(self):
        observation_shape = (3, 4)
        num_actions = 2
        policy_params, policy_apply = ppo.policy_net(
            self.rng_key,
            (-1, -1) + observation_shape,
            num_actions,
            # flatten except batch and time
            # step dimensions.
            [stax.Flatten(2)])

        # Generate a batch of observations.
        batch = 2
        time_steps = 10
        batch_of_observations = np.random.uniform(size=(batch, time_steps) +
                                                  observation_shape)

        # Apply the policy net to the observations.
        policy_output = policy_apply(policy_params, batch_of_observations)

        # Verify certain expectations on the output.
        self.assertEqual((batch, time_steps, num_actions), policy_output.shape)

        # Also, the last axis sums to 1, since these are probabilities.
        sum_actions = np.sum(policy_output, axis=-1)
        self.assertAllClose(np.ones_like(sum_actions), sum_actions)
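
Because the policy output is a normalized distribution over actions, actions can be drawn from it by categorical sampling, which is what the "categorical-sampling" policy in Example #2 amounts to. A self-contained sketch (the probs array below stands in for policy_output above):

import numpy as np
from jax import random as jax_random

batch, time_steps, num_actions = 2, 10, 2
# Stand-in for policy_output: random probabilities normalized over the last axis.
probs = np.random.uniform(size=(batch, time_steps, num_actions))
probs /= probs.sum(axis=-1, keepdims=True)

# Log-probabilities are valid logits for jax.random.categorical.
actions = jax_random.categorical(jax_random.PRNGKey(0), np.log(probs), axis=-1)
# actions has shape (batch, time_steps), with integer values in [0, num_actions).
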
Example #4
    def test_value_net(self):
        observation_shape = (3, 4, 5)
        num_actions = 2
        value_params, value_apply = ppo.value_net(
            self.rng_key, (-1, -1) + observation_shape, num_actions,
            [stax.Flatten(num_axis_to_keep=2)])
        batch = 2
        time_steps = 10
        batch_of_observations = np.random.uniform(size=(batch, time_steps) +
                                                  observation_shape)
        value_output = value_apply(value_params, batch_of_observations)

        # NOTE: The extra dimension at the end comes from the final Dense(1).
        self.assertEqual((batch, time_steps, 1), value_output.shape)
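
The trailing dimension of size 1 comes from that final Dense(1) layer. A tiny standalone illustration using plain stax primitives (jax.example_libraries.stax in recent JAX); the layer stack below is an assumption, not the actual ppo.value_net architecture:

import numpy as np
from jax import random as jax_random
from jax.example_libraries import stax

# Feature-vector observations of shape (batch, time, features).
init_fun, apply_fun = stax.serial(stax.Dense(16), stax.Relu, stax.Dense(1))
_, params = init_fun(jax_random.PRNGKey(0), (-1, -1, 5))
values = apply_fun(params, np.random.uniform(size=(2, 10, 5)))
assert values.shape == (2, 10, 1)  # np.squeeze(values, -1) would drop the extra axis.
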
Example #5
    def test_policy_and_value_net(self):
        observation_shape = (3, 4, 5)
        batch_observation_shape = (-1, -1) + observation_shape
        num_actions = 2
        pnv_params, pnv_apply = ppo.policy_and_value_net(
            self.rng_key, batch_observation_shape, num_actions,
            [stax.Flatten(2)])
        batch = 2
        time_steps = 10
        batch_of_observations = np.random.uniform(size=(batch, time_steps) +
                                                  observation_shape)
        pnv_output = pnv_apply(pnv_params, batch_of_observations)

        # Output is a list: action probabilities first, then the value output.
        self.assertEqual(2, len(pnv_output))
        self.assertEqual((batch, time_steps, num_actions), pnv_output[0].shape)
        self.assertEqual((batch, time_steps, 1), pnv_output[1].shape)
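
A two-headed network of this kind can be assembled from plain stax combinators: a shared trunk, FanOut(2), and parallel policy/value heads. The sketch below (using jax.example_libraries.stax and feature-vector observations) is an assumption about the general shape of such a network, not the actual ppo.policy_and_value_net implementation:

import numpy as np
from jax import random as jax_random
from jax.example_libraries import stax

num_actions = 2
init_fun, apply_fun = stax.serial(
    stax.Dense(16), stax.Relu,                               # shared trunk
    stax.FanOut(2),                                          # copy features to both heads
    stax.parallel(
        stax.serial(stax.Dense(num_actions), stax.Softmax),  # policy head: probabilities
        stax.Dense(1),                                       # value head
    ),
)

_, params = init_fun(jax_random.PRNGKey(0), (-1, -1, 5))
observations = np.random.uniform(size=(2, 10, 5))
probs, values = apply_fun(params, observations)
# probs: (2, 10, num_actions), summing to 1 over the last axis; values: (2, 10, 1).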