Ejemplo n.º 1
0
 def test_detection_decoder(self, tfds_name):
     """Checks the detection decoder that the factory builds for `tfds_name`.

     The factory result must be a `base_decoder.Decoder`. Decoding the example
     returned by `_create_test_example()` must produce eight entries, and each
     of the expected detection keys must be among them.
     """
     expected_keys = (
         'image',
         'source_id',
         'height',
         'width',
         'groundtruth_classes',
         'groundtruth_is_crowd',
         'groundtruth_area',
         'groundtruth_boxes',
     )

     detection_decoder = tfds_factory.get_detection_decoder(tfds_name)
     self.assertIsInstance(detection_decoder, base_decoder.Decoder)

     example = self._create_test_example()
     decoded = detection_decoder.decode(example)

     # Size check first, then membership in the same order as the key tuple.
     self.assertLen(decoded, len(expected_keys))
     for key in expected_keys:
         self.assertIn(key, decoded)
Ejemplo n.º 2
0
    def build_inputs(
            self,
            params: exp_cfg.DataConfig,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Assemble the input dataset from a decoder, a parser and a reader.

        Args:
            params: Data configuration. `tfds_name`, `decoder`, `parser`,
                `dtype`, `file_type` and `is_training` are read from it, and
                the object itself is forwarded to the input reader.
            input_context: Optional distribution input context, passed
                through unchanged to `reader.read`.

        Returns:
            Whatever `reader.read(input_context=input_context)` returns.

        Raises:
            ValueError: If `params.tfds_name` is falsy and
                `params.decoder.type` is neither 'simple_decoder' nor
                'label_map_decoder'.
        """
        # A TFDS name takes precedence over the explicit decoder config.
        tfds_name = params.tfds_name
        if tfds_name:
            example_decoder = tfds_factory.get_detection_decoder(tfds_name)
        else:
            decoder_options = params.decoder.get()
            decoder_type = params.decoder.type
            if decoder_type == 'simple_decoder':
                example_decoder = tf_example_decoder.TfExampleDecoder(
                    regenerate_source_id=decoder_options.regenerate_source_id)
            elif decoder_type == 'label_map_decoder':
                example_decoder = (
                    tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                        label_map=decoder_options.label_map,
                        regenerate_source_id=(
                            decoder_options.regenerate_source_id)))
            else:
                raise ValueError('Unknown decoder type: {}!'.format(
                    decoder_type))

        # Bind the two config sub-trees once instead of re-walking them for
        # every keyword argument below.
        model_cfg = self.task_config.model
        anchor_cfg = model_cfg.anchor
        parser_cfg = params.parser

        example_parser = retinanet_input.Parser(
            output_size=model_cfg.input_size[:2],
            min_level=model_cfg.min_level,
            max_level=model_cfg.max_level,
            num_scales=anchor_cfg.num_scales,
            aspect_ratios=anchor_cfg.aspect_ratios,
            anchor_size=anchor_cfg.anchor_size,
            dtype=params.dtype,
            match_threshold=parser_cfg.match_threshold,
            unmatched_threshold=parser_cfg.unmatched_threshold,
            aug_type=parser_cfg.aug_type,
            aug_rand_hflip=parser_cfg.aug_rand_hflip,
            aug_scale_min=parser_cfg.aug_scale_min,
            aug_scale_max=parser_cfg.aug_scale_max,
            skip_crowd_during_training=parser_cfg.skip_crowd_during_training,
            max_num_instances=parser_cfg.max_num_instances)

        source_fn = dataset_fn.pick_dataset_fn(params.file_type)
        decode_fn = example_decoder.decode
        parse_fn = example_parser.parse_fn(params.is_training)
        data_reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=source_fn,
            decoder_fn=decode_fn,
            parser_fn=parse_fn)

        return data_reader.read(input_context=input_context)
Ejemplo n.º 3
0
    def build_inputs(
            self,
            params: exp_cfg.DataConfig,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Assemble the input dataset from a decoder, a parser and a reader.

        Args:
            params: Data configuration. `tfds_name`, `decoder`, `parser`,
                `dtype` and `is_training` are read from it, and the object
                itself is forwarded to the input reader.
            input_context: Optional distribution input context, passed
                through unchanged to `reader.read`.

        Returns:
            Whatever `reader.read(input_context=input_context)` returns.

        Raises:
            ValueError: If `params.tfds_name` is falsy and
                `params.decoder.type` is neither 'simple_decoder' nor
                'label_map_decoder'.
        """
        # A TFDS name takes precedence over the explicit decoder config.
        tfds_name = params.tfds_name
        if tfds_name:
            example_decoder = tfds_factory.get_detection_decoder(tfds_name)
        else:
            decoder_options = params.decoder.get()
            decoder_type = params.decoder.type
            if decoder_type == 'simple_decoder':
                example_decoder = tf_example_decoder.TfExampleDecoder(
                    regenerate_source_id=decoder_options.regenerate_source_id)
            elif decoder_type == 'label_map_decoder':
                example_decoder = (
                    tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                        label_map=decoder_options.label_map,
                        regenerate_source_id=(
                            decoder_options.regenerate_source_id)))
            else:
                raise ValueError('Unknown decoder type: {}!'.format(
                    decoder_type))

        # Bind the config sub-trees once instead of re-walking them for every
        # keyword argument below.
        model_cfg = self.task_config.model
        parser_cfg = params.parser

        example_parser = centernet_input.CenterNetParser(
            output_height=model_cfg.input_size[0],
            output_width=model_cfg.input_size[1],
            max_num_instances=model_cfg.max_num_instances,
            bgr_ordering=parser_cfg.bgr_ordering,
            channel_means=parser_cfg.channel_means,
            channel_stds=parser_cfg.channel_stds,
            aug_rand_hflip=parser_cfg.aug_rand_hflip,
            aug_scale_min=parser_cfg.aug_scale_min,
            aug_scale_max=parser_cfg.aug_scale_max,
            aug_rand_hue=parser_cfg.aug_rand_hue,
            aug_rand_brightness=parser_cfg.aug_rand_brightness,
            aug_rand_contrast=parser_cfg.aug_rand_contrast,
            aug_rand_saturation=parser_cfg.aug_rand_saturation,
            odapi_augmentation=parser_cfg.odapi_augmentation,
            dtype=params.dtype)

        # The reader is always given `tf.data.TFRecordDataset` as its
        # dataset function here.
        decode_fn = example_decoder.decode
        parse_fn = example_parser.parse_fn(params.is_training)
        data_reader = input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=decode_fn,
            parser_fn=parse_fn)

        return data_reader.read(input_context=input_context)
