Example #1
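Both examples are test methods excerpted from a PPO test suite; the enclosing test class and module-level imports are not shown. A minimal import block for Example #1 might look like the sketch below (the ppo module path is an assumption based on the tensor2tensor/trax layout):

import jax
import numpy as np

from tensor2tensor.trax.rl import ppo  # assumed module path
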
  def test_value_loss(self):
    rewards = np.array([
        [1, 2, 4, 8, 16, 32, 64, 128],
        [1, 1, 1, 1, 1, 1, 1, 1],
    ])

    rewards_mask = np.array([
        [1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0],
    ])

    gamma = 0.5

    # Random observations and a value function that returns a constant value.
    # NOTE: Observations have an extra time-step.
    B, T = rewards.shape  # pylint: disable=invalid-name
    observation_shape = (210, 160, 3)  # atari pong
    random_observations = np.random.uniform(size=(B, T + 1) + observation_shape)

    def value_net_apply(observations, params, rng=None):
      del params, rng
      # pylint: disable=invalid-name
      B, T_p_1, OBS = (observations.shape[0], observations.shape[1],
                       observations.shape[2:])
      del OBS
      return np.ones((B, T_p_1, 1))
      # pylint: enable=invalid-name

    value_prediction = value_net_apply(random_observations, [])

    with jax.disable_jit():
      value_loss = ppo.value_loss_given_predictions(
          value_prediction,
          rewards,
          rewards_mask,
          gamma)
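    # Expected loss, assuming value_loss_given_predictions returns the mean
    # squared error between the constant prediction (1.0) and the masked
    # rewards-to-go with gamma = 0.5:
    #   row 0: rewards-to-go [5, 8, 12, 16, 16]   -> squared errors sum to 636
    #   row 1: rewards-to-go [1.984375, ..., 1.0] -> squared errors sum to ~4.3645
    #   (636 + 4.3645) / 12 masked steps ~= 53.3637084961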

    self.assertNear(53.3637084961, value_loss, 1e-6)
Example #2
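Example #2 additionally uses the trax layers library, JAX's PRNG utilities, and a self.rng_key assumed to be created in the test's setUp (e.g. via jax_random.PRNGKey). Plausible extra imports, again an assumption:

from jax import random as jax_random
from tensor2tensor.trax import layers  # assumed module path
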
    def test_combined_loss(self):
        self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)

        B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
        batch_observation_shape = (-1, -1) + OBS

        old_params, _ = ppo.policy_and_value_net(
            key1, batch_observation_shape, A,
            [layers.Flatten(num_axis_to_keep=2)])

        new_params, net_apply = ppo.policy_and_value_net(
            key2, batch_observation_shape, A,
            [layers.Flatten(num_axis_to_keep=2)])
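        # Each call to policy_and_value_net is assumed to return initialized
        # parameters plus an apply function; apply(observations, params)
        # yields per-timestep action log-probabilities and value predictions,
        # as used below.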

        # Generate a batch of observations, actions, rewards and a mask.
        observations = np.random.uniform(size=(B, T + 1) + OBS)
        actions = np.random.randint(0, A, size=(B, T))
        rewards = np.random.uniform(0, 1, size=(B, T))
        mask = np.ones_like(rewards)

        # Just test that this computes at all.
        new_log_probabs, _ = net_apply(observations, new_params)
        old_log_probabs, value_predictions = net_apply(observations,
                                                       old_params)

        gamma = 0.99
        lambda_ = 0.95
        epsilon = 0.2
        c1 = 1.0
        c2 = 0.01

        value_loss_1 = ppo.value_loss_given_predictions(value_predictions,
                                                        rewards,
                                                        mask,
                                                        gamma=gamma)
        ppo_loss_1 = ppo.ppo_loss_given_predictions(new_log_probabs,
                                                    old_log_probabs,
                                                    value_predictions,
                                                    actions,
                                                    rewards,
                                                    mask,
                                                    gamma=gamma,
                                                    lambda_=lambda_,
                                                    epsilon=epsilon)

        (combined_loss, ppo_loss_2, value_loss_2,
         entropy_bonus) = (ppo.combined_loss(new_params,
                                             old_params,
                                             net_apply,
                                             observations,
                                             actions,
                                             rewards,
                                             mask,
                                             gamma=gamma,
                                             lambda_=lambda_,
                                             epsilon=epsilon,
                                             c1=c1,
                                             c2=c2))
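        # Assuming the combined objective is ppo_loss + c1 * value_loss
        # - c2 * entropy_bonus; with a zero entropy bonus the last term drops
        # out, which the final assertion below relies on.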

        # Test that these compute at all and are self-consistent.
        self.assertEqual(0.0, entropy_bonus)
        self.assertNear(value_loss_1, value_loss_2, 1e-6)
        self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
        self.assertNear(combined_loss, ppo_loss_2 + (c1 * value_loss_2), 1e-6)