Example #1
  def test_combined_loss(self):
    B, T, C, A, OBS = 2, 10, 1, 2, (28, 28, 3)  # pylint: disable=invalid-name

    make_net = lambda: policy_based_utils.policy_and_value_net(  # pylint: disable=g-long-lambda
        bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
        observation_space=gym.spaces.Box(shape=OBS, low=0, high=1),
        action_space=gym.spaces.Discrete(A),
        vocab_size=None,
        two_towers=True,
    )[0]
    net = make_net()

    observations = np.random.uniform(size=(B, T + 1) + OBS)
    actions = np.random.randint(0, A, size=(B, T, C))
    input_signature = shapes.signature((observations, actions))
    old_params, _ = net.init(input_signature)
    new_params, state = make_net().init(input_signature)

    # Generate a batch of rewards and a padding mask.

    rewards = np.random.uniform(0, 1, size=(B, T))
    mask = np.ones_like(rewards)

    # Just test that this computes at all.
    (new_log_probabs, value_predictions_new) = (
        net((observations, actions), weights=new_params, state=state))
    (old_log_probabs, value_predictions_old) = (
        net((observations, actions), weights=old_params, state=state))

    gamma = 0.99
    lambda_ = 0.95
    epsilon = 0.2
    value_weight = 1.0
    entropy_weight = 0.01

    nontrainable_params = {
        'gamma': gamma,
        'lambda': lambda_,
        'epsilon': epsilon,
        'value_weight': value_weight,
        'entropy_weight': entropy_weight,
    }

    (value_loss_1, _) = ppo.value_loss_given_predictions(
        value_predictions_new, rewards, mask, gamma=gamma,
        value_prediction_old=value_predictions_old, epsilon=epsilon)
    (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(
        new_log_probabs[:, :-1],
        old_log_probabs[:, :-1],
        value_predictions_old,
        actions,
        rewards,
        mask,
        gamma=gamma,
        lambda_=lambda_,
        epsilon=epsilon)

    (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _, state) = (
        ppo.combined_loss(new_params,
                          old_log_probabs[:, :-1],
                          value_predictions_old,
                          net,
                          observations,
                          actions,
                          rewards,
                          mask,
                          nontrainable_params=nontrainable_params,
                          state=state)
    )

    # Test that these compute at all and are self consistent.
    self.assertGreater(entropy_bonus, 0.0)
    self.assertNear(value_loss_1, value_loss_2, 1e-5)
    self.assertNear(ppo_loss_1, ppo_loss_2, 1e-5)
    self.assertNear(
        combined_loss,
        ppo_loss_2 + (value_weight * value_loss_2) -
        (entropy_weight * entropy_bonus),
        1e-5
    )
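
For reference, the final assertion above checks the standard PPO combined objective, L = L_clip + c1 * L_value - c2 * H. A minimal sketch of that composition (hypothetical helper, not part of the ppo module under test):

def combined_objective(ppo_loss, value_loss, entropy_bonus,
                       value_weight=1.0, entropy_weight=0.01):
  # Clipped policy surrogate plus weighted value loss, minus a weighted
  # entropy bonus that encourages exploration.
  return (ppo_loss + value_weight * value_loss
          - entropy_weight * entropy_bonus)
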
Example #2
  def test_combined_loss(self):
    B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
    batch_observation_shape = (1, 1) + OBS

    net = ppo.policy_and_value_net(
        n_controls=1,
        n_actions=A,
        vocab_size=None,
        bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
        two_towers=True,
    )

    input_signature = ShapeDtype(batch_observation_shape)
    old_params, _ = net.init(input_signature)
    new_params, state = net.init(input_signature)

    # Generate a batch of observations.

    observations = np.random.uniform(size=(B, T + 1) + OBS)
    actions = np.random.randint(0, A, size=(B, T + 1))
    rewards = np.random.uniform(0, 1, size=(B, T))
    mask = np.ones_like(rewards)

    # Just test that this computes at all.
    (new_log_probabs, value_predictions_new) = (
        net(observations, weights=new_params, state=state))
    (old_log_probabs, value_predictions_old) = (
        net(observations, weights=old_params, state=state))

    gamma = 0.99
    lambda_ = 0.95
    epsilon = 0.2
    value_weight = 1.0
    entropy_weight = 0.01

    nontrainable_params = {
        'gamma': gamma,
        'lambda': lambda_,
        'epsilon': epsilon,
        'value_weight': value_weight,
        'entropy_weight': entropy_weight,
    }

    rewards_to_actions = np.eye(value_predictions_old.shape[1])
    (value_loss_1, _) = ppo.value_loss_given_predictions(
        value_predictions_new, rewards, mask, gamma=gamma,
        value_prediction_old=value_predictions_old, epsilon=epsilon)
    (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(
        new_log_probabs,
        old_log_probabs,
        value_predictions_old,
        actions,
        rewards_to_actions,
        rewards,
        mask,
        gamma=gamma,
        lambda_=lambda_,
        epsilon=epsilon)

    (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _, state) = (
        ppo.combined_loss(new_params,
                          old_log_probabs,
                          value_predictions_old,
                          net,
                          observations,
                          actions,
                          rewards_to_actions,
                          rewards,
                          mask,
                          nontrainable_params=nontrainable_params,
                          state=state)
    )

    # Test that these compute at all and are self consistent.
    self.assertGreater(entropy_bonus, 0.0)
    self.assertNear(value_loss_1, value_loss_2, 1e-6)
    self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
    self.assertNear(
        combined_loss,
        ppo_loss_2 + (value_weight * value_loss_2) -
        (entropy_weight * entropy_bonus),
        1e-6
    )
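
The value_prediction_old and epsilon arguments to ppo.value_loss_given_predictions point at a PPO-style clipped value loss. A hedged NumPy sketch of one common formulation (assumption: the exact masking and averaging inside the library may differ):

import numpy as np

def clipped_value_loss(v_new, v_old, returns, mask, epsilon=0.2):
  # Keep the new value predictions within epsilon of the old ones, then take
  # the elementwise maximum of the clipped and unclipped squared errors.
  v_clipped = v_old + np.clip(v_new - v_old, -epsilon, epsilon)
  unclipped = np.square(v_new - returns)
  clipped = np.square(v_clipped - returns)
  return np.sum(np.maximum(unclipped, clipped) * mask) / np.sum(mask)
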
Example #3
    def test_combined_loss(self):
        self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)

        B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
        batch_observation_shape = (1, 1) + OBS

        net = ppo.policy_and_value_net(
            n_controls=1,
            n_actions=A,
            vocab_size=None,
            bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
            two_towers=True,
        )

        old_params, _ = net.initialize_once(batch_observation_shape,
                                            np.float32, key1)
        new_params, state = net.initialize_once(batch_observation_shape,
                                                np.float32, key2)

        # Generate a batch of observations.

        observations = np.random.uniform(size=(B, T + 1) + OBS)
        actions = np.random.randint(0, A, size=(B, T + 1))
        rewards = np.random.uniform(0, 1, size=(B, T))
        mask = np.ones_like(rewards)

        # Just test that this computes at all.
        (new_log_probabs, value_predictions_new) = (net(observations,
                                                        params=new_params,
                                                        state=state))
        (old_log_probabs, value_predictions_old) = (net(observations,
                                                        params=old_params,
                                                        state=state))

        gamma = 0.99
        lambda_ = 0.95
        epsilon = 0.2
        c1 = 1.0
        c2 = 0.01

        rewards_to_actions = np.eye(value_predictions_old.shape[1])
        (value_loss_1, _) = ppo.value_loss_given_predictions(
            value_predictions_new,
            rewards,
            mask,
            gamma=gamma,
            value_prediction_old=value_predictions_old,
            epsilon=epsilon)
        (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(new_log_probabs,
                                                         old_log_probabs,
                                                         value_predictions_old,
                                                         actions,
                                                         rewards_to_actions,
                                                         rewards,
                                                         mask,
                                                         gamma=gamma,
                                                         lambda_=lambda_,
                                                         epsilon=epsilon)

        (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _,
         state) = (ppo.combined_loss(new_params,
                                     old_log_probabs,
                                     value_predictions_old,
                                     net,
                                     observations,
                                     actions,
                                     rewards_to_actions,
                                     rewards,
                                     mask,
                                     gamma=gamma,
                                     lambda_=lambda_,
                                     epsilon=epsilon,
                                     c1=c1,
                                     c2=c2,
                                     state=state))

        # Test that these compute at all and are self consistent.
        self.assertGreater(entropy_bonus, 0.0)
        self.assertNear(value_loss_1, value_loss_2, 1e-6)
        self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
        self.assertNear(
            combined_loss,
            ppo_loss_2 + (c1 * value_loss_2) - (c2 * entropy_bonus), 1e-6)
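
The gamma and lambda_ parameters used throughout these tests correspond to generalized advantage estimation (GAE). A short NumPy sketch over a single trajectory (an illustration of the idea, not necessarily the exact implementation behind ppo.ppo_loss_given_predictions):

import numpy as np

def gae_advantages(rewards, values, gamma=0.99, lambda_=0.95):
  # rewards: shape (T,); values: shape (T + 1,), including a bootstrap value.
  deltas = rewards + gamma * values[1:] - values[:-1]
  advantages = np.zeros_like(rewards)
  running = 0.0
  # Accumulate the discounted, lambda-weighted sum of TD errors backwards.
  for t in reversed(range(rewards.shape[0])):
    running = deltas[t] + gamma * lambda_ * running
    advantages[t] = running
  return advantages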