def test_raise_error_on_empty_matcher(self):
    """Building from a Matcher proto with no matcher set raises ValueError."""
    empty_matcher_config = matcher_pb2.Matcher()
    # Merging a whitespace-only text proto leaves the message unset.
    text_format.Merge(""" """, empty_matcher_config)
    self.assertRaises(ValueError, matcher_builder.build, empty_matcher_config)
def test_build_bipartite_matcher(self):
    """A `bipartite_matcher` config builds a GreedyBipartiteMatcher."""
    matcher_text_proto = """ bipartite_matcher { } """
    matcher_proto = matcher_pb2.Matcher()
    text_format.Merge(matcher_text_proto, matcher_proto)
    matcher_object = matcher_builder.build(matcher_proto)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only says "False is not true".
    self.assertIsInstance(
        matcher_object, bipartite_matcher.GreedyBipartiteMatcher)
def test_build_arg_max_matcher_with_defaults(self):
    """An empty `argmax_matcher` config builds an ArgMaxMatcher with defaults.

    The expected defaults checked here are: both thresholds at 0.5,
    negatives-lower-than-unmatched enabled, and force-match-for-each-row
    disabled.
    """
    matcher_text_proto = """ argmax_matcher { } """
    matcher_proto = matcher_pb2.Matcher()
    text_format.Merge(matcher_text_proto, matcher_proto)
    matcher_object = matcher_builder.build(matcher_proto)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only says "False is not true".
    self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
    self.assertAlmostEqual(matcher_object._matched_threshold, 0.5)
    self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.5)
    self.assertTrue(matcher_object._negatives_lower_than_unmatched)
    self.assertFalse(matcher_object._force_match_for_each_row)
def test_build_arg_max_matcher_with_non_default_parameters(self):
    """Explicit `argmax_matcher` settings are carried onto the built matcher.

    Every field set in the text proto is checked against the corresponding
    private attribute of the resulting ArgMaxMatcher.
    """
    matcher_text_proto = """ argmax_matcher { matched_threshold: 0.7 unmatched_threshold: 0.3 negatives_lower_than_unmatched: false force_match_for_each_row: true use_matmul_gather: true } """
    matcher_proto = matcher_pb2.Matcher()
    text_format.Merge(matcher_text_proto, matcher_proto)
    matcher_object = matcher_builder.build(matcher_proto)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only says "False is not true".
    self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
    self.assertAlmostEqual(matcher_object._matched_threshold, 0.7)
    self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.3)
    self.assertFalse(matcher_object._negatives_lower_than_unmatched)
    self.assertTrue(matcher_object._force_match_for_each_row)
    self.assertTrue(matcher_object._use_matmul_gather)
def _build_ssd_model(ssd_config, is_training, add_summaries,
                     add_background_class=True):
    """Assembles an SSDMetaArch from an SSD model config.

    Each component is created by its dedicated builder from the matching
    section of `ssd_config`; the components and the remaining scalar options
    are then passed to the `SSDMetaArch` constructor.

    Args:
      ssd_config: A ssd.proto object holding the configuration of the desired
        SSDMetaArch.
      is_training: True when the model is being built for training.
      add_summaries: Whether tf summaries should be added in the model.
      add_background_class: Whether an implicit background class is added to
        the one-hot encodings of groundtruth labels. Set to false when the
        groundtruth labels already carry an explicit background class, or
        when multiclass scores are used instead of truth (distillation).

    Returns:
      The SSDMetaArch described by the config.

    Raises:
      ValueError: Not raised directly by this function.
        NOTE(review): presumably propagated from one of the component
        builders when a config section is not recognized -- confirm against
        those builders.
    """
    # Components are built in a fixed order; keep it stable.
    extractor = _build_ssd_feature_extractor(
        feature_extractor_config=ssd_config.feature_extractor,
        is_training=is_training)
    coder = box_coder_builder.build(ssd_config.box_coder)
    target_matcher = matcher_builder.build(ssd_config.matcher)
    similarity_calc = sim_calc.build(ssd_config.similarity_calculator)
    predictor = box_predictor_builder.build(
        hyperparams_builder.build,
        ssd_config.box_predictor,
        is_training,
        ssd_config.num_classes)
    anchors = anchor_generator_builder.build(ssd_config.anchor_generator)
    resizer_fn = image_resizer_builder.build(ssd_config.image_resizer)
    nms_fn, score_fn = post_processing_builder.build(
        ssd_config.post_processing)

    loss_components = losses_builder.build(ssd_config.loss)
    (cls_loss, loc_loss, cls_weight, loc_weight,
     example_miner, example_sampler) = loss_components

    # The leading arguments are positional; their order must match the
    # SSDMetaArch constructor signature.
    return ssd_meta_arch.SSDMetaArch(
        is_training,
        anchors,
        predictor,
        coder,
        extractor,
        target_matcher,
        similarity_calc,
        ssd_config.encode_background_as_zeros,
        ssd_config.negative_class_weight,
        resizer_fn,
        nms_fn,
        score_fn,
        cls_loss,
        loc_loss,
        cls_weight,
        loc_weight,
        ssd_config.normalize_loss_by_num_matches,
        example_miner,
        add_summaries=add_summaries,
        normalize_loc_loss_by_codesize=(
            ssd_config.normalize_loc_loss_by_codesize),
        freeze_batchnorm=ssd_config.freeze_batchnorm,
        inplace_batchnorm_update=ssd_config.inplace_batchnorm_update,
        add_background_class=add_background_class,
        random_example_sampler=example_sampler)