コード例 #1
0
 def testDecodeImageLabels(self):
   """Checks decoding of image-level class labels.

   First decodes integer labels stored under 'image/class/label', then
   decodes class names stored under 'image/class/text', which the decoder
   maps to ids through a label map written to a temporary file
   ('cat' -> 3, 'dog' -> 1).
   """
   image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
   encoded_jpeg = self._EncodeImage(image_tensor)
   example = tf.train.Example(
       features=tf.train.Features(
           feature={
               'image/encoded': self._BytesFeature(encoded_jpeg),
               'image/format': self._BytesFeature('jpeg'),
               'image/class/label': self._Int64Feature([1, 2]),
           })).SerializeToString()
   example_decoder = tf_example_decoder.TfExampleDecoder()
   tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
   with self.test_session() as sess:
     tensor_dict = sess.run(tensor_dict)
   # assertIn names the missing key and shows the dict on failure;
   # assertTrue(key in dict) would only report "False is not true".
   self.assertIn(fields.InputDataFields.groundtruth_image_classes,
                 tensor_dict)
   self.assertAllEqual(
       tensor_dict[fields.InputDataFields.groundtruth_image_classes],
       np.array([1, 2]))
   example = tf.train.Example(
       features=tf.train.Features(
           feature={
               'image/encoded': self._BytesFeature(encoded_jpeg),
               'image/format': self._BytesFeature('jpeg'),
               'image/class/text': self._BytesFeature(['dog', 'cat']),
           })).SerializeToString()
   label_map_string = """
     item {
       id:3
       name:'cat'
     }
     item {
       id:1
       name:'dog'
     }
   """
   label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
   with tf.gfile.Open(label_map_path, 'wb') as f:
     f.write(label_map_string)
   example_decoder = tf_example_decoder.TfExampleDecoder(
       label_map_proto_file=label_map_path)
   tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
   with self.test_session() as sess:
     # Lookup tables (presumably built for the label-map mapping) must be
     # initialized before the decoded tensors are fetched.
     sess.run(tf.tables_initializer())
     tensor_dict = sess.run(tensor_dict)
   self.assertIn(fields.InputDataFields.groundtruth_image_classes,
                 tensor_dict)
   # Text labels ['dog', 'cat'] map to ids [1, 3] through the label map.
   self.assertAllEqual(
       tensor_dict[fields.InputDataFields.groundtruth_image_classes],
       np.array([1, 3]))
コード例 #2
0
  def testInstancesNotAvailableByDefault(self):
    """Checks that instance masks are not decoded unless requested.

    The example contains 'image/object/mask', but TfExampleDecoder is built
    with default arguments, and groundtruth_instance_masks must then be
    absent from the decoded dict.
    """
    num_instances = 4
    image_height = 5
    image_width = 3
    # Randomly generate image.
    image_tensor = np.random.randint(256, size=(image_height,
                                                image_width,
                                                3)).astype(np.uint8)
    encoded_jpeg = self._EncodeImage(image_tensor)

    # Randomly generate instance segmentation masks.
    instance_masks = (
        np.random.randint(2, size=(num_instances,
                                   image_height,
                                   image_width)).astype(np.float32))
    instance_masks_flattened = np.reshape(instance_masks, [-1])

    # Randomly generate class labels for each instance.
    object_classes = np.random.randint(
        100, size=(num_instances)).astype(np.int64)

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': self._BytesFeature(encoded_jpeg),
        'image/format': self._BytesFeature('jpeg'),
        'image/height': self._Int64Feature([image_height]),
        'image/width': self._Int64Feature([image_width]),
        'image/object/mask': self._FloatFeature(instance_masks_flattened),
        'image/object/class/label': self._Int64Feature(
            object_classes)})).SerializeToString()
    example_decoder = tf_example_decoder.TfExampleDecoder()
    tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
    # assertNotIn reports the offending key and the dict on failure, unlike
    # assertTrue(key not in dict).
    self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                     tensor_dict)
