def test_logits_to_predictions(self):
    logits = np.array([[0.1, 0.9, 0.5, -0.3], [0.8, 0.3, 0.4, -0.5]])
    predictions = eval_utils.logits_to_predictions(logits, max_predictions=3)

    # Row index encodes (phrase_len - 1); column index encodes start_idx, so
    # the top 3 logits are 0.9, 0.8, and 0.5.
    expected1 = eval_utils.KpPositionPrediction(
        start_idx=1, phrase_len=1, logit=0.9)
    expected2 = eval_utils.KpPositionPrediction(
        start_idx=0, phrase_len=2, logit=0.8)
    expected3 = eval_utils.KpPositionPrediction(
        start_idx=2, phrase_len=1, logit=0.5)

    # Sort by descending logit so the comparison order is deterministic.
    predictions.sort(key=lambda prediction: prediction.logit, reverse=True)
    self.assertEqual(predictions[0], expected1)
    self.assertEqual(predictions[1], expected2)
    self.assertEqual(predictions[2], expected3)
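
# A minimal reference sketch of the selection logic exercised above, assuming
# `eval_utils.logits_to_predictions` simply takes the top `max_predictions`
# cells of the [kp_max_length, long_max_length] logit grid, where row r maps
# to phrase_len r + 1 and column c to start_idx c. Illustrative only, not the
# library implementation:
def _naive_logits_to_predictions_sketch(logits, max_predictions):
    cells = [(float(logits[row, col]), row, col)
             for row in range(logits.shape[0])
             for col in range(logits.shape[1])]
    # Highest logit first, then keep the top `max_predictions` cells.
    cells.sort(key=lambda cell: cell[0], reverse=True)
    return [
        eval_utils.KpPositionPrediction(
            start_idx=col, phrase_len=row + 1, logit=logit)
        for logit, row, col in cells[:max_predictions]
    ]
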
def _process_prediction(
        prediction: Mapping[Text, np.ndarray],
        text_examples: Mapping[Text, eval_utils.OpenKpTextExample],
        writer_tfrecord,
        writer_jsonl,
        metrics: Optional[Mapping[int, _MetricAverager]] = None) -> None:
    """Processes a single TF `Estimator.predict` prediction.

  This function assumes that `Estimator.predict` was called with
  `yield_single_examples=True`.

  Args:
    prediction: Prediction from `Estimator.predict` for a single example.
    text_examples: A dictionary of `OpenKpTextExample` objects, keyed by URL.
      This is used to generate the KeyPhrase predictions based on the ngram
      logits in the prediction.
    writer_tfrecord: An open `tf.python_io.TFRecordWriter` to write to.
    writer_jsonl: An open text file writer to write JSON Lines to.
    metrics: Optional `_MetricAverager`s to update with this prediction. If
      None, metric calculation is skipped completely. None is appropriate for
      example if we're just running inference for unlabeled examples.
  """
    # [kp_max_length, long_max_length] shape.
    ngram_logits = prediction['ngram_logits']

    features = collections.OrderedDict()
    features['ngram_logits'] = input_utils.create_float_feature(
        ngram_logits.flatten())
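    # Flattening drops the [kp_max_length, long_max_length] shape; a reader
    # can recover the grid with something like (shape values are assumptions
    # that must match the model config):
    #   logits_2d = flat_logits.reshape([kp_max_length, long_max_length])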

    position_predictions = eval_utils.logits_to_predictions(
        ngram_logits, max_predictions=FLAGS.max_position_predictions)
    # Sort predictions for convenience.
    position_predictions.sort(key=lambda x: x.logit, reverse=True)
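    # Because of the sort above, the parallel `top_pos_*` features below are
    # stored in descending-logit order, so index 0 is always the best-scoring
    # position.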
    features['top_pos_logit'] = input_utils.create_float_feature(
        x.logit for x in position_predictions)
    features['top_pos_start_idx'] = input_utils.create_int_feature(
        x.start_idx for x in position_predictions)
    features['top_pos_phrase_len'] = input_utils.create_int_feature(
        x.phrase_len for x in position_predictions)

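    # `url_code_points` holds the URL's Unicode code points, padded with -1;
    # drop the padding and rebuild the string.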
    url = ''.join(chr(x) for x in prediction['url_code_points'] if x != -1)
    features['url'] = input_utils.create_bytes_feature([url])

    if url in text_examples:
        text_example = text_examples[url]
        kp_predictions = text_example.get_key_phrase_predictions(
            position_predictions, max_predictions=FLAGS.max_kp_predictions)
        if len(kp_predictions) < FLAGS.max_kp_predictions:
            tf.logging.warn(
                f'Made only {len(kp_predictions)} of max_kp_predictions='
                f'{FLAGS.max_kp_predictions} predictions for URL: {url}')
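        # Each output line is JSON of the form (values illustrative):
        #   {"url": "http://...", "KeyPhrases": [["phrase a"], ["phrase b"]]}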
        writer_jsonl.write(
            json.dumps({
                'url': url,
                'KeyPhrases': [[kp] for kp in kp_predictions]
            }) + '\n')

        features['kp_predictions'] = input_utils.create_bytes_feature(
            kp_predictions)

        if metrics is not None:
            precision, recall, f1 = text_example.get_score_full(kp_predictions)
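            # `get_score_full` is assumed to return per-rank lists indexed by
            # k - 1, so precision[i - 1] is precision@i for i in {1, 3, 5}.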
            for i in (1, 3, 5):
                p = precision[i - 1]
                r = recall[i - 1]
                f = f1[i - 1]
                features[f'p_at_{i}'] = input_utils.create_float_feature([p])
                features[f'r_at_{i}'] = input_utils.create_float_feature([r])
                features[f'f1_at_{i}'] = input_utils.create_float_feature([f])
                metrics[i].add_example(precision=p, recall=r, f1=f)
    else:
        tf.logging.error(f'No text example found for URL: {url}')

    writer_tfrecord.write(
        tf.train.Example(features=tf.train.Features(
            feature=features)).SerializeToString())
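
# A minimal driver sketch for `_process_prediction`, assuming an `estimator`,
# `predict_input_fn`, `text_examples`, and output paths are already built
# (all names here are illustrative, not part of this module):
#
#   with tf.python_io.TFRecordWriter(output_tfrecord_path) as writer_tfrecord, \
#       tf.io.gfile.GFile(output_jsonl_path, 'w') as writer_jsonl:
#     metrics = {k: _MetricAverager() for k in (1, 3, 5)}
#     for prediction in estimator.predict(
#         predict_input_fn, yield_single_examples=True):
#       _process_prediction(prediction, text_examples, writer_tfrecord,
#                           writer_jsonl, metrics=metrics)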