def f(model_output, targets, weights):  # pylint: disable=invalid-name
  """Returns the fraction of rows in which every non-padding position is right.

  Predicted categories are the argmax of `model_output` over its last axis;
  they must have exactly the shape of `targets` (enforced by
  `shapes.assert_same_shape`). A position counts as correct when the
  prediction equals the target OR its weight is 0, so padding never makes a
  row wrong. A row is correct only if all positions along the last axis are
  correct, and the result is the mean of those per-row booleans.

  Args:
    model_output: Scores whose last axis indexes categories.
    targets: Category ids compared element-wise with the argmax predictions.
    weights: Per-position weights where 0 marks padding (assumed to broadcast
      against `targets` — TODO confirm with callers).

  Returns:
    Scalar mean of the per-row "all positions correct" flags.
  """
  predicted = jnp.argmax(model_output, axis=-1)
  shapes.assert_same_shape(predicted, targets)
  # Padding positions (weight == 0) are forced to True so they are neutral
  # in the all-positions-correct reduction below.
  correct_or_padding = jnp.logical_or(
      jnp.equal(predicted, targets), jnp.equal(weights, 0))
  return jnp.average(jnp.all(correct_or_padding, axis=-1))
def f(model_output, targets):  # pylint: disable=invalid-name
  """Returns per-category F-scores averaged with per-category target weights.

  Predicted categories are the argmax of `model_output` over its last axis.
  For each category id in `range(initial_category_index, C)`, where `C` is
  the size of that last axis, `_precision_recall` supplies the category's
  target count, precision and recall; `_f_score` turns precision and recall
  into an F-score using `beta ** 2`. The scores are then averaged, each
  weighted by its category's target count.

  `beta` and `initial_category_index` are free variables read from a scope
  not visible here (presumably the enclosing metric factory — confirm).

  Args:
    model_output: Scores whose last axis indexes categories.
    targets: Category ids passed through to `_precision_recall`.

  Returns:
    Weighted average of the per-category F-scores.
  """
  beta_squared = beta**2
  predicted = jnp.argmax(model_output, axis=-1)
  first_category = initial_category_index
  num_categories = model_output.shape[-1]
  # Accumulate into float arrays (appending to an empty float array) so the
  # weights end up with the same float dtype as before.
  per_category_f = jnp.empty(0)
  per_category_count = jnp.empty(0)
  for category in range(first_category, num_categories):
    _, _, n_targets, precision, recall = _precision_recall(
        predicted, targets, category)
    score = _f_score(precision, recall, beta_squared)
    per_category_f = jnp.append(per_category_f, score)
    per_category_count = jnp.append(per_category_count, n_targets)
  return jnp.average(per_category_f, weights=per_category_count)
def f(model_output, targets):  # pylint: disable=invalid-name
  """Returns the mean binary cross-entropy of sigmoid(`model_output`) vs `targets`.

  Mathematically this is the average of
  `-(t * log(sigmoid(x)) + (1 - t) * log(1 - sigmoid(x)))`, but it is
  evaluated directly from the logits as `softplus(x) - t * x`, using
  `softplus(x) = logaddexp(0, x)`.

  Going through `sigmoid` first is numerically unsafe. For large-magnitude
  logits the sigmoid rounds to exactly 0 or 1, so a `log(0) = -inf` term
  appears. Multiplied by a zero target weight, that gives `0 * -inf = nan`,
  even for a confident, correct prediction. The logits form stays finite.

  Args:
    model_output: Logits (pre-sigmoid scores), one per binary prediction.
    targets: Target values broadcastable with `model_output`; typically 0/1,
      though the identity also holds for soft targets in [0, 1].

  Returns:
    Scalar mean of the element-wise binary cross-entropies.
  """
  # -(t*log(s(x)) + (1-t)*log(1-s(x))) == softplus(x) - t*x for all real x.
  binary_entropies = jnp.logaddexp(0, model_output) - targets * model_output
  return jnp.average(binary_entropies)
def f(model_output, targets):  # pylint: disable=invalid-name
  """Returns the mean of `_category_cross_entropy(model_output, targets)`.

  `_category_cross_entropy` is defined elsewhere in the file; presumably it
  yields one cross-entropy value per example (confirm there). This function
  only reduces its result with an unweighted average.

  Args:
    model_output: Passed through unchanged to `_category_cross_entropy`.
    targets: Passed through unchanged to `_category_cross_entropy`.

  Returns:
    Scalar average of the cross-entropy values.
  """
  return jnp.average(_category_cross_entropy(model_output, targets))