def test_detection_decoder(self, tfds_name):
  """Checks the TFDS detection decoder yields exactly the expected fields.

  Args:
    tfds_name: name of a TFDS dataset for which a detection decoder exists.
  """
  # The decoded output must contain exactly these keys (and no others,
  # enforced by the length check below).
  expected_keys = (
      'image',
      'source_id',
      'height',
      'width',
      'groundtruth_classes',
      'groundtruth_is_crowd',
      'groundtruth_area',
      'groundtruth_boxes',
  )
  decoder = tfds_factory.get_detection_decoder(tfds_name)
  self.assertIsInstance(decoder, base_decoder.Decoder)

  decoded = decoder.decode(self._create_test_example())
  self.assertLen(decoded, len(expected_keys))
  for key in expected_keys:
    self.assertIn(key, decoded)
def build_inputs(
    self,
    params: exp_cfg.DataConfig,
    input_context: Optional[tf.distribute.InputContext] = None):
  """Build input dataset.

  Selects a decoder, builds a RetinaNet parser from the task's model config
  and the data config's parser settings, and reads the dataset.

  Args:
    params: data config. `tfds_name`, `decoder`, `parser`, `dtype`,
      `file_type` and `is_training` are read from it.
    input_context: optional distribution input context, forwarded to
      `reader.read`.

  Returns:
    Whatever `reader.read(input_context=...)` returns.

  Raises:
    ValueError: if `params.tfds_name` is empty and `params.decoder.type` is
      neither 'simple_decoder' nor 'label_map_decoder'.
  """
  # A configured TFDS name takes precedence over the decoder config.
  if params.tfds_name:
    decoder = tfds_factory.get_detection_decoder(params.tfds_name)
  else:
    decoder_cfg = params.decoder.get()
    decoder_type = params.decoder.type
    if decoder_type == 'simple_decoder':
      decoder = tf_example_decoder.TfExampleDecoder(
          regenerate_source_id=decoder_cfg.regenerate_source_id)
    elif decoder_type == 'label_map_decoder':
      decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
          label_map=decoder_cfg.label_map,
          regenerate_source_id=decoder_cfg.regenerate_source_id)
    else:
      raise ValueError('Unknown decoder type: {}!'.format(decoder_type))

  model_cfg = self.task_config.model
  anchor_cfg = model_cfg.anchor
  parser_cfg = params.parser
  parser = retinanet_input.Parser(
      output_size=model_cfg.input_size[:2],
      min_level=model_cfg.min_level,
      max_level=model_cfg.max_level,
      num_scales=anchor_cfg.num_scales,
      aspect_ratios=anchor_cfg.aspect_ratios,
      anchor_size=anchor_cfg.anchor_size,
      dtype=params.dtype,
      match_threshold=parser_cfg.match_threshold,
      unmatched_threshold=parser_cfg.unmatched_threshold,
      aug_type=parser_cfg.aug_type,
      aug_rand_hflip=parser_cfg.aug_rand_hflip,
      aug_scale_min=parser_cfg.aug_scale_min,
      aug_scale_max=parser_cfg.aug_scale_max,
      skip_crowd_during_training=parser_cfg.skip_crowd_during_training,
      max_num_instances=parser_cfg.max_num_instances)

  reader = input_reader_factory.input_reader_generator(
      params,
      dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(params.is_training))
  return reader.read(input_context=input_context)
