コード例 #1
0
    def _get_dataset_next(self, files, config, batch_size):
        """Returns the get_next() output of a batched integer text-line dataset.

        A parse function that converts a string with tf.string_to_number
        (out_type=tf.int32) and wraps the result in a single-element list is
        handed to dataset_util.read_dataset together with
        tf.data.TextLineDataset. The resulting dataset is batched by
        `batch_size`.

        Args:
          files: Input files forwarded to dataset_util.read_dataset
            (presumably a list of text file paths -- confirm against caller).
          config: Reader configuration forwarded to dataset_util.read_dataset.
          batch_size: Number of elements per batch.

        Returns:
          The result of get_next() on a one-shot iterator over the batched
          dataset.
        """
        def parse_line(line):
            return [tf.string_to_number(line, out_type=tf.int32)]

        dataset = dataset_util.read_dataset(
            tf.data.TextLineDataset, parse_line, files, config)
        iterator = dataset.batch(batch_size).make_one_shot_iterator()
        return iterator.get_next()
コード例 #2
0
  def _get_dataset_next(self, files, config, batch_size):
    """Returns the get_next() output of a batched integer text-line dataset.

    Builds a dataset via dataset_util.read_dataset from
    tf.data.TextLineDataset, using a parse function that converts each
    string it receives to tf.int32 (wrapped in a single-element list), then
    batches it by `batch_size`.

    Args:
      files: Input files forwarded to dataset_util.read_dataset (presumably
        text file paths -- confirm against caller).
      config: Reader configuration forwarded to dataset_util.read_dataset.
      batch_size: Number of elements per batch.

    Returns:
      The result of get_next() on a one-shot iterator over the batched dataset.
    """
    def to_int32(line):
      return [tf.string_to_number(line, out_type=tf.int32)]

    batched = dataset_util.read_dataset(
        tf.data.TextLineDataset, to_int32, files, config).batch(batch_size)
    one_shot = batched.make_one_shot_iterator()
    return one_shot.get_next()
コード例 #3
0
def build(input_reader_config,
          transform_input_data_fn=None,
          batch_size=None,
          max_num_boxes=None,
          num_classes=None,
          spatial_image_shape=None):
    """Builds a tf.data.Dataset from an InputReader proto.

    Each record is decoded with a tf_example_decoder.TfExampleDecoder and, if
    provided, passed through `transform_input_data_fn`. When `batch_size` is
    truthy, the dataset is padded-batched with
    tf.contrib.data.padded_batch_and_drop_remainder.

    Args:
      input_reader_config: An input_reader_pb2.InputReader object.
      transform_input_data_fn: Function applied to every decoded record, or
        None if no extra decoding is required.
      batch_size: Batch size. If None (or 0), batching is not performed.
      max_num_boxes: Max number of groundtruth boxes needed to compute shapes
        for padding. If None, will use a dynamic shape.
      num_classes: Number of classes in the dataset needed to compute shapes
        for padding. If None, will use a dynamic shape.
      spatial_image_shape: A list of two integers of the form [height, width]
        containing the expected spatial shape of the image after applying
        transform_input_data_fn. If None, will use dynamic shapes.

    Returns:
      A tf.data.Dataset based on the input_reader_config.

    Raises:
      ValueError: If `input_reader_config` is not an InputReader proto, does
        not use a tf_record_input_reader, or specifies no input paths.
    """
    if not isinstance(input_reader_config, input_reader_pb2.InputReader):
        raise ValueError('input_reader_config not of type '
                         'input_reader_pb2.InputReader.')

    # Only TFRecord input readers are supported; bail out early otherwise.
    reader_kind = input_reader_config.WhichOneof('input_reader')
    if reader_kind != 'tf_record_input_reader':
        raise ValueError('Unsupported input_reader_config.')

    record_config = input_reader_config.tf_record_input_reader
    if not record_config.input_path:
        raise ValueError('At least one input path must be specified in '
                         '`input_reader_config`.')

    label_map_proto_file = (
        input_reader_config.label_map_path
        if input_reader_config.HasField('label_map_path') else None)
    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_proto_file)

    def decode_record(serialized):
        decoded = decoder.decode(serialized)
        if transform_input_data_fn is None:
            return decoded
        return transform_input_data_fn(decoded)

    # TFRecord reader configured with a read buffer of 8 * 1000 * 1000 bytes.
    record_reader = functools.partial(tf.data.TFRecordDataset,
                                      buffer_size=8 * 1000 * 1000)
    dataset = dataset_util.read_dataset(record_reader, decode_record,
                                        record_config.input_path[:],
                                        input_reader_config)

    if not batch_size:
        return dataset

    padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
                                         spatial_image_shape)
    return dataset.apply(
        tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
                                                        padding_shapes))
