def testDecodeEmptyPngInstanceMasks(self):
        """An empty mask feature list decodes to a [0, height, width] tensor."""
        img = np.random.randint(256, size=(10, 10, 3)).astype(np.uint8)
        jpeg_bytes = self._EncodeImage(img)
        features = {
            'image/encoded': self._BytesFeature(jpeg_bytes),
            'image/format': self._BytesFeature('jpeg'),
            'image/object/mask': self._BytesFeature([]),
            'image/height': self._Int64Feature([10]),
            'image/width': self._Int64Feature([10]),
        }
        serialized = tf.train.Example(
            features=tf.train.Features(feature=features)).SerializeToString()

        decoder = tf_example_decoder.TfExampleDecoder(
            load_instance_masks=True,
            instance_mask_type=input_reader_pb2.PNG_MASKS)
        output = decoder.decode(tf.convert_to_tensor(serialized))

        with self.test_session() as sess:
            output = sess.run(output)
            self.assertAllEqual(
                output[
                    fields.InputDataFields.groundtruth_instance_masks].shape,
                [0, 10, 10])
# Example #2 (0 votes)
def build(input_reader_config):
    """Builds a tensor dictionary based on the InputReader config.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.

  Returns:
    A tensor dict based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
    if not isinstance(input_reader_config, input_reader_pb2.InputReader):
        raise ValueError('input_reader_config not of type '
                         'input_reader_pb2.InputReader.')

    if input_reader_config.WhichOneof(
            'input_reader') == 'tf_record_input_reader':
        config = input_reader_config.tf_record_input_reader
        # Fail fast on an empty input_path instead of letting parallel_read
        # raise a less descriptive error later.
        if not config.input_path:
            raise ValueError('At least one input path must be specified in '
                             '`input_reader_config`.')
        _, string_tensor = parallel_reader.parallel_read(
            config.input_path[:],  # Convert `RepeatedScalarContainer` to list.
            reader_class=tf.TFRecordReader,
            # Proto default of 0 means "unlimited"; parallel_read expects None.
            num_epochs=(input_reader_config.num_epochs
                        if input_reader_config.num_epochs else None),
            num_readers=input_reader_config.num_readers,
            shuffle=input_reader_config.shuffle,
            dtypes=[tf.string, tf.string],
            capacity=input_reader_config.queue_capacity,
            min_after_dequeue=input_reader_config.min_after_dequeue)

        return tf_example_decoder.TfExampleDecoder().decode(string_tensor)

    raise ValueError('Unsupported input_reader_config.')
    def testDecodeDefaultGroundtruthWeights(self):
        """Weights default to all-ones when the example carries none."""
        img = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
        jpeg_bytes = self._EncodeImage(img)
        ymins, xmins = [0.0, 4.0], [1.0, 5.0]
        ymaxs, xmaxs = [2.0, 6.0], [3.0, 7.0]
        serialized = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded': self._BytesFeature(jpeg_bytes),
                'image/format': self._BytesFeature('jpeg'),
                'image/object/bbox/ymin': self._FloatFeature(ymins),
                'image/object/bbox/xmin': self._FloatFeature(xmins),
                'image/object/bbox/ymax': self._FloatFeature(ymaxs),
                'image/object/bbox/xmax': self._FloatFeature(xmaxs),
            })).SerializeToString()

        decoder = tf_example_decoder.TfExampleDecoder()
        output = decoder.decode(tf.convert_to_tensor(serialized))

        boxes_shape = output[
            fields.InputDataFields.groundtruth_boxes].get_shape().as_list()
        self.assertAllEqual(boxes_shape, [None, 4])

        with self.test_session() as sess:
            output = sess.run(output)

        self.assertAllClose(
            output[fields.InputDataFields.groundtruth_weights],
            np.ones(2, dtype=np.float32))
    def _predict_input_fn(params=None):
        """Decodes serialized tf.Examples and returns `ServingInputReceiver`.

    Args:
      params: Parameter dictionary passed from the estimator.

    Returns:
      `ServingInputReceiver`.
    """
        # Required by the estimator input_fn API but unused here.
        del params
        # Single serialized tf.Example string fed at serving time.
        example = tf.placeholder(dtype=tf.string,
                                 shape=[],
                                 name='input_feature')

        # NOTE(review): `model_config` and `transform_input_data` come from the
        # enclosing scope, which is not visible in this chunk.
        num_classes = config_util.get_number_of_classes(model_config)
        model = model_builder.build(model_config, is_training=False)
        image_resizer_config = config_util.get_image_resizer_config(
            model_config)
        image_resizer_fn = image_resizer_builder.build(image_resizer_config)

        # No augmentation at serving time.
        transform_fn = functools.partial(transform_input_data,
                                         model_preprocess_fn=model.preprocess,
                                         image_resizer_fn=image_resizer_fn,
                                         num_classes=num_classes,
                                         data_augmentation_fn=None)

        decoder = tf_example_decoder.TfExampleDecoder(
            load_instance_masks=False)
        input_dict = transform_fn(decoder.decode(example))
        # Add a leading batch dimension of size 1.
        images = tf.to_float(input_dict[fields.InputDataFields.image])
        images = tf.expand_dims(images, axis=0)

        return tf.estimator.export.ServingInputReceiver(
            features={fields.InputDataFields.image: images},
            receiver_tensors={SERVING_FED_EXAMPLE_KEY: example})
# Example #5 (0 votes)
    def testDecodeRotatedBoundingBox(self):
        """Rotated boxes decode to a [num_boxes, 5] (cy, cx, h, w, ang) tensor."""
        img = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
        jpeg_bytes = self._EncodeImage(img)
        cy, cx = [0.0, 4.0], [1.0, 5.0]
        h, w = [2.0, 6.0], [3.0, 7.0]
        ang = [0.1, 0.2]
        serialized = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded': self._BytesFeature(jpeg_bytes),
                'image/format': self._BytesFeature('jpeg'.encode('utf8')),
                'image/object/rbbox/cy': self._FloatFeature(cy),
                'image/object/rbbox/cx': self._FloatFeature(cx),
                'image/object/rbbox/h': self._FloatFeature(h),
                'image/object/rbbox/w': self._FloatFeature(w),
                'image/object/rbbox/ang': self._FloatFeature(ang),
            })).SerializeToString()

        decoder = tf_example_decoder.TfExampleDecoder()
        output = decoder.decode(tf.convert_to_tensor(serialized))

        rboxes_shape = output[
            fields.InputDataFields.groundtruth_rboxes].get_shape().as_list()
        self.assertAllEqual(rboxes_shape, [None, 5])
        with self.test_session() as sess:
            output = sess.run(output)

        expected_rboxes = np.stack([cy, cx, h, w, ang],
                                   axis=1).astype(np.float32)
        self.assertAllEqual(
            expected_rboxes,
            output[fields.InputDataFields.groundtruth_rboxes])
# Example #6 (0 votes)
 def decode(tf_example_string_tensor):
     """Decodes a serialized tf.Example and returns its image tensor.

     If `input_shape` (captured from the enclosing scope, not visible in
     this chunk) is set, the image is resized to `input_shape[1:3]` —
     presumably (height, width) of a NHWC shape; TODO confirm with caller.
     """
     tensor_dict = tf_example_decoder.TfExampleDecoder().decode(
         tf_example_string_tensor)
     image_tensor = tensor_dict[fields.InputDataFields.image]
     if input_shape is not None:
         image_tensor = tf.image.resize(image_tensor, input_shape[1:3])
     return image_tensor
    def testDecodePngInstanceMasks(self):
        """PNG-encoded instance masks decode back to their float mask stack."""
        img = np.random.randint(256, size=(10, 10, 3)).astype(np.uint8)
        jpeg_bytes = self._EncodeImage(img)
        masks = [
            np.random.randint(0, 2, size=(10, 10, 1)).astype(np.uint8)
            for _ in range(2)
        ]
        encoded_masks = [
            self._EncodeImage(m, encoding_type='png') for m in masks
        ]
        expected_masks = np.stack(
            [np.squeeze(m.astype(np.float32)) for m in masks])
        serialized = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded': self._BytesFeature(jpeg_bytes),
                'image/format': self._BytesFeature('jpeg'),
                'image/object/mask': self._BytesFeature(encoded_masks)
            })).SerializeToString()

        decoder = tf_example_decoder.TfExampleDecoder(
            load_instance_masks=True,
            instance_mask_type=input_reader_pb2.PNG_MASKS)
        output = decoder.decode(tf.convert_to_tensor(serialized))

        with self.test_session() as sess:
            output = sess.run(output)

        self.assertAllEqual(
            expected_masks,
            output[fields.InputDataFields.groundtruth_instance_masks])
# Example #8 (0 votes)
 def _read_and_decode(queue):
     """Reads one serialized example from `queue` and decodes it.

     `dtype` is captured from the enclosing scope (not visible in this
     chunk) and forwarded to the decoder.
     """
     reader = tf.TFRecordReader()
     _, serialized_example = reader.read(queue)
     example_decoder = tf_example_decoder.TfExampleDecoder(dtype=dtype)
     tensor_dict = example_decoder.decode(
         tf.convert_to_tensor(serialized_example))
     return tensor_dict
# Example #9 (0 votes)
    def testDecodeJpegImage(self):
        """JPEG bytes decode to their pixel values; source_id passes through."""
        img = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
        jpeg_bytes = self._EncodeImage(img)
        expected_pixels = self._DecodeImage(jpeg_bytes)
        serialized = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded': self._BytesFeature(jpeg_bytes),
                'image/format': self._BytesFeature('jpeg'.encode('utf8')),
                'image/source_id': self._BytesFeature('image_id'.encode(
                    'utf8')),
            })).SerializeToString()

        decoder = tf_example_decoder.TfExampleDecoder()
        output = decoder.decode(tf.convert_to_tensor(serialized))

        image_shape = output[
            fields.InputDataFields.image].get_shape().as_list()
        self.assertAllEqual(image_shape, [None, None, 3])
        with self.test_session() as sess:
            output = sess.run(output)

        self.assertAllEqual(expected_pixels,
                            output[fields.InputDataFields.image])
        self.assertEqual(b'image_id',
                         output[fields.InputDataFields.source_id])
    def testInstancesNotAvailableByDefault(self):
        """Masks stay out of the tensor dict unless load_instance_masks is set."""
        num_instances, height, width = 4, 5, 3
        # Randomly generate image.
        img = np.random.randint(256,
                                size=(height, width, 3)).astype(np.uint8)
        jpeg_bytes = self._EncodeImage(img)

        # Random binary masks, flattened the way the proto feature expects.
        masks = np.random.randint(
            2, size=(num_instances, height, width)).astype(np.float32)
        flat_masks = np.reshape(masks, [-1])

        # Random class label per instance.
        labels = np.random.randint(100,
                                   size=(num_instances)).astype(np.int64)

        serialized = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded': self._BytesFeature(jpeg_bytes),
                'image/format': self._BytesFeature('jpeg'),
                'image/height': self._Int64Feature([height]),
                'image/width': self._Int64Feature([width]),
                'image/object/mask': self._FloatFeature(flat_masks),
                'image/object/class/label': self._Int64Feature(labels)
            })).SerializeToString()
        decoder = tf_example_decoder.TfExampleDecoder()
        output = decoder.decode(tf.convert_to_tensor(serialized))
        self.assertTrue(fields.InputDataFields.groundtruth_instance_masks
                        not in output)
# Example #11 (0 votes)
    def testDecodeInstanceSegmentation(self):
        """Dense int64 segmentation features decode to bool masks and class ids."""
        num_instances = 4
        image_height = 5
        image_width = 3

        # Randomly generate image.
        image_tensor = np.random.randint(255,
                                         size=(image_height, image_width,
                                               3)).astype(np.uint8)
        encoded_jpeg = self._EncodeImage(image_tensor)

        # Randomly generate instance segmentation masks.
        instance_segmentation = (np.random.randint(
            2,
            size=(num_instances, image_height, image_width)).astype(np.int64))

        # Randomly generate class labels for each instance.
        instance_segmentation_classes = np.random.randint(
            100, size=(num_instances)).astype(np.int64)

        example = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded':
                self._BytesFeature(encoded_jpeg),
                'image/format':
                self._BytesFeature('jpeg'.encode('utf8')),
                'image/height':
                self._Int64Feature([image_height]),
                'image/width':
                self._Int64Feature([image_width]),
                'image/segmentation/object':
                self._Int64Feature(instance_segmentation.flatten()),
                'image/segmentation/object/class':
                self._Int64Feature(instance_segmentation_classes)
            })).SerializeToString()
        example_decoder = tf_example_decoder.TfExampleDecoder()
        tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))

        # Static shapes are fully dynamic before the graph runs.
        self.assertAllEqual(
            (tensor_dict[fields.InputDataFields.groundtruth_instance_masks].
             get_shape().as_list()), [None, None, None])

        self.assertAllEqual(
            (tensor_dict[fields.InputDataFields.groundtruth_instance_classes].
             get_shape().as_list()), [None])

        with self.test_session() as sess:
            tensor_dict = sess.run(tensor_dict)

        # Decoded masks compare equal to the boolean view of the int64 input.
        self.assertAllEqual(
            instance_segmentation.astype(np.bool),
            tensor_dict[fields.InputDataFields.groundtruth_instance_masks])
        self.assertAllEqual(
            instance_segmentation_classes,
            tensor_dict[fields.InputDataFields.groundtruth_instance_classes])
def build(input_reader_config):
    """
    Builds a tensor dictionary based on the InputReader config
    :param input_reader_config: config carrying input_path, num_epochs,
        num_readers, shuffle, queue_capacity and min_after_dequeue.
    :return:  A tensor dict based on the input_reader_config.
    :raises ValueError: if no input path is specified.
    """
    # Fail fast on an empty input path instead of letting parallel_read
    # raise a less descriptive error later.
    if not input_reader_config.input_path:
        raise ValueError('At least one input path must be specified in '
                         '`input_reader_config`.')
    _, string_tensor = parallel_reader.parallel_read(
        data_sources=input_reader_config.input_path,
        reader_class=tf.TFRecordReader,
        # A num_epochs of 0 (the proto default) means "no limit";
        # parallel_read expects None in that case, not 0.
        num_epochs=(input_reader_config.num_epochs
                    if input_reader_config.num_epochs else None),
        num_readers=input_reader_config.num_readers,
        shuffle=input_reader_config.shuffle,
        dtypes=[tf.string, tf.string],
        capacity=input_reader_config.queue_capacity,
        min_after_dequeue=input_reader_config.min_after_dequeue)
    decoder = tf_example_decoder.TfExampleDecoder()
    return decoder.decode(string_tensor)
def build(input_reader_config):
    """Builds a tensor dictionary based on the InputReader config.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.

  Returns:
    A tensor dict based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
    if not isinstance(input_reader_config, input_reader_pb2.InputReader):
        raise ValueError('input_reader_config not of type '
                         'input_reader_pb2.InputReader.')

    if input_reader_config.WhichOneof(
            'input_reader') == 'tf_record_input_reader':
        config = input_reader_config.tf_record_input_reader
        if not config.input_path:
            raise ValueError('At least one input path must be specified in '
                             '`input_reader_config`.')
        _, string_tensor = parallel_reader.parallel_read(
            config.input_path[:],  # Convert `RepeatedScalarContainer` to list.
            reader_class=tf.TFRecordReader,
            # Proto default of 0 means "unlimited"; parallel_read expects None.
            num_epochs=(input_reader_config.num_epochs
                        if input_reader_config.num_epochs else None),
            num_readers=input_reader_config.num_readers,
            shuffle=input_reader_config.shuffle,
            dtypes=[tf.string, tf.string],
            capacity=input_reader_config.queue_capacity,
            min_after_dequeue=input_reader_config.min_after_dequeue)

        # Forward the label map path (if configured) to the decoder.
        label_map_proto_file = None
        if input_reader_config.HasField('label_map_path'):
            label_map_proto_file = input_reader_config.label_map_path
        decoder = tf_example_decoder.TfExampleDecoder(
            load_instance_masks=input_reader_config.load_instance_masks,
            instance_mask_type=input_reader_config.mask_type,
            label_map_proto_file=label_map_proto_file)
        return decoder.decode(string_tensor)

    raise ValueError('Unsupported input_reader_config.')
    def testDecodeImageKeyAndFilename(self):
        """The sha256 key and filename features pass straight through."""
        img = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
        jpeg_bytes = self._EncodeImage(img)
        serialized = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded': self._BytesFeature(jpeg_bytes),
                'image/key/sha256': self._BytesFeature('abc'),
                'image/filename': self._BytesFeature('filename')
            })).SerializeToString()

        decoder = tf_example_decoder.TfExampleDecoder()
        output = decoder.decode(tf.convert_to_tensor(serialized))

        with self.test_session() as sess:
            output = sess.run(output)

        self.assertEqual('abc', output[fields.InputDataFields.key])
        self.assertEqual('filename',
                         output[fields.InputDataFields.filename])
