def testDecodeEmptyPngInstanceMasks(self):
  """Decoding an example with an empty PNG mask list yields zero masks.

  The example carries a random 10x10 image plus an empty 'image/object/mask'
  feature; the decoded instance-mask array must have shape [0, 10, 10].
  """
  pixels = np.random.randint(256, size=(10, 10, 3)).astype(np.uint8)
  jpeg_bytes = self._EncodeImage(pixels)
  no_masks = []
  feature_map = {
      'image/encoded': self._BytesFeature(jpeg_bytes),
      'image/format': self._BytesFeature('jpeg'),
      'image/object/mask': self._BytesFeature(no_masks),
      'image/height': self._Int64Feature([10]),
      'image/width': self._Int64Feature([10]),
  }
  serialized = tf.train.Example(
      features=tf.train.Features(feature=feature_map)).SerializeToString()

  decoder = tf_example_decoder.TfExampleDecoder(
      load_instance_masks=True,
      instance_mask_type=input_reader_pb2.PNG_MASKS)
  decoded = decoder.decode(tf.convert_to_tensor(serialized))

  masks_key = fields.InputDataFields.groundtruth_instance_masks
  with self.test_session() as sess:
    outputs = sess.run(decoded)
    self.assertAllEqual(outputs[masks_key].shape, [0, 10, 10])