コード例 #3
0
  def testDecodeDefaultGroundtruthWeights(self):
    """Checks that boxes without explicit weights are decoded with weight 1.

    The example carries two bounding boxes but no weight feature; after
    decoding, `groundtruth_weights` must equal a float32 vector of ones.
    """
    pixels = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    jpeg_bytes = self._EncodeImage(pixels)
    ymins, xmins = [0.0, 4.0], [1.0, 5.0]
    ymaxs, xmaxs = [2.0, 6.0], [3.0, 7.0]
    feature_map = {
        'image/encoded': self._BytesFeature(jpeg_bytes),
        'image/format': self._BytesFeature('jpeg'),
        'image/object/bbox/ymin': self._FloatFeature(ymins),
        'image/object/bbox/xmin': self._FloatFeature(xmins),
        'image/object/bbox/ymax': self._FloatFeature(ymaxs),
        'image/object/bbox/xmax': self._FloatFeature(xmaxs),
    }
    serialized = tf.train.Example(
        features=tf.train.Features(feature=feature_map)).SerializeToString()

    decoder = tf_example_decoder.TfExampleDecoder()
    decoded = decoder.decode(tf.convert_to_tensor(serialized))

    # Only the coordinate axis is statically known; the box count is not.
    boxes_shape = decoded[
        fields.InputDataFields.groundtruth_boxes].get_shape().as_list()
    self.assertAllEqual(boxes_shape, [None, 4])

    with self.test_session() as sess:
      decoded = sess.run(decoded)

    expected_weights = np.ones(2, dtype=np.float32)
    self.assertAllClose(
        decoded[fields.InputDataFields.groundtruth_weights], expected_weights)
コード例 #4
0
  def testDecodeAdditionalChannels(self):
    """Checks decoding of extra channels stacked along the depth axis.

    Two copies of the same encoded one-channel image are stored under
    'image/additional_channels/encoded'. The decoder is configured with
    num_additional_channels=2 and must return them concatenated on axis 2.
    """
    rgb = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    rgb_jpeg = self._EncodeImage(rgb)

    extra = np.random.randint(256, size=(4, 5, 1)).astype(np.uint8)
    extra_encoded = self._EncodeImage(extra)
    extra_decoded = self._DecodeImage(extra_encoded)
    num_extra = 2

    feature_map = {
        'image/encoded': self._BytesFeature(rgb_jpeg),
        'image/additional_channels/encoded': self._BytesFeatureFromList(
            np.array([extra_encoded] * num_extra)),
        'image/format': self._BytesFeature('jpeg'),
        'image/source_id': self._BytesFeature('image_id'),
    }
    serialized = tf.train.Example(
        features=tf.train.Features(feature=feature_map)).SerializeToString()

    decoder = tf_example_decoder.TfExampleDecoder(
        num_additional_channels=num_extra)
    decoded = decoder.decode(tf.convert_to_tensor(serialized))

    expected = np.concatenate([extra_decoded] * num_extra, axis=2)
    with self.test_session() as sess:
      decoded = sess.run(decoded)
      self.assertAllEqual(
          expected,
          decoded[fields.InputDataFields.image_additional_channels])
コード例 #5
0
  def testDecodePngInstanceMasks(self):
    """Checks that PNG-encoded instance masks decode to a stacked mask array.

    Two random binary 10x10 masks are PNG-encoded and stored under
    'image/object/mask'; with PNG_MASKS the decoder must return them as a
    [2, 10, 10] array equal to the original masks cast to float32.
    """
    image = np.random.randint(256, size=(10, 10, 3)).astype(np.uint8)
    image_jpeg = self._EncodeImage(image)
    raw_masks = [
        np.random.randint(0, 2, size=(10, 10, 1)).astype(np.uint8)
        for _ in range(2)
    ]
    png_masks = [self._EncodeImage(m, encoding_type='png') for m in raw_masks]
    expected_masks = np.stack(
        [np.squeeze(m.astype(np.float32)) for m in raw_masks])

    serialized = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/encoded': self._BytesFeature(image_jpeg),
                'image/format': self._BytesFeature('jpeg'),
                'image/object/mask': self._BytesFeature(png_masks)
            })).SerializeToString()

    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=True,
        instance_mask_type=input_reader_pb2.PNG_MASKS)
    decoded = decoder.decode(tf.convert_to_tensor(serialized))

    with self.test_session() as sess:
      decoded = sess.run(decoded)

    self.assertAllEqual(
        expected_masks,
        decoded[fields.InputDataFields.groundtruth_instance_masks])
