예제 #1
0
 def test_raise_error_on_empty_matcher(self):
     """Check that building from a Matcher proto with no fields set fails.

     The builder is expected to reject the empty config with ValueError.
     """
     # The text below is whitespace only, so merging it into a freshly
     # constructed Matcher leaves the message with no fields set.
     # text_format.Merge returns the message it merged into.
     empty_proto = text_format.Merge("""
 """, matcher_pb2.Matcher())
     self.assertRaises(ValueError, matcher_builder.build, empty_proto)
예제 #2
0
 def test_build_bipartite_matcher(self):
     """Check that a `bipartite_matcher` config builds a GreedyBipartiteMatcher."""
     matcher_text_proto = """
   bipartite_matcher {
   }
 """
     matcher_proto = matcher_pb2.Matcher()
     text_format.Merge(matcher_text_proto, matcher_proto)
     matcher_object = matcher_builder.build(matcher_proto)
     # assertIsInstance names the actual type on failure, whereas
     # assertTrue(isinstance(...)) only reports "False is not true".
     self.assertIsInstance(matcher_object,
                           bipartite_matcher.GreedyBipartiteMatcher)
예제 #3
0
 def test_build_arg_max_matcher_with_defaults(self):
     """Check that an empty `argmax_matcher` config builds with default settings.

     Expected: an ArgMaxMatcher with matched/unmatched thresholds of 0.5,
     negatives_lower_than_unmatched enabled, and both
     force_match_for_each_row and use_matmul_gather disabled.
     """
     matcher_text_proto = """
   argmax_matcher {
   }
 """
     matcher_proto = matcher_pb2.Matcher()
     text_format.Merge(matcher_text_proto, matcher_proto)
     matcher_object = matcher_builder.build(matcher_proto)
     # assertIsInstance names the actual type on failure, unlike
     # assertTrue(isinstance(...)).
     self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
     self.assertAlmostEqual(matcher_object._matched_threshold, 0.5)
     self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.5)
     self.assertTrue(matcher_object._negatives_lower_than_unmatched)
     self.assertFalse(matcher_object._force_match_for_each_row)
     # use_matmul_gather is left unset above. The non-default test shows it
     # can be set to true, so pin its default here as well.
     # NOTE(review): relies on the argmax_matcher proto defaulting this bool
     # to false; confirm against the .proto definition.
     self.assertFalse(matcher_object._use_matmul_gather)
예제 #4
0
 def test_build_arg_max_matcher_with_non_default_parameters(self):
     """Check that explicit `argmax_matcher` fields reach the built matcher.

     Every configurable field is set to a non-default value. The test then
     checks that each value appears on the resulting ArgMaxMatcher.
     """
     matcher_text_proto = """
   argmax_matcher {
     matched_threshold: 0.7
     unmatched_threshold: 0.3
     negatives_lower_than_unmatched: false
     force_match_for_each_row: true
     use_matmul_gather: true
   }
 """
     matcher_proto = matcher_pb2.Matcher()
     text_format.Merge(matcher_text_proto, matcher_proto)
     matcher_object = matcher_builder.build(matcher_proto)
     # assertIsInstance names the actual type on failure, unlike
     # assertTrue(isinstance(...)).
     self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
     self.assertAlmostEqual(matcher_object._matched_threshold, 0.7)
     self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.3)
     self.assertFalse(matcher_object._negatives_lower_than_unmatched)
     self.assertTrue(matcher_object._force_match_for_each_row)
     self.assertTrue(matcher_object._use_matmul_gather)