def graph_fn():
  decoder = decoder_builder.build(input_reader_proto)
  tensor_dict = decoder.decode(serialized_seq_example)
  return (tensor_dict[fields.InputDataFields.image],
          tensor_dict[fields.InputDataFields.groundtruth_classes],
          tensor_dict[fields.InputDataFields.groundtruth_boxes],
          tensor_dict[fields.InputDataFields.num_groundtruth_boxes])
def build(input_reader_config, batch_size=None, transform_input_data_fn=None,
          input_context=None, reduce_to_frame_fn=None):
  """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying `transform_input_data_fn` to all
  records and batching the result (dropping any remainder) when `batch_size`
  is set.

  Args:
    input_reader_config: An input_reader_pb2.InputReader object.
    batch_size: Batch size. If batch_size is None, no batching is performed.
    transform_input_data_fn: Function that transforms each record, or None if
      no extra decoding is required.
    input_context: Optional. A tf.distribute.InputContext object used to
      shard filenames and compute the per-replica batch_size when this
      function is being called per-replica.
    reduce_to_frame_fn: Function that extracts frames from tf.SequenceExample
      type input data.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  decoder = decoder_builder.build(input_reader_config)

  if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')
    shard_fn = shard_function_for_context(input_context)
    if input_context is not None:
      batch_size = input_context.get_per_replica_batch_size(batch_size)
    dataset = read_dataset(
        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
        config.input_path[:], input_reader_config, filename_shard_fn=shard_fn)
    if input_reader_config.sample_1_of_n_examples > 1:
      dataset = dataset.shard(input_reader_config.sample_1_of_n_examples, 0)
    # TODO(rathodv): make batch size a required argument once the old binaries
    # are deleted.
    dataset = dataset.map(decoder.decode, tf.data.experimental.AUTOTUNE)
    if reduce_to_frame_fn:
      dataset = reduce_to_frame_fn(dataset)
    if transform_input_data_fn is not None:
      dataset = dataset.map(transform_input_data_fn,
                            tf.data.experimental.AUTOTUNE)
    if batch_size:
      dataset = dataset.apply(
          tf_data.batch_and_drop_remainder(batch_size))
    dataset = dataset.prefetch(input_reader_config.num_prefetch_batches)
    return dataset

  raise ValueError('Unsupported input_reader_config.')
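
# Usage sketch (illustrative, not part of the library): the proto text and
# record path below are placeholders, and `input_reader_pb2` / `text_format`
# are assumed to be imported as in the tests that follow.
example_input_reader_proto = input_reader_pb2.InputReader()
text_format.Parse("""
  tf_record_input_reader {
    input_path: '/path/to/examples.record'
  }
""", example_input_reader_proto)
# With TF_EXAMPLE inputs no frame reduction is needed, so reduce_to_frame_fn
# can stay None. batch_size is left as None here, so no batching is applied;
# batching normally goes together with a transform_input_data_fn that pads
# per-record tensors to a common shape first.
example_dataset = build(example_input_reader_proto)
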
    def test_build_tf_record_input_reader_and_load_instance_masks(self):
        input_reader_text_proto = """
      load_instance_masks: true
      tf_record_input_reader {}
    """
        input_reader_proto = input_reader_pb2.InputReader()
        text_format.Parse(input_reader_text_proto, input_reader_proto)

        decoder = decoder_builder.build(input_reader_proto)
        tensor_dict = decoder.decode(self._make_serialized_tf_example())

        with tf.train.MonitoredSession() as sess:
            output_dict = sess.run(tensor_dict)

        self.assertAllEqual((1, 4, 5), output_dict[
            fields.InputDataFields.groundtruth_instance_masks].shape)
    def test_build_tf_record_input_reader_sequence_example(self):
        label_map_path = _get_labelmap_path()
        input_reader_text_proto = """
      input_type: TF_SEQUENCE_EXAMPLE
      tf_record_input_reader {}
    """
        input_reader_proto = input_reader_pb2.InputReader()
        input_reader_proto.label_map_path = label_map_path
        text_format.Parse(input_reader_text_proto, input_reader_proto)

        decoder = decoder_builder.build(input_reader_proto)
        tensor_dict = decoder.decode(
            self._make_serialized_tf_sequence_example())

        with tf.train.MonitoredSession() as sess:
            output_dict = sess.run(tensor_dict)

        expected_groundtruth_classes = [[-1, -1], [1, -1], [1, 2], [-1, -1]]
        expected_groundtruth_boxes = [[[0.0, 0.0, 0.0, 0.0],
                                       [0.0, 0.0, 0.0, 0.0]],
                                      [[0.0, 0.0, 1.0, 1.0],
                                       [0.0, 0.0, 0.0, 0.0]],
                                      [[0.0, 0.0, 1.0, 1.0],
                                       [0.1, 0.1, 0.2, 0.2]],
                                      [[0.0, 0.0, 0.0, 0.0],
                                       [0.0, 0.0, 0.0, 0.0]]]
        expected_num_groundtruth_boxes = [0, 1, 2, 0]

        self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                         output_dict)
        # Sequence example images are encoded.
        self.assertEqual((4, ),
                         output_dict[fields.InputDataFields.image].shape)
        self.assertAllEqual(
            expected_groundtruth_classes,
            output_dict[fields.InputDataFields.groundtruth_classes])
        self.assertEqual(
            (4, 2, 4),
            output_dict[fields.InputDataFields.groundtruth_boxes].shape)
        self.assertAllClose(
            expected_groundtruth_boxes,
            output_dict[fields.InputDataFields.groundtruth_boxes])
        self.assertAllClose(
            expected_num_groundtruth_boxes,
            output_dict[fields.InputDataFields.num_groundtruth_boxes])
    def test_build_tf_record_input_reader_and_load_instance_masks(self):
        input_reader_text_proto = """
      load_instance_masks: true
      tf_record_input_reader {}
    """
        input_reader_proto = input_reader_pb2.InputReader()
        text_format.Parse(input_reader_text_proto, input_reader_proto)

        decoder = decoder_builder.build(input_reader_proto)
        serialized_seq_example = self._make_serialized_tf_example()

        def graph_fn():
            tensor_dict = decoder.decode(serialized_seq_example)
            return tensor_dict[
                fields.InputDataFields.groundtruth_instance_masks]

        masks = self.execute_cpu(graph_fn, [])
        self.assertAllEqual((1, 4, 5), masks.shape)
    def test_build_tf_record_input_reader(self):
        input_reader_text_proto = 'tf_record_input_reader {}'
        input_reader_proto = input_reader_pb2.InputReader()
        text_format.Parse(input_reader_text_proto, input_reader_proto)

        decoder = decoder_builder.build(input_reader_proto)
        serialized_seq_example = self._make_serialized_tf_example()

        def graph_fn():
            tensor_dict = decoder.decode(serialized_seq_example)
            return (tensor_dict[fields.InputDataFields.image],
                    tensor_dict[fields.InputDataFields.groundtruth_classes],
                    tensor_dict[fields.InputDataFields.groundtruth_boxes])

        (image, groundtruth_classes,
         groundtruth_boxes) = self.execute_cpu(graph_fn, [])
        self.assertEqual((4, 5, 3), image.shape)
        self.assertAllEqual([2], groundtruth_classes)
        self.assertEqual((1, 4), groundtruth_boxes.shape)
        self.assertAllEqual([0.0, 0.0, 1.0, 1.0], groundtruth_boxes[0])
    def test_build_tf_record_input_reader_and_load_keypoint_depth(self):
        input_reader_text_proto = """
      load_keypoint_depth_features: true
      num_keypoints: 2
      tf_record_input_reader {}
    """
        input_reader_proto = input_reader_pb2.InputReader()
        text_format.Parse(input_reader_text_proto, input_reader_proto)

        decoder = decoder_builder.build(input_reader_proto)
        serialized_example = self._make_serialized_tf_example()

        def graph_fn():
            tensor_dict = decoder.decode(serialized_example)
            return (tensor_dict[
                fields.InputDataFields.groundtruth_keypoint_depths],
                    tensor_dict[fields.InputDataFields.
                                groundtruth_keypoint_depth_weights])

        (kpts_depths, kpts_depth_weights) = self.execute_cpu(graph_fn, [])
        self.assertAllEqual((1, 2), kpts_depths.shape)
        self.assertAllEqual((1, 2), kpts_depth_weights.shape)
    def test_build_tf_record_input_reader(self):
        input_reader_text_proto = 'tf_record_input_reader {}'
        input_reader_proto = input_reader_pb2.InputReader()
        text_format.Parse(input_reader_text_proto, input_reader_proto)

        decoder = decoder_builder.build(input_reader_proto)
        tensor_dict = decoder.decode(self._make_serialized_tf_example())

        with tf.train.MonitoredSession() as sess:
            output_dict = sess.run(tensor_dict)

        self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                         output_dict)
        self.assertEqual((4, 5, 3),
                         output_dict[fields.InputDataFields.image].shape)
        self.assertAllEqual(
            [2], output_dict[fields.InputDataFields.groundtruth_classes])
        self.assertEqual(
            (1, 4),
            output_dict[fields.InputDataFields.groundtruth_boxes].shape)
        self.assertAllEqual(
            [0.0, 0.0, 1.0, 1.0],
            output_dict[fields.InputDataFields.groundtruth_boxes][0])
def build(input_reader_config, batch_size=None, transform_input_data_fn=None,
          input_context=None):
  """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying `transform_input_data_fn` to all
  records and batching the result (dropping any remainder) when `batch_size`
  is set.

  Args:
    input_reader_config: An input_reader_pb2.InputReader object.
    batch_size: Batch size. If batch_size is None, no batching is performed.
    transform_input_data_fn: Function that transforms each record, or None if
      no extra decoding is required.
    input_context: Optional. A tf.distribute.InputContext object used to
      shard filenames and compute the per-replica batch_size when this
      function is being called per-replica.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  decoder = decoder_builder.build(input_reader_config)

  if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')

    def process_fn(value):
      """Sets up tf graph that decodes, transforms and pads input data."""
      processed_tensors = decoder.decode(value)
      if transform_input_data_fn is not None:
        processed_tensors = transform_input_data_fn(processed_tensors)
      return processed_tensors

    shard_fn = shard_function_for_context(input_context)
    if input_context is not None:
      batch_size = input_context.get_per_replica_batch_size(batch_size)

    dataset = read_dataset(
        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
        config.input_path[:], input_reader_config, filename_shard_fn=shard_fn)
    if input_reader_config.sample_1_of_n_examples > 1:
      dataset = dataset.shard(input_reader_config.sample_1_of_n_examples, 0)
    # TODO(rathodv): make batch size a required argument once the old binaries
    # are deleted.
    if batch_size:
      num_parallel_calls = batch_size * input_reader_config.num_parallel_batches
    else:
      num_parallel_calls = input_reader_config.num_parallel_map_calls
    # TODO(b/123952794): Migrate to V2 function.
    if hasattr(dataset, 'map_with_legacy_function'):
      data_map_fn = dataset.map_with_legacy_function
    else:
      data_map_fn = dataset.map
    dataset = data_map_fn(process_fn, num_parallel_calls=num_parallel_calls)
    if batch_size:
      dataset = dataset.apply(
          tf_data.batch_and_drop_remainder(batch_size))
    dataset = dataset.prefetch(input_reader_config.num_prefetch_batches)
    return dataset

  raise ValueError('Unsupported input_reader_config.')
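
# Sketch of a transform_input_data_fn for the builder above (illustrative,
# not part of the library). It assumes the decoded dict uses the
# fields.InputDataFields keys exercised in the tests, and simply rescales the
# decoded uint8 image to float32 in [0, 1] while passing other keys through.
def example_normalize_image(tensor_dict):
  image = tf.cast(tensor_dict[fields.InputDataFields.image], tf.float32)
  tensor_dict[fields.InputDataFields.image] = image / 255.0
  return tensor_dict

# It would be wired in as:
#   build(input_reader_proto, transform_input_data_fn=example_normalize_image)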