コード例 #6
0
  def testDecodeInstanceSegmentation(self):
    """Checks decoding of flattened float instance masks and class labels.

    Masks are serialized as one flat float list alongside image/height and
    image/width. The decoder (with load_instance_masks=True) must return a
    mask tensor with static shape [4, 5, 3] and a class tensor of static
    shape [4], both equal to the generated values.
    """
    num_instances, height, width = 4, 5, 3

    # Random RGB image that carries the annotations.
    image = np.random.randint(
        256, size=(height, width, 3)).astype(np.uint8)
    image_jpeg = self._EncodeImage(image)

    # Random binary masks, one per instance, stored as float32.
    masks = np.random.randint(
        2, size=(num_instances, height, width)).astype(np.float32)

    # Random class id per instance.
    classes = np.random.randint(100, size=(num_instances)).astype(np.int64)

    feature_map = {
        'image/encoded': self._BytesFeature(image_jpeg),
        'image/format': self._BytesFeature('jpeg'),
        'image/height': self._Int64Feature([height]),
        'image/width': self._Int64Feature([width]),
        'image/object/mask': self._FloatFeature(np.reshape(masks, [-1])),
        'image/object/class/label': self._Int64Feature(classes),
    }
    serialized = tf.train.Example(
        features=tf.train.Features(feature=feature_map)).SerializeToString()
    decoder = tf_example_decoder.TfExampleDecoder(load_instance_masks=True)
    decoded = decoder.decode(tf.convert_to_tensor(serialized))

    masks_key = fields.InputDataFields.groundtruth_instance_masks
    classes_key = fields.InputDataFields.groundtruth_classes

    self.assertAllEqual(decoded[masks_key].get_shape().as_list(), [4, 5, 3])
    self.assertAllEqual(decoded[classes_key].get_shape().as_list(), [4])

    with self.test_session() as sess:
      decoded = sess.run(decoded)

    expected_masks = masks.astype(np.float32)
    self.assertAllEqual(expected_masks, decoded[masks_key])
    self.assertAllEqual(classes, decoded[classes_key])
コード例 #7
0
  def testDecodeImageKeyAndFilename(self):
    """Checks decoding of the 'image/key/sha256' and 'image/filename' fields.

    `sess.run` returns tf.string values as `bytes` under Python 3, so the
    expected values are bytes literals. Under Python 2 a `b'...'` literal is
    a plain `str`, so the test remains valid there too.
    """
    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    encoded_jpeg = self._EncodeImage(image_tensor)
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': self._BytesFeature(encoded_jpeg),
        'image/key/sha256': self._BytesFeature('abc'),
        'image/filename': self._BytesFeature('filename')
    })).SerializeToString()

    example_decoder = tf_example_decoder.TfExampleDecoder()
    tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))

    with self.test_session() as sess:
      tensor_dict = sess.run(tensor_dict)

    self.assertEqual(b'abc', tensor_dict[fields.InputDataFields.key])
    self.assertEqual(b'filename',
                     tensor_dict[fields.InputDataFields.filename])
コード例 #8
0
def build(input_reader_config):
    """Builds a tensor dictionary from an InputReader config.

    Reads serialized records from the configured TFRecord input paths with
    `parallel_reader.parallel_read` and decodes them with TfExampleDecoder.

    Args:
      input_reader_config: An input_reader_pb2.InputReader object.

    Returns:
      A tensor dict based on the input_reader_config.

    Raises:
      ValueError: On invalid input reader proto, or if the config does not
        select `tf_record_input_reader`.
      ValueError: If no input paths are specified.
    """
    if not isinstance(input_reader_config, input_reader_pb2.InputReader):
        raise ValueError('input_reader_config not of type '
                         'input_reader_pb2.InputReader.')

    reader_type = input_reader_config.WhichOneof('input_reader')
    if reader_type != 'tf_record_input_reader':
        raise ValueError('Unsupported input_reader_config.')

    record_config = input_reader_config.tf_record_input_reader
    if not record_config.input_path:
        raise ValueError('At least one input path must be specified in '
                         '`input_reader_config`.')

    # Slicing turns the proto `RepeatedScalarContainer` into a plain list.
    # An unset (0) num_epochs is passed to the reader as None.
    _, serialized = parallel_reader.parallel_read(
        record_config.input_path[:],
        reader_class=tf.TFRecordReader,
        num_epochs=input_reader_config.num_epochs or None,
        num_readers=input_reader_config.num_readers,
        shuffle=input_reader_config.shuffle,
        dtypes=[tf.string, tf.string],
        capacity=input_reader_config.queue_capacity,
        min_after_dequeue=input_reader_config.min_after_dequeue)

    label_map_file = (input_reader_config.label_map_path
                      if input_reader_config.HasField('label_map_path')
                      else None)
    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_file)
    return decoder.decode(serialized)
