# Imports assumed for this excerpt (the enclosing module is not shown).
import jax
import jax.numpy as jnp

from flax import linen as nn
from flax.training import common_utils


def compute_logprob(inputs, model, mask_token=None):
  """Returns an array of log probabilities for the input sequences."""
  assert inputs.ndim == 2

  targets = inputs
  weights = jnp.where(targets != model.pad_token, 1, 0)
  if mask_token is not None:
    weights *= jnp.where(targets != mask_token, 1, 0)
  logits = model.score(inputs)
  assert logits.ndim == 3

  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  log_lik = jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  log_lik *= weights
  log_prob = jnp.sum(log_lik, axis=-1)

  return log_prob
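
# Usage sketch for compute_logprob. The `dummy_model` below is a hypothetical
# stand-in exposing the two attributes the function relies on: a `pad_token`
# id and a `score(inputs) -> [batch, length, vocab]` method; any real model
# with that interface is used the same way.
def _compute_logprob_example():
  import types

  vocab_size, pad_token = 32, 0
  dummy_model = types.SimpleNamespace(
      pad_token=pad_token,
      score=lambda inputs: jax.random.normal(
          jax.random.PRNGKey(0), inputs.shape + (vocab_size,)))
  # Two sequences of length 4; trailing zeros are padding and get zero weight.
  inputs = jnp.array([[5, 7, 9, pad_token], [3, 4, pad_token, pad_token]])
  log_probs = compute_logprob(inputs, dummy_model)
  return log_probs  # shape [batch] = (2,)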
def compute_weighted_cross_entropy(logits,
                                   targets,
                                   token_weights=None,
                                   example_weights=None):
  """Computes weighted cross entropy between logits and targets.

  The loss is assumed to be sum_i example_weights[i] * logprob[i], where i
  indexes elements in the batch. logprob[i] is the log probability of
  sequence i, which is a weighted sum of per-token log probabilities with
  weights according to token_weights. Typically token_weights is a mask
  indicating whether tokens are padding or not.

  Maximum likelihood training sets example_weights[i] = 1. Training with a
  REINFORCE-style objective may set example_weights[i] to any positive or
  negative number.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: Categorical targets [batch, length] int array.
    token_weights: None or array of shape [batch, length].
    example_weights: None or array of shape [batch].

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  normalizing_factor = onehot_targets.sum()
  if token_weights is not None:
    loss = loss * token_weights
    normalizing_factor = token_weights.sum()
  if example_weights is not None:
    loss = loss.sum(axis=1)
    loss *= example_weights

  return loss.sum(), normalizing_factor
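
# A second function with the same name (the label-smoothing variant) follows
# later in this excerpt, so keep a reference to this variant for the example.
_cross_entropy_with_example_weights = compute_weighted_cross_entropy


# Usage sketch for the token/example-weighted variant above. The shapes and
# the padding id (0) are illustrative assumptions; `example_weights` plays the
# role of per-sequence rewards in a REINFORCE-style objective, while
# `token_weights` masks out padding tokens.
def _weighted_cross_entropy_example():
  batch, length, num_classes = 2, 4, 8
  logits = jax.random.normal(
      jax.random.PRNGKey(0), (batch, length, num_classes))
  targets = jnp.array([[1, 2, 3, 0], [4, 5, 0, 0]])
  token_weights = (targets != 0).astype(jnp.float32)  # zero weight on padding
  example_weights = jnp.array([0.5, -1.0])            # e.g. scaled rewards
  loss, normalizing_factor = _cross_entropy_with_example_weights(
      logits, targets,
      token_weights=token_weights,
      example_weights=example_weights)
  return loss / normalizing_factor  # mean per-token loss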
def compute_weighted_cross_entropy(
    logits,  # 3D ndarray of floats
    targets,  # 2D ndarray of ints
    weights=None,  # 2D ndarray of floats
    label_smoothing=0.0):
  """Computes label-smoothed weighted cross entropy for logits and targets.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: Categorical targets [batch, length] int array.
    weights: None or array of shape [batch, length].
    label_smoothing: Label smoothing constant, used to determine the on and
      off values.

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
  if label_smoothing is None:
    label_smoothing = 0.0
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  # The cross entropy is shifted by the entropy of the smoothed target
  # distribution so that the minimum achievable loss is zero.
  normalizing_constant = -(
      confidence * jnp.log(confidence) +
      (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
  soft_targets = common_utils.onehot(
      targets, vocab_size, on_value=confidence, off_value=low_confidence)

  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
  loss = loss - normalizing_constant

  return weight_loss(loss, targets, weights)
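
# `weight_loss` is referenced above but not defined in this excerpt. The
# following is a minimal sketch, assuming it mirrors the masking and
# normalization of the token-weighted variant earlier in this file; it is not
# necessarily the repository's actual implementation.
def weight_loss(loss, targets, weights):
  """Applies optional per-token weights; returns (summed loss, normalizer)."""
  normalizing_factor = targets.size  # total token count when no mask is given
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()
  return loss.sum(), normalizing_factor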