def build(input_reader_config, batch_size=None, transform_input_data_fn=None):
    """Builds a tf.data.Dataset from the InputReader config.

    Args:
      input_reader_config: An input_reader_pb2.InputReader object.
      batch_size: Batch size. If None, no batching is performed.
      transform_input_data_fn: Function applied to each decoded record, or
        None if no extra decoding is required.

    Returns:
      A tf.data.Dataset based on the input_reader_config.

    Raises:
      ValueError: On invalid input reader proto or empty input path.
    """
    if not isinstance(input_reader_config, input_reader_pb2.InputReader):
        raise ValueError('input_reader_config not of type '
                         'input_reader_pb2.InputReader.')

    if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
        config = input_reader_config.tf_record_input_reader
        if not config.input_path:
            raise ValueError('At least one input path must be specified in '
                             '`input_reader_config`.')

        decoder = tf_example_decoder.TfExampleDecoder()

        def process_fn(value):
            """Decodes one serialized example and applies the optional transform."""
            processed_tensors = decoder.decode(value)
            if transform_input_data_fn is not None:
                processed_tensors = transform_input_data_fn(processed_tensors)
            return processed_tensors

        # Read input_path into a dataset of serialized-example strings.
        dataset = read_dataset(
            functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
            config.input_path[:],
            input_reader_config)

        if batch_size:
            num_parallel_calls = batch_size * input_reader_config.num_parallel_batches
        else:
            num_parallel_calls = input_reader_config.num_parallel_map_calls

        # Prefer the legacy-function map when available (TF 1.x compatibility).
        if hasattr(dataset, 'map_with_legacy_function'):
            data_map_fn = dataset.map_with_legacy_function
        else:
            data_map_fn = dataset.map

        dataset = data_map_fn(process_fn, num_parallel_calls=num_parallel_calls)

        if batch_size:
            dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
        dataset = dataset.prefetch(input_reader_config.num_prefetch_batches)
        return dataset

    raise ValueError('Unsupported input_reader_config.')
    def testDecodeObjectLabelNoText(self):
        """Numeric class labels decode even without text label features."""
        img = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
        jpeg_bytes = self._EncodeImage(img)
        class_ids = [1, 2]
        serialized = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded': self._BytesFeature(jpeg_bytes),
                'image/format': self._BytesFeature('jpeg'),
                'image/object/class/label': self._Int64Feature(class_ids),
            })).SerializeToString()
        label_map_string = """
      item {
        id:1
        name:'cat'
      }
      item {
        id:2
        name:'dog'
      }
    """
        label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
        with tf.gfile.Open(label_map_path, 'wb') as f:
            f.write(label_map_string)

        decoder = tf_example_decoder.TfExampleDecoder(
            label_map_proto_file=label_map_path)
        output = decoder.decode(tf.convert_to_tensor(serialized))

        classes_shape = output[
            fields.InputDataFields.groundtruth_classes].get_shape().as_list()
        self.assertAllEqual(classes_shape, [None])

        init = tf.tables_initializer()
        with self.test_session() as sess:
            sess.run(init)
            output = sess.run(output)

        self.assertAllEqual(
            class_ids, output[fields.InputDataFields.groundtruth_classes])
    def testDecodeObjectArea(self):
        """Per-object area values decode into groundtruth_area."""
        img = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
        jpeg_bytes = self._EncodeImage(img)
        areas = [100., 174.]
        serialized = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded': self._BytesFeature(jpeg_bytes),
                'image/format': self._BytesFeature('jpeg'),
                'image/object/area': self._FloatFeature(areas),
            })).SerializeToString()

        decoder = tf_example_decoder.TfExampleDecoder()
        output = decoder.decode(tf.convert_to_tensor(serialized))

        area_shape = output[
            fields.InputDataFields.groundtruth_area].get_shape().as_list()
        self.assertAllEqual(area_shape, [None])
        with self.test_session() as sess:
            output = sess.run(output)

        self.assertAllEqual(
            areas, output[fields.InputDataFields.groundtruth_area])
