  def test_kl_constraint_loss_stop_gradients(self):
    """Tests the stop gradients in the `kl_constraint_loss` function.

    The `alpha_loss` term should not affect the KL and the `kl` term should
    not affect `alpha`.
    """
    kl = jnp.array(1., jnp.float32)
    alpha = jnp.array(1., jnp.float32)
    _, _, alpha = mpo_ops.kl_constraint_loss(kl, mpo_ops.LagrangePenalty(
        alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False),
        _PROJECTION_OPERATOR)

    def kl_loss_fn(alpha_):
      penalty = mpo_ops.LagrangePenalty(
          alpha=alpha_, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
      kl_loss, _, _ = mpo_ops.kl_constraint_loss(
          kl, penalty, _PROJECTION_OPERATOR)
      return kl_loss

    kl_gradients = jax.grad(kl_loss_fn)(alpha)

    def alpha_loss_fn(kl_):
      penalty = mpo_ops.LagrangePenalty(
          alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
      _, alpha_loss, _ = mpo_ops.kl_constraint_loss(
          kl_, penalty, _PROJECTION_OPERATOR)
      return alpha_loss

    alpha_gradients = jax.grad(alpha_loss_fn)(kl)

    # Test that there are no gradients of KL w.r.t. alpha.
    self.assertEqual(kl_gradients, 0.)
    # Test that there are no gradients of alpha w.r.t. KL.
    self.assertEqual(alpha_gradients, 0.)
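  # Illustrative sketch (an assumption, not part of the original suite and not
  # the library implementation): `jax.lax.stop_gradient` is the standard way
  # to keep a value in the forward pass while cutting its gradient path, which
  # is the behaviour `test_kl_constraint_loss_stop_gradients` checks for
  # `mpo_ops.kl_constraint_loss`. The helper name below is hypothetical.
  def _toy_stop_gradient_sketch(self):
    f = lambda x, y: jax.lax.stop_gradient(y) * x  # y enters value, not grad.
    x, y = jnp.array(2.), jnp.array(3.)
    self.assertEqual(jax.grad(f, argnums=1)(x, y), 0.)  # No gradient w.r.t. y.
    self.assertEqual(jax.grad(f, argnums=0)(x, y), 3.)  # d/dx = stop_grad(y).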
  def test_vmpo_input_axis_order_equivalence(self, per_dimension):
    """Test loss functions are equivalent regardless of axis order."""
    key = jax.random.PRNGKey(_RANDOM_SEED)
    key, new_key = jax.random.split(key)
    params = _init_params(new_key)
    out, vmpo_inputs = get_common_loss_fn_inputs(params, key, 'advantages')
    kl_constraints = get_coupled_kl_constraints(
        out, params, per_dimension=per_dimension)
    vmpo_inputs.update({'kl_constraints': kl_constraints})

    # Original loss fn inputs are [T B].
    tb_loss, tb_outputs = mpo_ops.vmpo_loss(**vmpo_inputs)
    mean_tb_loss = jnp.mean(tb_loss)

    # Swap axes and try [B T].
    vmpo_inputs.update({
        'sample_log_probs': jnp.swapaxes(
            vmpo_inputs['sample_log_probs'], 0, 1),
        'advantages': jnp.swapaxes(vmpo_inputs['advantages'], 0, 1),
        'kl_constraints': [
            (jnp.swapaxes(kl, 0, 1),
             mpo_ops.LagrangePenalty(
                 alpha=jnp.swapaxes(pen.alpha, 0, 1), epsilon=pen.epsilon,
                 per_dimension=pen.per_dimension))
            for (kl, pen) in kl_constraints],
    })
    bt_loss, bt_outputs = mpo_ops.vmpo_loss(**vmpo_inputs)
    mean_bt_loss = jnp.mean(bt_loss)

    self.assertAlmostEqual(mean_tb_loss, mean_bt_loss, places=4)
    self.assertEqual(tb_outputs.num_samples, bt_outputs.num_samples)
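  # Sketch of why axis-order equivalence is expected (a toy, not the rlax
  # code): swapping the [T, B] axes of every per-element input permutes the
  # per-element losses in the same way, so full reductions such as `jnp.mean`
  # are unchanged. The helper name is hypothetical.
  def _toy_axis_order_invariance(self):
    x = jnp.arange(6.).reshape(2, 3)      # [T, B]
    per_element_loss = lambda v: v ** 2   # Any elementwise loss.
    self.assertAlmostEqual(
        jnp.mean(per_element_loss(x)),
        jnp.mean(per_element_loss(jnp.swapaxes(x, 0, 1))), places=4)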
  def test_kl_constraint_loss_gradients(self):
    """Tests the gradients in the `kl_constraint_loss` function."""
    kl = jnp.array(1., jnp.float32)
    alpha = jnp.array(1., jnp.float32)
    _, _, alpha = mpo_ops.kl_constraint_loss(kl, mpo_ops.LagrangePenalty(
        alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False),
        _PROJECTION_OPERATOR)

    def alpha_loss_fn(alpha_):
      penalty = mpo_ops.LagrangePenalty(
          alpha=alpha_, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
      _, alpha_loss, _ = mpo_ops.kl_constraint_loss(
          kl, penalty, _PROJECTION_OPERATOR)
      return alpha_loss

    alpha_gradients = jax.grad(alpha_loss_fn)(alpha)
    actual_alpha_gradients = _EPSILON_MEAN_BOUND - kl

    def kl_loss_fn(kl_):
      penalty = mpo_ops.LagrangePenalty(
          alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
      kl_loss, _, _ = mpo_ops.kl_constraint_loss(
          kl_, penalty, _PROJECTION_OPERATOR)
      return kl_loss

    kl_gradients = jax.grad(kl_loss_fn)(kl)
    actual_kl_gradients = alpha

    self.assertAlmostEqual(kl_gradients, actual_kl_gradients)
    self.assertAlmostEqual(alpha_gradients, actual_alpha_gradients)
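  # Illustrative sketch (assumption, not the library implementation): a toy
  # Lagrange penalty with the gradient structure that
  # `test_kl_constraint_loss_gradients` expects, namely
  # d(alpha_loss)/d(alpha) = epsilon - kl and d(kl_loss)/d(kl) = alpha, using
  # `jax.lax.stop_gradient` to decouple the two terms. The name
  # `_toy_lagrange_penalty` is hypothetical.
  def _toy_lagrange_penalty(self, kl, alpha, epsilon):
    kl_loss = jax.lax.stop_gradient(alpha) * kl
    alpha_loss = alpha * (epsilon - jax.lax.stop_gradient(kl))
    return kl_loss, alpha_loss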
    def mean_weights_fn(target_, temperature_):
      temperature_constraint = mpo_ops.LagrangePenalty(
          temperature_, _EPSILON_BOUND)
      _, weights, _ = e_step_fn(
          target_, temperature_constraint=temperature_constraint,
          projection_operator=_PROJECTION_OPERATOR,
          **additional_inputs)
      return jnp.mean(weights)
  def test_importance_weights(
      self, advantages, importance_weights, expected_temperature_loss):
    """Test that importance weights have the correct effect."""
    temperature_loss, _, _ = mpo_ops.vmpo_compute_weights_and_temperature_loss(
        advantages, np.ones_like(importance_weights), importance_weights,
        mpo_ops.LagrangePenalty(1.0, _EPSILON_BOUND),
        functools.partial(np.clip, a_min=1e-8, a_max=None), 1.0)
    self.assertAlmostEqual(
        temperature_loss, expected_temperature_loss, places=4)
  def test_restarting_weights(
      self, advantages, restarting_weights, expected_temperature_loss):
    """Test that the calculation is correct when a restarting weight is 0."""
    temperature_loss, _, _ = mpo_ops.vmpo_compute_weights_and_temperature_loss(
        advantages, restarting_weights, np.ones_like(restarting_weights),
        mpo_ops.LagrangePenalty(1.0, _EPSILON_BOUND),
        functools.partial(np.clip, a_min=1e-8, a_max=None), 1.0)
    self.assertAlmostEqual(
        temperature_loss, expected_temperature_loss, places=4)
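  # Illustrative sketch (assumption, not necessarily the exact formula inside
  # `mpo_ops.vmpo_compute_weights_and_temperature_loss`): a temperature/dual
  # loss of the weighted log-mean-exp form shows how the two weight types in
  # the tests above typically act. A restarting weight of 0 drops an element
  # from both the numerator and the normaliser, while an importance weight
  # rescales its contribution. The helper name is hypothetical.
  def _toy_temperature_loss(self, advantages, restarting_w, importance_w,
                            temperature, epsilon):
    w = restarting_w * importance_w
    lse = jnp.log(jnp.sum(w * jnp.exp(advantages / temperature)) /
                  jnp.sum(restarting_w))
    return temperature * epsilon + temperature * lse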
def get_decoupled_kl_constraints(out, params, per_dimension):
  # Factorize KL for Gaussian.
  kl_mean, kl_covariance = (
      distributions.decoupled_multivariate_normal_kl_divergence(
          out['target_pi_params']['mean'], out['target_pi_params']['stddev'],
          out['pi_params']['mean'], out['pi_params']['stddev'],
          per_dimension=per_dimension))
  alpha_mean = params['mpo']['alpha_mean'] * jnp.ones_like(kl_mean)
  alpha_covariance = params['mpo']['alpha_covariance'] * jnp.ones_like(
      kl_covariance)
  return [
      (kl_mean, mpo_ops.LagrangePenalty(
          alpha=alpha_mean, epsilon=_EPSILON_MEAN_BOUND,
          per_dimension=per_dimension)),
      (kl_covariance, mpo_ops.LagrangePenalty(
          alpha=alpha_covariance, epsilon=_EPSILON_COVARIANCE_BOUND,
          per_dimension=per_dimension)),
  ]
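# Sketch (an assumption about the decomposition, not necessarily what
# `distributions.decoupled_multivariate_normal_kl_divergence` computes): for
# diagonal Gaussians the KL can be split into a mean-only term (covariance
# held fixed) and a covariance-only term (mean held fixed), so each part can
# be given its own epsilon bound and Lagrange multiplier, as in
# `get_decoupled_kl_constraints` above. The helper name is hypothetical.
def _sketch_decoupled_diag_gaussian_kl(mean_p, std_p, mean_q, std_q):
  # KL(N(mean_p, std_p) || N(mean_q, std_p)): only the means differ.
  kl_mean = 0.5 * jnp.square((mean_p - mean_q) / std_p)
  # KL(N(mean_p, std_p) || N(mean_p, std_q)): only the stddevs differ.
  kl_cov = jnp.log(std_q / std_p) + 0.5 * jnp.square(std_p / std_q) - 0.5
  return jnp.sum(kl_mean, axis=-1), jnp.sum(kl_cov, axis=-1)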
def get_common_loss_fn_inputs(params, key, target_name):
  out = _mock_outputs(params['online'], params['target'], key, target_name)
  pi_sample_log_probs = _DIAGONAL_GAUSSIAN_DIST.logprob(
      out['target_actions'], out['pi_params']['mean'],
      out['pi_params']['stddev'])

  return out, {
      'sample_log_probs': pi_sample_log_probs,
      target_name: out[target_name],
      'temperature_constraint': mpo_ops.LagrangePenalty(
          params['mpo']['temperature'], _EPSILON_BOUND)}
def get_coupled_kl_constraints(out, params, per_dimension):
  kl_mean, kl_covariance = (
      distributions.decoupled_multivariate_normal_kl_divergence(
          out['target_pi_params']['mean'], out['target_pi_params']['stddev'],
          out['pi_params']['mean'], out['pi_params']['stddev'],
          per_dimension=per_dimension))
  alpha_mean = params['mpo']['alpha_mean'] * jnp.ones_like(kl_mean)
  return [
      (kl_mean + kl_covariance,
       mpo_ops.LagrangePenalty(
           alpha=alpha_mean,
           epsilon=_EPSILON_MEAN_BOUND + _EPSILON_COVARIANCE_BOUND,
           per_dimension=per_dimension))]
  def test_mpo_input_axis_order_equivalence(self, per_dimension):
    """Test loss functions are equivalent regardless of axis order."""
    key = jax.random.PRNGKey(_RANDOM_SEED)
    key, new_key = jax.random.split(key)
    params = _init_params(new_key)
    out, mpo_inputs = get_common_loss_fn_inputs(params, key, 'sample_q_values')
    kl_constraints = get_coupled_kl_constraints(
        out, params, per_dimension=per_dimension)
    mpo_inputs.update({'kl_constraints': kl_constraints})

    # Original loss fn inputs are [S T B].
    stb_loss, stb_outputs = mpo_ops.mpo_loss(**mpo_inputs)
    mean_stb_loss = jnp.mean(stb_loss)

    # Swap axes and try [S B T].
    mpo_inputs.update({
        'sample_log_probs': jnp.swapaxes(
            mpo_inputs['sample_log_probs'], 1, 2),
        'sample_q_values': jnp.swapaxes(
            mpo_inputs['sample_q_values'], 1, 2),
        'kl_constraints': [
            (jnp.swapaxes(kl, 0, 1),
             mpo_ops.LagrangePenalty(
                 alpha=jnp.swapaxes(pen.alpha, 0, 1), epsilon=pen.epsilon,
                 per_dimension=pen.per_dimension))
            for (kl, pen) in kl_constraints],
    })
    sbt_loss, sbt_outputs = mpo_ops.mpo_loss(**mpo_inputs)
    mean_sbt_loss = jnp.mean(sbt_loss)

    # Try [T B S], i.e. sample_axis at 2 instead of 0.
    mpo_inputs.update({
        'sample_log_probs': jnp.swapaxes(
            mpo_inputs['sample_log_probs'], 0, 2),
        'sample_q_values': jnp.swapaxes(
            mpo_inputs['sample_q_values'], 0, 2),
        'kl_constraints': kl_constraints,  # [T B]
        'sample_axis': 2,
    })
    tbs_loss, tbs_outputs = mpo_ops.mpo_loss(**mpo_inputs)
    mean_tbs_loss = jnp.mean(tbs_loss)

    self.assertAlmostEqual(mean_stb_loss, mean_sbt_loss, places=4)
    self.assertAlmostEqual(mean_tbs_loss, mean_sbt_loss, places=4)
    self.assertEqual(tbs_outputs.num_samples, sbt_outputs.num_samples)
    self.assertEqual(tbs_outputs.num_samples, stb_outputs.num_samples)