def _get_dataset_next(self, files, config, batch_size):
  """Returns the next batch tensor of int32 values parsed from text files.

  Reads `files` line by line via `dataset_util.read_dataset` using a
  `tf.data.TextLineDataset`, converts every line to an int32 number, groups
  the results into batches and returns the `get_next()` op of a one-shot
  iterator over the batched dataset.

  Args:
    files: Input files forwarded unchanged to `dataset_util.read_dataset`.
    config: Reader configuration forwarded unchanged to
      `dataset_util.read_dataset`.
    batch_size: Number of consecutive elements combined by `Dataset.batch`.

  Returns:
    The result of `get_next()` on a one-shot iterator over the batched
    dataset.
  """

  def _line_to_int32(line):
    # Each decoded element is a single-item list holding the parsed number.
    parsed = tf.string_to_number(line, out_type=tf.int32)
    return [parsed]

  batched = dataset_util.read_dataset(
      tf.data.TextLineDataset, _line_to_int32, files, config).batch(batch_size)
  iterator = batched.make_one_shot_iterator()
  return iterator.get_next()
def _get_dataset_next(self, files, config, batch_size):
  """Builds a batched int32 text-line dataset and returns its next element.

  Args:
    files: Input files handed as-is to `dataset_util.read_dataset`.
    config: Reader configuration handed as-is to
      `dataset_util.read_dataset`.
    batch_size: Batch size passed to `Dataset.batch`.

  Returns:
    The `get_next()` op of a one-shot iterator over the batched dataset.
  """

  def _parse_line(text):
    # The parsed int32 number is wrapped in a one-element list.
    return [tf.string_to_number(text, out_type=tf.int32)]

  lines = dataset_util.read_dataset(
      tf.data.TextLineDataset, _parse_line, files, config)
  one_shot = lines.batch(batch_size).make_one_shot_iterator()
  return one_shot.get_next()
