Example #1
0
  def testSegmentationInputReader(self, input_size, num_classes, num_channels):
    """Reads one batch through `InputReader` and checks image/label shapes.

    Builds an eval-mode (`is_training=False`) reader over the 'tfrecord' data
    at `self._data_path` with a global batch size of 2. It pulls the first
    batch and verifies that both tensors have shape
    `[2, input_size[0], input_size[1], input_size[2], depth]`. Here `depth` is
    `num_channels` for the image and `num_classes` for the labels.
    """
    batch_size = 2
    data_config = cfg.DataConfig(
        input_path=self._data_path,
        global_batch_size=batch_size,
        is_training=False)

    seg_decoder = segmentation_input_3d.Decoder()
    seg_parser = segmentation_input_3d.Parser(
        input_size=input_size,
        num_classes=num_classes,
        num_channels=num_channels)

    input_pipeline = input_reader.InputReader(
        data_config,
        dataset_fn=dataset_fn.pick_dataset_fn('tfrecord'),
        decoder_fn=seg_decoder.decode,
        parser_fn=seg_parser.parse_fn(data_config.is_training))

    batched_dataset = input_pipeline.read()
    first_batch = next(iter(batched_dataset))
    image, labels = first_batch

    # Image and labels share the batch and spatial axes; only the trailing
    # (channel vs. class) axis differs.
    image_shape = list(image.numpy().shape)
    self.assertEqual(
        image_shape,
        [batch_size, input_size[0], input_size[1], input_size[2],
         num_channels])
    labels_shape = list(labels.numpy().shape)
    self.assertEqual(
        labels_shape,
        [batch_size, input_size[0], input_size[1], input_size[2],
         num_classes])
    def testSegmentationInputReader(self, input_size, num_classes,
                                    num_channels, is_training):
        """Decodes and parses `self._example` directly, then checks shapes.

        No reader or batching is involved. The expected shapes are therefore
        `[input_size[0], input_size[1], input_size[2], depth]`, where `depth`
        is `num_channels` for the image and `num_classes` for the labels.
        """
        # NOTE(review): at this indentation this def is nested inside the
        # body of the preceding method and is never called. It looks like a
        # separate example pasted in; confirm and move it to its own method.
        seg_decoder = segmentation_input_3d.Decoder()
        seg_parser = segmentation_input_3d.Parser(
            input_size=input_size,
            num_classes=num_classes,
            num_channels=num_channels)

        serialized = self._example.SerializeToString()
        decoded = seg_decoder.decode(serialized)
        parse = seg_parser.parse_fn(is_training=is_training)
        image, labels = parse(decoded)

        # Unbatched: spatial axes followed by channels (image) or classes.
        image_shape = list(image.numpy().shape)
        self.assertEqual(
            image_shape,
            [input_size[0], input_size[1], input_size[2], num_channels])
        labels_shape = list(labels.numpy().shape)
        self.assertEqual(
            labels_shape,
            [input_size[0], input_size[1], input_size[2], num_classes])
Example #3
0
  def build_inputs(self, params, input_context=None) -> tf.data.Dataset:
    """Builds the 3D segmentation input pipeline.

    Args:
      params: Data config. The fields read here are `image_field_key`,
        `label_field_key`, `input_size`, `num_classes`, `num_channels`,
        `dtype`, `label_dtype`, `file_type` and `is_training`. The whole
        object is also passed to `input_reader.InputReader`.
      input_context: Optional context forwarded unchanged to
        `InputReader.read`. Presumably a `tf.distribute.InputContext` used
        for per-replica sharding; confirm against callers.

    Returns:
      The `tf.data.Dataset` returned by `InputReader.read`. Each record is
      decoded by `segmentation_input_3d.Decoder` and parsed by
      `segmentation_input_3d.Parser` in training or eval mode, according to
      `params.is_training`.
    """
    decoder = segmentation_input_3d.Decoder(
        image_field_key=params.image_field_key,
        label_field_key=params.label_field_key)
    parser = segmentation_input_3d.Parser(
        input_size=params.input_size,
        num_classes=params.num_classes,
        num_channels=params.num_channels,
        image_field_key=params.image_field_key,
        label_field_key=params.label_field_key,
        dtype=params.dtype,
        label_dtype=params.label_dtype)

    reader = input_reader.InputReader(
        params,
        # The on-disk format is config-driven rather than fixed.
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    return reader.read(input_context=input_context)