Beispiel #1
0
 def test_raise_error_on_empty_matcher(self):
   """Building from a Matcher proto with no matcher configured raises ValueError."""
   # The text proto is intentionally empty, so no matcher type is selected.
   matcher_text_proto = """
   """
   matcher_proto = matcher_pb2.Matcher()
   text_format.Merge(matcher_text_proto, matcher_proto)
   with self.assertRaises(ValueError):
     matcher_builder.build(matcher_proto)
 def test_build_bipartite_matcher(self):
   """A `bipartite_matcher` config builds a GreedyBipartiteMatcher."""
   matcher_text_proto = """
     bipartite_matcher {
     }
   """
   matcher_proto = matcher_pb2.Matcher()
   text_format.Merge(matcher_text_proto, matcher_proto)
   matcher_object = matcher_builder.build(matcher_proto)
   # assertIsInstance reports the actual type on failure, unlike
   # assertTrue(isinstance(...)), which only reports "False is not true".
   self.assertIsInstance(matcher_object,
                         bipartite_matcher.GreedyBipartiteMatcher)
 def test_build_arg_max_matcher_with_defaults(self):
   """An empty `argmax_matcher` config builds an ArgMaxMatcher with defaults.

   Expected defaults: thresholds of 0.5, negatives_lower_than_unmatched
   enabled, and force_match_for_each_row / use_matmul_gather disabled.
   """
   matcher_text_proto = """
     argmax_matcher {
     }
   """
   matcher_proto = matcher_pb2.Matcher()
   text_format.Merge(matcher_text_proto, matcher_proto)
   matcher_object = matcher_builder.build(matcher_proto)
   self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
   self.assertAlmostEqual(matcher_object._matched_threshold, 0.5)
   self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.5)
   self.assertTrue(matcher_object._negatives_lower_than_unmatched)
   self.assertFalse(matcher_object._force_match_for_each_row)
   # Mirrors the non-default test, which checks this flag when it is set.
   self.assertFalse(matcher_object._use_matmul_gather)
 def test_build_arg_max_matcher_with_non_default_parameters(self):
   """Every explicitly set `argmax_matcher` field reaches the built matcher."""
   matcher_text_proto = """
     argmax_matcher {
       matched_threshold: 0.7
       unmatched_threshold: 0.3
       negatives_lower_than_unmatched: false
       force_match_for_each_row: true
       use_matmul_gather: true
     }
   """
   matcher_proto = matcher_pb2.Matcher()
   text_format.Merge(matcher_text_proto, matcher_proto)
   matcher_object = matcher_builder.build(matcher_proto)
   self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
   # Each assertion below pairs with one field set in the text proto above.
   self.assertAlmostEqual(matcher_object._matched_threshold, 0.7)
   self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.3)
   self.assertFalse(matcher_object._negatives_lower_than_unmatched)
   self.assertTrue(matcher_object._force_match_for_each_row)
   self.assertTrue(matcher_object._use_matmul_gather)