Exemplo n.º 1
0
    def build_inputs(
            self,
            params: cfg.DataConfig,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Builds the segmentation input pipeline.

        Args:
          params: Data config. This method reads its `tfds_name`, `output_size`,
            `crop_size`, `resize_eval_groundtruth`, `groundtruth_padded_size`,
            `aug_scale_min`, `aug_scale_max`, `aug_rand_hflip`, `dtype`,
            `file_type` and `is_training` fields.
          input_context: Optional `tf.distribute.InputContext`, forwarded
            unchanged to the reader's `read` call.

        Returns:
          The dataset returned by the input reader's `read` method.
        """
        ignore_label = self.task_config.losses.ignore_label

        # A named TFDS dataset gets its dataset-specific decoder; otherwise fall
        # back to the generic segmentation decoder.
        if params.tfds_name:
            decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)
        else:
            decoder = segmentation_input.Decoder()

        parser = ClassMappingParser(
            output_size=params.output_size,
            crop_size=params.crop_size,
            ignore_label=ignore_label,
            resize_eval_groundtruth=params.resize_eval_groundtruth,
            groundtruth_padded_size=params.groundtruth_padded_size,
            aug_scale_min=params.aug_scale_min,
            aug_scale_max=params.aug_scale_max,
            aug_rand_hflip=params.aug_rand_hflip,
            dtype=params.dtype)

        # Highest class id for this model (num_classes - 1), assigned as an
        # attribute rather than passed to the constructor.
        # NOTE(review): how ClassMappingParser consumes `max_class` is not
        # visible here — confirm against its definition.
        parser.max_class = self.task_config.model.num_classes - 1

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))

        return reader.read(input_context=input_context)
Exemplo n.º 2
0
 def test_segmentation_decoder(self, tfds_name):
     """Checks the TFDS segmentation decoder's type and its decoded keys.

     The decoder returned for `tfds_name` must be a `base_decoder.Decoder`.
     Decoding the test example must yield exactly four entries: the encoded
     image, the encoded class mask, and the image height and width.
     """
     expected_keys = (
         'image/encoded',
         'image/segmentation/class/encoded',
         'image/height',
         'image/width',
     )
     seg_decoder = tfds_factory.get_segmentation_decoder(tfds_name)
     self.assertIsInstance(seg_decoder, base_decoder.Decoder)
     outputs = seg_decoder.decode(self._create_test_example())
     self.assertLen(outputs, len(expected_keys))
     for key in expected_keys:
         self.assertIn(key, outputs)
Exemplo n.º 3
0
 def test_doesnt_exit_segmentation_decoder(self, tfds_name):
     """Checks that looking up a segmentation decoder raises ValueError.

     Presumably `tfds_name` is an unsupported dataset name supplied by test
     parameterization that is not visible here — confirm.
     """
     # NOTE(review): "exit" in the method name likely means "exist"; the name
     # is left as-is so test selection and parameterization keep working.
     self.assertRaises(
         ValueError, tfds_factory.get_segmentation_decoder, tfds_name)