def build(input_reader_config):
  """Builds a tensor dictionary based on the InputReader config.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.

  Returns:
    A tensor dict based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input path is specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  if input_reader_config.WhichOneof(
      'input_reader') == 'tf_record_input_reader':
    config = input_reader_config.tf_record_input_reader
    # Fail early with a clear message instead of letting the reader choke on
    # an empty path (works for both a single string and a repeated field).
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')
    _, string_tensor = parallel_reader.parallel_read(
        # Slicing yields a plain str/list; a repeated proto field would
        # otherwise be passed as a protobuf container rather than a list.
        config.input_path[:],
        reader_class=tf.TFRecordReader,
        # A falsy epoch count (unset field) is passed as None.
        num_epochs=(input_reader_config.num_epochs
                    if input_reader_config.num_epochs else None),
        num_readers=input_reader_config.num_readers,
        shuffle=input_reader_config.shuffle,
        dtypes=[tf.string, tf.string],
        capacity=input_reader_config.queue_capacity,
        min_after_dequeue=input_reader_config.min_after_dequeue)

    return tf_example_decoder.TfExampleDecoder().decode(string_tensor)

  raise ValueError('Unsupported input_reader_config.')
def testDecodeDefaultGroundtruthWeights(self):
  """Boxes given without weights decode with a weight of one per box."""
  pixels = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
  jpeg_bytes = self._EncodeImage(pixels)
  # Two boxes, one list per coordinate.
  ymins, xmins = [0.0, 4.0], [1.0, 5.0]
  ymaxs, xmaxs = [2.0, 6.0], [3.0, 7.0]
  feature_map = {
      'image/encoded': self._BytesFeature(jpeg_bytes),
      'image/format': self._BytesFeature('jpeg'),
      'image/object/bbox/ymin': self._FloatFeature(ymins),
      'image/object/bbox/xmin': self._FloatFeature(xmins),
      'image/object/bbox/ymax': self._FloatFeature(ymaxs),
      'image/object/bbox/xmax': self._FloatFeature(xmaxs),
  }
  serialized = tf.train.Example(
      features=tf.train.Features(feature=feature_map)).SerializeToString()

  decoder = tf_example_decoder.TfExampleDecoder()
  decoded = decoder.decode(tf.convert_to_tensor(serialized))

  # The static shape is checked on the graph tensor, before running it.
  boxes_static_shape = decoded[
      fields.InputDataFields.groundtruth_boxes].get_shape().as_list()
  self.assertAllEqual(boxes_static_shape, [None, 4])

  with self.test_session() as sess:
    outputs = sess.run(decoded)

  self.assertAllClose(
      outputs[fields.InputDataFields.groundtruth_weights],
      np.ones(2, dtype=np.float32))
def _predict_input_fn(params=None):
  """Decodes serialized tf.Examples and returns `ServingInputReceiver`.

  A scalar string placeholder receives one serialized example. It is decoded
  (without instance masks), passed through `transform_input_data`, cast to
  float and given a leading batch axis before being exposed as the image
  feature.

  Args:
    params: Parameter dictionary passed from the estimator. Unused.

  Returns:
    `ServingInputReceiver`.
  """
  del params
  serialized_example = tf.placeholder(
      dtype=tf.string, shape=[], name='input_feature')

  num_classes = config_util.get_number_of_classes(model_config)
  detection_model = model_builder.build(model_config, is_training=False)
  resizer_config = config_util.get_image_resizer_config(model_config)
  resizer_fn = image_resizer_builder.build(resizer_config)

  example_decoder = tf_example_decoder.TfExampleDecoder(
      load_instance_masks=False)
  transformed = transform_input_data(
      example_decoder.decode(serialized_example),
      model_preprocess_fn=detection_model.preprocess,
      image_resizer_fn=resizer_fn,
      num_classes=num_classes,
      data_augmentation_fn=None)

  image_batch = tf.expand_dims(
      tf.to_float(transformed[fields.InputDataFields.image]), axis=0)

  return tf.estimator.export.ServingInputReceiver(
      features={fields.InputDataFields.image: image_batch},
      receiver_tensors={SERVING_FED_EXAMPLE_KEY: serialized_example})
def testDecodeRotatedBoundingBox(self):
  """Rotated boxes decode to rows of (cy, cx, h, w, ang)."""
  pixels = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
  jpeg_bytes = self._EncodeImage(pixels)
  # Two rotated boxes, one list per component.
  centers_y = [0.0, 4.0]
  centers_x = [1.0, 5.0]
  heights = [2.0, 6.0]
  widths = [3.0, 7.0]
  angles = [0.1, 0.2]
  feature_map = {
      'image/encoded': self._BytesFeature(jpeg_bytes),
      'image/format': self._BytesFeature('jpeg'.encode('utf8')),
      'image/object/rbbox/cy': self._FloatFeature(centers_y),
      'image/object/rbbox/cx': self._FloatFeature(centers_x),
      'image/object/rbbox/h': self._FloatFeature(heights),
      'image/object/rbbox/w': self._FloatFeature(widths),
      'image/object/rbbox/ang': self._FloatFeature(angles),
  }
  serialized = tf.train.Example(
      features=tf.train.Features(feature=feature_map)).SerializeToString()

  decoder = tf_example_decoder.TfExampleDecoder()
  decoded = decoder.decode(tf.convert_to_tensor(serialized))

  rboxes_key = fields.InputDataFields.groundtruth_rboxes
  self.assertAllEqual(decoded[rboxes_key].get_shape().as_list(), [None, 5])

  with self.test_session() as sess:
    outputs = sess.run(decoded)

  # One row per box, one column per component.
  expected_rboxes = np.vstack(
      [centers_y, centers_x, heights, widths, angles]).T.astype(np.float32)
  self.assertAllEqual(expected_rboxes, outputs[rboxes_key])
def decode(tf_example_string_tensor):
  """Decodes a serialized tf.Example and returns its image tensor.

  When `input_shape` (a name resolved outside this function) is not None, the
  image is resized to `input_shape[1:3]`.
  NOTE(review): presumably `input_shape` is (batch, height, width, channels)
  -- confirm against the enclosing scope.
  """
  decoded = tf_example_decoder.TfExampleDecoder().decode(
      tf_example_string_tensor)
  image = decoded[fields.InputDataFields.image]
  if input_shape is None:
    return image
  return tf.image.resize(image, input_shape[1:3])
def testDecodePngInstanceMasks(self):
  """PNG-encoded instance masks decode to a stacked float mask array."""
  pixels = np.random.randint(256, size=(10, 10, 3)).astype(np.uint8)
  jpeg_bytes = self._EncodeImage(pixels)

  # Two random binary 10x10 single-channel masks.
  raw_masks = [
      np.random.randint(0, 2, size=(10, 10, 1)).astype(np.uint8)
      for _ in range(2)
  ]
  png_masks = [
      self._EncodeImage(mask, encoding_type='png') for mask in raw_masks
  ]
  # The expected output drops the channel axis and is float32.
  expected_masks = np.stack(
      [np.squeeze(mask.astype(np.float32)) for mask in raw_masks])

  feature_map = {
      'image/encoded': self._BytesFeature(jpeg_bytes),
      'image/format': self._BytesFeature('jpeg'),
      'image/object/mask': self._BytesFeature(png_masks),
  }
  serialized = tf.train.Example(
      features=tf.train.Features(feature=feature_map)).SerializeToString()

  decoder = tf_example_decoder.TfExampleDecoder(
      load_instance_masks=True,
      instance_mask_type=input_reader_pb2.PNG_MASKS)
  decoded = decoder.decode(tf.convert_to_tensor(serialized))

  with self.test_session() as sess:
    outputs = sess.run(decoded)

  self.assertAllEqual(
      expected_masks,
      outputs[fields.InputDataFields.groundtruth_instance_masks])
def _read_and_decode(queue):
  """Reads one serialized record from `queue` and decodes it.

  `dtype` is a name resolved outside this function; it is forwarded to the
  `TfExampleDecoder` constructor.

  Args:
    queue: Queue handed to `tf.TFRecordReader.read`.

  Returns:
    The tensor dict produced by the example decoder.
  """
  _, serialized = tf.TFRecordReader().read(queue)
  decoder = tf_example_decoder.TfExampleDecoder(dtype=dtype)
  return decoder.decode(tf.convert_to_tensor(serialized))
def testDecodeJpegImage(self):
  """A JPEG example decodes to the reference pixels and its source id."""
  pixels = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
  jpeg_bytes = self._EncodeImage(pixels)
  # Reference pixels come from the test's own decode helper.
  reference_pixels = self._DecodeImage(jpeg_bytes)

  feature_map = {
      'image/encoded': self._BytesFeature(jpeg_bytes),
      'image/format': self._BytesFeature('jpeg'.encode('utf8')),
      'image/source_id': self._BytesFeature('image_id'.encode('utf8')),
  }
  serialized = tf.train.Example(
      features=tf.train.Features(feature=feature_map)).SerializeToString()

  decoder = tf_example_decoder.TfExampleDecoder()
  decoded = decoder.decode(tf.convert_to_tensor(serialized))

  image_key = fields.InputDataFields.image
  self.assertAllEqual(
      decoded[image_key].get_shape().as_list(), [None, None, 3])

  with self.test_session() as sess:
    outputs = sess.run(decoded)

  self.assertAllEqual(reference_pixels, outputs[image_key])
  self.assertEqual(b'image_id', outputs[fields.InputDataFields.source_id])
def testInstancesNotAvailableByDefault(self):
  """A default-constructed decoder does not emit instance masks.

  The example contains an 'image/object/mask' feature, but the decoder is
  built without `load_instance_masks`, so the mask key must be absent from
  the decoded tensor dict.
  """
  num_instances = 4
  image_height = 5
  image_width = 3

  # Randomly generate image.
  image_tensor = np.random.randint(256, size=(image_height,
                                              image_width,
                                              3)).astype(np.uint8)
  encoded_jpeg = self._EncodeImage(image_tensor)

  # Randomly generate instance segmentation masks.
  instance_masks = (np.random.randint(2, size=(num_instances,
                                               image_height,
                                               image_width)).astype(
                                                   np.float32))
  instance_masks_flattened = np.reshape(instance_masks, [-1])

  # Randomly generate class labels for each instance.
  object_classes = np.random.randint(100, size=(num_instances)).astype(
      np.int64)

  example = tf.train.Example(features=tf.train.Features(
      feature={
          'image/encoded': self._BytesFeature(encoded_jpeg),
          'image/format': self._BytesFeature('jpeg'),
          'image/height': self._Int64Feature([image_height]),
          'image/width': self._Int64Feature([image_width]),
          'image/object/mask': self._FloatFeature(
              instance_masks_flattened),
          'image/object/class/label': self._Int64Feature(object_classes)
      })).SerializeToString()
  example_decoder = tf_example_decoder.TfExampleDecoder()
  tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
  # assertNotIn reports the offending key and container on failure, unlike
  # assertTrue(... not in ...), which only says "False is not true".
  self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                   tensor_dict)
def testDecodeInstanceSegmentation(self):
  """Instance segmentation masks and their classes are decoded.

  Masks are fed as a flattened int64 feature and compared against the decoded
  output as booleans; classes are compared as-is.
  """
  num_instances = 4
  image_height = 5
  image_width = 3

  # Randomly generate image.
  image_tensor = np.random.randint(255, size=(image_height,
                                              image_width,
                                              3)).astype(np.uint8)
  encoded_jpeg = self._EncodeImage(image_tensor)

  # Randomly generate instance segmentation masks.
  instance_segmentation = (np.random.randint(
      2, size=(num_instances, image_height, image_width)).astype(np.int64))

  # Randomly generate class labels for each instance.
  instance_segmentation_classes = np.random.randint(
      100, size=(num_instances)).astype(np.int64)

  example = tf.train.Example(features=tf.train.Features(
      feature={
          'image/encoded': self._BytesFeature(encoded_jpeg),
          'image/format': self._BytesFeature('jpeg'.encode('utf8')),
          'image/height': self._Int64Feature([image_height]),
          'image/width': self._Int64Feature([image_width]),
          'image/segmentation/object':
              self._Int64Feature(instance_segmentation.flatten()),
          'image/segmentation/object/class':
              self._Int64Feature(instance_segmentation_classes)
      })).SerializeToString()
  example_decoder = tf_example_decoder.TfExampleDecoder()
  tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))

  # Static shapes are fully unknown until the graph is run.
  self.assertAllEqual(
      (tensor_dict[fields.InputDataFields.groundtruth_instance_masks].
       get_shape().as_list()), [None, None, None])

  self.assertAllEqual(
      (tensor_dict[fields.InputDataFields.groundtruth_instance_classes].
       get_shape().as_list()), [None])

  with self.test_session() as sess:
    tensor_dict = sess.run(tensor_dict)

  # np.bool_ is the NumPy boolean scalar type; the bare `np.bool` alias is
  # deprecated and no longer exists in recent NumPy releases.
  self.assertAllEqual(
      instance_segmentation.astype(np.bool_),
      tensor_dict[fields.InputDataFields.groundtruth_instance_masks])
  self.assertAllEqual(
      instance_segmentation_classes,
      tensor_dict[fields.InputDataFields.groundtruth_instance_classes])
def build(input_reader_config):
  """Builds a tensor dictionary based on the InputReader config.

  Args:
    input_reader_config: Reader configuration. This function reads its
      `input_path`, `num_epochs`, `num_readers`, `shuffle`, `queue_capacity`
      and `min_after_dequeue` attributes.

  Returns:
    A tensor dict based on the input_reader_config.
  """
  _, string_tensor = parallel_reader.parallel_read(
      data_sources=input_reader_config.input_path,
      reader_class=tf.TFRecordReader,
      # A falsy epoch count (e.g. an unset field reading as 0) is forwarded
      # as None, i.e. "no epoch limit", instead of handing 0 to the reader.
      num_epochs=input_reader_config.num_epochs or None,
      num_readers=input_reader_config.num_readers,
      shuffle=input_reader_config.shuffle,
      dtypes=[tf.string, tf.string],
      capacity=input_reader_config.queue_capacity,
      min_after_dequeue=input_reader_config.min_after_dequeue)

  decoder = tf_example_decoder.TfExampleDecoder()
  return decoder.decode(string_tensor)
def build(input_reader_config):
  """Creates the decoded tensor dictionary described by an InputReader config.

  Serialized examples are read from the configured TFRecord paths through
  `parallel_reader.parallel_read` and decoded with a `TfExampleDecoder` set up
  from the config's mask and label-map settings.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.

  Returns:
    A tensor dict based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  reader_kind = input_reader_config.WhichOneof('input_reader')
  if reader_kind != 'tf_record_input_reader':
    raise ValueError('Unsupported input_reader_config.')

  tf_record_config = input_reader_config.tf_record_input_reader
  if not tf_record_config.input_path:
    raise ValueError('At least one input path must be specified in '
                     '`input_reader_config`.')

  # Convert `RepeatedScalarContainer` to list.
  input_paths = tf_record_config.input_path[:]
  # A falsy epoch count is passed on as None.
  epochs = input_reader_config.num_epochs or None
  _, serialized_examples = parallel_reader.parallel_read(
      input_paths,
      reader_class=tf.TFRecordReader,
      num_epochs=epochs,
      num_readers=input_reader_config.num_readers,
      shuffle=input_reader_config.shuffle,
      dtypes=[tf.string, tf.string],
      capacity=input_reader_config.queue_capacity,
      min_after_dequeue=input_reader_config.min_after_dequeue)

  if input_reader_config.HasField('label_map_path'):
    label_map_file = input_reader_config.label_map_path
  else:
    label_map_file = None

  example_decoder = tf_example_decoder.TfExampleDecoder(
      load_instance_masks=input_reader_config.load_instance_masks,
      instance_mask_type=input_reader_config.mask_type,
      label_map_proto_file=label_map_file)
  return example_decoder.decode(serialized_examples)
