def _get_metrics(eval_top_k):
  """Gets model evaluation metrics."""
  eval_metrics = [
      metrics.BatchRecall(name='Recall/Recall_{0}'.format(k), top_k=k)
      for k in eval_top_k
  ]
  batch_mean_rank = metrics.BatchMeanRank()
  eval_metrics.append(batch_mean_rank)
  return eval_metrics
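For orientation, here is a minimal sketch of how a streaming top-k recall metric of this general kind can be built on tf.keras.metrics.Mean. The class name SimpleTopKRecall and its choice to rank the true label against every logit column are illustrative assumptions, not the BatchRecall implementation used above.

import tensorflow as tf


class SimpleTopKRecall(tf.keras.metrics.Mean):
  """Streaming fraction of examples whose true label lands in the top-k logits."""

  def __init__(self, name='simple_top_k_recall', top_k=10, **kwargs):
    super().__init__(name=name, **kwargs)
    self._top_k = top_k

  def update_state(self, y_true, y_pred, sample_weight=None):
    # y_true: [batch, 1] label indices; y_pred: [batch, num_candidates] logits.
    y_true = tf.cast(tf.reshape(y_true, [-1, 1]), tf.int32)
    _, top_k_ids = tf.math.top_k(y_pred, k=self._top_k)
    hits = tf.reduce_max(
        tf.cast(tf.equal(top_k_ids, y_true), tf.float32), axis=-1)
    # The Mean base class accumulates the per-example hit rate across batches.
    return super().update_state(hits, sample_weight=sample_weight)

Because this generic variant ranks the true label against all candidate columns, its numbers need not match the in-batch figures asserted in the test below.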
def test_batch_recall_and_mean_rank(self):
  batch_recall = keras_metrics.BatchRecall(top_k=2)
  batch_mean_rank = keras_metrics.BatchMeanRank()
  true_label = tf.constant([[2], [0], [1]], dtype=tf.int32)
  logits = tf.constant(
      [[0.8, 0.1, 1.1, 0.3],
       [0.2, 0.7, 0.1, 0.5],
       [0.7, 0.4, 0.9, 0.2]],
      dtype=tf.float32)
  batch_recall.update_state(y_true=true_label, y_pred=logits)
  batch_mean_rank.update_state(y_true=true_label, y_pred=logits)
  self.assertBetween(batch_recall.result().numpy(), 0.6, 0.7)
  self.assertEqual(batch_mean_rank.result().numpy(), 1.0)
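The asserted values are consistent with the batch metrics ranking each example's true-label logit only against the logits of the other labels present in the batch: under that reading the three examples get ranks 0, 1, and 2, so the mean rank is 1.0 and recall@2 is 2/3, which falls between 0.6 and 0.7.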
def _get_metrics(eval_top_k):
  """Gets model evaluation metrics of both batch samples and full vocabulary."""
  metrics_list = [
      metrics.GlobalRecall(name=f'Global_Recall/Recall_{k}', top_k=k)
      for k in eval_top_k
  ]
  metrics_list.append(metrics.GlobalMeanRank(name='global_mean_rank'))
  metrics_list.extend(
      metrics.BatchRecall(name=f'Batch_Recall/Recall_{k}', top_k=k)
      for k in eval_top_k)
  metrics_list.append(metrics.BatchMeanRank(name='batch_mean_rank'))
  return metrics_list
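As a usage sketch, the returned list can be handed to Keras at compile time so every metric is updated once per evaluation batch. The model architecture, vocabulary size, and top-k values below are placeholder assumptions, not taken from the source.

import tensorflow as tf

VOCAB_SIZE = 5000  # assumed vocabulary size

# Hypothetical retrieval model that emits one logit per vocabulary item.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=32),
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dense(VOCAB_SIZE),
])
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=_get_metrics(eval_top_k=[1, 5, 10]))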