  def test_decode_span_prediction_example_fn_with_annotations(self):
    """Tests the span prediction TFExample parsing function with annotations."""
    tf_example = self._get_test_span_prediction_tf_example()

    num_blocks_per_example = 2
    block_length = 7
    max_num_answer_annotations = 2

    decode_fn = input_utils.get_span_prediction_example_decode_fn(
        num_blocks_per_example=num_blocks_per_example,
        block_length=block_length,
        max_num_answer_annotations=max_num_answer_annotations)
    features = decode_fn(tf_example.SerializeToString())

    self.assertAllEqual([num_blocks_per_example, block_length],
                        features['token_ids'].shape)
    self.assertAllEqual([1, 2], features['block_ids'])
    self.assertAllEqual([1, 1], features['block_pos'])
    self.assertAllEqual([0, 1], features['prefix_length'])
    self.assertAllEqual([[0, 2], [2, 0]],
                        features['answer_annotation_begins'])
    self.assertAllEqual([[0, 4], [6, 0]],
                        features['answer_annotation_ends'])
    self.assertAllEqual([[1, 2], [3, 0]],
                        features['answer_annotation_labels'])
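For reference, here is a minimal sketch of what the `_get_test_span_prediction_tf_example` helper might construct. The flat int64 layout and the feature keys (mirroring the decoded feature names) are assumptions inferred from the assertions above, not the library's actual serialization format; the values are chosen to match the expected outputs.

import tensorflow as tf


def get_test_span_prediction_tf_example():
  """Hypothetical stand-in for the test helper; values match the assertions."""

  def int64_list(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

  return tf.train.Example(
      features=tf.train.Features(
          feature={
              # 2 blocks x 7 tokens, flattened row-major (values arbitrary).
              'token_ids': int64_list(list(range(14))),
              'block_ids': int64_list([1, 2]),
              'block_pos': int64_list([1, 1]),
              'prefix_length': int64_list([0, 1]),
              # 2 blocks x 2 annotations, flattened row-major.
              'answer_annotation_begins': int64_list([0, 2, 2, 0]),
              'answer_annotation_ends': int64_list([0, 4, 6, 0]),
              'answer_annotation_labels': int64_list([1, 2, 3, 0]),
              # Consumed by the summaries test further below.
              'summary_token_ids': int64_list([0] * 14),
          }))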
  def input_fn(params):
    """The actual input function."""
    logging.info("*** Input: Params ***")
    for name in sorted(params.keys()):
      logging.info("  %s = %s", name, params[name])

    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    if is_training:
      d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))

      # From https://www.tensorflow.org/guide/data#randomly_shuffling_input_data
      # Dataset.shuffle doesn't signal the end of an epoch until the shuffle
      # buffer is empty. So a shuffle placed before a repeat will show every
      # element of one epoch before moving to the next.
      d = d.shuffle(buffer_size=len(input_files))
      d = d.repeat()

      # `cycle_length` is the number of parallel files that get read.
      cycle_length = min(num_cpu_threads, len(input_files))

      # `sloppy` mode means that the interleaving is not exact. This adds
      # even more randomness to the training pipeline.
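      # (`tf.data.experimental.parallel_interleave` is deprecated in TF2; the
      # equivalent is `d.interleave(tf.data.TFRecordDataset,
      # cycle_length=cycle_length,
      # num_parallel_calls=tf.data.experimental.AUTOTUNE,
      # deterministic=False)`.)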
      d = d.apply(
          tf.data.experimental.parallel_interleave(
              tf.data.TFRecordDataset,
              sloppy=is_training,
              cycle_length=cycle_length))
      d = d.shuffle(buffer_size=1000)
    else:
      d = tf.data.TFRecordDataset(input_files)

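    # Task-specific per-block integer labels decoded alongside the standard
    # features; the exact semantics of `answer_type` and `is_supporting_fact`
    # are not defined in this snippet.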
    extra_int_features_shapes = {
        "answer_type": [num_blocks_per_example],
        "is_supporting_fact": [num_blocks_per_example],
    }
    d = d.map(
        input_utils.get_span_prediction_example_decode_fn(
            num_blocks_per_example,
            block_length,
            max_num_answer_annotations=max_num_annotations,
            max_num_entity_annotations=max_num_annotations,
            extra_int_features_shapes=extra_int_features_shapes,
        ),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    d = d.prefetch(tf.data.experimental.AUTOTUNE)
    return d
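The free variables noted above suggest this `input_fn` is the inner closure of a BERT-style builder. A minimal sketch of such a wrapper follows; the name `input_fn_builder` and its signature are assumptions, not taken from the source.

def input_fn_builder(input_files, num_blocks_per_example, block_length,
                     max_num_annotations, is_training, num_cpu_threads=4):
  """Hypothetical builder that binds the free variables used by `input_fn`."""

  def input_fn(params):
    # The body from the snippet above goes here verbatim; it closes over the
    # builder's arguments.
    ...

  return input_fn

Note that the snippet never batches the dataset; an Estimator-style caller would usually finish with `d.batch(params["batch_size"], drop_remainder=True)`, so batching presumably happens elsewhere in the original code.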
  def test_decode_span_prediction_example_fn_with_summaries(self):
    """Tests the span prediction TFExample parsing function."""
    tf_example = self._get_test_span_prediction_tf_example()

    num_blocks_per_example = 2
    block_length = 7

    decode_fn = input_utils.get_span_prediction_example_decode_fn(
        num_blocks_per_example=num_blocks_per_example,
        block_length=block_length,
        max_num_answer_annotations=None,
        extra_int_features_shapes=dict(
            summary_token_ids=[num_blocks_per_example, block_length]))
    features = decode_fn(tf_example.SerializeToString())

    self.assertAllEqual([num_blocks_per_example, block_length],
                        features['token_ids'].shape)
    self.assertAllEqual([1, 2], features['block_ids'])
    self.assertAllEqual([1, 1], features['block_pos'])
    self.assertAllEqual([0, 1], features['prefix_length'])
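Given the `extra_int_features_shapes` mapping above, a natural additional check (assuming the serialized example actually carries a flat `summary_token_ids` feature, as in the hypothetical helper sketched earlier) would be:

    self.assertAllEqual([num_blocks_per_example, block_length],
                        features['summary_token_ids'].shape)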