def testDecodeImageKeyAndFilename(self):
  """The image key and filename are decoded from the example.

  String tensors come back from `sess.run` as `bytes`, so the features fed in
  and the expected values are bytes literals. On Python 2 this is identical
  to the plain-str form; on Python 3 it avoids a str-vs-bytes comparison that
  can never be equal.
  """
  image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
  encoded_jpeg = self._EncodeImage(image_tensor)
  example = tf.train.Example(features=tf.train.Features(
      feature={
          'image/encoded': self._BytesFeature(encoded_jpeg),
          'image/key/sha256': self._BytesFeature(b'abc'),
          'image/filename': self._BytesFeature(b'filename')
      })).SerializeToString()

  example_decoder = tf_example_decoder.TfExampleDecoder()
  tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))

  with self.test_session() as sess:
    tensor_dict = sess.run(tensor_dict)

  self.assertEqual(b'abc', tensor_dict[fields.InputDataFields.key])
  self.assertEqual(b'filename', tensor_dict[fields.InputDataFields.filename])
def build(input_reader_config, batch_size=None, transform_input_data_fn=None):
  """Builds a tf.data.Dataset from an InputReader config.

  Each record is decoded with a default `TfExampleDecoder` and, if given,
  passed through `transform_input_data_fn`. When `batch_size` is set the
  dataset is batched with the remainder dropped; the result is prefetched.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    batch_size: Batch size. If falsy, no batching is performed.
    transform_input_data_fn: Function applied to every decoded record, or
      None if no extra transformation is required.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')

    decoder = tf_example_decoder.TfExampleDecoder()

    def process_fn(value):
      """Decodes one serialized record and applies the optional transform."""
      processed_tensors = decoder.decode(value)
      if transform_input_data_fn is not None:
        processed_tensors = transform_input_data_fn(processed_tensors)
      return processed_tensors

    # Read the input paths into a dataset of serialized example strings.
    dataset = read_dataset(
        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
        config.input_path[:], input_reader_config)

    # Map parallelism scales with the batch size when batching is on.
    if batch_size:
      num_parallel_calls = batch_size * input_reader_config.num_parallel_batches
    else:
      num_parallel_calls = input_reader_config.num_parallel_map_calls

    # Use the legacy map function when the dataset object provides one.
    if hasattr(dataset, 'map_with_legacy_function'):
      data_map_fn = dataset.map_with_legacy_function
    else:
      data_map_fn = dataset.map
    dataset = data_map_fn(process_fn, num_parallel_calls=num_parallel_calls)

    if batch_size:
      dataset = dataset.apply(
          tf.contrib.data.batch_and_drop_remainder(batch_size))
    dataset = dataset.prefetch(input_reader_config.num_prefetch_batches)
    return dataset

  raise ValueError('Unsupported input_reader_config.')
def testDecodeObjectLabelNoText(self):
  """Integer class labels decode unchanged when a label map file is given.

  The example carries only numeric labels (no class text); the decoder is
  constructed with a label map written to a temporary file.
  """
  pixels = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
  jpeg_bytes = self._EncodeImage(pixels)
  class_ids = [1, 2]
  feature_map = {
      'image/encoded': self._BytesFeature(jpeg_bytes),
      'image/format': self._BytesFeature('jpeg'),
      'image/object/class/label': self._Int64Feature(class_ids),
  }
  serialized = tf.train.Example(
      features=tf.train.Features(feature=feature_map)).SerializeToString()

  label_map_string = """
    item {
      id:1
      name:'cat'
    }

    item {
      id:2
      name:'dog'
    }
  """
  label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
  with tf.gfile.Open(label_map_path, 'wb') as label_map_file:
    label_map_file.write(label_map_string)

  decoder = tf_example_decoder.TfExampleDecoder(
      label_map_proto_file=label_map_path)
  decoded = decoder.decode(tf.convert_to_tensor(serialized))

  classes_key = fields.InputDataFields.groundtruth_classes
  self.assertAllEqual(decoded[classes_key].get_shape().as_list(), [None])

  # Tables must be initialized before the decoded tensors are evaluated.
  tables_init = tf.tables_initializer()
  with self.test_session() as sess:
    sess.run(tables_init)
    outputs = sess.run(decoded)

  self.assertAllEqual(class_ids, outputs[classes_key])
def testDecodeObjectArea(self):
  """Per-object areas are decoded as given in the example."""
  pixels = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
  jpeg_bytes = self._EncodeImage(pixels)
  areas = [100., 174.]
  feature_map = {
      'image/encoded': self._BytesFeature(jpeg_bytes),
      'image/format': self._BytesFeature('jpeg'),
      'image/object/area': self._FloatFeature(areas),
  }
  serialized = tf.train.Example(
      features=tf.train.Features(feature=feature_map)).SerializeToString()

  decoder = tf_example_decoder.TfExampleDecoder()
  decoded = decoder.decode(tf.convert_to_tensor(serialized))

  area_key = fields.InputDataFields.groundtruth_area
  self.assertAllEqual(decoded[area_key].get_shape().as_list(), [None])

  with self.test_session() as sess:
    outputs = sess.run(decoded)

  self.assertAllEqual(areas, outputs[area_key])
def _predict_input_fn(params=None):
  """Builds the serving input receiver fed by one serialized tf.Example.

  The example is decoded, run through `transform_input_data`, and both the
  float image and its true image shape are given a leading batch axis before
  being exposed as features.

  Args:
    params: Unused; deleted immediately.

  Returns:
    A `tf.estimator.export.ServingInputReceiver`.
  """
  del params
  serialized_example = tf.placeholder(
      dtype=tf.string, shape=[], name='tf_example')

  detection_model = model_builder.build(model_config, is_training=False)
  resizer_config = config_utils.get_image_resizer_config(model_config)
  resizer_fn = image_resizer_builder.build(resizer_config)
  transform = functools.partial(
      transform_input_data,
      model_preprocess_fn=detection_model.preprocess,
      image_resizer_fn=resizer_fn)

  example_decoder = tf_example_decoder.TfExampleDecoder()
  transformed = transform(example_decoder.decode(serialized_example))

  image_batch = tf.expand_dims(
      tf.to_float(transformed[fields.InputDataFields.image]), axis=0)
  shape_batch = tf.expand_dims(
      transformed[fields.InputDataFields.true_image_shape], axis=0)

  serving_features = {
      fields.InputDataFields.image: image_batch,
      fields.InputDataFields.true_image_shape: shape_batch,
  }
  return tf.estimator.export.ServingInputReceiver(
      features=serving_features,
      receiver_tensors={SERVING_FED_EXAMPLE_KEY: serialized_example})
def testDecodeObjectGroupOf(self):
  """Integer group-of flags are decoded as booleans."""
  pixels = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
  jpeg_bytes = self._EncodeImage(pixels)
  group_of_flags = [0, 1]
  feature_map = {
      'image/encoded': self._BytesFeature(jpeg_bytes),
      'image/format': self._BytesFeature('jpeg'),
      'image/object/group_of': self._Int64Feature(group_of_flags),
  }
  serialized = tf.train.Example(
      features=tf.train.Features(feature=feature_map)).SerializeToString()

  decoder = tf_example_decoder.TfExampleDecoder()
  decoded = decoder.decode(tf.convert_to_tensor(serialized))

  group_of_key = fields.InputDataFields.groundtruth_group_of
  self.assertAllEqual(decoded[group_of_key].get_shape().as_list(), [None])

  with self.test_session() as sess:
    outputs = sess.run(decoded)

  expected_flags = list(map(bool, group_of_flags))
  self.assertAllEqual(expected_flags, outputs[group_of_key])
def testDecodeObjectLabel(self):
  """Integer class labels are decoded as given in the example."""
  pixels = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
  jpeg_bytes = self._EncodeImage(pixels)
  class_ids = [0, 1]
  feature_map = {
      'image/encoded': self._BytesFeature(jpeg_bytes),
      'image/format': self._BytesFeature('jpeg'.encode('utf8')),
      'image/object/class/label': self._Int64Feature(class_ids),
  }
  serialized = tf.train.Example(
      features=tf.train.Features(feature=feature_map)).SerializeToString()

  decoder = tf_example_decoder.TfExampleDecoder()
  decoded = decoder.decode(tf.convert_to_tensor(serialized))

  classes_key = fields.InputDataFields.groundtruth_classes
  self.assertAllEqual(decoded[classes_key].get_shape().as_list(), [None])

  with self.test_session() as sess:
    outputs = sess.run(decoded)

  self.assertAllEqual(class_ids, outputs[classes_key])
def build(input_reader_config, batch_size=None, transform_input_data_fn=None):
  """Creates a tf.data.Dataset from an InputReader config.

  Records are read from the configured TFRecord paths, optionally sharded
  (keeping shard 0 when `sample_1_of_n_examples` > 1), decoded, optionally
  transformed, optionally batched with the remainder dropped, and finally
  prefetched.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    batch_size: Batch size. If batch size is None, no batching is performed.
    transform_input_data_fn: Function to apply transformation to all records,
      or None if no extra decoding is required.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  reader_kind = input_reader_config.WhichOneof('input_reader')
  if reader_kind != 'tf_record_input_reader':
    raise ValueError('Unsupported input_reader_config.')

  tf_record_config = input_reader_config.tf_record_input_reader
  if not tf_record_config.input_path:
    raise ValueError('At least one input path must be specified in '
                     '`input_reader_config`.')

  if input_reader_config.HasField('label_map_path'):
    label_map_file = input_reader_config.label_map_path
  else:
    label_map_file = None

  example_decoder = tf_example_decoder.TfExampleDecoder(
      load_instance_masks=input_reader_config.load_instance_masks,
      load_multiclass_scores=input_reader_config.load_multiclass_scores,
      instance_mask_type=input_reader_config.mask_type,
      label_map_proto_file=label_map_file,
      use_display_name=input_reader_config.use_display_name,
      num_additional_channels=input_reader_config.num_additional_channels)

  def process_fn(value):
    """Decodes one serialized record and applies the optional transform."""
    decoded = example_decoder.decode(value)
    if transform_input_data_fn is None:
      return decoded
    return transform_input_data_fn(decoded)

  record_dataset_fn = functools.partial(
      tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000)
  dataset = read_dataset(
      record_dataset_fn, tf_record_config.input_path[:], input_reader_config)

  sample_rate = input_reader_config.sample_1_of_n_examples
  if sample_rate > 1:
    dataset = dataset.shard(sample_rate, 0)

  # TODO(rathodv): make batch size a required argument once the old binaries
  # are deleted.
  num_parallel_calls = (
      batch_size * input_reader_config.num_parallel_batches
      if batch_size else input_reader_config.num_parallel_map_calls)

  # TODO(b/123952794): Migrate to V2 function.
  map_fn = (dataset.map_with_legacy_function
            if hasattr(dataset, 'map_with_legacy_function')
            else dataset.map)
  dataset = map_fn(process_fn, num_parallel_calls=num_parallel_calls)

  if batch_size:
    dataset = dataset.apply(
        tf.contrib.data.batch_and_drop_remainder(batch_size))
  return dataset.prefetch(input_reader_config.num_prefetch_batches)
