def test_decode_span_prediction_example_fn_with_annotations(self):
  """Checks decoding of a span-prediction TFExample with answer annotations."""
  blocks = 2
  length = 7
  max_annotations = 2
  example = self._get_test_span_prediction_tf_example()
  decode_fn = input_utils.get_span_prediction_example_decode_fn(
      num_blocks_per_example=blocks,
      block_length=length,
      max_num_answer_annotations=max_annotations)
  features = decode_fn(example.SerializeToString())

  # Token ids are padded/truncated to a fixed [blocks, length] shape.
  self.assertAllEqual([blocks, length], features['token_ids'].shape)
  self.assertAllEqual([1, 2], features['block_ids'])
  self.assertAllEqual([1, 1], features['block_pos'])
  self.assertAllEqual([0, 1], features['prefix_length'])
  # Annotation tensors are [blocks, max_annotations], zero-padded.
  self.assertAllEqual([[0, 2], [2, 0]], features['answer_annotation_begins'])
  self.assertAllEqual([[0, 4], [6, 0]], features['answer_annotation_ends'])
  self.assertAllEqual([[1, 2], [3, 0]], features['answer_annotation_labels'])
def input_fn(params):
  """Builds the tf.data.Dataset of decoded span-prediction examples."""
  logging.info("*** Input: Params ***")
  for key in sorted(params):
    logging.info("  %s = %s", key, params[key])

  # Training wants heavy shuffling and parallel file reads; eval reads the
  # files deterministically with no shuffling.
  if is_training:
    dataset = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
    # Shuffle before repeat so every element of one epoch is seen before the
    # next epoch starts — Dataset.shuffle only signals an epoch boundary once
    # its buffer drains.
    # (https://www.tensorflow.org/guide/data#randomly_shuffling_input_data)
    dataset = dataset.shuffle(buffer_size=len(input_files))
    dataset = dataset.repeat()
    # Read up to this many files in parallel; sloppy (inexact) interleaving
    # injects extra randomness into the training stream.
    parallel_reads = min(num_cpu_threads, len(input_files))
    dataset = dataset.apply(
        tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset,
            sloppy=is_training,
            cycle_length=parallel_reads))
    dataset = dataset.shuffle(buffer_size=1000)
  else:
    dataset = tf.data.TFRecordDataset(input_files)

  # Per-block integer features carried alongside the standard ones.
  extra_int_features_shapes = {
      "answer_type": [num_blocks_per_example],
      "is_supporting_fact": [num_blocks_per_example],
  }
  decode_fn = input_utils.get_span_prediction_example_decode_fn(
      num_blocks_per_example,
      block_length,
      max_num_answer_annotations=max_num_annotations,
      max_num_entity_annotations=max_num_annotations,
      extra_int_features_shapes=extra_int_features_shapes,
  )
  dataset = dataset.map(
      decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)
def test_decode_span_prediction_example_fn_with_summaries(self):
  """Checks decoding of a span-prediction TFExample with summary features."""
  blocks = 2
  length = 7
  example = self._get_test_span_prediction_tf_example()
  # No answer annotations are requested; a per-block summary_token_ids
  # feature is parsed via extra_int_features_shapes instead.
  decode_fn = input_utils.get_span_prediction_example_decode_fn(
      num_blocks_per_example=blocks,
      block_length=length,
      max_num_answer_annotations=None,
      extra_int_features_shapes=dict(summary_token_ids=[blocks, length]))
  features = decode_fn(example.SerializeToString())

  self.assertAllEqual([blocks, length], features['token_ids'].shape)
  self.assertAllEqual([1, 2], features['block_ids'])
  self.assertAllEqual([1, 1], features['block_pos'])
  self.assertAllEqual([0, 1], features['prefix_length'])