def testSegmentationInputReader(self, input_size, num_classes, num_channels):
  """Reads one batch through InputReader and checks image/label shapes.

  A non-training DataConfig with a global batch size of 2 points at
  `self._data_path`. Records are read with the 'tfrecord' dataset fn,
  decoded by the 3D segmentation Decoder, and parsed by the 3D segmentation
  Parser. Only the first batch is inspected.
  """
  batch_size = 2
  data_config = cfg.DataConfig(
      input_path=self._data_path,
      global_batch_size=batch_size,
      is_training=False)

  seg_decoder = segmentation_input_3d.Decoder()
  seg_parser = segmentation_input_3d.Parser(
      input_size=input_size,
      num_classes=num_classes,
      num_channels=num_channels)

  tfrecord_reader = input_reader.InputReader(
      data_config,
      dataset_fn=dataset_fn.pick_dataset_fn('tfrecord'),
      decoder_fn=seg_decoder.decode,
      parser_fn=seg_parser.parse_fn(data_config.is_training))

  # Pull just the first batch out of the pipeline.
  image, labels = next(iter(tfrecord_reader.read()))

  # Expected layout: batch size, the first three entries of input_size, then
  # num_channels (image) or num_classes (labels) as the trailing size.
  spatial_dims = [input_size[0], input_size[1], input_size[2]]
  expected_image_shape = [batch_size] + spatial_dims + [num_channels]
  expected_label_shape = [batch_size] + spatial_dims + [num_classes]
  self.assertEqual(list(image.numpy().shape), expected_image_shape)
  self.assertEqual(list(labels.numpy().shape), expected_label_shape)
def testSegmentationInputReader(self, input_size, num_classes, num_channels,
                                is_training):
  """Decodes and parses `self._example` directly and checks output shapes.

  No InputReader is involved: the serialized example goes straight through
  the 3D segmentation Decoder and the parse function built for
  `is_training`, so the outputs carry no batch dimension.
  """
  seg_decoder = segmentation_input_3d.Decoder()
  seg_parser = segmentation_input_3d.Parser(
      input_size=input_size,
      num_classes=num_classes,
      num_channels=num_channels)

  serialized = self._example.SerializeToString()
  decoded = seg_decoder.decode(serialized)
  parse = seg_parser.parse_fn(is_training=is_training)
  image, labels = parse(decoded)

  # Expected layout: the first three entries of input_size, then
  # num_channels (image) or num_classes (labels) as the trailing size.
  spatial_dims = [input_size[0], input_size[1], input_size[2]]
  self.assertEqual(
      list(image.numpy().shape), spatial_dims + [num_channels])
  self.assertEqual(
      list(labels.numpy().shape), spatial_dims + [num_classes])
def build_inputs(self, params, input_context=None) -> tf.data.Dataset:
  """Builds the 3D segmentation input pipeline.

  Args:
    params: Data config. This method reads `image_field_key`,
      `label_field_key`, `input_size`, `num_classes`, `num_channels`,
      `dtype`, `label_dtype`, `file_type` and `is_training` from it, and
      also passes it whole to `input_reader.InputReader`.
    input_context: Optional context forwarded unchanged to
      `InputReader.read` (presumably a `tf.distribute.InputContext` used
      for per-replica sharding -- confirm against InputReader).

  Returns:
    The `tf.data.Dataset` returned by `InputReader.read`.
  """
  # Decoder and parser are given the same feature keys so they agree on
  # where the image and label live in each record.
  decoder = segmentation_input_3d.Decoder(
      image_field_key=params.image_field_key,
      label_field_key=params.label_field_key)
  parser = segmentation_input_3d.Parser(
      input_size=params.input_size,
      num_classes=params.num_classes,
      num_channels=params.num_channels,
      image_field_key=params.image_field_key,
      label_field_key=params.label_field_key,
      dtype=params.dtype,
      label_dtype=params.label_dtype)

  # The file format (params.file_type) selects how raw records are read;
  # the parse function is specialised for training vs. evaluation.
  reader = input_reader.InputReader(
      params,
      dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(params.is_training))

  dataset = reader.read(input_context=input_context)
  return dataset