def encode(self, encoder_input_tokens, encoder_padding_mask):
    """Run the encoder on a single, unbatched token sequence.

    Args:
        encoder_input_tokens: rank-1 integer array of tokens, (seq_len,).
        encoder_padding_mask: rank-1 boolean array, (seq_len,), used for
            masking.

    Returns:
        The first entry along axis 0 of the rank-3 encoder output; per the
        shape checks below this is a (seq_len, emb_dim) array.
    """
    unbatched_inputs = [encoder_input_tokens, encoder_padding_mask]
    chex.assert_rank(unbatched_inputs, [1, 1])
    chex.assert_type(unbatched_inputs, [int, bool])

    # `expand_dims` presumably prepends a batch axis of size one; the
    # prefix check after the module call relies on that -- TODO confirm.
    batched_tokens, batched_mask = expand_dims(
        encoder_input_tokens, encoder_padding_mask)

    batched_encoding = self.module.encode(
        encoder_input_tokens=batched_tokens,
        encoder_padding_mask=batched_mask,
        deterministic=not self.train,
    )

    # The output must be rank 3 and share its first two dimensions with
    # the batched tokens.
    chex.assert_rank(batched_encoding, 3)
    chex.assert_equal_shape_prefix([batched_encoding, batched_tokens], 2)
    return batched_encoding[0]
def implicit_least_squares(phis, phis_for_wi1, phis_for_wi2, phis_for_cov1,
                           phis_for_cov2, psis, psis_for_wi1, psis_for_wi2,
                           alpha):
    """Implicit least squares objective.

    Estimates a weight matrix from the first set of auxiliary samples
    (`phis_for_cov1`, `phis_for_wi1`, `psis_for_wi1`), uses it to predict
    `psis` from `phis`, and returns half the mean squared prediction error.

    Args:
        phis: rank-2 array of features used for the prediction.
        phis_for_wi1: rank-2 array passed to `weight_estimate`.
        phis_for_wi2: rank-2 array; only shape-checked against
            `phis_for_wi1` here.
        phis_for_cov1: rank-2 array passed to `weight_estimate`.
        phis_for_cov2: rank-2 array; only shape-checked against
            `phis_for_cov1` here.
        psis: rank-2 array of regression targets; shares its leading
            dimension with `phis`.
        psis_for_wi1: rank-2 array passed to `weight_estimate`.
        psis_for_wi2: rank-2 array; only shape-checked against
            `psis_for_wi1` here.
        alpha: scalar forwarded to `weight_estimate`.

    Returns:
        `0.5 * mean((phis @ w - psis) ** 2)`.
    """
    matrix_inputs = [
        phis,
        phis_for_cov1,
        phis_for_cov2,
        phis_for_wi1,
        phis_for_wi2,
        psis,
        psis_for_wi1,
        psis_for_wi2,
    ]

    # Validate that the paired inputs agree with each other.
    chex.assert_equal_shape([phis_for_cov1, phis_for_cov2])
    chex.assert_equal_shape([phis_for_wi1, phis_for_wi2])
    chex.assert_equal_shape([psis_for_wi1, psis_for_wi2])
    chex.assert_equal_shape_prefix([phis, psis], 1)
    chex.assert_scalar(alpha)
    chex.assert_rank(matrix_inputs, 2)

    # The loss value only uses the first sample set (w_1).
    weights = weight_estimate(
        phis_for_cov1, phis_for_wi1, psis_for_wi1, alpha=alpha)

    residuals = phis @ weights - psis
    return 0.5 * jnp.mean(residuals**2)