コード例 #9
0
  def testDecodeKeypoint(self):
    """Checks decoding of boxes together with per-object keypoints.

    Six keypoints are stored as flat y and x lists. With num_keypoints=3 the
    decoder must group them into a [2, 3, 2] tensor (two objects, three
    keypoints each, with (y, x) on the last axis).
    """
    pixels = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    jpeg_bytes = self._EncodeImage(pixels)
    ymins, xmins = [0.0, 4.0], [1.0, 5.0]
    ymaxs, xmaxs = [2.0, 6.0], [3.0, 7.0]
    kp_ys = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
    kp_xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    feature_map = {
        'image/encoded': self._BytesFeature(jpeg_bytes),
        'image/format': self._BytesFeature('jpeg'),
        'image/object/bbox/ymin': self._FloatFeature(ymins),
        'image/object/bbox/xmin': self._FloatFeature(xmins),
        'image/object/bbox/ymax': self._FloatFeature(ymaxs),
        'image/object/bbox/xmax': self._FloatFeature(xmaxs),
        'image/object/keypoint/y': self._FloatFeature(kp_ys),
        'image/object/keypoint/x': self._FloatFeature(kp_xs),
    }
    serialized = tf.train.Example(
        features=tf.train.Features(feature=feature_map)).SerializeToString()

    decoder = tf_example_decoder.TfExampleDecoder(num_keypoints=3)
    decoded = decoder.decode(tf.convert_to_tensor(serialized))

    boxes_key = fields.InputDataFields.groundtruth_boxes
    keypoints_key = fields.InputDataFields.groundtruth_keypoints
    self.assertAllEqual(decoded[boxes_key].get_shape().as_list(), [None, 4])
    self.assertAllEqual(
        decoded[keypoints_key].get_shape().as_list(), [2, 3, 2])

    with self.test_session() as sess:
      decoded = sess.run(decoded)

    # One row per box, columns (ymin, xmin, ymax, xmax).
    expected_boxes = np.stack([ymins, xmins, ymaxs, xmaxs], axis=1)
    self.assertAllEqual(expected_boxes, decoded[boxes_key])
    self.assertAllEqual(
        2, decoded[fields.InputDataFields.num_groundtruth_boxes])

    # Pair each y with its x, then split the six points into two objects.
    expected_keypoints = np.stack([kp_ys, kp_xs], axis=1).reshape((2, 3, 2))
    self.assertAllEqual(expected_keypoints, decoded[keypoints_key])
コード例 #10
0
  def testDecodeObjectArea(self):
    """Checks that per-object areas decode to a length-2 tensor of inputs."""
    pixels = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    jpeg_bytes = self._EncodeImage(pixels)
    areas = [100., 174.]
    serialized = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': self._BytesFeature(jpeg_bytes),
        'image/format': self._BytesFeature('jpeg'),
        'image/object/area': self._FloatFeature(areas),
    })).SerializeToString()

    decoded = tf_example_decoder.TfExampleDecoder().decode(
        tf.convert_to_tensor(serialized))
    area_key = fields.InputDataFields.groundtruth_area

    self.assertAllEqual(decoded[area_key].get_shape().as_list(), [2])
    with self.test_session() as sess:
      decoded = sess.run(decoded)

    self.assertAllEqual(areas, decoded[area_key])
コード例 #11
0
  def testDecodePngImage(self):
    """Checks decoding of a PNG-encoded image and its source id.

    `sess.run` returns tf.string values as `bytes` under Python 3, so the
    expected source id is a bytes literal (a plain `str` under Python 2).
    """
    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    encoded_png = self._EncodeImage(image_tensor, encoding_type='png')
    decoded_png = self._DecodeImage(encoded_png, encoding_type='png')
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': self._BytesFeature(encoded_png),
        'image/format': self._BytesFeature('png'),
        'image/source_id': self._BytesFeature('image_id')
    })).SerializeToString()

    example_decoder = tf_example_decoder.TfExampleDecoder()
    tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))

    # Height and width are dynamic; only the 3 channels are static.
    self.assertAllEqual((tensor_dict[fields.InputDataFields.image].
                         get_shape().as_list()), [None, None, 3])
    with self.test_session() as sess:
      tensor_dict = sess.run(tensor_dict)

    self.assertAllEqual(decoded_png, tensor_dict[fields.InputDataFields.image])
    self.assertEqual(b'image_id',
                     tensor_dict[fields.InputDataFields.source_id])