def decode(tf_example_string_tensor):
  """Returns the image tensor decoded from a serialized tf.Example."""
  decoded = tf_example_decoder.TfExampleDecoder().decode(
      tf_example_string_tensor)
  return decoded[fields.InputDataFields.image]
def build(input_reader_config, transform_input_data_fn=None,
          batch_size=None, max_num_boxes=None, num_classes=None,
          spatial_image_shape=None):
  """Creates a tf.data.Dataset from an InputReader config.

  Every record is decoded and, if requested, passed through
  `transform_input_data_fn`. When `batch_size` is set, the dataset is turned
  into padded batches with the remainder dropped.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    transform_input_data_fn: Function to apply to all records, or None if
      no extra decoding is required.
    batch_size: Batch size. If None, batching is not performed.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes
      for padding. If None, will use a dynamic shape.
    num_classes: Number of classes in the dataset needed to compute shapes
      for padding. If None, will use a dynamic shape.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image after applying
      transform_input_data_fn. If None, will use dynamic shapes.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  reader_kind = input_reader_config.WhichOneof('input_reader')
  if reader_kind != 'tf_record_input_reader':
    raise ValueError('Unsupported input_reader_config.')

  tf_record_config = input_reader_config.tf_record_input_reader
  if not tf_record_config.input_path:
    raise ValueError('At least one input path must be specified in '
                     '`input_reader_config`.')

  if input_reader_config.HasField('label_map_path'):
    label_map_file = input_reader_config.label_map_path
  else:
    label_map_file = None

  example_decoder = tf_example_decoder.TfExampleDecoder(
      load_instance_masks=input_reader_config.load_instance_masks,
      instance_mask_type=input_reader_config.mask_type,
      label_map_proto_file=label_map_file)

  def process_fn(value):
    """Decodes one serialized record and applies the optional transform."""
    decoded = example_decoder.decode(value)
    if transform_input_data_fn is None:
      return decoded
    return transform_input_data_fn(decoded)

  record_dataset_fn = functools.partial(
      tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000)
  dataset = dataset_util.read_dataset(
      record_dataset_fn, process_fn, tf_record_config.input_path[:],
      input_reader_config)

  if not batch_size:
    return dataset

  padding_shapes = _get_padding_shapes(
      dataset, max_num_boxes, num_classes, spatial_image_shape)
  return dataset.apply(
      tf.contrib.data.padded_batch_and_drop_remainder(
          batch_size, padding_shapes))