Exemple #1
0
def kl_alpha_loss(restarting_weights: Array,
                  kl_constraints: Sequence[Tuple[Array, LagrangePenalty]] = (),
                  axis_name: Optional[str] = None) -> Tuple[Array, Array]:
    """Calculates the losses for multiple KL constraints.

    Each constraint's KL loss and dual-variable (alpha) loss are obtained from
    `kl_constraint_loss` with an identity projection for the dual variable.
    They are summed over constraints and then averaged over E*, using
    `restarting_weights` as per-step weights.

    Args:
      restarting_weights: Restarting weights, shape E*, 0 means that this step
        is the start of a new episode and we ignore losses at this step because
        the agent cannot influence these.
      kl_constraints: KL and variables for applying Lagrangian penalties to
        bound them in the M-step, KLs are [E*, A?]. Here A is the action
        dimension in the case of per-dimension KL constraints. Any iterable of
        `(kl, penalty)` pairs is accepted; it is consumed exactly once.
      axis_name: Optional axis name for `pmap`. If `None`, computations are
        performed locally on each device.

    Returns:
      The kl loss and the dual variable loss. Both are restarting-weighted
      means over E*, summed across devices when `axis_name` is set. When there
      are no constraints, both are `jnp.asarray(0.0)`.
    """
    chex.assert_type(restarting_weights, float)
    # Materialize once: the constraints are iterated twice below (validation,
    # then loss computation). A one-shot iterator would be exhausted by the
    # first pass, and would be truthy even when empty. `or ()` keeps accepting
    # `None` as "no constraints", as before.
    kl_constraints = tuple(kl_constraints or ())
    if not kl_constraints:
        # No M-step constraint.
        return jnp.asarray(0.0), jnp.asarray(0.0)

    for kl, penalty in kl_constraints:
        chex.assert_rank(penalty.epsilon, 0)
        chex.assert_type([kl, penalty.alpha, penalty.epsilon], float)
        chex.assert_equal_shape_prefix([kl, restarting_weights],
                                       restarting_weights.ndim)

    # Implement decoupled KL constraints: every constraint carries its own
    # penalty (and hence its own alpha); the per-constraint losses are summed.
    kl_alpha_losses = [
        kl_constraint_loss(kl, penalty, lambda x: x)[:2]
        for kl, penalty in kl_constraints
    ]
    kl_loss, alpha_loss = [sum(losses) for losses in zip(*kl_alpha_losses)]

    # Restarting-weighted mean; _EPSILON guards against a zero total weight.
    # NOTE(review): nothing is reshaped here. This relies on kl_constraint_loss
    # returning losses that broadcast against `restarting_weights` (E*), i.e.
    # any per-dimension axis A is presumably reduced inside it -- confirm.
    all_sum = base.AllSum(axis_name)
    num_samples = all_sum(restarting_weights) + _EPSILON
    kl_loss = all_sum(kl_loss * restarting_weights) / num_samples
    alpha_loss = all_sum(alpha_loss * restarting_weights) / num_samples
    return kl_loss, alpha_loss
