def test_global_recall_and_mean_rank(self):
  global_recall = keras_metrics.GlobalRecall(top_k=2)
  global_mean_rank = keras_metrics.GlobalMeanRank()
  # One example per row; each true label indexes into that row's 4 logits.
  true_label = tf.constant([[2], [0], [1]], dtype=tf.int32)
  logits = tf.constant(
      [[0.8, 0.1, 1.1, 0.3], [0.2, 0.7, 0.1, 0.5], [0.7, 0.4, 0.9, 0.2]],
      dtype=tf.float32)
  global_recall.update_state(y_true=true_label, y_pred=logits)
  global_mean_rank.update_state(y_true=true_label, y_pred=logits)
  # Only the first example places its label in the top 2, so recall@2 is 1/3;
  # the (0-indexed) ranks of the true labels are 0, 2, 2, giving a mean of 4/3.
  self.assertBetween(global_recall.result().numpy(), 0.3, 0.4)
  self.assertBetween(global_mean_rank.result().numpy(), 1.3, 1.4)
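

# Hedged sketch, not part of the original suite: recomputes the expected values
# above by hand so the asserted ranges are easy to verify. Assumes `tf` is
# imported as in the test above, and that GlobalMeanRank reports 0-indexed
# ranks, which is consistent with the 1.3-1.4 range asserted there.
def _manual_recall_and_mean_rank(labels, logits, top_k):
  """Returns (recall@top_k, mean 0-indexed rank) for dense labels and logits."""
  top_indices = tf.math.top_k(logits, k=top_k).indices          # [batch, top_k]
  hits = tf.reduce_any(tf.equal(top_indices, labels), axis=1)   # label in top-k?
  recall = tf.reduce_mean(tf.cast(hits, tf.float32))
  # Rank of the true label = number of classes scored strictly higher.
  label_logits = tf.gather(logits, tf.squeeze(labels, axis=1), batch_dims=1)
  ranks = tf.reduce_sum(
      tf.cast(logits > label_logits[:, tf.newaxis], tf.float32), axis=1)
  return recall.numpy(), tf.reduce_mean(ranks).numpy()
# For the inputs above this yields recall@2 == 1/3 and mean rank == 4/3.
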
def _get_metrics(eval_top_k):
  """Gets model evaluation metrics of both batch samples and full vocabulary."""
  metrics_list = [
      metrics.GlobalRecall(name=f'Global_Recall/Recall_{k}', top_k=k)
      for k in eval_top_k
  ]
  metrics_list.append(metrics.GlobalMeanRank(name='global_mean_rank'))
  metrics_list.extend(
      metrics.BatchRecall(name=f'Batch_Recall/Recall_{k}', top_k=k)
      for k in eval_top_k)
  metrics_list.append(metrics.BatchMeanRank(name='batch_mean_rank'))
  return metrics_list
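

# Hedged usage sketch, not from the original file: how the list returned by
# _get_metrics might be used. `eval_top_k=[1, 5]` is an illustrative choice,
# and the call still depends on the `metrics` module imported by this file.
eval_metrics = _get_metrics(eval_top_k=[1, 5])
# eval_metrics now holds, in order:
#   Global_Recall/Recall_1, Global_Recall/Recall_5, global_mean_rank,
#   Batch_Recall/Recall_1, Batch_Recall/Recall_5, batch_mean_rank
# and can be passed directly via model.compile(..., metrics=eval_metrics).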