示例#1
0
def neg_log_perplexity(batch, model_predictions):
  """Average the prediction values found at the target indices.

  Args:
    batch: A two-element sequence; only the second element (the targets)
      is used, the first is discarded.
    model_predictions: Array whose last axis has one entry per class.
      Presumably log-probabilities, given the function name -- confirm
      against the model that produces them.

  Returns:
    Whatever `masked_mean` yields for the per-position selected values and
    the targets (presumably a mean that skips masked/padded targets --
    verify against `masked_mean`).
  """
  _unused_inputs, targets = batch
  num_classes = model_predictions.shape[-1]
  target_selector = stax.one_hot(targets, num_classes)
  # Multiplying by the one-hot encoding zeroes every class except the target,
  # so the sum over the class axis leaves just the target's prediction value.
  selected_values = np.sum(model_predictions * target_selector, axis=-1)
  return masked_mean(selected_values, targets)
示例#2
0
def loss(params, batch, model_predict, rng):
  """Return the negated masked mean of the predictions at the target indices.

  Args:
    params: Model parameters, passed through to `model_predict` unchanged.
    batch: A two-element sequence `(inputs, targets)`.
    model_predict: Callable invoked as `model_predict(params, inputs, rng=rng)`;
      its result must expose `.shape` with the class axis last.
    rng: Random state forwarded to `model_predict` as the `rng` keyword.

  Returns:
    The negation of `masked_mean` applied to the per-position selected
    prediction values and the targets. If the predictions are
    log-probabilities this is a cross-entropy style loss -- confirm against
    the model and `masked_mean`.
  """
  inputs, targets = batch
  predictions = model_predict(params, inputs, rng=rng)
  target_selector = stax.one_hot(targets, predictions.shape[-1])
  # The one-hot product keeps only the target class's value at each position.
  selected_values = np.sum(predictions * target_selector, axis=-1)
  return -masked_mean(selected_values, targets)
示例#3
0
def loss(params, batch, model_predict):
    """Return the negated mean of the predictions at the target indices.

    Args:
        params: Model parameters, passed through to `model_predict` unchanged.
        batch: A two-element sequence `(inputs, targets)`.
        model_predict: Callable invoked as `model_predict(params, inputs)`;
            its result must expose `.shape` with the class axis last.

    Returns:
        The negation of the mean, over all positions, of the prediction value
        selected by each target. If the predictions are log-probabilities
        this is the average cross-entropy -- confirm against the model.
    """
    inputs, targets = batch
    predictions = model_predict(params, inputs)
    target_selector = stax.one_hot(targets, predictions.shape[-1])
    # The one-hot product keeps only the target class's value at each position.
    selected_values = np.sum(predictions * target_selector, axis=-1)
    return -np.mean(selected_values)