def build_inputs(self,
                 params: cfg.DataConfig,
                 input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the segmentation input pipeline.

    Chooses a decoder: a TFDS-specific segmentation decoder when
    `params.tfds_name` is set, otherwise `segmentation_input.Decoder`.
    It then configures a `ClassMappingParser` from `params` and the task's
    loss/model config, and reads the data through an input reader built by
    `input_reader_factory.input_reader_generator`.

    Args:
        params: Data config. The fields read here are `tfds_name`,
            `output_size`, `crop_size`, `resize_eval_groundtruth`,
            `groundtruth_padded_size`, `aug_scale_min`, `aug_scale_max`,
            `aug_rand_hflip`, `dtype`, `file_type` and `is_training`.
        input_context: Optional distribution input context, forwarded
            unchanged to `reader.read`.

    Returns:
        The result of `reader.read(...)`. This is presumably a
        `tf.data.Dataset`; confirm against `input_reader_factory`.
    """
    ignore_label = self.task_config.losses.ignore_label

    if params.tfds_name:
        decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)
    else:
        decoder = segmentation_input.Decoder()

    parser = ClassMappingParser(
        output_size=params.output_size,
        crop_size=params.crop_size,
        ignore_label=ignore_label,
        resize_eval_groundtruth=params.resize_eval_groundtruth,
        groundtruth_padded_size=params.groundtruth_padded_size,
        aug_scale_min=params.aug_scale_min,
        aug_scale_max=params.aug_scale_max,
        aug_rand_hflip=params.aug_rand_hflip,
        dtype=params.dtype)
    # num_classes - 1, i.e. the largest class index if ids are 0-based.
    # NOTE(review): this is assigned after construction, so ClassMappingParser
    # presumably reads it when parsing; confirm in ClassMappingParser.
    parser.max_class = self.task_config.model.num_classes - 1

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)
    return dataset
def test_segmentation_decoder(self, tfds_name):
    """Checks the TFDS segmentation decoder's type and its decoded output.

    The factory must return a `base_decoder.Decoder` for `tfds_name`.
    Decoding the test example must yield exactly four entries, including the
    encoded image, the encoded class mask, the height and the width.
    """
    seg_decoder = tfds_factory.get_segmentation_decoder(tfds_name)
    self.assertIsInstance(seg_decoder, base_decoder.Decoder)

    example = self._create_test_example()
    outputs = seg_decoder.decode(example)

    self.assertLen(outputs, 4)
    expected_keys = (
        'image/encoded',
        'image/segmentation/class/encoded',
        'image/height',
        'image/width',
    )
    for key in expected_keys:
        self.assertIn(key, outputs)
def test_doesnt_exit_segmentation_decoder(self, tfds_name):
    """Checks that `get_segmentation_decoder` raises ValueError for `tfds_name`."""
    # NOTE(review): "exit" in the method name looks like a typo for "exist".
    # The name is kept so that selecting this test by name keeps working.
    self.assertRaises(
        ValueError, tfds_factory.get_segmentation_decoder, tfds_name)