Exemplo n.º 1
0
def _category_cross_entropy(  # pylint: disable=invalid-name
    model_output, targets, label_smoothing):
  """Computes category cross entropy with label smoothing."""
  n_categories = model_output.shape[-1]
  target_distributions = core.one_hot(targets, n_categories)
  if label_smoothing:
    if label_smoothing < 0. or label_smoothing > 1.:
      raise ValueError(
          f'Arg label_smoothing ({label_smoothing}) must be between 0 and 1.')
    target_distributions *= (1. - label_smoothing)
    target_distributions += label_smoothing / n_categories
  model_log_distributions = core.log_softmax(model_output)
  return - jnp.sum(target_distributions * model_log_distributions, axis=-1)
Exemplo n.º 2
0
def _category_cross_entropy(model_output, targets):  # pylint: disable=invalid-name
    """Return the per-example categorical cross entropy of `model_output`.

    `targets` is one-hot encoded with depth equal to the size of the last
    axis of `model_output`; the result is minus the last-axis sum of that
    one-hot tensor multiplied by `core.log_softmax(model_output)`.
    """
    depth = model_output.shape[-1]
    one_hot_targets = core.one_hot(targets, depth)
    log_probs = core.log_softmax(model_output)
    weighted = one_hot_targets * log_probs
    return -jnp.sum(weighted, axis=-1)