def build(input_reader_config, transform_input_data_fn=None,
          batch_size=None, max_num_boxes=None, num_classes=None,
          spatial_image_shape=None):
  """Builds a tf.data.Dataset from an InputReader config.

  Every record is decoded with a `TfExampleDecoder` and then, when
  `transform_input_data_fn` is given, passed through that function. When
  `batch_size` is truthy the resulting dataset is additionally batched with
  padding.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    transform_input_data_fn: Function to apply to all records, or None if no
      extra decoding is required.
    batch_size: Batch size. If None (or 0), batching is not performed.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes
      for padding. If None, will use a dynamic shape.
    num_classes: Number of classes in the dataset needed to compute shapes
      for padding. If None, will use a dynamic shape.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image after applying
      transform_input_data_fn. If None, will use dynamic shapes.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  # Only the TFRecord reader is handled; anything else is rejected up front.
  reader_kind = input_reader_config.WhichOneof('input_reader')
  if reader_kind != 'tf_record_input_reader':
    raise ValueError('Unsupported input_reader_config.')

  config = input_reader_config.tf_record_input_reader
  if not config.input_path:
    raise ValueError('At least one input path must be specified in '
                     '`input_reader_config`.')

  if input_reader_config.HasField('label_map_path'):
    label_map_proto_file = input_reader_config.label_map_path
  else:
    label_map_proto_file = None
  decoder = tf_example_decoder.TfExampleDecoder(
      load_instance_masks=input_reader_config.load_instance_masks,
      instance_mask_type=input_reader_config.mask_type,
      label_map_proto_file=label_map_proto_file)

  def process_fn(value):
    # Decode first, then optionally hand the decoded record to the caller's
    # transform.
    decoded = decoder.decode(value)
    if transform_input_data_fn is None:
      return decoded
    return transform_input_data_fn(decoded)

  record_reader = functools.partial(
      tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000)
  dataset = dataset_util.read_dataset(
      record_reader, process_fn, config.input_path[:], input_reader_config)

  if not batch_size:
    return dataset

  padding_shapes = _get_padding_shapes(
      dataset, max_num_boxes, num_classes, spatial_image_shape)
  batcher = tf.contrib.data.padded_batch_and_drop_remainder(
      batch_size, padding_shapes)
  return dataset.apply(batcher)
def build(input_reader_config, transform_input_data_fn=None,
          batch_size=None, max_num_boxes=None, num_classes=None,
          spatial_image_shape=None, num_additional_channels=0):
  """Builds a tf.data.Dataset from an InputReader config.

  Every record is decoded with a `TfExampleDecoder` and then, when
  `transform_input_data_fn` is given, passed through that function. When
  `batch_size` is truthy the resulting dataset is additionally batched with
  padding.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    transform_input_data_fn: Function to apply to all records, or None if no
      extra decoding is required.
    batch_size: Batch size. If None (or 0), batching is not performed.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes
      for padding. If None, will use a dynamic shape.
    num_classes: Number of classes in the dataset needed to compute shapes
      for padding. If None, will use a dynamic shape.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image after applying
      transform_input_data_fn. If None, will use dynamic shapes.
    num_additional_channels: Number of additional channels to use in the
      input; forwarded to the `TfExampleDecoder`.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  # Only the TFRecord reader is handled; anything else is rejected up front.
  reader_kind = input_reader_config.WhichOneof('input_reader')
  if reader_kind != 'tf_record_input_reader':
    raise ValueError('Unsupported input_reader_config.')

  config = input_reader_config.tf_record_input_reader
  if not config.input_path:
    raise ValueError('At least one input path must be specified in '
                     '`input_reader_config`.')

  if input_reader_config.HasField('label_map_path'):
    label_map_proto_file = input_reader_config.label_map_path
  else:
    label_map_proto_file = None
  decoder = tf_example_decoder.TfExampleDecoder(
      load_instance_masks=input_reader_config.load_instance_masks,
      instance_mask_type=input_reader_config.mask_type,
      label_map_proto_file=label_map_proto_file,
      use_display_name=input_reader_config.use_display_name,
      num_additional_channels=num_additional_channels)

  def process_fn(value):
    # Decode first, then optionally hand the decoded record to the caller's
    # transform.
    decoded = decoder.decode(value)
    if transform_input_data_fn is None:
      return decoded
    return transform_input_data_fn(decoded)

  record_reader = functools.partial(
      tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000)
  dataset = dataset_util.read_dataset(
      record_reader, process_fn, config.input_path[:], input_reader_config)

  if not batch_size:
    return dataset

  padding_shapes = _get_padding_shapes(
      dataset, max_num_boxes, num_classes, spatial_image_shape)
  batcher = tf.contrib.data.padded_batch_and_drop_remainder(
      batch_size, padding_shapes)
  return dataset.apply(batcher)
