def test_combined_loss(self):
  B, T, C, A, OBS = 2, 10, 1, 2, (28, 28, 3)  # pylint: disable=invalid-name
  make_net = lambda: policy_based_utils.policy_and_value_net(  # pylint: disable=g-long-lambda
      bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
      observation_space=gym.spaces.Box(shape=OBS, low=0, high=1),
      action_space=gym.spaces.Discrete(A),
      vocab_size=None,
      two_towers=True,
  )[0]
  net = make_net()

  # Generate a batch of observations and actions, and initialize the net.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T, C))
  input_signature = shapes.signature((observations, actions))
  old_params, _ = net.init(input_signature)
  new_params, state = make_net().init(input_signature)

  # Generate rewards and a mask matching them.
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)

  # Just test that this computes at all.
  (new_log_probabs, value_predictions_new) = (
      net((observations, actions), weights=new_params, state=state))
  (old_log_probabs, value_predictions_old) = (
      net((observations, actions), weights=old_params, state=state))

  gamma = 0.99
  lambda_ = 0.95
  epsilon = 0.2
  value_weight = 1.0
  entropy_weight = 0.01

  nontrainable_params = {
      'gamma': gamma,
      'lambda': lambda_,
      'epsilon': epsilon,
      'value_weight': value_weight,
      'entropy_weight': entropy_weight,
  }

  (value_loss_1, _) = ppo.value_loss_given_predictions(
      value_predictions_new, rewards, mask, gamma=gamma,
      value_prediction_old=value_predictions_old, epsilon=epsilon)
  (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(
      new_log_probabs[:, :-1],
      old_log_probabs[:, :-1],
      value_predictions_old,
      actions,
      rewards,
      mask,
      gamma=gamma,
      lambda_=lambda_,
      epsilon=epsilon)
  (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _, state) = (
      ppo.combined_loss(new_params,
                        old_log_probabs[:, :-1],
                        value_predictions_old,
                        net,
                        observations,
                        actions,
                        rewards,
                        mask,
                        nontrainable_params=nontrainable_params,
                        state=state)
  )

  # Test that these compute at all and are self-consistent.
  self.assertGreater(entropy_bonus, 0.0)
  self.assertNear(value_loss_1, value_loss_2, 1e-5)
  self.assertNear(ppo_loss_1, ppo_loss_2, 1e-5)
  self.assertNear(
      combined_loss,
      ppo_loss_2 + (value_weight * value_loss_2) -
      (entropy_weight * entropy_bonus),
      1e-5
  )
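
# The gamma/lambda_ pair above parameterizes the advantage estimates that
# ppo_loss_given_predictions computes internally. As a point of reference,
# here is a minimal NumPy sketch of the standard generalized advantage
# estimation (GAE) recursion; it is an illustration only, not the library's
# implementation, and the helper name is hypothetical.
import numpy as np


def _gae_advantages_sketch(rewards, values, gamma=0.99, lambda_=0.95):
  """Sketch of GAE over a batch: rewards (B, T); values (B, T + 1)."""
  n_steps = rewards.shape[1]
  # One-step TD residuals: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t).
  deltas = rewards + gamma * values[:, 1:] - values[:, :-1]
  advantages = np.zeros_like(deltas)
  gae = np.zeros(rewards.shape[0])
  # Backward recursion: A_t = delta_t + gamma * lambda * A_{t+1}.
  for t in reversed(range(n_steps)):
    gae = deltas[:, t] + gamma * lambda_ * gae
    advantages[:, t] = gae
  return advantages
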
def test_combined_loss(self):
  B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
  batch_observation_shape = (1, 1) + OBS

  net = ppo.policy_and_value_net(
      n_controls=1,
      n_actions=A,
      vocab_size=None,
      bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
      two_towers=True,
  )

  input_signature = ShapeDtype(batch_observation_shape)
  old_params, _ = net.init(input_signature)
  new_params, state = net.init(input_signature)

  # Generate a batch of observations.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T + 1))
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)

  # Just test that this computes at all.
  (new_log_probabs, value_predictions_new) = (
      net(observations, weights=new_params, state=state))
  (old_log_probabs, value_predictions_old) = (
      net(observations, weights=old_params, state=state))

  gamma = 0.99
  lambda_ = 0.95
  epsilon = 0.2
  value_weight = 1.0
  entropy_weight = 0.01

  nontrainable_params = {
      'gamma': gamma,
      'lambda': lambda_,
      'epsilon': epsilon,
      'value_weight': value_weight,
      'entropy_weight': entropy_weight,
  }

  rewards_to_actions = np.eye(value_predictions_old.shape[1])
  (value_loss_1, _) = ppo.value_loss_given_predictions(
      value_predictions_new, rewards, mask, gamma=gamma,
      value_prediction_old=value_predictions_old, epsilon=epsilon)
  (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(
      new_log_probabs,
      old_log_probabs,
      value_predictions_old,
      actions,
      rewards_to_actions,
      rewards,
      mask,
      gamma=gamma,
      lambda_=lambda_,
      epsilon=epsilon)
  (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _, state) = (
      ppo.combined_loss(new_params,
                        old_log_probabs,
                        value_predictions_old,
                        net,
                        observations,
                        actions,
                        rewards_to_actions,
                        rewards,
                        mask,
                        nontrainable_params=nontrainable_params,
                        state=state)
  )

  # Test that these compute at all and are self-consistent.
  self.assertGreater(entropy_bonus, 0.0)
  self.assertNear(value_loss_1, value_loss_2, 1e-6)
  self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
  self.assertNear(
      combined_loss,
      ppo_loss_2 + (value_weight * value_loss_2) -
      (entropy_weight * entropy_bonus),
      1e-6
  )
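
# value_loss_given_predictions above receives both new and old value
# predictions plus an epsilon, which suggests a PPO-style clipped value loss.
# Below is a minimal NumPy sketch of that standard construction, under the
# assumption that returns have already been computed and that all inputs
# share the shape (B, T). The helper name is hypothetical, not the ppo
# module's API.
def _clipped_value_loss_sketch(values_new, values_old, returns, mask,
                               epsilon=0.2):
  """Sketch: max of unclipped and clipped squared errors, masked mean."""
  clipped = values_old + np.clip(values_new - values_old, -epsilon, epsilon)
  per_step = np.maximum((values_new - returns) ** 2,
                        (clipped - returns) ** 2)
  return np.sum(per_step * mask) / np.sum(mask)
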
def test_combined_loss(self):
  self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)

  B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
  batch_observation_shape = (1, 1) + OBS

  net = ppo.policy_and_value_net(
      n_controls=1,
      n_actions=A,
      vocab_size=None,
      bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
      two_towers=True,
  )

  old_params, _ = net.initialize_once(
      batch_observation_shape, np.float32, key1)
  new_params, state = net.initialize_once(
      batch_observation_shape, np.float32, key2)

  # Generate a batch of observations.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T + 1))
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)

  # Just test that this computes at all.
  (new_log_probabs, value_predictions_new) = (
      net(observations, params=new_params, state=state))
  (old_log_probabs, value_predictions_old) = (
      net(observations, params=old_params, state=state))

  gamma = 0.99
  lambda_ = 0.95
  epsilon = 0.2
  c1 = 1.0
  c2 = 0.01

  rewards_to_actions = np.eye(value_predictions_old.shape[1])
  (value_loss_1, _) = ppo.value_loss_given_predictions(
      value_predictions_new, rewards, mask, gamma=gamma,
      value_prediction_old=value_predictions_old, epsilon=epsilon)
  (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(
      new_log_probabs,
      old_log_probabs,
      value_predictions_old,
      actions,
      rewards_to_actions,
      rewards,
      mask,
      gamma=gamma,
      lambda_=lambda_,
      epsilon=epsilon)
  (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _, state) = (
      ppo.combined_loss(new_params,
                        old_log_probabs,
                        value_predictions_old,
                        net,
                        observations,
                        actions,
                        rewards_to_actions,
                        rewards,
                        mask,
                        gamma=gamma,
                        lambda_=lambda_,
                        epsilon=epsilon,
                        c1=c1,
                        c2=c2,
                        state=state))

  # Test that these compute at all and are self-consistent.
  self.assertGreater(entropy_bonus, 0.0)
  self.assertNear(value_loss_1, value_loss_2, 1e-6)
  self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
  self.assertNear(
      combined_loss,
      ppo_loss_2 + (c1 * value_loss_2) - (c2 * entropy_bonus),
      1e-6)
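
# All three versions of the test assert the same identity:
#   combined_loss == ppo_loss + c1 * value_loss - c2 * entropy_bonus
# (with c1/c2 named value_weight/entropy_weight in the newer versions).
# The sketches below show the standard clipped-surrogate objective and that
# combination in plain NumPy; they are illustrations of the textbook PPO
# formulas, not the ppo module's implementations, and the helper names are
# hypothetical.
def _ppo_surrogate_sketch(new_log_probs, old_log_probs, advantages, mask,
                          epsilon=0.2):
  """Sketch: clipped PPO surrogate, negated so that lower is better."""
  ratios = np.exp(new_log_probs - old_log_probs)
  clipped_ratios = np.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
  objective = np.minimum(ratios * advantages, clipped_ratios * advantages)
  return -np.sum(objective * mask) / np.sum(mask)


def _combined_loss_sketch(ppo_loss, value_loss, entropy_bonus,
                          value_weight=1.0, entropy_weight=0.01):
  """Sketch: the combination that each final assertNear checks."""
  return (ppo_loss + value_weight * value_loss
          - entropy_weight * entropy_bonus)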