def _category_cross_entropy(  # pylint: disable=invalid-name
    model_output, targets, label_smoothing=None):
  """Computes category cross entropy with optional label smoothing.

  Args:
    model_output: Unnormalized scores; `core.log_softmax` is applied to them.
        The last axis indexes categories, and its size is used as the number
        of categories.
    targets: Category ids, one-hot encoded via `core.one_hot` into as many
        categories as `model_output` has along its last axis.
    label_smoothing: Optional smoothing amount in [0, 1]. When falsy (None or
        0) the one-hot targets are used unchanged. Otherwise each target
        distribution becomes
        `one_hot * (1 - label_smoothing) + label_smoothing / n_categories`.

  Returns:
    The negated sum over the last axis of (target distribution *
    log-softmax of `model_output`), i.e. one cross-entropy value per
    position of the leading axes.

  Raises:
    ValueError: If `label_smoothing` is truthy but not within [0, 1]
      (NaN included).
  """
  n_categories = model_output.shape[-1]
  target_distributions = core.one_hot(targets, n_categories)
  if label_smoothing:
    # Negated range test: NaN fails every comparison, so it is rejected here
    # instead of silently turning every loss into NaN.
    if not 0. <= label_smoothing <= 1.:
      raise ValueError(
          f'Arg label_smoothing ({label_smoothing}) must be between 0 and 1.')
    # Rebind instead of updating in place so this does not depend on the
    # mutability or dtype of whatever array type `core.one_hot` returns.
    target_distributions = (target_distributions * (1. - label_smoothing)
                            + label_smoothing / n_categories)
  model_log_distributions = core.log_softmax(model_output)
  return -jnp.sum(target_distributions * model_log_distributions, axis=-1)
def _category_cross_entropy(  # pylint: disable=invalid-name
    model_output, targets, label_smoothing=None):
  """Computes category cross entropy with optional label smoothing.

  Args:
    model_output: Unnormalized scores; `core.log_softmax` is applied to them.
        The last axis indexes categories, and its size is used as the number
        of categories.
    targets: Category ids, one-hot encoded via `core.one_hot` into as many
        categories as `model_output` has along its last axis.
    label_smoothing: Optional smoothing amount in [0, 1]. When falsy (None or
        0) the one-hot targets are used unchanged. Otherwise each target
        distribution becomes
        `one_hot * (1 - label_smoothing) + label_smoothing / n_categories`.

  Returns:
    The negated sum over the last axis of (target distribution *
    log-softmax of `model_output`), i.e. one cross-entropy value per
    position of the leading axes.

  Raises:
    ValueError: If `label_smoothing` is truthy but not within [0, 1]
      (NaN included).
  """
  # NOTE(review): `_category_cross_entropy` is defined more than once in this
  # module, and the last definition is the one bound at import time. This one
  # therefore accepts the optional `label_smoothing` argument too, so both the
  # two- and three-argument call forms keep working; consider removing the
  # duplicate.
  n_categories = model_output.shape[-1]
  target_distributions = core.one_hot(targets, n_categories)
  if label_smoothing:
    # Negated range test: NaN fails every comparison, so it is rejected here.
    if not 0. <= label_smoothing <= 1.:
      raise ValueError(
          f'Arg label_smoothing ({label_smoothing}) must be between 0 and 1.')
    target_distributions = (target_distributions * (1. - label_smoothing)
                            + label_smoothing / n_categories)
  model_log_distributions = core.log_softmax(model_output)
  return -jnp.sum(target_distributions * model_log_distributions, axis=-1)