Exemplo n.º 1
0
def _save_raw_predictions(checkpoint: str,
                          raw_predictions: Sequence[Mapping[str, np.ndarray]],
                          use_wordpiece: bool) -> None:
    """Writes every raw prediction to a TFRecord file as a tf.Example.

    The output path is `checkpoint` with ".predicted-tfrecords" appended.

    Args:
      checkpoint: String used as the prefix of the output file name.
      raw_predictions: Mappings from output name to value; each one becomes
        a single serialized `tf.train.Example` record.
      use_wordpiece: If True, the "long_tokens_to_unigrams" output is also
        written (as an int feature) for every prediction.
    """
    bytes_keys = ("unique_ids", "type", "level")
    int_keys = (
        "long_token_ids", "long_sentence_ids",
        "long_token_type_ids", "global_token_ids",
        "global_sentence_ids", "global_paragraph_ids",
        "answer_begin_top_indices", "answer_end_top_indices",
        "answer_types",
    )
    float_keys = (
        "supporting_facts_probs",
        "answer_begin_top_probs",
        "answer_end_top_probs",
    )

    records_path = f"{checkpoint}.predicted-tfrecords"
    with tf.python_io.TFRecordWriter(records_path) as record_writer:
        for prediction in raw_predictions:
            # Bytes outputs are wrapped in a one-element list before being
            # handed to the feature builder.
            feature_map = collections.OrderedDict(
                (key, input_utils.create_bytes_feature([prediction[key]]))
                for key in bytes_keys)
            feature_map.update(
                (key, input_utils.create_int_feature(prediction[key]))
                for key in int_keys)
            feature_map.update(
                (key, input_utils.create_float_feature(prediction[key]))
                for key in float_keys)
            if use_wordpiece:
                wordpiece_key = "long_tokens_to_unigrams"
                feature_map[wordpiece_key] = input_utils.create_int_feature(
                    prediction[wordpiece_key])

            example = tf.train.Example(
                features=tf.train.Features(feature=feature_map))
            record_writer.write(example.SerializeToString())
Exemplo n.º 2
0
    def to_tf_example(self) -> tf.train.Example:
        """Returns a TF Example with one feature per `OpenKpEtcFeatures` field.

        Fields listed in `float_features` below are encoded with
        `input_utils.create_float_feature`; every other field is encoded with
        `input_utils.create_int_feature`. Features are added in the order the
        fields are reported by `attr.fields_dict`.

        Returns:
          A `tf.train.Example` whose feature names are the attrs field names.
        """
        # All features are int features except for these float features.
        float_features = {
            'global_x_coords', 'global_y_coords', 'global_widths',
            'global_heights', 'global_parent_x_coords',
            'global_parent_y_coords', 'global_parent_widths',
            'global_parent_heights'
        }

        # Validate against the class's field definitions rather than
        # `attr.asdict(self)`: asdict recursively copies every feature value,
        # which is wasted work when only the field names are needed.
        fields_by_name = attr.fields_dict(OpenKpEtcFeatures)
        unknown_floats = float_features.difference(fields_by_name)
        assert not unknown_floats, (
            f'Float features are not fields: {sorted(unknown_floats)}')

        features = collections.OrderedDict()
        for name in fields_by_name:
            values = getattr(self, name)
            if name in float_features:
                features[name] = input_utils.create_float_feature(values)
            else:
                features[name] = input_utils.create_int_feature(values)
        return tf.train.Example(features=tf.train.Features(feature=features))
def _process_prediction(
        prediction: Mapping[Text, np.ndarray],
        text_examples: Mapping[Text, eval_utils.OpenKpTextExample],
        writer_tfrecord,
        writer_jsonl,
        metrics: Optional[Mapping[int, _MetricAverager]] = None) -> None:
    """Handles one prediction yielded by TF `Estimator.predict`.

    `Estimator.predict` is assumed to have been called with
    `yield_single_examples=True`.

    A `tf.train.Example` is always written to `writer_tfrecord`. A JSON line
    is written to `writer_jsonl` only when the prediction's URL is a key of
    `text_examples`; otherwise an error is logged and only the position-level
    features and the URL are recorded.

    Args:
      prediction: Prediction from `Estimator.predict` for a single example.
        The 'ngram_logits' and 'url_code_points' entries are read.
      text_examples: A dictionary of `OpenKpTextExample` objects, keyed by
        URL. Used to turn position predictions into KeyPhrase predictions.
      writer_tfrecord: An open `tf.python_io.TFRecordWriter` to write to.
      writer_jsonl: An open text file writer to write JSON Lines to.
      metrics: Optional `_MetricAverager`s, indexed by 1, 3 and 5, to update
        with this prediction. If None, metric calculation is skipped
        completely, e.g. when running inference on unlabeled examples.
    """
    features = collections.OrderedDict()

    # [kp_max_length, long_max_length] shape.
    logits = prediction['ngram_logits']
    features['ngram_logits'] = input_utils.create_float_feature(
        logits.flatten())

    ranked = eval_utils.logits_to_predictions(
        logits, max_predictions=FLAGS.max_position_predictions)
    # Highest logit first, for convenience.
    ranked.sort(key=lambda pred: pred.logit, reverse=True)
    features['top_pos_logit'] = input_utils.create_float_feature(
        pred.logit for pred in ranked)
    features['top_pos_start_idx'] = input_utils.create_int_feature(
        pred.start_idx for pred in ranked)
    features['top_pos_phrase_len'] = input_utils.create_int_feature(
        pred.phrase_len for pred in ranked)

    # Entries equal to -1 are skipped when rebuilding the URL string.
    code_points = (cp for cp in prediction['url_code_points'] if cp != -1)
    url = ''.join(map(chr, code_points))
    features['url'] = input_utils.create_bytes_feature([url])

    if url not in text_examples:
        tf.logging.error(f'No text example found for URL: {url}')
    else:
        text_example = text_examples[url]
        max_kp = FLAGS.max_kp_predictions
        kp_predictions = text_example.get_key_phrase_predictions(
            ranked, max_predictions=max_kp)
        if len(kp_predictions) < max_kp:
            tf.logging.warn(
                f'Made fewer than `max_kp_predictions` for URL: {url}')

        # Each key phrase is wrapped in its own single-element list.
        json_record = {
            'url': url,
            'KeyPhrases': [[phrase] for phrase in kp_predictions],
        }
        writer_jsonl.write(json.dumps(json_record) + '\n')

        features['kp_predictions'] = input_utils.create_bytes_feature(
            kp_predictions)

        if metrics is not None:
            precision, recall, f1 = text_example.get_score_full(kp_predictions)
            for k in (1, 3, 5):
                # Scores are indexed from zero, so the value "at k" is k - 1.
                p_k, r_k, f1_k = precision[k - 1], recall[k - 1], f1[k - 1]
                features[f'p_at_{k}'] = input_utils.create_float_feature(
                    [p_k])
                features[f'r_at_{k}'] = input_utils.create_float_feature(
                    [r_k])
                features[f'f1_at_{k}'] = input_utils.create_float_feature(
                    [f1_k])
                metrics[k].add_example(precision=p_k, recall=r_k, f1=f1_k)

    example = tf.train.Example(features=tf.train.Features(feature=features))
    writer_tfrecord.write(example.SerializeToString())