Example #1
# Assumed imports: this snippet relies on JAX and Flax utilities.
import jax.numpy as jnp
from flax import linen as nn
from flax.training import common_utils


def compute_logprob(inputs, model, mask_token=None):
  """Returns an array of log probabilities for the input sequences."""

  assert inputs.ndim == 2

  # The inputs double as the targets: the model scores every input position.
  targets = inputs
  # Zero out padding positions (and, optionally, masked positions).
  weights = jnp.where(targets != model.pad_token, 1, 0)
  if mask_token is not None:
    weights *= jnp.where(targets != mask_token, 1, 0)
  logits = model.score(inputs)
  assert logits.ndim == 3

  # Per-token log-likelihood: pick out the log probability of each target token.
  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  log_lik = jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  log_lik *= weights
  # Sum over the length dimension to get one log probability per sequence.
  log_prob = jnp.sum(log_lik, axis=-1)

  return log_prob
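
Below is a minimal usage sketch for compute_logprob. The SimpleScorer class is hypothetical and exists only for illustration: any object exposing a pad_token attribute and a score(inputs) method returning [batch, length, vocab] logits would do. Uniform logits are used so the expected output is easy to check by hand.

# Hypothetical scorer used only for illustration; not part of the original code.
class SimpleScorer:
  pad_token = 0
  vocab_size = 8

  def score(self, inputs):
    # Uniform logits stand in for a real model's predictions.
    return jnp.zeros(inputs.shape + (self.vocab_size,))


inputs = jnp.array([[5, 3, 2, 0], [1, 4, 0, 0]])  # trailing zeros are padding
log_probs = compute_logprob(inputs, SimpleScorer())
# With uniform logits every non-pad token contributes log(1/8), so
# log_probs is approximately [3 * log(1/8), 2 * log(1/8)].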
Example #2
# Assumed imports, as in Example #1.
import jax.numpy as jnp
from flax import linen as nn
from flax.training import common_utils


def compute_weighted_cross_entropy(logits,
                                   targets,
                                   token_weights=None,
                                   example_weights=None):
  """Computes weighted cross entropy between logits and targets.

  The returned loss is sum_i example_weights[i] * xent[i], where i indexes
  sequences in the batch and xent[i] is the negative log probability of
  sequence i: a weighted sum of per-token negative log probabilities with
  weights given by token_weights. Typically token_weights is a mask
  indicating whether a token is padding or not.

  Maximum likelihood training sets example_weights[i] = 1.
  Training with a REINFORCE-style objective may set example_weights[i]
  to any positive or negative number.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: Categorical targets [batch, length] int array.
    token_weights: None or array of shape [batch, length].
    example_weights: None or array of shape [batch].

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s targets' %
        (str(logits.shape), str(targets.shape)))
  # Per-token negative log-likelihood.
  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  normalizing_factor = onehot_targets.sum()
  if token_weights is not None:
    # Mask out (or re-weight) individual tokens, e.g. padding positions.
    loss = loss * token_weights
    normalizing_factor = token_weights.sum()

  if example_weights is not None:
    # Collapse the length dimension, then re-weight each sequence.
    loss = loss.sum(axis=1)
    loss *= example_weights

  return loss.sum(), normalizing_factor
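
A short usage sketch follows; the shapes and values are illustrative, not taken from the original code. The returned total loss is typically divided by the normalizing factor to obtain a mean per-token loss.

import jax

key = jax.random.PRNGKey(0)
batch, length, num_classes = 2, 5, 10
logits = jax.random.normal(key, (batch, length, num_classes))
targets = jnp.array([[1, 2, 3, 0, 0],
                     [4, 5, 0, 0, 0]])
token_weights = (targets != 0).astype(jnp.float32)  # mask out padding

total_loss, norm = compute_weighted_cross_entropy(
    logits, targets, token_weights=token_weights)
mean_loss = total_loss / norm  # average cross entropy per non-pad token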
Example #3
# Assumed imports, as in Example #1.
import jax.numpy as jnp
from flax import linen as nn
from flax.training import common_utils


def compute_weighted_cross_entropy(
    logits,  # 3D ndarray of floats
    targets,  # 2D ndarray of ints
    weights=None,  # 2D ndarray of floats
    label_smoothing=0.0):
  """Computes weighted cross entropy between logits and targets.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: Categorical targets [batch, length] int array.
    weights: None or array of shape [batch, length].
    label_smoothing: Label smoothing constant, used to determine the on and
      off values.

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
  if label_smoothing is None:
    label_smoothing = 0.0
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s targets' %
        (str(logits.shape), str(targets.shape)))
  vocab_size = logits.shape[-1]
  # Smoothed target distribution: `confidence` on the true class and
  # `low_confidence` spread uniformly over the remaining classes.
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  # Entropy of the smoothed distribution; subtracting it below makes the
  # minimum achievable loss zero.
  normalizing_constant = -(
      confidence * jnp.log(confidence) +
      (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
  soft_targets = common_utils.onehot(targets,
                                     vocab_size,
                                     on_value=confidence,
                                     off_value=low_confidence)

  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
  loss = loss - normalizing_constant
  # `weight_loss` is an external helper (not shown in this example) that
  # applies the optional per-token weights and returns
  # (total_loss, normalizing_factor).
  return weight_loss(loss, targets, weights)
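
The weight_loss helper called on the last line is defined elsewhere in the source codebase and is not shown here. As a rough sketch only, an implementation consistent with the (loss, normalizing_factor) convention of Example #2 might look like this; treat it as an assumption rather than the original helper.

# Hypothetical sketch of weight_loss; the real helper is not shown in this
# example. It applies the optional per-token weights and returns the total
# loss together with a normalizing factor.
def weight_loss(loss, targets, weights):
  normalizing_factor = targets.size  # total number of target tokens
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()
  return loss.sum(), normalizing_factor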