# Example #18 (0 votes)
    def _predict_input_fn(params=None):
        """Builds a `ServingInputReceiver` that decodes one serialized tf.Example.

        Args:
          params: Parameter dictionary passed from the estimator; unused.

        Returns:
          `tf.estimator.export.ServingInputReceiver` exposing the decoded,
          batched image and its true shape.
        """
        del params
        example = tf.placeholder(dtype=tf.string, shape=[], name='tf_example')
        # NOTE(review): `model_config`, `config_utils` and `transform_input_data`
        # come from the enclosing scope, which is not visible in this chunk.
        model = model_builder.build(model_config, is_training=False)
        image_resizer_config = config_utils.get_image_resizer_config(model_config)
        image_resizer_fn = image_resizer_builder.build(image_resizer_config)

        transform_fn = functools.partial(
            transform_input_data,
            model_preprocess_fn=model.preprocess,
            image_resizer_fn=image_resizer_fn)

        decoder = tf_example_decoder.TfExampleDecoder()
        input_dict = transform_fn(decoder.decode(example))
        # Add a leading batch dimension of size 1 for serving.
        images = tf.to_float(input_dict[fields.InputDataFields.image])
        images = tf.expand_dims(images, axis=0)
        true_image_shape = tf.expand_dims(
            input_dict[fields.InputDataFields.true_image_shape], axis=0)

        return tf.estimator.export.ServingInputReceiver(features={
            fields.InputDataFields.image:images,
            fields.InputDataFields.true_image_shape:true_image_shape},
            receiver_tensors={SERVING_FED_EXAMPLE_KEY: example})
    def testDecodeObjectGroupOf(self):
        """Integer group_of flags decode into a boolean tensor."""
        img = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
        jpeg_bytes = self._EncodeImage(img)
        group_of_flags = [0, 1]
        serialized = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded': self._BytesFeature(jpeg_bytes),
                'image/format': self._BytesFeature('jpeg'),
                'image/object/group_of': self._Int64Feature(group_of_flags),
            })).SerializeToString()

        decoder = tf_example_decoder.TfExampleDecoder()
        output = decoder.decode(tf.convert_to_tensor(serialized))

        group_of_shape = output[
            fields.InputDataFields.groundtruth_group_of].get_shape().as_list()
        self.assertAllEqual(group_of_shape, [None])
        with self.test_session() as sess:
            output = sess.run(output)

        self.assertAllEqual(
            [bool(flag) for flag in group_of_flags],
            output[fields.InputDataFields.groundtruth_group_of])