def implicit_least_squares_fwd(phis, phis_for_wi1, phis_for_wi2,
                               phis_for_cov1, phis_for_cov2, psis,
                               psis_for_wi1, psis_for_wi2, alpha):
    """Forward pass for the implicit least squares objective.

    Takes the same arguments as `implicit_least_squares`. The name suggests
    this is the forward rule of a custom differentiation pair -- TODO
    confirm against where it is registered.

    Returns:
        A pair `(mse, residuals)`. `mse` is the value returned by
        `implicit_least_squares` for the same arguments. `residuals` is the
        tuple `(cost, phis_for_cov2, phis_for_wi2, psis_for_wi2)`, where
        `cost = phis @ w_1 - psis`; the second-sample arrays are saved so
        that w_2 can be computed on the backward pass.
    """
    matrix_inputs = [
        phis,
        phis_for_cov1,
        phis_for_cov2,
        phis_for_wi1,
        phis_for_wi2,
        psis,
        psis_for_wi1,
        psis_for_wi2,
    ]
    chex.assert_equal_shape([phis_for_cov1, phis_for_cov2])
    chex.assert_equal_shape([phis_for_wi1, phis_for_wi2])
    chex.assert_equal_shape([psis_for_wi1, psis_for_wi2])
    chex.assert_equal_shape_prefix([phis, psis], 1)
    chex.assert_scalar(alpha)
    chex.assert_rank(matrix_inputs, 2)

    # The w_1 estimate from the first sample set drives the prediction error.
    weights = weight_estimate(
        phis_for_cov1, phis_for_wi1, psis_for_wi1, alpha=alpha)
    prediction_error = phis @ weights - psis

    # NOTE(review): this call repeats the shape checks and the weight
    # estimate above; it is kept so the loss value always matches the
    # primal function exactly.
    loss = implicit_least_squares(
        phis,
        phis_for_wi1,
        phis_for_wi2,
        phis_for_cov1,
        phis_for_cov2,
        psis,
        psis_for_wi1,
        psis_for_wi2,
        alpha=alpha,
    )

    saved_for_backward = (
        prediction_error, phis_for_cov2, phis_for_wi2, psis_for_wi2)
    return loss, saved_for_backward
def kl_alpha_loss(
    restarting_weights: Array,
    kl_constraints: Sequence[Tuple[Array, LagrangePenalty]] = (),
    axis_name: Optional[str] = None,
) -> Tuple[Array, Array]:
    """Calculates the losses for multiple KL constraints.

    Args:
        restarting_weights: Restarting weights, shape E*, 0 means that this
            step is the start of a new episode and we ignore losses at this
            step because the agent cannot influence these.
        kl_constraints: KL and variables for applying Lagrangian penalties to
            bound them in the M-step, KLs are [E*, A?]. Here A is the action
            dimension in the case of per-dimension KL constraints.
        axis_name: Optional axis name for `pmap`. If `None`, computations are
            performed locally on each device.

    Returns:
        A `(kl_loss, alpha_loss)` pair. With at least one constraint, each
        is the sum of the per-constraint losses, weighted by
        `restarting_weights` and divided by the all-summed restarting
        weights (plus `_EPSILON`). With no constraints, both are scalar
        zero arrays.
    """
    chex.assert_type(restarting_weights, float)

    if not kl_constraints:
        # No M-step constraint.
        return jnp.asarray(0.0), jnp.asarray(0.0)

    for kl, penalty in kl_constraints:
        chex.assert_rank(penalty.epsilon, 0)
        chex.assert_type([kl, penalty.alpha, penalty.epsilon], float)
        # Each KL must match the restarting weights on all leading (E*) axes.
        chex.assert_equal_shape_prefix(
            [kl, restarting_weights], restarting_weights.ndim)

    # Implement decoupled KL constraints. The identity function is passed
    # as the third argument of `kl_constraint_loss`, and only the first two
    # entries of its result (KL loss, alpha loss) are used.
    kl_alpha_losses = [
        kl_constraint_loss(kl, penalty, lambda x: x)[:2]
        for kl, penalty in kl_constraints
    ]
    kl_loss, alpha_loss = [sum(losses) for losses in zip(*kl_alpha_losses)]

    all_sum = base.AllSum(axis_name)
    # _EPSILON guards against division by zero when every weight is zero.
    num_samples = all_sum(restarting_weights) + _EPSILON

    # Mask out restarting steps and normalise by the effective sample count.
    kl_loss = all_sum(kl_loss * restarting_weights) / num_samples
    alpha_loss = all_sum(alpha_loss * restarting_weights) / num_samples
    return kl_loss, alpha_loss
