示例#1
0
  def test_kl_constraint_loss_stop_gradients(self):
    """Checks `kl_constraint_loss` stops gradients across its two terms.

    Gradients of the KL term must not flow into `alpha`, and gradients of
    the alpha term must not flow into the KL.
    """
    kl_value = jnp.array(1., jnp.float32)
    initial_alpha = jnp.array(1., jnp.float32)

    def make_penalty(alpha_):
      # Scalar (non-per-dimension) penalty with the mean epsilon bound.
      return mpo_ops.LagrangePenalty(
          alpha=alpha_, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)

    # Run once to obtain the projected alpha used by both loss fns below.
    _, _, projected_alpha = mpo_ops.kl_constraint_loss(
        kl_value, make_penalty(initial_alpha), _PROJECTION_OPERATOR)

    def kl_loss_fn(alpha_):
      # KL-term loss viewed purely as a function of alpha.
      return mpo_ops.kl_constraint_loss(
          kl_value, make_penalty(alpha_), _PROJECTION_OPERATOR)[0]

    def alpha_loss_fn(kl_):
      # Alpha-term loss viewed purely as a function of the KL.
      return mpo_ops.kl_constraint_loss(
          kl_, make_penalty(projected_alpha), _PROJECTION_OPERATOR)[1]

    kl_gradients = jax.grad(kl_loss_fn)(projected_alpha)
    alpha_gradients = jax.grad(alpha_loss_fn)(kl_value)

    # No gradient of the KL term w.r.t. alpha ...
    self.assertEqual(kl_gradients, 0.)

    # ... and no gradient of the alpha term w.r.t. the KL.
    self.assertEqual(alpha_gradients, 0.)
示例#2
0
  def test_vmpo_input_axis_order_equivalence(self, per_dimension):
    """VMPO loss must be invariant to swapping the leading [T, B] axes."""
    key = jax.random.PRNGKey(_RANDOM_SEED)
    key, init_key = jax.random.split(key)
    params = _init_params(init_key)
    out, vmpo_inputs = get_common_loss_fn_inputs(params, key, 'advantages')
    kl_constraints = get_coupled_kl_constraints(
        out, params, per_dimension=per_dimension)
    vmpo_inputs['kl_constraints'] = kl_constraints

    # Loss with the original [T, B] layout.
    tb_loss, tb_outputs = mpo_ops.vmpo_loss(**vmpo_inputs)
    mean_tb_loss = jnp.mean(tb_loss)

    def _swap_tb(x):
      # Transposes the time and batch axes.
      return jnp.swapaxes(x, 0, 1)

    # Recompute with everything transposed to [B, T].
    swapped_kls = [
        (_swap_tb(kl),
         mpo_ops.LagrangePenalty(
             alpha=_swap_tb(penalty.alpha), epsilon=penalty.epsilon,
             per_dimension=penalty.per_dimension))
        for kl, penalty in kl_constraints]
    vmpo_inputs['sample_log_probs'] = _swap_tb(vmpo_inputs['sample_log_probs'])
    vmpo_inputs['advantages'] = _swap_tb(vmpo_inputs['advantages'])
    vmpo_inputs['kl_constraints'] = swapped_kls
    bt_loss, bt_outputs = mpo_ops.vmpo_loss(**vmpo_inputs)
    mean_bt_loss = jnp.mean(bt_loss)

    self.assertAlmostEqual(mean_tb_loss, mean_bt_loss, places=4)
    self.assertEqual(tb_outputs.num_samples, bt_outputs.num_samples)
