def test_logits_to_predictions(self):
    """Checks the top-3 position predictions extracted from a logit grid.

    The expectations below encode the layout this test relies on: the row
    index of `logits` corresponds to `phrase_len - 1` and the column index
    corresponds to `start_idx`.
    """
    logits = np.array([
        [0.1, 0.9, 0.5, -0.3],
        [0.8, 0.3, 0.4, -0.5],
    ])
    predictions = eval_utils.logits_to_predictions(logits, max_predictions=3)

    # Expected predictions, listed from highest to lowest logit.
    expected_by_rank = [
        eval_utils.KpPositionPrediction(start_idx=1, phrase_len=1, logit=0.9),
        eval_utils.KpPositionPrediction(start_idx=0, phrase_len=2, logit=0.8),
        eval_utils.KpPositionPrediction(start_idx=2, phrase_len=1, logit=0.5),
    ]

    # Sort in place so the comparison does not depend on the order in which
    # `logits_to_predictions` returns its results.
    predictions.sort(key=lambda item: item.logit, reverse=True)

    for rank, expected in enumerate(expected_by_rank):
        self.assertEqual(predictions[rank], expected)
def _process_prediction(
        prediction: Mapping[Text, np.ndarray],
        text_examples: Mapping[Text, eval_utils.OpenKpTextExample],
        writer_tfrecord,
        writer_jsonl,
        metrics: Optional[Mapping[int, _MetricAverager]] = None) -> None:
    """Writes one `Estimator.predict` result as a TFRecord and a JSON line.

    The caller is expected to have invoked `Estimator.predict` with
    `yield_single_examples=True`, so that `prediction` describes exactly one
    example.

    A `tf.train.Example` is always written to `writer_tfrecord`. The JSON line,
    the `kp_predictions` feature and the metric features are only produced
    when the decoded URL is present in `text_examples`; otherwise an error is
    logged.

    Args:
      prediction: Prediction from `Estimator.predict` for a single example.
        The keys 'ngram_logits' and 'url_code_points' are read.
      text_examples: A dictionary of `OpenKpTextExample` objects, keyed by
        URL. Used to turn the position predictions derived from the ngram
        logits into KeyPhrase predictions.
      writer_tfrecord: An open `tf.python_io.TFRecordWriter` to write to.
      writer_jsonl: An open text file writer to write JSON Lines to.
      metrics: Optional `_MetricAverager`s (looked up by the keys 1, 3 and 5)
        to update with this prediction. If None, metric calculation is
        skipped completely, which is appropriate for example when running
        inference on unlabeled examples.
    """
    # Shape: [kp_max_length, long_max_length].
    ngram_logits = prediction['ngram_logits']

    features = collections.OrderedDict()
    features['ngram_logits'] = input_utils.create_float_feature(
        ngram_logits.flatten())

    position_predictions = eval_utils.logits_to_predictions(
        ngram_logits, max_predictions=FLAGS.max_position_predictions)

    # Order from highest to lowest logit, purely for convenience downstream.
    position_predictions.sort(key=lambda pos: pos.logit, reverse=True)
    features['top_pos_logit'] = input_utils.create_float_feature(
        pos.logit for pos in position_predictions)
    features['top_pos_start_idx'] = input_utils.create_int_feature(
        pos.start_idx for pos in position_predictions)
    features['top_pos_phrase_len'] = input_utils.create_int_feature(
        pos.phrase_len for pos in position_predictions)

    # Rebuild the URL from its code points, dropping every -1 entry
    # (presumably padding -- TODO confirm against the input pipeline).
    url_chars = (
        chr(code_point)
        for code_point in prediction['url_code_points']
        if code_point != -1)
    url = ''.join(url_chars)
    features['url'] = input_utils.create_bytes_feature([url])

    if url not in text_examples:
        tf.logging.error(f'No text example found for URL: {url}')
    else:
        text_example = text_examples[url]
        kp_predictions = text_example.get_key_phrase_predictions(
            position_predictions, max_predictions=FLAGS.max_kp_predictions)
        if len(kp_predictions) < FLAGS.max_kp_predictions:
            tf.logging.warn(
                f'Made fewer than `max_kp_predictions` for URL: {url}')

        # Each key phrase is wrapped in its own single-element list in the
        # JSON Lines output.
        json_record = {
            'url': url,
            'KeyPhrases': [[kp] for kp in kp_predictions]
        }
        writer_jsonl.write(json.dumps(json_record) + '\n')

        features['kp_predictions'] = input_utils.create_bytes_feature(
            kp_predictions)

        if metrics is not None:
            precision, recall, f1 = text_example.get_score_full(
                kp_predictions)
            # Scores for cutoff `i` are read from position `i - 1`
            # (presumably "@i" metrics stored zero-based -- TODO confirm).
            for i in (1, 3, 5):
                offset = i - 1
                p, r, f = precision[offset], recall[offset], f1[offset]
                features[f'p_at_{i}'] = input_utils.create_float_feature([p])
                features[f'r_at_{i}'] = input_utils.create_float_feature([r])
                features[f'f1_at_{i}'] = input_utils.create_float_feature([f])
                metrics[i].add_example(precision=p, recall=r, f1=f)

    # The TFRecord is written whether or not a text example was found.
    example = tf.train.Example(features=tf.train.Features(feature=features))
    writer_tfrecord.write(example.SerializeToString())