def build_inputs(
    self,
    params: exp_cfg.DataConfig,
    input_context: Optional[tf.distribute.InputContext] = None):
  """Build input dataset.

  Selects a decoder, builds a CenterNet parser from the task's model config
  and the data config's parser settings, and reads TFRecord input.

  Args:
    params: data config. `tfds_name`, `decoder`, `parser`, `dtype` and
      `is_training` are read from it.
    input_context: optional distribution input context, forwarded to
      `reader.read`.

  Returns:
    Whatever `reader.read(input_context=...)` returns.

  Raises:
    ValueError: if `params.tfds_name` is empty and `params.decoder.type` is
      neither 'simple_decoder' nor 'label_map_decoder'.
  """
  # A configured TFDS name takes precedence over the decoder config.
  if params.tfds_name:
    decoder = tfds_factory.get_detection_decoder(params.tfds_name)
  else:
    decoder_cfg = params.decoder.get()
    decoder_type = params.decoder.type
    if decoder_type == 'simple_decoder':
      decoder = tf_example_decoder.TfExampleDecoder(
          regenerate_source_id=decoder_cfg.regenerate_source_id)
    elif decoder_type == 'label_map_decoder':
      decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
          label_map=decoder_cfg.label_map,
          regenerate_source_id=decoder_cfg.regenerate_source_id)
    else:
      raise ValueError('Unknown decoder type: {}!'.format(decoder_type))

  model_cfg = self.task_config.model
  parser_cfg = params.parser
  parser = centernet_input.CenterNetParser(
      output_height=model_cfg.input_size[0],
      output_width=model_cfg.input_size[1],
      max_num_instances=model_cfg.max_num_instances,
      bgr_ordering=parser_cfg.bgr_ordering,
      channel_means=parser_cfg.channel_means,
      channel_stds=parser_cfg.channel_stds,
      aug_rand_hflip=parser_cfg.aug_rand_hflip,
      aug_scale_min=parser_cfg.aug_scale_min,
      aug_scale_max=parser_cfg.aug_scale_max,
      aug_rand_hue=parser_cfg.aug_rand_hue,
      aug_rand_brightness=parser_cfg.aug_rand_brightness,
      aug_rand_contrast=parser_cfg.aug_rand_contrast,
      aug_rand_saturation=parser_cfg.aug_rand_saturation,
      odapi_augmentation=parser_cfg.odapi_augmentation,
      dtype=params.dtype)

  # NOTE(review): the dataset_fn is hard-coded to TFRecordDataset here,
  # unlike tasks that select it from `params.file_type` — confirm intended.
  reader = input_reader.InputReader(
      params,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(params.is_training))
  return reader.read(input_context=input_context)
def _get_data_decoder(self, params):
  """Get a decoder object to decode the dataset.

  Args:
    params: data config. `tfds_name` and `decoder` are read from it.

  Returns:
    A TFDS detection decoder when `params.tfds_name` is set; otherwise a
    `TfExampleDecoder` or `TfExampleDecoderLabelMap` chosen by
    `params.decoder.type`.

  Raises:
    ValueError: if `params.tfds_name` is empty and `params.decoder.type` is
      neither 'simple_decoder' nor 'label_map_decoder'.
  """
  if params.tfds_name:
    return tfds_factory.get_detection_decoder(params.tfds_name)

  decoder_cfg = params.decoder.get()
  decoder_type = params.decoder.type

  if decoder_type == 'simple_decoder':
    # Only this path records the 91->80 class-remapping flag on the task.
    self._coco_91_to_80 = decoder_cfg.coco91_to_80
    return tf_example_decoder.TfExampleDecoder(
        coco91_to_80=decoder_cfg.coco91_to_80,
        regenerate_source_id=decoder_cfg.regenerate_source_id)

  if decoder_type == 'label_map_decoder':
    return tf_example_label_map_decoder.TfExampleDecoderLabelMap(
        label_map=decoder_cfg.label_map,
        regenerate_source_id=decoder_cfg.regenerate_source_id)

  raise ValueError('Unknown decoder type: {}!'.format(decoder_type))
def test_doesnt_exit_detection_decoder(self, tfds_name):
  """Checks that requesting a decoder for an unsupported name raises.

  NOTE(review): the method name reads "exit"; "exist" was presumably
  intended. It is left unchanged so test selection by name keeps working.

  Args:
    tfds_name: a TFDS name for which no detection decoder is registered.
  """
  self.assertRaises(
      ValueError, tfds_factory.get_detection_decoder, tfds_name)