コード例 #12
0
  def testDecodeObjectLabelWithMapping(self):
    """Checks that object class names are mapped to ids via a label map.

    The label map assigns 'cat' -> 3 and 'dog' -> 1, so the text labels
    ['cat', 'dog'] must decode to [3, 1].
    """
    pixels = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    jpeg_bytes = self._EncodeImage(pixels)
    class_names = ['cat', 'dog']
    serialized = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/encoded': self._BytesFeature(jpeg_bytes),
                'image/format': self._BytesFeature('jpeg'),
                'image/object/class/text': self._BytesFeature(class_names),
            })).SerializeToString()

    label_map_string = """
      item {
        id:3
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as label_map_file:
      label_map_file.write(label_map_string)

    decoder = tf_example_decoder.TfExampleDecoder(
        label_map_proto_file=label_map_path)
    decoded = decoder.decode(tf.convert_to_tensor(serialized))
    classes_key = fields.InputDataFields.groundtruth_classes

    self.assertAllEqual(decoded[classes_key].get_shape().as_list(), [None])

    with self.test_session() as sess:
      # Lookup tables must be initialized before fetching decoded tensors.
      sess.run(tf.tables_initializer())
      decoded = sess.run(decoded)

    self.assertAllEqual([3, 1], decoded[classes_key])
コード例 #13
0
ファイル: inputs.py プロジェクト: uniquetrij/SecureIt
    def _predict_input_fn(params=None):
        """Decodes serialized tf.Examples and returns `ServingInputReceiver`.

        Args:
          params: Parameter dictionary passed from the estimator. Unused.

        Returns:
          `ServingInputReceiver` whose features hold the float32 image and
          its true image shape (each with a leading batch dimension of 1),
          and whose receiver tensor is a scalar string placeholder.
        """
        del params  # Unused.
        example = tf.placeholder(dtype=tf.string, shape=[], name='tf_example')

        num_classes = config_util.get_number_of_classes(model_config)
        model = model_builder.build(model_config, is_training=False)
        image_resizer_config = config_util.get_image_resizer_config(
            model_config)
        image_resizer_fn = image_resizer_builder.build(image_resizer_config)

        # No data augmentation is applied when serving.
        transform_fn = functools.partial(transform_input_data,
                                         model_preprocess_fn=model.preprocess,
                                         image_resizer_fn=image_resizer_fn,
                                         num_classes=num_classes,
                                         data_augmentation_fn=None)

        num_additional_channels = predict_input_config.num_additional_channels
        decoder = tf_example_decoder.TfExampleDecoder(
            load_instance_masks=False,
            num_additional_channels=num_additional_channels)
        input_dict = transform_fn(decoder.decode(example))
        # tf.to_float is deprecated; tf.cast is its supported replacement.
        images = tf.cast(input_dict[fields.InputDataFields.image],
                         dtype=tf.float32)
        # Add a batch dimension of size 1 to the single decoded example.
        images = tf.expand_dims(images, axis=0)
        true_image_shape = tf.expand_dims(
            input_dict[fields.InputDataFields.true_image_shape], axis=0)

        return tf.estimator.export.ServingInputReceiver(
            features={
                fields.InputDataFields.image: images,
                fields.InputDataFields.true_image_shape: true_image_shape
            },
            receiver_tensors={SERVING_FED_EXAMPLE_KEY: example})
コード例 #14
0
  def testDecodeObjectDifficult(self):
    """Checks that integer 'difficult' flags decode to their boolean values."""
    pixels = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    jpeg_bytes = self._EncodeImage(pixels)
    difficult_flags = [0, 1]
    serialized = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': self._BytesFeature(jpeg_bytes),
        'image/format': self._BytesFeature('jpeg'),
        'image/object/difficult': self._Int64Feature(difficult_flags),
    })).SerializeToString()

    decoded = tf_example_decoder.TfExampleDecoder().decode(
        tf.convert_to_tensor(serialized))
    difficult_key = fields.InputDataFields.groundtruth_difficult

    self.assertAllEqual(decoded[difficult_key].get_shape().as_list(), [2])
    with self.test_session() as sess:
      decoded = sess.run(decoded)

    expected = list(map(bool, difficult_flags))
    self.assertAllEqual(expected, decoded[difficult_key])
コード例 #15
0
  def testDecodeEmptyPngInstanceMasks(self):
    """Checks that an empty PNG mask list decodes to a [0, 10, 10] tensor.

    The example stores no masks but supplies image/height and image/width
    (both 10); the decoded mask array must have shape [0, 10, 10].
    """
    pixels = np.random.randint(256, size=(10, 10, 3)).astype(np.uint8)
    jpeg_bytes = self._EncodeImage(pixels)
    serialized = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/encoded': self._BytesFeature(jpeg_bytes),
                'image/format': self._BytesFeature('jpeg'),
                'image/object/mask': self._BytesFeature([]),
                'image/height': self._Int64Feature([10]),
                'image/width': self._Int64Feature([10]),
            })).SerializeToString()

    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=True,
        instance_mask_type=input_reader_pb2.PNG_MASKS)
    decoded = decoder.decode(tf.convert_to_tensor(serialized))

    with self.test_session() as sess:
      masks = sess.run(decoded)[
          fields.InputDataFields.groundtruth_instance_masks]
      self.assertAllEqual(masks.shape, [0, 10, 10])
コード例 #16
0
ファイル: exporter.py プロジェクト: suvratjain1995/SecureIt
 def decode(tf_example_string_tensor):
     """Decodes serialized tf.Example data and returns only the image tensor.

     Args:
       tf_example_string_tensor: String tensor of serialized tf.Example
         data, passed unchanged to `TfExampleDecoder.decode`.

     Returns:
       The `fields.InputDataFields.image` entry of the decoded dict.
     """
     decoder = tf_example_decoder.TfExampleDecoder()
     decoded = decoder.decode(tf_example_string_tensor)
     return decoded[fields.InputDataFields.image]
コード例 #17
0
def build(input_reader_config, batch_size=None, transform_input_data_fn=None):
    """Builds a tf.data.Dataset from an InputReader config.

    Reads records from `tf_record_input_reader.input_path` via `read_dataset`
    using TFRecordDataset, decodes each with TfExampleDecoder, optionally
    applies `transform_input_data_fn`, and then batches (dropping any
    incomplete final batch) and prefetches.

    Args:
      input_reader_config: An input_reader_pb2.InputReader object.
      batch_size: Batch size. If None (or 0), no batching is performed.
      transform_input_data_fn: Function applied to every decoded record,
        or None if no extra transformation is required.

    Returns:
      A tf.data.Dataset based on the input_reader_config.

    Raises:
      ValueError: On invalid input reader proto, or if the config does not
        select `tf_record_input_reader`.
      ValueError: If no input paths are specified.
    """
    if not isinstance(input_reader_config, input_reader_pb2.InputReader):
        raise ValueError('input_reader_config not of type '
                         'input_reader_pb2.InputReader.')

    if (input_reader_config.WhichOneof('input_reader') !=
            'tf_record_input_reader'):
        raise ValueError('Unsupported input_reader_config.')

    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
        raise ValueError('At least one input path must be specified in '
                         '`input_reader_config`.')

    label_map_proto_file = (input_reader_config.label_map_path
                            if input_reader_config.HasField('label_map_path')
                            else None)
    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_proto_file,
        use_display_name=input_reader_config.use_display_name,
        num_additional_channels=input_reader_config.num_additional_channels)

    def process_fn(value):
        """Decodes one serialized record and applies the optional transform."""
        tensors = decoder.decode(value)
        if transform_input_data_fn is None:
            return tensors
        return transform_input_data_fn(tensors)

    # TFRecordDataset read buffer of 8 * 1000 * 1000 bytes.
    record_reader = functools.partial(tf.data.TFRecordDataset,
                                      buffer_size=8 * 1000 * 1000)
    dataset = read_dataset(record_reader, config.input_path[:],
                           input_reader_config)
    # TODO(rathodv): make batch size a required argument once the old binaries
    # are deleted.
    num_parallel_calls = (
        batch_size * input_reader_config.num_parallel_batches if batch_size
        else input_reader_config.num_parallel_map_calls)
    dataset = dataset.map(process_fn, num_parallel_calls=num_parallel_calls)
    if batch_size:
        dataset = dataset.apply(
            tf.contrib.data.batch_and_drop_remainder(batch_size))
    return dataset.prefetch(input_reader_config.num_prefetch_batches)