def decode(self, encoded, decoder_input_tokens, encoder_padding_mask,
           decoder_padding_mask, timestep=None):
    """Apply the Transformer decoder branch to one unbatched example.

    Args:
        encoded: encoded input data from the encoder (rank 2), or `None`.
        decoder_input_tokens: rank-1 array of input tokens for the decoder.
        encoder_padding_mask: padding mask for the encoder; checked to be
            rank 1 only when `encoded` is given.
        decoder_padding_mask: rank-1 padding mask for the decoder.
        timestep: optionally, a timestep to condition the input on.

    Returns:
        The first entry along axis 0 of the logits produced by the decoder.
    """
    chex.assert_rank([decoder_input_tokens, decoder_padding_mask], [1, 1])
    has_encoding = encoded is not None
    if has_encoding:
        chex.assert_rank([encoded, encoder_padding_mask], [2, 1])

    # `expand_dims` presumably prepends a batch axis and tolerates `None`
    # entries (`encoded` / `timestep` may be `None` here) -- TODO confirm.
    (batched_encoded,
     batched_tokens,
     batched_encoder_mask,
     batched_decoder_mask,
     batched_timestep) = expand_dims(
         encoded,
         decoder_input_tokens,
         encoder_padding_mask,
         decoder_padding_mask,
         timestep)

    batched_logits = self.module.decode(
        encoded=batched_encoded,
        decoder_input_tokens=batched_tokens,
        encoder_padding_mask=batched_encoder_mask,
        decoder_padding_mask=batched_decoder_mask,
        timestep=batched_timestep,
        deterministic=not self.train,
    )

    # Logits must share their first two dimensions with the batched tokens.
    chex.assert_equal_shape_prefix([batched_logits, batched_tokens], 2)
    return batched_logits[0]