示例#3
0
 def alpha_loss_fn(kl_):
     # Alpha-term loss for a fixed `alpha`, viewed as a function of the KL.
     fixed_penalty = mpo_ops.LagrangePenalty(
         alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
     outputs = mpo_ops.kl_constraint_loss(kl_, fixed_penalty,
                                          _PROJECTION_OPERATOR)
     # kl_constraint_loss returns (kl_loss, alpha_loss, alpha).
     return outputs[1]
示例#4
0
  def test_kl_constraint_loss_gradients(self):
    """Checks the analytic gradients of both `kl_constraint_loss` terms."""
    kl_value = jnp.array(1., jnp.float32)
    initial_alpha = jnp.array(1., jnp.float32)

    def make_penalty(alpha_):
      # Scalar (non-per-dimension) penalty with the mean epsilon bound.
      return mpo_ops.LagrangePenalty(
          alpha=alpha_, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)

    # One pass to obtain the projected alpha used by the loss fns below.
    _, _, alpha = mpo_ops.kl_constraint_loss(
        kl_value, make_penalty(initial_alpha), _PROJECTION_OPERATOR)

    def alpha_loss_fn(alpha_):
      # Alpha-term loss as a function of alpha.
      return mpo_ops.kl_constraint_loss(
          kl_value, make_penalty(alpha_), _PROJECTION_OPERATOR)[1]

    def kl_loss_fn(kl_):
      # KL-term loss as a function of the KL.
      return mpo_ops.kl_constraint_loss(
          kl_, make_penalty(alpha), _PROJECTION_OPERATOR)[0]

    alpha_gradients = jax.grad(alpha_loss_fn)(alpha)
    expected_alpha_gradients = _EPSILON_MEAN_BOUND - kl_value

    kl_gradients = jax.grad(kl_loss_fn)(kl_value)
    expected_kl_gradients = alpha

    # d(kl_loss)/d(kl) = alpha and d(alpha_loss)/d(alpha) = epsilon - kl.
    self.assertAlmostEqual(kl_gradients, expected_kl_gradients)
    self.assertAlmostEqual(alpha_gradients, expected_alpha_gradients)
示例#5
0
 def mean_weights_fn(target_, temperature_):
   # Average e-step weight for the given target and temperature.
   penalty = mpo_ops.LagrangePenalty(temperature_, _EPSILON_BOUND)
   _, e_step_weights, _ = e_step_fn(
       target_,
       temperature_constraint=penalty,
       projection_operator=_PROJECTION_OPERATOR,
       **additional_inputs)
   return jnp.mean(e_step_weights)
示例#6
0
 def test_importance_weights(
     self, advantages, importance_weights, expected_temperature_loss):
   """Checks the temperature loss under the given importance weights."""
   # Restarting weights are all-ones so only importance weights matter.
   all_ones_restarting = np.ones_like(importance_weights)
   penalty = mpo_ops.LagrangePenalty(1.0, _EPSILON_BOUND)
   clip_positive = functools.partial(np.clip, a_min=1e-8, a_max=None)
   temperature_loss, _, _ = mpo_ops.vmpo_compute_weights_and_temperature_loss(
       advantages, all_ones_restarting, importance_weights, penalty,
       clip_positive, 1.0)
   self.assertAlmostEqual(
       temperature_loss, expected_temperature_loss, places=4)
示例#7
0
 def test_restarting_weights(
     self, advantages, restarting_weights, expected_temperature_loss):
   """Checks the calculation when a restarting weight is set to 0."""
   # Importance weights are all-ones so only restarting weights matter.
   all_ones_importance = np.ones_like(restarting_weights)
   penalty = mpo_ops.LagrangePenalty(1.0, _EPSILON_BOUND)
   clip_positive = functools.partial(np.clip, a_min=1e-8, a_max=None)
   temperature_loss, _, _ = mpo_ops.vmpo_compute_weights_and_temperature_loss(
       advantages, restarting_weights, all_ones_importance, penalty,
       clip_positive, 1.0)
   self.assertAlmostEqual(
       temperature_loss, expected_temperature_loss, places=4)
示例#8
0
def get_decoupled_kl_constraints(out, params, per_dimension):
  """Builds separate (kl, penalty) constraints for mean and covariance.

  The Gaussian KL is factorized into a mean term and a covariance term,
  each paired with its own Lagrange penalty.
  """
  target = out['target_pi_params']
  online = out['pi_params']
  kl_mean, kl_covariance = (
      distributions.decoupled_multivariate_normal_kl_divergence(
          target['mean'], target['stddev'], online['mean'], online['stddev'],
          per_dimension=per_dimension))

  # Broadcast the scalar alphas to the shape of their KL terms.
  mean_penalty = mpo_ops.LagrangePenalty(
      alpha=params['mpo']['alpha_mean'] * jnp.ones_like(kl_mean),
      epsilon=_EPSILON_MEAN_BOUND,
      per_dimension=per_dimension)
  covariance_penalty = mpo_ops.LagrangePenalty(
      alpha=params['mpo']['alpha_covariance'] * jnp.ones_like(kl_covariance),
      epsilon=_EPSILON_COVARIANCE_BOUND,
      per_dimension=per_dimension)

  return [(kl_mean, mean_penalty), (kl_covariance, covariance_penalty)]
示例#9
0
def get_common_loss_fn_inputs(params, key, target_name):
  """Returns mocked outputs plus the kwargs shared by the MPO/VMPO losses."""
  out = _mock_outputs(params['online'], params['target'], key, target_name)
  pi_params = out['pi_params']
  sample_log_probs = _DIAGONAL_GAUSSIAN_DIST.logprob(
      out['target_actions'], pi_params['mean'], pi_params['stddev'])
  temperature_constraint = mpo_ops.LagrangePenalty(
      params['mpo']['temperature'], _EPSILON_BOUND)

  loss_kwargs = {
      'sample_log_probs': sample_log_probs,
      target_name: out[target_name],
      'temperature_constraint': temperature_constraint,
  }
  return out, loss_kwargs
示例#10
0
def get_coupled_kl_constraints(out, params, per_dimension):
    """Builds a single coupled (mean + covariance) KL constraint."""
    kl_mean, kl_covariance = _decoupled_multivariate_normal_kl_divergence(
        out['target_pi_params']['mean'],
        out['target_pi_params']['stddev'],
        out['pi_params']['mean'],
        out['pi_params']['stddev'],
        per_dimension=per_dimension)
    # The coupled constraint shares one alpha and sums both epsilon bounds.
    penalty = mpo_ops.LagrangePenalty(
        alpha=params['mpo']['alpha_mean'] * jnp.ones_like(kl_mean),
        epsilon=_EPSILON_MEAN_BOUND + _EPSILON_COVARIANCE_BOUND,
        per_dimension=per_dimension)
    return [(kl_mean + kl_covariance, penalty)]
示例#11
0
    def test_mpo_input_axis_order_equivalence(self, per_dimension):
        """Test loss functions are equivalent regardless of axis order.

        Computes the MPO loss for three layouts of the same inputs —
        [S T B], [S B T], and [T B S] (with `sample_axis` pointing at S) —
        and checks all three mean losses and sample counts agree.
        """
        key = jax.random.PRNGKey(_RANDOM_SEED)
        key, new_key = jax.random.split(key)
        params = _init_params(new_key)
        out, mpo_inputs = get_common_loss_fn_inputs(params, key,
                                                    'sample_q_values')
        kl_constraints = get_coupled_kl_constraints(
            out, params, per_dimension=per_dimension)
        mpo_inputs.update({'kl_constraints': kl_constraints})

        # Original loss fn inputs are [S T B],
        stb_loss, stb_outputs = mpo_ops.mpo_loss(**mpo_inputs)
        mean_stb_loss = jnp.mean(stb_loss)

        # Swap axes and try [S B T]
        # NOTE: kl_constraints carry no sample axis, so swapping their
        # axes 0 and 1 corresponds to swapping T and B.
        mpo_inputs.update({
            'sample_log_probs':
            jnp.swapaxes(mpo_inputs['sample_log_probs'], 1, 2),
            'sample_q_values':
            jnp.swapaxes(mpo_inputs['sample_q_values'], 1, 2),
            'kl_constraints':
            [(jnp.swapaxes(kl, 0, 1),
              mpo_ops.LagrangePenalty(alpha=jnp.swapaxes(pen.alpha, 0, 1),
                                      epsilon=pen.epsilon,
                                      per_dimension=pen.per_dimension))
             for (kl, pen) in kl_constraints],
        })
        sbt_loss, sbt_outputs = mpo_ops.mpo_loss(**mpo_inputs)
        mean_sbt_loss = jnp.mean(sbt_loss)

        # Try [T B S] denoting sample_axis at 2 instead of 0.
        # The inputs are currently laid out as [S B T] (mutated above), so
        # swapping axes 0 and 2 here yields [T B S]; the constraints revert
        # to the original (unswapped) [T B] versions.
        mpo_inputs.update({
            'sample_log_probs':
            jnp.swapaxes(mpo_inputs['sample_log_probs'], 0, 2),
            'sample_q_values':
            jnp.swapaxes(mpo_inputs['sample_q_values'], 0, 2),
            'kl_constraints':
            kl_constraints,  # T B
            'sample_axis':
            2
        })
        tbs_loss, tbs_outputs = mpo_ops.mpo_loss(**mpo_inputs)
        mean_tbs_loss = jnp.mean(tbs_loss)

        self.assertAlmostEqual(mean_stb_loss, mean_sbt_loss, places=4)
        self.assertAlmostEqual(mean_tbs_loss, mean_sbt_loss, places=4)
        self.assertEqual(tbs_outputs.num_samples, sbt_outputs.num_samples)
        self.assertEqual(tbs_outputs.num_samples, stb_outputs.num_samples)