Ejemplo n.º 4
0
 def _get_data_decoder(self, params):
   """Select and construct the decoder described by `params`.

   Args:
     params: Data configuration providing `tfds_name` and `decoder`.

   Returns:
     The decoder from `tfds_factory` when `params.tfds_name` is truthy;
     otherwise a `TfExampleDecoder` or `TfExampleDecoderLabelMap`, chosen by
     `params.decoder.type`.

   Raises:
     ValueError: If `params.tfds_name` is falsy and `params.decoder.type` is
       neither 'simple_decoder' nor 'label_map_decoder'.
   """
   # A TFDS name takes precedence over the explicit decoder config.
   if params.tfds_name:
     return tfds_factory.get_detection_decoder(params.tfds_name)

   decoder_options = params.decoder.get()
   decoder_type = params.decoder.type

   if decoder_type == 'simple_decoder':
     # Side effect: the flag is also stored on the instance, and only on
     # this branch.
     use_coco91_to_80 = decoder_options.coco91_to_80
     self._coco_91_to_80 = use_coco91_to_80
     return tf_example_decoder.TfExampleDecoder(
         coco91_to_80=use_coco91_to_80,
         regenerate_source_id=decoder_options.regenerate_source_id)

   if decoder_type == 'label_map_decoder':
     return tf_example_label_map_decoder.TfExampleDecoderLabelMap(
         label_map=decoder_options.label_map,
         regenerate_source_id=decoder_options.regenerate_source_id)

   raise ValueError('Unknown decoder type: {}!'.format(decoder_type))
Ejemplo n.º 5
0
 def test_doesnt_exit_detection_decoder(self, tfds_name):
     """Expects `ValueError` when a detection decoder is requested for `tfds_name`."""
     # Callable form of assertRaises: the factory call happens inside the
     # assertion, and its return value is discarded.
     self.assertRaises(
         ValueError, tfds_factory.get_detection_decoder, tfds_name)