def vmpo_loss(
    sample_log_probs: Array,
    advantages: Array,
    temperature_constraint: LagrangePenalty,
    kl_constraints: Sequence[Tuple[Array, LagrangePenalty]],
    # NOTE(review): newer JAX releases rename the `a_min` keyword of
    # `jnp.clip` to `min` -- confirm against the pinned JAX version.
    projection_operator: Callable[[Numeric], Numeric] = functools.partial(
        jnp.clip, a_min=_EPSILON),
    restarting_weights: Optional[Array] = None,
    importance_weights: Optional[Array] = None,
    top_k_fraction: float = 0.5,
    policy_loss_weight: float = 1.0,
    temperature_loss_weight: float = 1.0,
    kl_loss_weight: float = 1.0,
    alpha_loss_weight: float = 1.0,
    axis_name: Optional[str] = None,
    use_stop_gradient: bool = True,
) -> Tuple[Array, MpoOutputs]:
    """Calculates the V-MPO policy improvement loss.

    Note: This is a per-example loss which works on any shape inputs as long
    as they are consistent. The shape of the examples is denoted E* below.

    Args:
        sample_log_probs: Log probabilities of actions for each example.
            Shape E*.
        advantages: Advantages for the E-step. Shape E*.
        temperature_constraint: Lagrange constraint for the E-step
            temperature optimization.
        kl_constraints: KL and variables for applying Lagrangian penalties to
            bound them in the M-step, KLs are E* or [E*, A]. Here A is the
            action dimension in the case of per-dimension KL constraints.
        projection_operator: Function to project dual variables (temperature
            and kl constraint alphas) into the positive range.
        restarting_weights: Optional restarting weights, shape E*, 0 means
            that this step is the start of a new episode and we ignore losses
            at this step because the agent cannot influence these. Defaults
            to all ones.
        importance_weights: Optional importance weights, shape E*. Defaults
            to all ones.
        top_k_fraction: Fraction of samples to use in the E-step.
        policy_loss_weight: Weight for the policy loss.
        temperature_loss_weight: Weight for the temperature loss.
        kl_loss_weight: Weight for the KL loss.
        alpha_loss_weight: Weight for the alpha loss.
        axis_name: Optional axis name for `pmap`. If `None`, computations are
            performed locally on each device.
        use_stop_gradient: bool indicating whether or not to apply stop
            gradient.

    Returns:
        Per example `loss` with the same shape E* as the array inputs, and an
        `MpoOutputs` holding the components of this loss, the normalized
        weights and the number of samples.
    """
    # Fall back to uniform weights when none are supplied.
    if restarting_weights is None:
        restarting_weights = jnp.ones_like(sample_log_probs)
    if importance_weights is None:
        importance_weights = jnp.ones_like(sample_log_probs)

    # Validate shapes and dtypes of the per-example inputs.
    chex.assert_equal_shape(
        [advantages, sample_log_probs, restarting_weights, importance_weights])
    chex.assert_rank(temperature_constraint.epsilon, 0)
    float_inputs = [
        sample_log_probs,
        advantages,
        restarting_weights,
        importance_weights,
        temperature_constraint.alpha,
        temperature_constraint.epsilon,
    ]
    chex.assert_type(float_inputs, float)

    # Validate every KL constraint against the example shape.
    example_rank = advantages.ndim
    for kl, penalty in kl_constraints:
        chex.assert_rank(penalty.epsilon, 0)
        chex.assert_type([kl, penalty.alpha, penalty.epsilon], float)
        if not penalty.per_dimension:
            chex.assert_equal_shape([kl, advantages])
            continue
        # Per-dimension KLs carry one extra trailing axis.
        chex.assert_rank(kl, example_rank + 1)
        chex.assert_equal_shape_prefix([kl, advantages], example_rank)

    # E-step: compute the reweighting and the temperature loss.
    e_step_outputs = vmpo_compute_weights_and_temperature_loss(
        advantages,
        restarting_weights,
        importance_weights,
        temperature_constraint,
        projection_operator,
        top_k_fraction,
        axis_name=axis_name,
        use_stop_gradient=use_stop_gradient,
    )
    temperature_loss, norm_weights, num_samples = e_step_outputs

    # M-step: supervised learning of the reweighted trajectories using the
    # E-step weights, with additional KL constraints. The weights are
    # normalized to sum to 1, so they are scaled by the number of examples to
    # yield a per-example policy loss whose mean can be taken; the
    # `restarting_weights` are assumed to be included already.
    num_examples = (
        jax.lax.all_gather(sample_log_probs, axis_name=axis_name).size
        if axis_name
        else sample_log_probs.size
    )
    policy_loss = -sample_log_probs * norm_weights * num_examples

    kl_loss, alpha_loss = compute_parametric_kl_penalty_and_dual_loss(
        kl_constraints, projection_operator, use_stop_gradient)
    chex.assert_equal_shape([policy_loss, kl_loss, alpha_loss])

    # Total policy improvement loss: weighted sum of the four components.
    loss = policy_loss_weight * policy_loss
    loss = loss + temperature_loss_weight * temperature_loss
    loss = loss + kl_loss_weight * kl_loss
    loss = loss + alpha_loss_weight * alpha_loss

    outputs = MpoOutputs(
        temperature_loss=temperature_loss,
        policy_loss=policy_loss,
        kl_loss=kl_loss,
        alpha_loss=alpha_loss,
        normalized_weights=norm_weights,
        num_samples=num_samples,
    )
    return loss, outputs