def test_global_recall_and_mean_rank(self):
  """Checks GlobalRecall@2 and GlobalMeanRank on a tiny hand-built batch."""
  labels = tf.constant([[2], [0], [1]], dtype=tf.int32)
  scores = tf.constant(
      [[0.8, 0.1, 1.1, 0.3],
       [0.2, 0.7, 0.1, 0.5],
       [0.7, 0.4, 0.9, 0.2]],
      dtype=tf.float32)

  recall_metric = keras_metrics.GlobalRecall(top_k=2)
  rank_metric = keras_metrics.GlobalMeanRank()
  recall_metric.update_state(y_true=labels, y_pred=scores)
  rank_metric.update_state(y_true=labels, y_pred=scores)

  # Expected values depend on how many labels land in the top-2 scores and
  # on the average rank of each label's score within its row.
  self.assertBetween(recall_metric.result().numpy(), 0.3, 0.4)
  self.assertBetween(rank_metric.result().numpy(), 1.3, 1.4)
def _get_metrics(eval_top_k):
  """Builds evaluation metrics over both batch samples and the full vocabulary.

  Args:
    eval_top_k: Iterable of integer k values, one recall@k metric per value.

  Returns:
    A list of metric instances: a GlobalRecall per k, a GlobalMeanRank,
    a BatchRecall per k, and a BatchMeanRank, in that order.
  """
  global_recalls = [
      metrics.GlobalRecall(name=f'Global_Recall/Recall_{k}', top_k=k)
      for k in eval_top_k
  ]
  batch_recalls = [
      metrics.BatchRecall(name=f'Batch_Recall/Recall_{k}', top_k=k)
      for k in eval_top_k
  ]
  return [
      *global_recalls,
      metrics.GlobalMeanRank(name='global_mean_rank'),
      *batch_recalls,
      metrics.BatchMeanRank(name='batch_mean_rank'),
  ]