Example 1
  def test_value_loss(self):
    rewards = np.array([
        [1, 2, 4, 8, 16, 32, 64, 128],
        [1, 1, 1, 1, 1, 1, 1, 1],
    ])

    rewards_mask = np.array([
        [1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0],
    ])

    gamma = 0.5  # Discount factor.

    # Random observations and a value function that returns a constant value.
    # NOTE: Observations have an extra time-step.
    B, T = rewards.shape  # pylint: disable=invalid-name
    observation_shape = (210, 160, 3)  # atari pong
    random_observations = np.random.uniform(size=(B, T + 1) + observation_shape)

    def value_net_apply(observations, params, rng=None):
      del params, rng
      # pylint: disable=invalid-name
      B, T_p_1, OBS = (observations.shape[0], observations.shape[1],
                       observations.shape[2:])
      del OBS
      return np.ones((B, T_p_1, 1))
      # pylint: enable=invalid-name

    value_prediction = value_net_apply(random_observations, [])

    with jax.disable_jit():
      value_loss = ppo.value_loss_given_predictions(
          value_prediction,
          rewards,
          rewards_mask,
          gamma)

    self.assertNear(53.3637084961, value_loss, 1e-6)
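
As a sanity check on the expected constant, 53.3637084961 can be reproduced with plain NumPy under the assumption (inferred from the numbers above, not from the library source) that value_loss_given_predictions reduces to the mean squared error, over the masked-in time-steps, between the constant predictions and the discounted rewards-to-go:

import numpy as np

rewards = np.array([
    [1, 2, 4, 8, 16, 32, 64, 128],
    [1, 1, 1, 1, 1, 1, 1, 1],
], dtype=np.float64)
mask = np.array([
    [1, 1, 1, 1, 1, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 0],
], dtype=np.float64)
gamma = 0.5

# Discounted rewards-to-go over the masked rewards, accumulated right-to-left.
masked_rewards = rewards * mask
returns = np.zeros_like(masked_rewards)
acc = np.zeros(masked_rewards.shape[0])
for t in reversed(range(masked_rewards.shape[1])):
  acc = masked_rewards[:, t] + gamma * acc
  returns[:, t] = acc

# The toy value net above predicts a constant 1.0 at every time-step; the loss
# is the mean squared error restricted to the valid (masked-in) positions.
predictions = np.ones_like(rewards)
value_loss = np.sum(((returns - predictions) ** 2) * mask) / np.sum(mask)
print(value_loss)  # ~53.3637084961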
Example 2
    def test_combined_loss(self):
        self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)

        B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
        batch_observation_shape = (1, 1) + OBS

        old_params, _, _ = ppo.policy_and_value_net(
            key1, batch_observation_shape, np.float32, A,
            lambda: [layers.Flatten(n_axes_to_keep=2)])

        new_params, state, net_apply = ppo.policy_and_value_net(
            key2, batch_observation_shape, np.float32, A,
            lambda: [layers.Flatten(n_axes_to_keep=2)])

        # Generate a batch of observations.

        observations = np.random.uniform(size=(B, T + 1) + OBS)
        actions = np.random.randint(0, A, size=(B, T))
        rewards = np.random.uniform(0, 1, size=(B, T))
        mask = np.ones_like(rewards)

        # Just test that this computes at all.
        (new_log_probabs,
         value_predictions_new), _ = net_apply(observations, new_params, state)
        (old_log_probabs,
         value_predictions_old), _ = net_apply(observations, old_params, state)

        gamma = 0.99  # Discount factor.
        lambda_ = 0.95  # GAE parameter.
        epsilon = 0.2  # PPO clipping parameter.
        c1 = 1.0  # Weight of the value loss in the combined loss.
        c2 = 0.01  # Weight of the entropy bonus in the combined loss.

        (value_loss_1, _) = ppo.value_loss_given_predictions(
            value_predictions_new,
            rewards,
            mask,
            gamma=gamma,
            value_prediction_old=value_predictions_old,
            epsilon=epsilon)
        (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(new_log_probabs,
                                                         old_log_probabs,
                                                         value_predictions_old,
                                                         actions,
                                                         rewards,
                                                         mask,
                                                         gamma=gamma,
                                                         lambda_=lambda_,
                                                         epsilon=epsilon)

        (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _,
         state) = (ppo.combined_loss(new_params,
                                     old_log_probabs,
                                     value_predictions_old,
                                     net_apply,
                                     observations,
                                     actions,
                                     rewards,
                                     mask,
                                     gamma=gamma,
                                     lambda_=lambda_,
                                     epsilon=epsilon,
                                     c1=c1,
                                     c2=c2,
                                     state=state))

        # Test that these compute at all and are self consistent.
        self.assertGreater(entropy_bonus, 0.0)
        self.assertNear(value_loss_1, value_loss_2, 1e-6)
        self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
        self.assertNear(
            combined_loss,
            ppo_loss_2 + (c1 * value_loss_2) - (c2 * entropy_bonus), 1e-6)
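
The final assertion mirrors the standard combined PPO objective: combined_loss = ppo_loss + c1 * value_loss - c2 * entropy_bonus, where ppo_loss is the clipped surrogate policy loss. As a generic illustration (this sketches the textbook PPO-clip formula, not necessarily the exact internals of ppo.ppo_loss_given_predictions), the clipped surrogate can be computed as:

import numpy as np

def clipped_ppo_surrogate(new_log_probs, old_log_probs, advantages, mask,
                          epsilon=0.2):
  """Textbook PPO-clip surrogate loss (illustrative sketch only).

  Args:
    new_log_probs: log-probs of the taken actions under the new policy, (B, T).
    old_log_probs: log-probs of the taken actions under the old policy, (B, T).
    advantages: advantage estimates, e.g. from GAE(gamma, lambda), (B, T).
    mask: 1.0 for valid time-steps, 0.0 for padding, (B, T).
    epsilon: clipping parameter.
  """
  ratios = np.exp(new_log_probs - old_log_probs)
  clipped = np.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
  # Pessimistic bound: take the element-wise minimum of the unclipped and
  # clipped surrogates, then negate so that minimizing maximizes the objective.
  objective = np.minimum(ratios * advantages, clipped * advantages)
  return -np.sum(objective * mask) / np.sum(mask)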