コード例 #4
0
ファイル: dataset_builder.py プロジェクト: ALISCIFP/models
def build(input_reader_config,
          transform_input_data_fn=None,
          batch_size=None,
          max_num_boxes=None,
          num_classes=None,
          spatial_image_shape=None,
          num_additional_channels=0):
  """Builds a tf.data.Dataset from an InputReader proto.

  Each record is decoded with a tf_example_decoder.TfExampleDecoder and, if
  provided, passed through `transform_input_data_fn`. When `batch_size` is
  truthy, the dataset is padded-batched with
  tf.contrib.data.padded_batch_and_drop_remainder.

  Args:
    input_reader_config: An input_reader_pb2.InputReader object.
    transform_input_data_fn: Function applied to every decoded record, or None
      if no extra decoding is required.
    batch_size: Batch size. If None (or 0), batching is not performed.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
      padding. If None, will use a dynamic shape.
    num_classes: Number of classes in the dataset needed to compute shapes for
      padding. If None, will use a dynamic shape.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing the expected spatial shape of the image after applying
      transform_input_data_fn. If None, will use dynamic shapes.
    num_additional_channels: Number of additional channels to use in the
      input; forwarded to the TfExampleDecoder.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: If `input_reader_config` is not an InputReader proto, does not
      use a tf_record_input_reader, or specifies no input paths.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  # Only TFRecord input readers are supported; bail out early otherwise.
  reader_kind = input_reader_config.WhichOneof('input_reader')
  if reader_kind != 'tf_record_input_reader':
    raise ValueError('Unsupported input_reader_config.')

  record_config = input_reader_config.tf_record_input_reader
  if not record_config.input_path:
    raise ValueError('At least one input path must be specified in '
                     '`input_reader_config`.')

  label_map_proto_file = (
      input_reader_config.label_map_path
      if input_reader_config.HasField('label_map_path') else None)
  decoder = tf_example_decoder.TfExampleDecoder(
      load_instance_masks=input_reader_config.load_instance_masks,
      instance_mask_type=input_reader_config.mask_type,
      label_map_proto_file=label_map_proto_file,
      use_display_name=input_reader_config.use_display_name,
      num_additional_channels=num_additional_channels)

  def decode_record(serialized):
    decoded = decoder.decode(serialized)
    if transform_input_data_fn is None:
      return decoded
    return transform_input_data_fn(decoded)

  # TFRecord reader configured with a read buffer of 8 * 1000 * 1000 bytes.
  record_reader = functools.partial(tf.data.TFRecordDataset,
                                    buffer_size=8 * 1000 * 1000)
  dataset = dataset_util.read_dataset(record_reader, decode_record,
                                      record_config.input_path[:],
                                      input_reader_config)

  if not batch_size:
    return dataset

  padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
                                       spatial_image_shape)
  return dataset.apply(
      tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
                                                      padding_shapes))
コード例 #5
0
ファイル: dataset_builder.py プロジェクト: forging2012/models
def build(input_reader_config, transform_input_data_fn=None,
          batch_size=1, max_num_boxes=None, num_classes=None,
          spatial_image_shape=None):
  """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
  records. If `batch_size` is greater than 1, returns a padded batched
  tf.data.Dataset; this requires `max_num_boxes`, `num_classes` and
  `spatial_image_shape` to be set. Otherwise records are returned unbatched.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    transform_input_data_fn: Function to apply to all records, or None if
      no extra decoding is required.
    batch_size: Batch size. If greater than 1, returns a padded batch dataset
      built with padded_batch_and_drop_remainder. If None or <= 1, batching is
      not performed.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
      padding. This is only used if batch_size is greater than 1.
    num_classes: Number of classes in the dataset needed to compute shapes for
      padding. This is only used if batch_size is greater than 1.
    spatial_image_shape: a list of two integers of the form [height, width]
      containing expected spatial shape of the image after applying
      transform_input_data_fn. This is needed to compute shapes for padding and
      only used if batch_size is greater than 1.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
    ValueError: If batch_size > 1 and any of (max_num_boxes, num_classes,
      spatial_image_shape) is None.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')

    label_map_proto_file = None
    if input_reader_config.HasField('label_map_path'):
      label_map_proto_file = input_reader_config.label_map_path
    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_proto_file)

    def process_fn(value):
      processed = decoder.decode(value)
      if transform_input_data_fn is not None:
        return transform_input_data_fn(processed)
      return processed

    # TFRecord reader configured with a read buffer of 8 * 1000 * 1000 bytes.
    dataset = dataset_util.read_dataset(
        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
        process_fn, config.input_path[:], input_reader_config)

    # `batch_size=None` means "no batching"; check for it explicitly because
    # `None > 1` raises TypeError in Python 3.
    if batch_size is not None and batch_size > 1:
      if num_classes is None:
        raise ValueError('`num_classes` must be set when batch_size > 1.')
      if max_num_boxes is None:
        raise ValueError('`max_num_boxes` must be set when batch_size > 1.')
      if spatial_image_shape is None:
        raise ValueError('`spatial_image_shape` must be set when batch_size > '
                         '1 .')
      padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
                                           spatial_image_shape)
      dataset = dataset.apply(
          tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
                                                          padding_shapes))
    return dataset

  raise ValueError('Unsupported input_reader_config.')
