def test_raise_error_on_empty_matcher(self):
    """A Matcher proto parsed from blank text has no fields set.

    Passing such a proto to `matcher_builder.build` must raise ValueError.
    """
    empty_config = matcher_pb2.Matcher()
    text_format.Merge(""" """, empty_config)
    # Callable form of assertRaises: build(empty_config) must raise.
    self.assertRaises(ValueError, matcher_builder.build, empty_config)
def test_build_bipartite_matcher(self):
    """Building from a `bipartite_matcher` config yields a GreedyBipartiteMatcher.

    NOTE(review): a second test method with this exact name (one that skips
    under TF 2.X) also exists. If both are defined in the same class, the
    later definition silently replaces this one, so this test never runs —
    confirm and rename or remove one of them.
    """
    matcher_text_proto = """ bipartite_matcher { } """
    matcher_proto = matcher_pb2.Matcher()
    text_format.Merge(matcher_text_proto, matcher_proto)
    matcher_object = matcher_builder.build(matcher_proto)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only reports "False is not true".
    self.assertIsInstance(matcher_object,
                          bipartite_matcher.GreedyBipartiteMatcher)
def test_build_bipartite_matcher(self):
    """A `bipartite_matcher` config builds a GreedyBipartiteMatcher.

    The test is skipped when running under TF 2.X, where the bipartite
    matcher is not supported.
    """
    if tf_version.is_tf2():
        self.skipTest('BipartiteMatcher unsupported in TF 2.X. Skipping.')
    config = matcher_pb2.Matcher()
    text_format.Merge(""" bipartite_matcher { } """, config)
    built = matcher_builder.build(config)
    self.assertIsInstance(built, bipartite_matcher.GreedyBipartiteMatcher)
def test_build_arg_max_matcher_with_defaults(self):
    """An empty `argmax_matcher` config builds an ArgMaxMatcher with defaults.

    Expected defaults: both thresholds 0.5,
    negatives_lower_than_unmatched True, and force_match_for_each_row False.
    """
    config = matcher_pb2.Matcher()
    text_format.Merge(""" argmax_matcher { } """, config)
    built = matcher_builder.build(config)
    self.assertIsInstance(built, argmax_matcher.ArgMaxMatcher)
    # Both thresholds share the same default value.
    for threshold_attr in ('_matched_threshold', '_unmatched_threshold'):
        self.assertAlmostEqual(getattr(built, threshold_attr), 0.5)
    self.assertTrue(built._negatives_lower_than_unmatched)
    self.assertFalse(built._force_match_for_each_row)
def test_build_arg_max_matcher_without_thresholds(self):
    """`ignore_thresholds: true` builds an ArgMaxMatcher with no thresholds.

    Both thresholds must be None. negatives_lower_than_unmatched must stay
    True and force_match_for_each_row must stay False.
    """
    matcher_text_proto = """ argmax_matcher { ignore_thresholds: true } """
    matcher_proto = matcher_pb2.Matcher()
    text_format.Merge(matcher_text_proto, matcher_proto)
    matcher_object = matcher_builder.build(matcher_proto)
    # assertIsInstance / assertIsNone check type / identity and report the
    # offending value on failure, unlike assertTrue(isinstance(...)) and
    # assertEqual(..., None), which relies on __eq__.
    self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
    self.assertIsNone(matcher_object._matched_threshold)
    self.assertIsNone(matcher_object._unmatched_threshold)
    self.assertTrue(matcher_object._negatives_lower_than_unmatched)
    self.assertFalse(matcher_object._force_match_for_each_row)
def test_build_arg_max_matcher_with_non_default_parameters(self):
    """Explicit `argmax_matcher` fields are propagated to the built matcher.

    The config sets matched_threshold 0.7, unmatched_threshold 0.3,
    negatives_lower_than_unmatched false, and force_match_for_each_row true.
    The built ArgMaxMatcher must expose exactly those values.
    """
    matcher_text_proto = """ argmax_matcher { matched_threshold: 0.7 unmatched_threshold: 0.3 negatives_lower_than_unmatched: false force_match_for_each_row: true } """
    matcher_proto = matcher_pb2.Matcher()
    text_format.Merge(matcher_text_proto, matcher_proto)
    matcher_object = matcher_builder.build(matcher_proto)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
    self.assertAlmostEqual(matcher_object._matched_threshold, 0.7)
    self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.3)
    self.assertFalse(matcher_object._negatives_lower_than_unmatched)
    self.assertTrue(matcher_object._force_match_for_each_row)
def generate_ssd_model(num_classes):
    """Build a partially populated `ssd_pb2.Ssd` config.

    Only these fields are filled in:
      * num_classes, from the argument;
      * box_coder, as a Faster R-CNN box coder with y/x scale 10.0 and
        height/width scale 5.0;
      * matcher, as an argmax matcher with both thresholds 0.5,
        ignore_thresholds False, negatives_lower_than_unmatched True, and
        force_match_for_each_row True;
      * anchor_generator, as an SSD anchor generator with 6 layers, scales
        0.2..0.95, and aspect ratios [1.0, 2.0, 0.5, 3.0, 0.3333].
    Every other Ssd field is left unset.

    Args:
      num_classes: value stored in `Ssd.num_classes`.

    Returns:
      The populated `ssd_pb2.Ssd` message.
    """
    from object_detection.protos import anchor_generator_pb2
    from object_detection.protos import argmax_matcher_pb2
    from object_detection.protos import box_coder_pb2
    from object_detection.protos import faster_rcnn_box_coder_pb2
    from object_detection.protos import matcher_pb2
    from object_detection.protos import ssd_anchor_generator_pb2

    ssd_config = ssd_pb2.Ssd()
    ssd_config.num_classes = num_classes

    # Box coder: Faster R-CNN encoding.
    rcnn_coder = faster_rcnn_box_coder_pb2.FasterRcnnBoxCoder(
        y_scale=10.0,
        x_scale=10.0,
        height_scale=5.0,
        width_scale=5.0,
    )
    ssd_config.box_coder.CopyFrom(
        box_coder_pb2.BoxCoder(faster_rcnn_box_coder=rcnn_coder))

    # Matcher: argmax with a single 0.5 threshold.
    argmax_config = argmax_matcher_pb2.ArgMaxMatcher(
        matched_threshold=0.5,
        unmatched_threshold=0.5,
        ignore_thresholds=False,
        negatives_lower_than_unmatched=True,
        force_match_for_each_row=True,
    )
    ssd_config.matcher.CopyFrom(
        matcher_pb2.Matcher(argmax_matcher=argmax_config))

    # Anchor generator: multi-layer SSD anchors.
    ssd_anchors = ssd_anchor_generator_pb2.SsdAnchorGenerator(
        num_layers=6,
        min_scale=0.2,
        max_scale=0.95,
        aspect_ratios=[1.0, 2.0, 0.5, 3.0, 0.3333],
    )
    ssd_config.anchor_generator.CopyFrom(
        anchor_generator_pb2.AnchorGenerator(ssd_anchor_generator=ssd_anchors))

    return ssd_config