コード例 #1
0
 def test_score_examples(self):
     """Checks the aggregate scores from score_examples against fixed values.

     The same example is scored twice, once with a single prediction at
     start_idx=4 and once with a single prediction at start_idx=0 (both
     phrase_len=2). ``summary`` must be indexable and hold at least
     ``len(expected)`` scores.
     """
     example = eval_utils.OpenKpTextExample.from_json(EXAMPLE_JSON)
     pr1 = eval_utils.KpPositionPrediction(start_idx=4,
                                           phrase_len=2,
                                           logit=0.1)
     pr2 = eval_utils.KpPositionPrediction(start_idx=0,
                                           phrase_len=2,
                                           logit=0.3)
     summary = eval_utils.score_examples([example, example], [[pr1], [pr2]])
     # NOTE(review): these values are consistent with precision@{1,3,5},
     # recall@{1,3,5} and F1@{1,3,5}, averaged over the two examples, when
     # exactly one of the two predictions matches one of two gold phrases.
     # Confirm the ordering against score_examples before relying on it.
     expected = [0.5, 1 / 6, 1 / 10, 0.25, 0.25, 0.25, 1 / 3, 0.2, 1 / 7]
     # Drive the loop from `expected` rather than a hard-coded range(9) so the
     # bound can never drift from the list. `summary` is still indexed, so a
     # result shorter than `expected` fails loudly with IndexError.
     for i, want in enumerate(expected):
         with self.subTest(index=i):
             self.assertAlmostEqual(summary[i], want)
コード例 #2
0
 def test_get_key_phrase_predictions(self):
     """Expects the highest-logit candidates as phrases, best first."""
     example = eval_utils.OpenKpTextExample.from_json(EXAMPLE_JSON)
     make = eval_utils.KpPositionPrediction
     low = make(start_idx=4, phrase_len=2, logit=0.1)
     mid = make(start_idx=2, phrase_len=2, logit=0.2)
     high = make(start_idx=0, phrase_len=2, logit=0.3)
     # Candidates are passed in ascending logit order. With max_predictions=2
     # the lowest-scoring one (`low`) is expected to be dropped, and the rest
     # come back in descending logit order.
     top_two = example.get_key_phrase_predictions([low, mid, high],
                                                  max_predictions=2)
     self.assertEqual(top_two, ['star trek', 'discovery season'])
コード例 #3
0
 def test_get_key_phrase_predictions_skip_invalid_indices(self):
     """Expects candidates pointing outside the document to be skipped."""
     example = eval_utils.OpenKpTextExample.from_json(EXAMPLE_JSON)
     make = eval_utils.KpPositionPrediction
     low = make(start_idx=4, phrase_len=2, logit=0.1)
     mid = make(start_idx=2, phrase_len=2, logit=0.2)
     # Highest logit, but start_idx=20 lies past the end of the document, so
     # the two valid candidates are expected to fill both slots instead.
     out_of_range = make(start_idx=20, phrase_len=2, logit=0.3)
     kept = example.get_key_phrase_predictions([low, mid, out_of_range],
                                               max_predictions=2)
     self.assertEqual(kept, ['discovery season', '1 director'])
コード例 #4
0
 def test_logits_to_predictions(self):
     """Expects the top logits to be mapped back to (start_idx, phrase_len).

     The expected predictions below treat row r of ``logits`` as phrase
     length r + 1 and column c as start index c.
     """
     logits = np.array([[0.1, 0.9, 0.5, -0.3], [0.8, 0.3, 0.4, -0.5]])
     predictions = eval_utils.logits_to_predictions(logits,
                                                    max_predictions=3)
     make = eval_utils.KpPositionPrediction
     want = [
         make(start_idx=1, phrase_len=1, logit=0.9),  # logits[0][1]
         make(start_idx=0, phrase_len=2, logit=0.8),  # logits[1][0]
         make(start_idx=2, phrase_len=1, logit=0.5),  # logits[0][2]
     ]
     # The return order is not part of the contract under test, so rank the
     # predictions by logit (highest first) before comparing positionally.
     predictions.sort(key=lambda p: p.logit, reverse=True)
     for rank, expected in enumerate(want):
         self.assertEqual(predictions[rank], expected)