コード例 #1
0
 def f(model_output, targets, weights):  # pylint: disable=invalid-name
     """Fraction of sequences predicted correctly at every non-padding position.

     Args:
       model_output: Scores whose argmax over the last axis is taken as the
         predicted category; those predictions must have the same shape as
         `targets` (checked with shapes.assert_same_shape).
       targets: Target categories; presumably (batch, seq_len) — TODO confirm.
       weights: Per-position weights; entries equal to 0 mark padding.

     Returns:
       jnp.average of a boolean per sequence (reduced over the last axis)
       that is True when every non-padding position is predicted correctly.
     """
     predicted_ids = jnp.argmax(model_output, axis=-1)
     shapes.assert_same_shape(predicted_ids, targets)
     # A padding position (weight 0) must never disqualify its sequence, so it
     # is counted as correct regardless of what was predicted there.
     ok_at_position = jnp.logical_or(
         jnp.equal(predicted_ids, targets), jnp.equal(weights, 0))
     # A sequence counts only if all of its positions (last axis) are OK.
     return jnp.average(jnp.all(ok_at_position, axis=-1))
コード例 #2
0
 def f(model_output, targets):  # pylint: disable=invalid-name
     """Per-category F-beta scores averaged with per-category target weights.

     `beta` and `initial_category_index` come from the enclosing scope.

     Args:
       model_output: Scores whose argmax over the last axis is the predicted
         category; model_output.shape[-1] is taken as the number of
         categories.
       targets: Target categories, passed with the predictions to
         _precision_recall.

     Returns:
       jnp.average of the F-scores for categories
       initial_category_index .. n_categories - 1, weighted by the
       n_k_targets value _precision_recall reports for each category.
     """
     beta2 = beta**2
     predictions = jnp.argmax(model_output, axis=-1)
     n_categories = model_output.shape[-1]
     # Accumulate in Python lists and convert once at the end; calling
     # jnp.append inside the loop re-copies the whole array on every
     # iteration (quadratic in the number of categories). An empty range
     # still yields empty float arrays, as before.
     f_scores = []
     weights = []
     for k in range(initial_category_index, n_categories):
         _, _, n_k_targets, precision, recall = _precision_recall(
             predictions, targets, k)
         f_scores.append(_f_score(precision, recall, beta2))
         weights.append(n_k_targets)
     return jnp.average(jnp.asarray(f_scores), weights=jnp.asarray(weights))
コード例 #3
0
 def f(model_output, targets):  # pylint: disable=invalid-name
     """Average binary cross-entropy of sigmoid(model_output) against targets.

     Args:
       model_output: Unnormalized scores (logits); the predicted probability
         of the positive class is sigmoid(model_output).
       targets: Values broadcastable against model_output; presumably 0/1
         labels — TODO confirm with callers.

     Returns:
       Scalar jnp.average over all elements of the per-element binary
       cross-entropy.
     """
     # -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))] simplifies exactly to
     # softplus(x) - t*x, with softplus(x) = logaddexp(0, x) = log(1 + e^x).
     # Evaluating it this way avoids sigmoid saturating to exactly 0 or 1 for
     # large |x|, where the direct form hits log(0) = -inf and then
     # 0 * -inf = nan.
     binary_entropies = jnp.logaddexp(0.0, model_output) - targets * model_output
     return jnp.average(binary_entropies)
コード例 #4
0
 def f(model_output, targets):  # pylint: disable=invalid-name
     """Return the unweighted jnp.average of the category cross-entropies.

     Args:
       model_output: Model scores, passed unchanged to _category_cross_entropy.
       targets: Targets, passed unchanged to _category_cross_entropy.

     Returns:
       jnp.average (no weights) of the values _category_cross_entropy returns.
     """
     return jnp.average(_category_cross_entropy(model_output, targets))