def build(input_reader_config, transform_input_data_fn=None,
          batch_size=1, max_num_boxes=None, num_classes=None,
          spatial_image_shape=None):
  """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
  records. If `batch_size` > 1, returns a padded batched tf.data.Dataset; in
  that case `max_num_boxes`, `num_classes` and `spatial_image_shape` must all
  be provided.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    transform_input_data_fn: Function to apply to all records, or None if no
      extra decoding is required.
    batch_size: Batch size. Padded batching is applied only when this is
      greater than 1; None or any value <= 1 returns the unbatched dataset.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes
      for padding. This is only used if batch_size is greater than 1.
    num_classes: Number of classes in the dataset needed to compute shapes
      for padding. This is only used if batch_size is greater than 1.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image after applying
      transform_input_data_fn. This is needed to compute shapes for padding
      and only used if batch_size is greater than 1.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
    ValueError: If batch_size > 1 and any of (max_num_boxes, num_classes,
      spatial_image_shape) is None.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  # Only the TFRecord reader is handled; anything else is rejected up front.
  if input_reader_config.WhichOneof('input_reader') != 'tf_record_input_reader':
    raise ValueError('Unsupported input_reader_config.')

  config = input_reader_config.tf_record_input_reader
  if not config.input_path:
    raise ValueError('At least one input path must be specified in '
                     '`input_reader_config`.')

  label_map_proto_file = None
  if input_reader_config.HasField('label_map_path'):
    label_map_proto_file = input_reader_config.label_map_path
  decoder = tf_example_decoder.TfExampleDecoder(
      load_instance_masks=input_reader_config.load_instance_masks,
      instance_mask_type=input_reader_config.mask_type,
      label_map_proto_file=label_map_proto_file)

  def process_fn(value):
    # Decode first, then optionally hand the decoded record to the caller's
    # transform.
    processed = decoder.decode(value)
    if transform_input_data_fn is not None:
      return transform_input_data_fn(processed)
    return processed

  dataset = dataset_util.read_dataset(
      functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
      process_fn, config.input_path[:], input_reader_config)

  # `None > 1` raises TypeError on Python 3, so a None batch size is checked
  # explicitly and treated as "no batching".
  if batch_size is None or batch_size <= 1:
    return dataset

  if num_classes is None:
    raise ValueError('`num_classes` must be set when batch_size > 1.')
  if max_num_boxes is None:
    raise ValueError('`max_num_boxes` must be set when batch_size > 1.')
  if spatial_image_shape is None:
    raise ValueError('`spatial_image_shape` must be set when batch_size > '
                     '1 .')

  padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
                                       spatial_image_shape)
  return dataset.apply(
      tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
                                                      padding_shapes))
def build(input_reader_config, transform_input_data_fn=None, num_workers=1,
          worker_index=0, batch_size=1, max_num_boxes=None, num_classes=None,
          spatial_image_shape=None):
  """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
  records. If `batch_size` > 1, returns a padded batched tf.data.Dataset; in
  that case `max_num_boxes`, `num_classes` and `spatial_image_shape` must all
  be provided.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    transform_input_data_fn: Function to apply to all records, or None if no
      extra decoding is required.
    num_workers: Number of workers (tpu shard). Forwarded to
      `dataset_util.read_dataset`.
    worker_index: Id for the current worker (tpu shard). Forwarded to
      `dataset_util.read_dataset`.
    batch_size: Batch size. Padded batching is applied only when this is
      greater than 1; None or any value <= 1 returns the unbatched dataset.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes
      for padding. This is only used if batch_size is greater than 1.
    num_classes: Number of classes in the dataset needed to compute shapes
      for padding. This is only used if batch_size is greater than 1.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image after applying
      transform_input_data_fn. This is needed to compute shapes for padding
      and only used if batch_size is greater than 1.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
    ValueError: If batch_size > 1 and any of (max_num_boxes, num_classes,
      spatial_image_shape) is None.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  # Only the TFRecord reader is handled; anything else is rejected up front.
  if input_reader_config.WhichOneof('input_reader') != 'tf_record_input_reader':
    raise ValueError('Unsupported input_reader_config.')

  config = input_reader_config.tf_record_input_reader
  if not config.input_path:
    raise ValueError('At least one input path must be specified in '
                     '`input_reader_config`.')

  label_map_proto_file = None
  if input_reader_config.HasField('label_map_path'):
    label_map_proto_file = input_reader_config.label_map_path
  decoder = tf_example_decoder.TfExampleDecoder(
      load_instance_masks=input_reader_config.load_instance_masks,
      instance_mask_type=input_reader_config.mask_type,
      label_map_proto_file=label_map_proto_file)

  def process_fn(value):
    # Decode first, then optionally hand the decoded record to the caller's
    # transform.
    processed = decoder.decode(value)
    if transform_input_data_fn is not None:
      return transform_input_data_fn(processed)
    return processed

  dataset = dataset_util.read_dataset(
      tf.data.TFRecordDataset, process_fn, config.input_path[:],
      input_reader_config, num_workers, worker_index)

  # `None > 1` raises TypeError on Python 3, so a None batch size is checked
  # explicitly and treated as "no batching".
  if batch_size is None or batch_size <= 1:
    return dataset

  if num_classes is None:
    raise ValueError('`num_classes` must be set when batch_size > 1.')
  if max_num_boxes is None:
    raise ValueError('`max_num_boxes` must be set when batch_size > 1.')
  if spatial_image_shape is None:
    raise ValueError('`spatial_image_shape` must be set when batch_size > '
                     '1 .')

  padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
                                       spatial_image_shape)
  return dataset.apply(
      tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
                                                      padding_shapes))