def build(input_reader_config, transform_input_data_fn=None, num_workers=1,
          worker_index=0, batch_size=1, max_num_boxes=None, num_classes=None,
          spatial_image_shape=None):
  """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
  records. If `batch_size` is greater than 1, returns a padded batched
  tf.data.Dataset; this requires `max_num_boxes`, `num_classes` and
  `spatial_image_shape` to be set. Otherwise records are returned unbatched.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    transform_input_data_fn: Function to apply to all records, or None if
      no extra decoding is required.
    num_workers: Number of workers (tpu shard).
    worker_index: Id for the current worker (tpu shard).
    batch_size: Batch size. If greater than 1, returns a padded batch dataset
      built with padded_batch_and_drop_remainder. If None or <= 1, batching is
      not performed.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
      padding. This is only used if batch_size is greater than 1.
    num_classes: Number of classes in the dataset needed to compute shapes for
      padding. This is only used if batch_size is greater than 1.
    spatial_image_shape: a list of two integers of the form [height, width]
      containing expected spatial shape of the image after applying
      transform_input_data_fn. This is needed to compute shapes for padding and
      only used if batch_size is greater than 1.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
    ValueError: If batch_size > 1 and any of (max_num_boxes, num_classes,
      spatial_image_shape) is None.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')

    label_map_proto_file = None
    if input_reader_config.HasField('label_map_path'):
      label_map_proto_file = input_reader_config.label_map_path
    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_proto_file)

    def process_fn(value):
      processed = decoder.decode(value)
      if transform_input_data_fn is not None:
        return transform_input_data_fn(processed)
      return processed

    # num_workers / worker_index are forwarded to read_dataset, presumably to
    # shard the input across workers -- confirm against dataset_util.
    dataset = dataset_util.read_dataset(
        tf.data.TFRecordDataset, process_fn, config.input_path[:],
        input_reader_config, num_workers, worker_index)

    # `batch_size=None` means "no batching"; check for it explicitly because
    # `None > 1` raises TypeError in Python 3.
    if batch_size is not None and batch_size > 1:
      if num_classes is None:
        raise ValueError('`num_classes` must be set when batch_size > 1.')
      if max_num_boxes is None:
        raise ValueError('`max_num_boxes` must be set when batch_size > 1.')
      if spatial_image_shape is None:
        raise ValueError('`spatial_image_shape` must be set when batch_size > '
                         '1 .')
      padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
                                           spatial_image_shape)
      dataset = dataset.apply(
          tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
                                                          padding_shapes))
    return dataset

  raise ValueError('Unsupported input_reader_config.')