Exemple #2
0
def vmpo_compute_weights_and_temperature_loss(
    advantages: Array,
    restarting_weights: Array,
    importance_weights: Array,
    temperature_constraint: LagrangePenalty,
    projection_operator: Callable[[Numeric], Numeric],
    top_k_fraction: float,
    axis_name: Optional[str] = None,
    use_stop_gradient: bool = True,
) -> Tuple[Scalar, Array, Scalar]:
  """Computes the weights and temperature loss for V-MPO.

  The advantages are scaled by the projected temperature and restricted to the
  top-k fraction via `get_top_k_weights`. They are then exponentiated
  (shifted by their maximum for numerical stability) and multiplied by the
  importance weights. The result is normalized into sample weights. The
  temperature loss is `temperature * (epsilon + log(mean of the unshifted
  exponentiated weights))`, with the max shift added back inside the log.

  Args:
    advantages: Advantages for the E-step. Shape E*.
    restarting_weights: Restarting weights, 0 means that this
      step is the start of a new episode and we ignore losses at this step
      because the agent cannot influence these. Shape E*.
    importance_weights: Importance weights multiplied into the unnormalized
      sample weights. Shape E*.
    temperature_constraint: Lagrange constraint for the E-step temperature
      optimization.
    projection_operator: Function to project dual variables (temperature and kl
      constraint alphas) into the positive range.
    top_k_fraction: Fraction of samples to use in the E-step.
    axis_name: Optional axis name for `pmap` or 'vmap'. If `None`, computations
      are performed locally on each device.
    use_stop_gradient: Whether to stop gradients through the importance
      weights, the max scaled advantage and the returned normalized weights
      (it is also forwarded to `get_top_k_weights`). Must be True when
      `axis_name` is set.

  Returns:
    The temperature loss, the normalized weights, and the number of samples
    used. The number of samples is the sum of the top-k restarting weights,
    plus a small epsilon.

  Raises:
    ValueError: If `axis_name` is set while `use_stop_gradient` is False,
      because the cross-device `jax.lax.pmax` cannot be differentiated.
  """
  chex.assert_equal_shape([advantages, restarting_weights, importance_weights])
  chex.assert_rank(temperature_constraint.epsilon, 0)
  chex.assert_type([
      advantages, restarting_weights, importance_weights,
      temperature_constraint.alpha, temperature_constraint.epsilon], float)

  importance_weights = jax.lax.select(
      use_stop_gradient, jax.lax.stop_gradient(importance_weights),
      importance_weights)

  # Lagrange constraint.
  temperature = projection_operator(temperature_constraint.alpha)
  epsilon_temperature = temperature_constraint.epsilon

  # Scale the advantages; steps with a restarting weight of 0 scale to 0.
  scaled_advantages = restarting_weights * advantages / temperature
  max_scaled_advantage = jnp.max(scaled_advantages)
  # If the axis_name is not None find the maximum across all devices.
  if axis_name:
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not use_stop_gradient:
      raise ValueError(
          "use_stop_gradient must be True when axis_name is set: "
          "cannot differentiate through jax.lax.pmax.")
    max_scaled_advantage = jax.lax.stop_gradient(max_scaled_advantage)
    max_scaled_advantage = jax.lax.pmax(
        max_scaled_advantage, axis_name=axis_name)
  else:
    max_scaled_advantage = jax.lax.select(
        use_stop_gradient, jax.lax.stop_gradient(max_scaled_advantage),
        max_scaled_advantage)
  # Maybe don't use all of the advantages.
  top_k_restarting_weights = get_top_k_weights(
      top_k_fraction, restarting_weights, scaled_advantages, axis_name,
      use_stop_gradient)

  all_sum = base.AllSum(axis_name)

  # Reweight the old trajectories. Subtracting the max before `jnp.exp` keeps
  # the exponent bounded; the shift is added back in `log_mean_weights`.
  unnormalized_weights = (top_k_restarting_weights * importance_weights
                          * jnp.exp(scaled_advantages - max_scaled_advantage))
  # If the axis_name is not None these sums will be taken across all devices.
  sum_weights = all_sum(unnormalized_weights) + _EPSILON
  num_samples = all_sum(top_k_restarting_weights) + _EPSILON

  normalized_weights = unnormalized_weights / sum_weights
  normalized_weights = jax.lax.select(use_stop_gradient,
                                      jax.lax.stop_gradient(normalized_weights),
                                      normalized_weights)

  # Calculate the temperature loss: log of the mean (over num_samples) of the
  # unshifted weights, i.e. log(sum_weights) + max - log(num_samples).
  log_mean_weights = (jnp.log(sum_weights) + max_scaled_advantage
                      - jnp.log(num_samples))
  temperature_loss = temperature * (epsilon_temperature + log_mean_weights)

  return temperature_loss, normalized_weights, num_samples