def kl_alpha_loss(
    restarting_weights: Array,
    kl_constraints: Sequence[Tuple[Array, LagrangePenalty]] = (),
    axis_name: Optional[str] = None):
  """Calculates the losses for multiple KL constraints.

  Args:
    restarting_weights: Restarting weights, shape E*, 0 means that this step is
      the start of a new episode and we ignore losses at this step because the
      agent cannot influence these.
    kl_constraints: KL and variables for applying Lagrangian penalties to bound
      them in the M-step, KLs are [E*, A?]. Here A is the action dimension in
      the case of per-dimension KL constraints.
    axis_name: Optional axis name for `pmap`. If `None`, computations are
      performed locally on each device.

  Returns:
    The KL loss and dual variable (alpha) loss, both scalars.
  """
  chex.assert_type(restarting_weights, float)
  if kl_constraints:
    for kl, penalty in kl_constraints:
      chex.assert_rank(penalty.epsilon, 0)
      chex.assert_type([kl, penalty.alpha, penalty.epsilon], float)
      chex.assert_equal_shape_prefix([kl, restarting_weights],
                                     restarting_weights.ndim)

    # Implement decoupled KL constraints.
    kl_alpha_losses = [kl_constraint_loss(kl, penalty, lambda x: x)[:2]
                       for kl, penalty in kl_constraints]
    kl_loss, alpha_loss = [sum(losses) for losses in zip(*kl_alpha_losses)]
    all_sum = base.AllSum(axis_name)
    num_samples = all_sum(restarting_weights) + _EPSILON
    # Reshape in case KL is per dimension.
    kl_loss = all_sum(kl_loss * restarting_weights) / num_samples
    alpha_loss = all_sum(alpha_loss * restarting_weights) / num_samples
  else:
    # No M-step constraint.
    kl_loss = jnp.asarray(0.0)
    alpha_loss = jnp.asarray(0.0)
  return kl_loss, alpha_loss
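
# Usage sketch (illustrative only, not part of the library API): wiring a
# single M-step KL constraint through `kl_alpha_loss`. The helper name, shapes
# and values below are hypothetical; it assumes `LagrangePenalty` can be
# constructed from the `alpha` and `epsilon` fields used above.
def _example_kl_alpha_loss():
  restarting_weights = jnp.ones((2, 3))  # [batch, time]; no episode restarts.
  kl = jnp.full((2, 3), 0.05)            # Per-step KL to be bounded by epsilon.
  penalty = LagrangePenalty(alpha=jnp.asarray(1.0), epsilon=1e-2)
  kl_loss, alpha_loss = kl_alpha_loss(restarting_weights, [(kl, penalty)])
  # Both losses are scalars and are typically added to the M-step objective.
  return kl_loss + alpha_loss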
def vmpo_compute_weights_and_temperature_loss(
    advantages: Array,
    restarting_weights: Array,
    importance_weights: Array,
    temperature_constraint: LagrangePenalty,
    projection_operator: Callable[[Numeric], Numeric],
    top_k_fraction: float,
    axis_name: Optional[str] = None,
    use_stop_gradient: bool = True,
) -> Tuple[Scalar, Array, Scalar]:
  """Computes the weights and temperature loss for V-MPO.

  Args:
    advantages: Advantages for the E-step. Shape E*.
    restarting_weights: Restarting weights, 0 means that this step is the start
      of a new episode and we ignore losses at this step because the agent
      cannot influence these. Shape E*.
    importance_weights: Optional importance weights. Shape E*.
    temperature_constraint: Lagrange constraint for the E-step temperature
      optimization.
    projection_operator: Function to project dual variables (temperature and
      KL constraint alphas) into the positive range.
    top_k_fraction: Fraction of samples to use in the E-step.
    axis_name: Optional axis name for `pmap` or `vmap`. If `None`, computations
      are performed locally on each device.
    use_stop_gradient: Bool indicating whether or not to apply stop gradient.

  Returns:
    The temperature loss, normalized weights and number of samples used.
  """
  chex.assert_equal_shape([advantages, restarting_weights, importance_weights])
  chex.assert_rank(temperature_constraint.epsilon, 0)
  chex.assert_type([
      advantages, restarting_weights, importance_weights,
      temperature_constraint.alpha, temperature_constraint.epsilon], float)

  importance_weights = jax.lax.select(
      use_stop_gradient,
      jax.lax.stop_gradient(importance_weights), importance_weights)

  # Lagrange constraint.
  temperature = projection_operator(temperature_constraint.alpha)
  epsilon_temperature = temperature_constraint.epsilon

  # Scale the advantages.
  scaled_advantages = restarting_weights * advantages / temperature
  max_scaled_advantage = jnp.max(scaled_advantages)
  # If the axis_name is not None, find the maximum across all devices.
  if axis_name:
    assert use_stop_gradient  # Can't differentiate through pmax.
    max_scaled_advantage = jax.lax.stop_gradient(max_scaled_advantage)
    max_scaled_advantage = jax.lax.pmax(
        max_scaled_advantage, axis_name=axis_name)
  else:
    max_scaled_advantage = jax.lax.select(
        use_stop_gradient,
        jax.lax.stop_gradient(max_scaled_advantage), max_scaled_advantage)

  # Maybe don't use all of the advantages.
  top_k_restarting_weights = get_top_k_weights(
      top_k_fraction, restarting_weights, scaled_advantages, axis_name,
      use_stop_gradient)

  all_sum = base.AllSum(axis_name)

  # Reweight the old trajectories.
  unnormalized_weights = (top_k_restarting_weights * importance_weights *
                          jnp.exp(scaled_advantages - max_scaled_advantage))
  # If the axis_name is not None, these sums will be taken across all devices.
  sum_weights = all_sum(unnormalized_weights) + _EPSILON
  num_samples = all_sum(top_k_restarting_weights) + _EPSILON

  normalized_weights = unnormalized_weights / sum_weights
  normalized_weights = jax.lax.select(
      use_stop_gradient, jax.lax.stop_gradient(normalized_weights),
      normalized_weights)

  # Calculate the temperature loss.
  log_mean_weights = (jnp.log(sum_weights) + max_scaled_advantage -
                      jnp.log(num_samples))
  temperature_loss = temperature * (epsilon_temperature + log_mean_weights)

  return temperature_loss, normalized_weights, num_samples
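
# Usage sketch (illustrative only, not part of the library API): computing
# V-MPO E-step weights and the temperature loss for a short trajectory. The
# helper name, shapes and values are hypothetical; `jax.nn.softplus` is just
# one valid choice of positivity projection for the dual variable.
def _example_vmpo_weights_and_temperature_loss():
  advantages = jnp.array([0.5, -0.2, 1.3, 0.0])
  restarting_weights = jnp.array([1.0, 1.0, 0.0, 1.0])  # Step 2 starts an episode.
  importance_weights = jnp.ones_like(advantages)        # On-policy data.
  temperature_constraint = LagrangePenalty(
      alpha=jnp.asarray(1.0), epsilon=1e-2)
  temperature_loss, weights, num_samples = (
      vmpo_compute_weights_and_temperature_loss(
          advantages, restarting_weights, importance_weights,
          temperature_constraint, projection_operator=jax.nn.softplus,
          top_k_fraction=0.5))
  # `weights` sums to ~1 over the retained top-k steps and would typically
  # weight the policy's log probabilities in the E-step loss.
  return temperature_loss, weights, num_samples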