# Example #20 (0 votes)
    def testDecodeObjectLabel(self):
        """Integer class labels pass through into groundtruth_classes."""
        img = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
        jpeg_bytes = self._EncodeImage(img)
        class_ids = [0, 1]
        serialized = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded': self._BytesFeature(jpeg_bytes),
                'image/format': self._BytesFeature('jpeg'.encode('utf8')),
                'image/object/class/label': self._Int64Feature(class_ids),
            })).SerializeToString()

        decoder = tf_example_decoder.TfExampleDecoder()
        output = decoder.decode(tf.convert_to_tensor(serialized))

        classes_shape = output[
            fields.InputDataFields.groundtruth_classes].get_shape().as_list()
        self.assertAllEqual(classes_shape, [None])

        with self.test_session() as sess:
            output = sess.run(output)

        self.assertAllEqual(
            class_ids, output[fields.InputDataFields.groundtruth_classes])
def build(input_reader_config, batch_size=None, transform_input_data_fn=None):
    """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
  records. Applies a padded batch to the resulting dataset.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    batch_size: Batch size. If batch size is None, no batching is performed.
    transform_input_data_fn: Function to apply transformation to all records,
      or None if no extra decoding is required.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
    if not isinstance(input_reader_config, input_reader_pb2.InputReader):
        raise ValueError('input_reader_config not of type '
                         'input_reader_pb2.InputReader.')

    if input_reader_config.WhichOneof(
            'input_reader') == 'tf_record_input_reader':
        config = input_reader_config.tf_record_input_reader
        if not config.input_path:
            raise ValueError('At least one input path must be specified in '
                             '`input_reader_config`.')

        # Forward the label map path (if configured) to the decoder.
        label_map_proto_file = None
        if input_reader_config.HasField('label_map_path'):
            label_map_proto_file = input_reader_config.label_map_path
        decoder = tf_example_decoder.TfExampleDecoder(
            load_instance_masks=input_reader_config.load_instance_masks,
            load_multiclass_scores=input_reader_config.load_multiclass_scores,
            instance_mask_type=input_reader_config.mask_type,
            label_map_proto_file=label_map_proto_file,
            use_display_name=input_reader_config.use_display_name,
            num_additional_channels=input_reader_config.num_additional_channels
        )

        def process_fn(value):
            """Sets up tf graph that decodes, transforms and pads input data."""
            processed_tensors = decoder.decode(value)
            if transform_input_data_fn is not None:
                processed_tensors = transform_input_data_fn(processed_tensors)
            return processed_tensors

        dataset = read_dataset(
            functools.partial(tf.data.TFRecordDataset,
                              buffer_size=8 * 1000 * 1000),
            config.input_path[:], input_reader_config)
        # Keep only every Nth example when subsampling is configured.
        if input_reader_config.sample_1_of_n_examples > 1:
            dataset = dataset.shard(input_reader_config.sample_1_of_n_examples,
                                    0)
        # TODO(rathodv): make batch size a required argument once the old binaries
        # are deleted.
        if batch_size:
            num_parallel_calls = batch_size * input_reader_config.num_parallel_batches
        else:
            num_parallel_calls = input_reader_config.num_parallel_map_calls
        # TODO(b/123952794): Migrate to V2 function.
        if hasattr(dataset, 'map_with_legacy_function'):
            data_map_fn = dataset.map_with_legacy_function
        else:
            data_map_fn = dataset.map
        dataset = data_map_fn(process_fn,
                              num_parallel_calls=num_parallel_calls)
        if batch_size:
            dataset = dataset.apply(
                tf.contrib.data.batch_and_drop_remainder(batch_size))
        dataset = dataset.prefetch(input_reader_config.num_prefetch_batches)
        return dataset

    raise ValueError('Unsupported input_reader_config.')
 def decode(tf_example_string_tensor):
     """Decodes a serialized tf.Example string and returns its image tensor."""
     tensor_dict = tf_example_decoder.TfExampleDecoder().decode(
         tf_example_string_tensor)
     image_tensor = tensor_dict[fields.InputDataFields.image]
     return image_tensor
def build(input_reader_config,
          transform_input_data_fn=None,
          batch_size=None,
          max_num_boxes=None,
          num_classes=None,
          spatial_image_shape=None):
    """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
  records. Applies a padded batch to the resulting dataset.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    transform_input_data_fn: Function to apply to all records, or None if
      no extra decoding is required.
    batch_size: Batch size. If None, batching is not performed.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
      padding. If None, will use a dynamic shape.
    num_classes: Number of classes in the dataset needed to compute shapes for
      padding. If None, will use a dynamic shape.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image after applying
      transform_input_data_fn. If None, will use dynamic shapes.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
    if not isinstance(input_reader_config, input_reader_pb2.InputReader):
        raise ValueError('input_reader_config not of type '
                         'input_reader_pb2.InputReader.')

    if input_reader_config.WhichOneof(
            'input_reader') == 'tf_record_input_reader':
        config = input_reader_config.tf_record_input_reader
        if not config.input_path:
            raise ValueError('At least one input path must be specified in '
                             '`input_reader_config`.')

        # Forward the label map path (if configured) to the decoder.
        label_map_proto_file = None
        if input_reader_config.HasField('label_map_path'):
            label_map_proto_file = input_reader_config.label_map_path
        decoder = tf_example_decoder.TfExampleDecoder(
            load_instance_masks=input_reader_config.load_instance_masks,
            instance_mask_type=input_reader_config.mask_type,
            label_map_proto_file=label_map_proto_file)

        def process_fn(value):
            """Decodes `value` and applies the optional transform function."""
            processed = decoder.decode(value)
            if transform_input_data_fn is not None:
                return transform_input_data_fn(processed)
            return processed

        dataset = dataset_util.read_dataset(
            functools.partial(tf.data.TFRecordDataset,
                              buffer_size=8 * 1000 * 1000), process_fn,
            config.input_path[:], input_reader_config)

        if batch_size:
            # Pad tensors to static shapes so examples can be batched; the
            # incomplete final batch is dropped.
            padding_shapes = _get_padding_shapes(dataset, max_num_boxes,
                                                 num_classes,
                                                 spatial_image_shape)
            dataset = dataset.apply(
                tf.contrib.data.padded_batch_and_drop_remainder(
                    batch_size, padding_shapes))
        return dataset

    raise ValueError('Unsupported input_reader_config.')