Example #1
def encode(self, encoder_input_tokens, encoder_padding_mask):
  """Encodes an input sequence with a particular encoder_mask.

  Args:
    encoder_input_tokens: (seq_len,) array to be encoded.
    encoder_padding_mask: (seq_len,) mask array to use for masking.

  Returns:
    A (seq_len, emb_dim) array of encoded inputs.
  """
  chex.assert_rank([encoder_input_tokens, encoder_padding_mask], [1, 1])
  chex.assert_type([encoder_input_tokens, encoder_padding_mask],
                   [int, bool])

  (encoder_input_tokens,
   encoder_padding_mask) = expand_dims(encoder_input_tokens,
                                       encoder_padding_mask)

  encoded = self.module.encode(encoder_input_tokens=encoder_input_tokens,
                               encoder_padding_mask=encoder_padding_mask,
                               deterministic=not self.train)

  chex.assert_rank(encoded, 3)
  chex.assert_equal_shape_prefix([encoded, encoder_input_tokens], 2)

  return encoded[0]
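The assertions above pin down the expected shapes and dtypes of a single (unbatched) example before it is handed to the batched Flax module. A minimal, self-contained sketch of how these chex checks behave on toy arrays (the shapes, dtypes and values below are made up purely for illustration):

import chex
import jax.numpy as jnp

tokens = jnp.array([5, 17, 2, 0], dtype=jnp.int32)    # (seq_len,)
padding_mask = jnp.array([True, True, True, False])   # (seq_len,)

# Both inputs must be 1-D and of integer / boolean dtype respectively.
chex.assert_rank([tokens, padding_mask], [1, 1])
chex.assert_type([tokens, padding_mask], [int, bool])

# After batching, the encoder output is (batch, seq_len, emb_dim) and its
# first two dimensions must match the batched token array.
encoded = jnp.zeros((1, 4, 8))
chex.assert_rank(encoded, 3)
chex.assert_equal_shape_prefix([encoded, tokens[None]], 2)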
Example #2
def implicit_least_squares(phis, phis_for_wi1,
                           phis_for_wi2, phis_for_cov1,
                           phis_for_cov2, psis,
                           psis_for_wi1, psis_for_wi2,
                           alpha):
  """Implicit least squares objective."""
  # Make sure all the shapes agree
  chex.assert_equal_shape([phis_for_cov1, phis_for_cov2])
  chex.assert_equal_shape([phis_for_wi1, phis_for_wi2])
  chex.assert_equal_shape([psis_for_wi1, psis_for_wi2])
  chex.assert_equal_shape_prefix([phis, psis], 1)
  chex.assert_scalar(alpha)
  chex.assert_rank([
      phis,
      phis_for_cov1,
      phis_for_cov2,
      phis_for_wi1,
      phis_for_wi2,
      psis,
      psis_for_wi1,
      psis_for_wi2,
  ], 2)

  # Get w_1 estimate on the forward pass when computing the loss
  w = weight_estimate(phis_for_cov1, phis_for_wi1, psis_for_wi1, alpha=alpha)
  # w = weight_estimate(phis_for_cov1, phis, psis, alpha=alpha)
  # Predict using w_1
  predictions = phis @ w
  # Least-squares cost
  cost = predictions - psis
  # MSE Loss
  mse = 0.5 * jnp.mean(cost**2)

  return mse
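`weight_estimate` is defined elsewhere in the project and is not shown here. A hypothetical minimal sketch that is consistent with its call signature, assuming it computes a ridge-regularised least-squares weight estimate (an assumption for illustration, not the project's actual implementation):

import jax.numpy as jnp

def weight_estimate(phis_for_cov, phis, psis, alpha):
  # Hypothetical: solve (Phi_cov^T Phi_cov + alpha * I) w = Phi^T Psi.
  d = phis_for_cov.shape[1]
  cov = phis_for_cov.T @ phis_for_cov + alpha * jnp.eye(d)
  return jnp.linalg.solve(cov, phis.T @ psis)

Under this reading, `implicit_least_squares` measures the squared error of predictions made with weights estimated from one data split (`*_for_wi1`, `*_for_cov1`) and evaluated on another (`phis`, `psis`).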
Example #3
def implicit_least_squares_fwd(
    phis, phis_for_wi1, phis_for_wi2,
    phis_for_cov1, phis_for_cov2, psis,
    psis_for_wi1, psis_for_wi2,
    alpha):
  """Forward pass for implicit least squares objective."""
  chex.assert_equal_shape([phis_for_cov1, phis_for_cov2])
  chex.assert_equal_shape([phis_for_wi1, phis_for_wi2])
  chex.assert_equal_shape([psis_for_wi1, psis_for_wi2])
  chex.assert_equal_shape_prefix([phis, psis], 1)
  chex.assert_scalar(alpha)
  chex.assert_rank([
      phis,
      phis_for_cov1,
      phis_for_cov2,
      phis_for_wi1,
      phis_for_wi2,
      psis,
      psis_for_wi1,
      psis_for_wi2,
  ], 2)

  # Get w_1 estimate on the forward pass when computing the loss
  w = weight_estimate(phis_for_cov1, phis_for_wi1, psis_for_wi1, alpha=alpha)
  # w = weight_estimate(phis_for_cov1, phis, psis, alpha=alpha)

  # Predict using w_1
  predictions = phis @ w
  # Least-squares cost
  cost = predictions - psis
  # MSE Loss
  mse = implicit_least_squares(
      phis,
      phis_for_wi1,
      phis_for_wi2,
      phis_for_cov1,
      phis_for_cov2,
      psis,
      psis_for_wi1,
      psis_for_wi2,
      alpha=alpha)

  # Return appropriate residuals so we can compute w_2 on backward pass
  return mse, (cost, phis_for_cov2, phis_for_wi2, psis_for_wi2)
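The `*_fwd` suffix and the `(output, residuals)` return value follow the `jax.custom_vjp` convention. A sketch of how this forward pass would typically be wired up, assuming a matching backward function exists elsewhere in the project (the registration and the `implicit_least_squares_bwd` stub below are hypothetical, shown only to illustrate the pattern):

import jax

implicit_least_squares_vjp = jax.custom_vjp(implicit_least_squares)

def implicit_least_squares_bwd(residuals, g):
  cost, phis_for_cov2, phis_for_wi2, psis_for_wi2 = residuals
  # Hypothetical: the real backward pass recomputes a second weight
  # estimate w_2 from this data split and returns one cotangent per
  # argument of implicit_least_squares, each scaled by `g`.
  ...

implicit_least_squares_vjp.defvjp(implicit_least_squares_fwd,
                                  implicit_least_squares_bwd)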
Example #4
File: mpo_ops.py  Project: mhadsouza/rlax
def kl_alpha_loss(restarting_weights: Array,
                  kl_constraints: Sequence[Tuple[Array, LagrangePenalty]] = (),
                  axis_name: Optional[str] = None):
  """Calculates the losses for multiple KL constraints.

  Args:
    restarting_weights: Restarting weights, shape E*, 0 means that this step is
      the start of a new episode and we ignore losses at this step because the
      agent cannot influence these.
    kl_constraints: KL and variables for applying Lagrangian penalties to bound
      them in the M-step, KLs are [E*, A?]. Here A is the action dimension
      in the case of per-dimension KL constraints.
    axis_name: Optional axis name for `pmap`. If `None`, computations are
      performed locally on each device.

  Returns:
    The KL loss and the dual variable loss, both of shape E*.
  """
  chex.assert_type(restarting_weights, float)
  if kl_constraints:
    for kl, penalty in kl_constraints:
      chex.assert_rank(penalty.epsilon, 0)
      chex.assert_type([kl, penalty.alpha, penalty.epsilon], float)
      chex.assert_equal_shape_prefix([kl, restarting_weights],
                                     restarting_weights.ndim)

    # Implement decoupled KL constraints.
    kl_alpha_losses = [
        kl_constraint_loss(kl, penalty, lambda x: x)[:2]
        for kl, penalty in kl_constraints
    ]
    kl_loss, alpha_loss = [sum(losses) for losses in zip(*kl_alpha_losses)]
    all_sum = base.AllSum(axis_name)
    num_samples = all_sum(restarting_weights) + _EPSILON
    # Reshape in case KL is per dimension.
    kl_loss = all_sum(kl_loss * restarting_weights) / num_samples
    alpha_loss = all_sum(alpha_loss * restarting_weights) / num_samples
  else:
    # No M-step constraint.
    kl_loss = jnp.asarray(0.0)
    alpha_loss = jnp.asarray(0.0)
  return kl_loss, alpha_loss
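The shape check inside the loop only constrains the leading `restarting_weights.ndim` dimensions of each KL array, which is what allows a trailing action dimension for per-dimension KL constraints. A small standalone illustration of that chex check (the toy shapes are assumptions for illustration):

import chex
import jax.numpy as jnp

restarting_weights = jnp.ones((6, 4))     # E*, e.g. (time, batch)
kl_joint = jnp.ones((6, 4))               # KL with shape E*
kl_per_dim = jnp.ones((6, 4, 3))          # KL with shape [E*, A], A = 3

# Both pass: only the first restarting_weights.ndim dimensions must match.
chex.assert_equal_shape_prefix([kl_joint, restarting_weights],
                               restarting_weights.ndim)
chex.assert_equal_shape_prefix([kl_per_dim, restarting_weights],
                               restarting_weights.ndim)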
Example #5
def decode(self,
           encoded,
           decoder_input_tokens,
           encoder_padding_mask,
           decoder_padding_mask,
           timestep=None):
  """Applies the Transformer decoder branch to the encoded input and target.

  Args:
    encoded: encoded input data from the encoder.
    decoder_input_tokens: input tokens to the decoder.
    encoder_padding_mask: padding mask for the encoder.
    decoder_padding_mask: padding mask for the decoder.
    timestep: optionally, a timestep to condition the input on.

  Returns:
    Logits array from the transformer decoder.
  """
  chex.assert_rank([decoder_input_tokens, decoder_padding_mask], [1, 1])

  if encoded is not None:
    chex.assert_rank([encoded, encoder_padding_mask], [2, 1])

  (encoded, decoder_input_tokens, encoder_padding_mask,
   decoder_padding_mask,
   timestep) = expand_dims(encoded, decoder_input_tokens,
                           encoder_padding_mask, decoder_padding_mask,
                           timestep)

  logits = self.module.decode(
      encoded=encoded,
      decoder_input_tokens=decoder_input_tokens,
      encoder_padding_mask=encoder_padding_mask,
      decoder_padding_mask=decoder_padding_mask,
      timestep=timestep,
      deterministic=not self.train,
  )

  chex.assert_equal_shape_prefix([logits, decoder_input_tokens], 2)

  return logits[0]
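`expand_dims` is a helper from the surrounding project (also used in the `encode` wrapper of Example #1) and is not shown in these snippets. Judging from how its outputs are consumed, a batched module call followed by `logits[0]`, it presumably adds a leading batch axis to each non-None argument; a hypothetical minimal version might look like this (an assumption, not the project's code):

import jax.numpy as jnp

def expand_dims(*arrays):
  # Hypothetical: prepend a batch axis of size 1, passing None through.
  return tuple(None if a is None else jnp.expand_dims(a, axis=0)
               for a in arrays)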
Example #6
File: mpo_ops.py  Project: deepmind/rlax
def vmpo_loss(
    sample_log_probs: Array,
    advantages: Array,
    temperature_constraint: LagrangePenalty,
    kl_constraints: Sequence[Tuple[Array, LagrangePenalty]],
    projection_operator: Callable[[Numeric], Numeric] = functools.partial(
        jnp.clip, a_min=_EPSILON),
    restarting_weights: Optional[Array] = None,
    importance_weights: Optional[Array] = None,
    top_k_fraction: float = 0.5,
    policy_loss_weight: float = 1.0,
    temperature_loss_weight: float = 1.0,
    kl_loss_weight: float = 1.0,
    alpha_loss_weight: float = 1.0,
    axis_name: Optional[str] = None,
    use_stop_gradient: bool = True,
) -> Tuple[Array, MpoOutputs]:
  """Calculates the V-MPO policy improvement loss.

  Note: This is a per-example loss which works on any shape inputs as long as
  they are consistent. We denote the shape of the examples E* for ease of
  reference.

  Args:
    sample_log_probs: Log probabilities of actions for each example. Shape E*.
    advantages: Advantages for the E-step. Shape E*.
    temperature_constraint: Lagrange constraint for the E-step temperature
      optimization.
    kl_constraints: KL and variables for applying Lagrangian penalties to bound
      them in the M-step, KLs are E* or [E*, A]. Here A is the action dimension
      in the case of per-dimension KL constraints.
    projection_operator: Function to project dual variables (temperature and kl
      constraint alphas) into the positive range.
    restarting_weights: Optional restarting weights, shape E*, 0 means that this
      step is the start of a new episode and we ignore losses at this step
      because the agent cannot influence these.
    importance_weights: Optional importance weights, shape E*.
    top_k_fraction: Fraction of samples to use in the E-step.
    policy_loss_weight: Weight for the policy loss.
    temperature_loss_weight: Weight for the temperature loss.
    kl_loss_weight: Weight for the KL loss.
    alpha_loss_weight: Weight for the alpha loss.
    axis_name: Optional axis name for `pmap`. If `None`, computations
      are performed locally on each device.
    use_stop_gradient: bool indicating whether or not to apply stop gradient.

  Returns:
    Per example `loss` with same shape E* as array inputs, and additional data
    including the components of this loss and the normalized weights in the
    AdditionalOutputs.
  """
  # Define default restarting weights and importance weights.
  if restarting_weights is None:
    restarting_weights = jnp.ones_like(sample_log_probs)
  if importance_weights is None:
    importance_weights = jnp.ones_like(sample_log_probs)

  # Check shapes.
  chex.assert_equal_shape(
      [advantages, sample_log_probs, restarting_weights, importance_weights])

  chex.assert_rank(temperature_constraint.epsilon, 0)
  chex.assert_type([
      sample_log_probs, advantages, restarting_weights, importance_weights,
      temperature_constraint.alpha, temperature_constraint.epsilon], float)

  for kl, penalty in kl_constraints:
    chex.assert_rank(penalty.epsilon, 0)
    chex.assert_type([kl, penalty.alpha, penalty.epsilon], float)
    if penalty.per_dimension:
      chex.assert_rank(kl, advantages.ndim + 1)
      chex.assert_equal_shape_prefix([kl, advantages], advantages.ndim)
    else:
      chex.assert_equal_shape([kl, advantages])

  # E-step: Calculate the reweighting and the temperature loss.
  temperature_loss, norm_weights, num_samples = (
      vmpo_compute_weights_and_temperature_loss(
          advantages, restarting_weights, importance_weights,
          temperature_constraint, projection_operator, top_k_fraction,
          axis_name=axis_name, use_stop_gradient=use_stop_gradient))

  # M-step: Supervised learning of reweighted trajectories using the weights
  # from the E-step, with additional KL constraints.
  # The weights are normalized so that the sum is 1. We multiply by the number
  # of examples so that we can give a policy loss per example and take the mean,
  # and we assume `restarting_weights` are already included.
  if axis_name:
    num_examples = jax.lax.all_gather(
        sample_log_probs, axis_name=axis_name).size
  else:
    num_examples = sample_log_probs.size
  policy_loss = -sample_log_probs * norm_weights * num_examples

  kl_loss, alpha_loss = compute_parametric_kl_penalty_and_dual_loss(
      kl_constraints, projection_operator, use_stop_gradient)

  chex.assert_equal_shape([policy_loss, kl_loss, alpha_loss])

  # Calculate the total policy improvement loss.
  loss = (policy_loss_weight * policy_loss +
          temperature_loss_weight * temperature_loss +
          kl_loss_weight * kl_loss +
          alpha_loss_weight * alpha_loss)

  return loss, MpoOutputs(
      temperature_loss=temperature_loss, policy_loss=policy_loss,
      kl_loss=kl_loss, alpha_loss=alpha_loss, normalized_weights=norm_weights,
      num_samples=num_samples)
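When the loss runs under `pmap`, `num_examples` counts examples across all devices rather than just the local shard: `jax.lax.all_gather` stacks every device's copy of `sample_log_probs` along a new leading axis, so `.size` is the global element count. A small standalone illustration (the device count and per-device shape are illustrative only):

import jax
import jax.numpy as jnp

def count_examples(x):
  # Inside pmap, gather every device's shard and count all elements.
  return jnp.asarray(jax.lax.all_gather(x, axis_name='i').size)

n_dev = jax.local_device_count()
sharded = jnp.ones((n_dev, 5))    # 5 examples per device
counts = jax.pmap(count_examples, axis_name='i')(sharded)
# Every device reports the global count, n_dev * 5.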