Example #1
0
  def build_inputs(self,
                   params: cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the semantic-segmentation input pipeline.

    Picks a decoder (the TFDS-specific segmentation decoder when
    `params.tfds_name` is set, otherwise `segmentation_input.Decoder`),
    configures a `ClassMappingParser` from `params`, and reads the dataset
    through `input_reader_factory.input_reader_generator`.

    Args:
      params: Data config providing the TFDS name / file type, output and
        crop sizes, groundtruth resizing/padding options, augmentation
        settings, dtype and the `is_training` flag.
      input_context: Optional distribution input context, forwarded
        unchanged to `reader.read`.

    Returns:
      The dataset returned by `reader.read(input_context=input_context)`.
    """
    # Read from the loss config so the parser marks ignored pixels with the
    # same label value the loss is configured to skip.
    ignore_label = self.task_config.losses.ignore_label

    if params.tfds_name:
      decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)
    else:
      decoder = segmentation_input.Decoder()

    parser = ClassMappingParser(
        output_size=params.output_size,
        crop_size=params.crop_size,
        ignore_label=ignore_label,
        resize_eval_groundtruth=params.resize_eval_groundtruth,
        groundtruth_padded_size=params.groundtruth_padded_size,
        aug_scale_min=params.aug_scale_min,
        aug_scale_max=params.aug_scale_max,
        aug_rand_hflip=params.aug_rand_hflip,
        dtype=params.dtype)

    # Highest valid class id for the model (ids are 0-based). Set after
    # construction because the constructor call above does not take it;
    # presumably ClassMappingParser uses it to bound the label mapping —
    # NOTE(review): confirm against ClassMappingParser.
    parser.max_class = self.task_config.model.num_classes - 1

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset
 def test_segmentation_decoder(self, tfds_name):
     """Checks the TFDS segmentation decoder returned for `tfds_name`.

     The factory must return a `base_decoder.Decoder` whose decoded output
     for the synthetic test example has length 4 and contains each of the
     image, segmentation-mask, height and width entries listed below.
     """
     expected_keys = (
         'image/encoded',
         'image/segmentation/class/encoded',
         'image/height',
         'image/width',
     )
     seg_decoder = tfds_factory.get_segmentation_decoder(tfds_name)
     self.assertIsInstance(seg_decoder, base_decoder.Decoder)
     decoded = seg_decoder.decode(self._create_test_example())
     self.assertLen(decoded, len(expected_keys))
     for key in expected_keys:
         self.assertIn(key, decoded)
 def test_doesnt_exit_segmentation_decoder(self, tfds_name):
     """Checks that `get_segmentation_decoder(tfds_name)` raises ValueError.

     Presumably `tfds_name` here is a dataset name the factory does not
     support. NOTE(review): the method name reads "exit" but likely means
     "exist"; the name is kept unchanged so the test id stays stable.
     """
     self.assertRaises(
         ValueError, tfds_factory.get_segmentation_decoder, tfds_name)