Example No. 1
0
def _get_metrics(eval_top_k):
  """Builds the evaluation metrics for the model.

  Args:
    eval_top_k: Iterable of k values; one `metrics.BatchRecall` named
      'Recall/Recall_<k>' is created per value, in iteration order.

  Returns:
    A list holding the per-k `BatchRecall` metrics followed by a single
    `metrics.BatchMeanRank` metric (constructed without an explicit name).
  """
  recall_metrics = []
  for top_k in eval_top_k:
    recall_metrics.append(
        metrics.BatchRecall(name='Recall/Recall_{0}'.format(top_k),
                            top_k=top_k))
  mean_rank = metrics.BatchMeanRank()
  return recall_metrics + [mean_rank]
 def test_batch_recall_and_mean_rank(self):
     """Checks BatchRecall(top_k=2) and BatchMeanRank on one fixed batch.

     The batch has three examples, each with an int32 label and a row of four
     float32 logits; both metrics are fed the same labels and logits.
     """
     labels = tf.constant([[2], [0], [1]], dtype=tf.int32)
     scores = tf.constant(
         [
             [0.8, 0.1, 1.1, 0.3],
             [0.2, 0.7, 0.1, 0.5],
             [0.7, 0.4, 0.9, 0.2],
         ],
         dtype=tf.float32)

     recall_at_2 = keras_metrics.BatchRecall(top_k=2)
     mean_rank = keras_metrics.BatchMeanRank()
     for metric in (recall_at_2, mean_rank):
         metric.update_state(y_true=labels, y_pred=scores)

     self.assertBetween(recall_at_2.result().numpy(), 0.6, 0.7)
     self.assertEqual(mean_rank.result().numpy(), 1.0)
def _get_metrics(eval_top_k):
  """Gets model evaluation metrics of both batch samples and full vocabulary.

  Args:
    eval_top_k: Iterable of k values. For each k, one global recall metric
      and one batch recall metric are created. Any iterable is accepted,
      including a one-shot iterator or generator.

  Returns:
    A list containing, in order: a `GlobalRecall` per k (named
    'Global_Recall/Recall_<k>'), one `GlobalMeanRank` ('global_mean_rank'),
    a `BatchRecall` per k (named 'Batch_Recall/Recall_<k>'), and one
    `BatchMeanRank` ('batch_mean_rank').
  """
  # Materialize once: the k values are consumed twice below, and a one-shot
  # iterator would otherwise be exhausted by the first pass, silently
  # dropping every BatchRecall metric.
  top_ks = list(eval_top_k)
  metrics_list = [
      metrics.GlobalRecall(name=f'Global_Recall/Recall_{k}', top_k=k)
      for k in top_ks
  ]
  metrics_list.append(metrics.GlobalMeanRank(name='global_mean_rank'))
  metrics_list.extend(
      metrics.BatchRecall(name=f'Batch_Recall/Recall_{k}', top_k=k)
      for k in top_ks)
  metrics_list.append(metrics.BatchMeanRank(name='batch_mean_rank'))
  return metrics_list