def test_nested_kl_divergence(self):
    """Verify nested_kl_divergence sums KL across the nest and event dims.

    For unit-scale Normals, KL(N(m1,1) || N(m2,1)) = (m1 - m2)^2 / 2 per
    event dimension, so each of the 3 event dims contributes
    0.5 + (2.0 + 0.5) = 3.0 across the nest, giving 9.0 total.
    """
    ones = tf.constant([1.0] * 3, dtype=tf.float32)
    zeros = tf.constant([0.0] * 3, dtype=tf.float32)

    def unit_normal(loc):
        # All test distributions share unit scale; only the mean varies.
        return tfp.distributions.Normal(loc=loc, scale=ones)

    dist_neg_one = unit_normal(-ones)
    dist_zero = unit_normal(zeros)
    dist_one = unit_normal(ones)

    nested_dist1 = [dist_zero, [dist_neg_one, dist_one]]
    nested_dist2 = [dist_one, [dist_one, dist_zero]]
    kl_divergence = ppo_utils.nested_kl_divergence(nested_dist1, nested_dist2)
    expected_kl_divergence = 3 * 3.0  # 3 * (0.5 + (2.0 + 0.5))
    self.assertAllClose(expected_kl_divergence, self.evaluate(kl_divergence))

    # Same check, but with one pair of distributions carrying an extra
    # trailing dimension — the KL sum must be unaffected by the reshape.
    ones_col = tf.constant([[1.0]] * 3, dtype=tf.float32)
    dist_neg_one_reshaped = tfp.distributions.Normal(
        loc=-ones_col, scale=ones_col)
    dist_one_reshaped = tfp.distributions.Normal(
        loc=ones_col, scale=ones_col)

    nested_dist1 = [dist_zero, [dist_neg_one_reshaped, dist_one]]
    nested_dist2 = [dist_one, [dist_one_reshaped, dist_zero]]
    kl_divergence = ppo_utils.nested_kl_divergence(nested_dist1, nested_dist2)
    expected_kl_divergence = 3 * 3.0  # 3 * (0.5 + (2.0 + 0.5))
    self.assertAllClose(expected_kl_divergence, self.evaluate(kl_divergence))
def _kl_divergence(self, time_steps, action_distribution_parameters,
                   current_policy_distribution):
    """Computes KL divergence between the old and current policies.

    Args:
      time_steps: Batch of TimeStep tuples; used only to infer the batch
        (outer) dimensions against `self.time_step_spec`.
      action_distribution_parameters: Nest of distribution parameters that
        were captured when the actions were originally sampled.
      current_policy_distribution: Nest of distributions from the current
        policy, matching `self._action_distribution_spec`.

    Returns:
      A tensor of KL divergences between the reconstructed old action
      distributions and the current policy distributions.
    """
    # Batch dimensions are everything outside the time_step spec's shape.
    num_outer = nest_utils.get_outer_rank(time_steps, self.time_step_spec)
    outer_dims = list(range(num_outer))
    # Rebuild the sampling-time distributions from their saved parameters.
    old_actions_distribution = (
        distribution_spec.nested_distributions_from_specs(
            self._action_distribution_spec, action_distribution_parameters))
    return ppo_utils.nested_kl_divergence(
        old_actions_distribution,
        current_policy_distribution,
        outer_dims=outer_dims)
def _kl_divergence(self, time_steps, action_distribution_parameters,
                   current_policy_distribution):
    """Computes mean KL divergence of two policies on a batch of timesteps.

    Args:
      time_steps: Batch of TimeStep tuples; used only to infer the batch
        (outer) dimensions against `self.time_step_spec`.
      action_distribution_parameters: Mapping whose "dist_params" entry
        holds the nest of distribution parameters captured when the
        actions were originally sampled.
      current_policy_distribution: Nest of distributions from the current
        policy, matching `self._action_distribution_spec`.

    Returns:
      A tensor of KL divergences between the reconstructed old action
      distributions and the current policy distributions.
    """
    # Batch dimensions are everything outside the time_step spec's shape.
    num_outer = nest_utils.get_outer_rank(time_steps, self.time_step_spec)
    outer_dims = list(range(num_outer))
    # Rebuild the sampling-time distributions from their saved parameters.
    old_actions_distribution = (
        distribution_spec.nested_distributions_from_specs(
            self._action_distribution_spec,
            action_distribution_parameters["dist_params"]))
    return ppo_utils.nested_kl_divergence(
        old_actions_distribution,
        current_policy_distribution,
        outer_dims=outer_dims)