def test_score_examples(self):
    """Check the summary values produced by `eval_utils.score_examples`.

    The same example is scored twice, each time against a single position
    prediction, and the first nine entries of the returned summary are
    compared (approximately) with hand-computed values.
    """
    example = eval_utils.OpenKpTextExample.from_json(EXAMPLE_JSON)
    # One prediction list per example, in the same order as the examples.
    predictions_per_example = [
        [eval_utils.KpPositionPrediction(start_idx=4, phrase_len=2, logit=0.1)],
        [eval_utils.KpPositionPrediction(start_idx=0, phrase_len=2, logit=0.3)],
    ]
    summary = eval_utils.score_examples(
        [example, example], predictions_per_example
    )
    expected = [0.5, 1 / 6, 1 / 10, 0.25, 0.25, 0.25, 1 / 3, 0.2, 1 / 7]
    for idx, want in enumerate(expected):
        self.assertAlmostEqual(summary[idx], want)
def test_get_key_phrase_predictions(self):
    """Check the phrases returned for three position predictions.

    With `max_predictions=2`, the test expects exactly the phrases for the
    logit-0.3 and logit-0.2 predictions, in that order; the logit-0.1
    prediction is not expected in the output.
    """
    example = eval_utils.OpenKpTextExample.from_json(EXAMPLE_JSON)
    lowest = eval_utils.KpPositionPrediction(
        start_idx=4, phrase_len=2, logit=0.1
    )
    highest = eval_utils.KpPositionPrediction(
        start_idx=0, phrase_len=2, logit=0.3
    )
    middle = eval_utils.KpPositionPrediction(
        start_idx=2, phrase_len=2, logit=0.2
    )
    # Deliberately passed in ascending-logit order.
    phrases = example.get_key_phrase_predictions(
        [lowest, middle, highest], max_predictions=2
    )
    self.assertEqual(phrases, ['star trek', 'discovery season'])
def test_get_key_phrase_predictions_skip_invalid_indices(self):
    """Check that a prediction starting past the document end is skipped.

    The logit-0.3 prediction has an out-of-range start index, so the test
    expects the two remaining predictions' phrases to fill the
    `max_predictions=2` slots.
    """
    example = eval_utils.OpenKpTextExample.from_json(EXAMPLE_JSON)
    in_range_low = eval_utils.KpPositionPrediction(
        start_idx=4, phrase_len=2, logit=0.1
    )
    # start_idx=20 lies beyond the end of the document.
    out_of_range = eval_utils.KpPositionPrediction(
        start_idx=20, phrase_len=2, logit=0.3
    )
    in_range_high = eval_utils.KpPositionPrediction(
        start_idx=2, phrase_len=2, logit=0.2
    )
    phrases = example.get_key_phrase_predictions(
        [in_range_low, in_range_high, out_of_range], max_predictions=2
    )
    self.assertEqual(phrases, ['discovery season', '1 director'])
def test_logits_to_predictions(self):
    """Check the top predictions extracted from a 2x4 logit matrix.

    The predictions returned for `max_predictions=3` are ordered by
    descending logit and the first three are compared with the expected
    `KpPositionPrediction` values.
    """
    logits = np.array([[0.1, 0.9, 0.5, -0.3], [0.8, 0.3, 0.4, -0.5]])
    predictions = eval_utils.logits_to_predictions(logits, max_predictions=3)
    expected = [
        eval_utils.KpPositionPrediction(start_idx=1, phrase_len=1, logit=0.9),
        eval_utils.KpPositionPrediction(start_idx=0, phrase_len=2, logit=0.8),
        eval_utils.KpPositionPrediction(start_idx=2, phrase_len=1, logit=0.5),
    ]
    # Rank by logit, highest first, so the comparison does not depend on
    # the order in which logits_to_predictions returns its results.
    ranked = sorted(predictions, key=lambda item: item.logit, reverse=True)
    for rank, want in enumerate(expected):
        self.assertEqual(ranked[rank], want)