Ejemplo n.º 1
0
 def graph_fn(similarity):
     """Run an ArgMaxMatcher (matched_threshold=3.) over `similarity`.

     Returns:
       A 3-tuple of the match's matched-column indicator, its
       unmatched-column indicator, and its raw `match_results`.
     """
     outcome = argmax_matcher.ArgMaxMatcher(
         matched_threshold=3.).match(similarity)
     return (outcome.matched_column_indicator(),
             outcome.unmatched_column_indicator(),
             outcome.match_results)
Ejemplo n.º 2
0
def build(matcher_config):
    """Constructs the matcher selected by a matcher config proto.

    Args:
      matcher_config: A matcher_pb2.Matcher message whose `matcher_oneof`
        field selects and configures the desired matcher.

    Returns:
      An argmax_matcher.ArgMaxMatcher when the `argmax_matcher` field is set,
      or a bipartite_matcher.GreedyBipartiteMatcher when the
      `bipartite_matcher` field is set.

    Raises:
      ValueError: If `matcher_config` is not a matcher_pb2.Matcher, or if
        `matcher_oneof` selects neither of the two supported matchers.
    """
    if not isinstance(matcher_config, matcher_pb2.Matcher):
        raise ValueError('matcher_config not of type matcher_pb2.Matcher.')

    selected = matcher_config.WhichOneof('matcher_oneof')

    if selected == 'argmax_matcher':
        argmax_cfg = matcher_config.argmax_matcher
        if argmax_cfg.ignore_thresholds:
            # Thresholds are disabled: pass None for both instead of the
            # values stored in the proto.
            matched_threshold, unmatched_threshold = None, None
        else:
            matched_threshold = argmax_cfg.matched_threshold
            unmatched_threshold = argmax_cfg.unmatched_threshold
        return argmax_matcher.ArgMaxMatcher(
            matched_threshold=matched_threshold,
            unmatched_threshold=unmatched_threshold,
            negatives_lower_than_unmatched=(
                argmax_cfg.negatives_lower_than_unmatched),
            force_match_for_each_row=argmax_cfg.force_match_for_each_row,
            use_matmul_gather=argmax_cfg.use_matmul_gather)

    if selected == 'bipartite_matcher':
        return bipartite_matcher.GreedyBipartiteMatcher(
            matcher_config.bipartite_matcher.use_matmul_gather)

    raise ValueError('Empty matcher.')
Ejemplo n.º 3
0
 def graph_fn(similarity):
     """Match `similarity` using distinct matched/unmatched thresholds.

     The ArgMaxMatcher is built with matched_threshold=3.,
     unmatched_threshold=2. and negatives_lower_than_unmatched=False.

     Returns:
       A 3-tuple of the match's matched-column indicator, its
       unmatched-column indicator, and its raw `match_results`.
     """
     matcher_kwargs = dict(matched_threshold=3.,
                           unmatched_threshold=2.,
                           negatives_lower_than_unmatched=False)
     outcome = argmax_matcher.ArgMaxMatcher(**matcher_kwargs).match(similarity)
     return (outcome.matched_column_indicator(),
             outcome.unmatched_column_indicator(),
             outcome.match_results)
Ejemplo n.º 4
0
 def graph_fn(anchor_means, groundtruth_box_corners):
   """Assign targets to anchors with an IOU / argmax / mean-stddev assigner.

   Both inputs are wrapped in `box_list.BoxList` before assignment, and no
   unmatched class label is supplied (`unmatched_class_label=None`).

   Returns:
     The first four elements of the assigner's 5-element result:
     (cls_targets, cls_weights, reg_targets, reg_weights).
   """
   assigner = targetassigner.TargetAssigner(
       region_similarity_calculator.IouSimilarity(),
       argmax_matcher.ArgMaxMatcher(matched_threshold=0.5,
                                    unmatched_threshold=0.3),
       mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1))
   anchors = box_list.BoxList(anchor_means)
   groundtruth = box_list.BoxList(groundtruth_box_corners)
   # The fifth element of the result is not needed by callers of this fn.
   cls_targets, cls_weights, reg_targets, reg_weights, _ = assigner.assign(
       anchors, groundtruth, unmatched_class_label=None)
   return (cls_targets, cls_weights, reg_targets, reg_weights)
Ejemplo n.º 5
0
 def graph_fn(anchor_means, groundtruth_box_corners,
              groundtruth_keypoints):
   """Assign targets to anchors using a keypoint-aware box coder.

   The groundtruth box list carries `groundtruth_keypoints` under the
   `fields.BoxListFields.keypoints` field; the coder is a KeypointBoxCoder
   with 6 keypoints and scale factors [10.0, 10.0, 5.0, 5.0].

   Returns:
     The first four elements of the assigner's 5-element result:
     (cls_targets, cls_weights, reg_targets, reg_weights).
   """
   assigner = targetassigner.TargetAssigner(
       region_similarity_calculator.IouSimilarity(),
       argmax_matcher.ArgMaxMatcher(matched_threshold=0.5,
                                    unmatched_threshold=0.5),
       keypoint_box_coder.KeypointBoxCoder(
           num_keypoints=6, scale_factors=[10.0, 10.0, 5.0, 5.0]))
   anchors = box_list.BoxList(anchor_means)
   groundtruth = box_list.BoxList(groundtruth_box_corners)
   groundtruth.add_field(fields.BoxListFields.keypoints,
                         groundtruth_keypoints)
   # The fifth element of the result is not needed by callers of this fn.
   cls_targets, cls_weights, reg_targets, reg_weights, _ = assigner.assign(
       anchors, groundtruth, unmatched_class_label=None)
   return (cls_targets, cls_weights, reg_targets, reg_weights)
Ejemplo n.º 6
0
    def graph_fn(anchor_means, groundtruth_box_corners, groundtruth_labels,
                 groundtruth_weights):
      """Assign targets with per-groundtruth weights; return only the weights.

      Uses a 7-element one-hot float32 vector (index 0 set) as the label for
      unmatched anchors and forwards `groundtruth_weights` to the assigner.

      Returns:
        (cls_weights, reg_weights): the second and fourth elements of the
        assigner's 5-element result.
      """
      iou = region_similarity_calculator.IouSimilarity()
      argmax = argmax_matcher.ArgMaxMatcher(matched_threshold=0.5,
                                            unmatched_threshold=0.5)
      coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
      background_label = tf.constant([1, 0, 0, 0, 0, 0, 0], tf.float32)
      assigner = targetassigner.TargetAssigner(iou, argmax, coder)

      anchors = box_list.BoxList(anchor_means)
      groundtruth = box_list.BoxList(groundtruth_box_corners)
      # Only the two weight tensors are of interest; discard the rest.
      _, cls_weights, _, reg_weights, _ = assigner.assign(
          anchors,
          groundtruth,
          groundtruth_labels,
          unmatched_class_label=background_label,
          groundtruth_weights=groundtruth_weights)
      return (cls_weights, reg_weights)
Ejemplo n.º 7
0
 def _get_target_assigner(self):
   """Return a TargetAssigner built from IOU similarity, an ArgMaxMatcher
   with both thresholds at 0.5, and a MeanStddevBoxCoder (stddev=0.1)."""
   return targetassigner.TargetAssigner(
       region_similarity_calculator.IouSimilarity(),
       argmax_matcher.ArgMaxMatcher(matched_threshold=0.5,
                                    unmatched_threshold=0.5),
       mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1))
Ejemplo n.º 8
0
def create_target_assigner(reference,
                           stage=None,
                           negative_class_weight=1.0,
                           use_matmul_gather=False):
    """Factory for the standard target assigner presets.

    Supported presets:
      * reference='Multibox', stage='proposal'
      * reference='FasterRCNN', stage='proposal'
      * reference='FasterRCNN', stage='detection'
      * reference='FastRCNN' (`stage` is not consulted)

    Args:
      reference: string naming the type of TargetAssigner.
      stage: string denoting the stage: {proposal, detection}.
      negative_class_weight: classification weight to be associated with
        negative anchors (default: 1.0).
      use_matmul_gather: whether to use matrix multiplication based gather,
        which is better suited for TPUs. Forwarded to the ArgMaxMatcher
        presets only; the Multibox preset does not use it.

    Returns:
      TargetAssigner: the desired target assigner.

    Raises:
      ValueError: if the combination of reference and stage is invalid.
    """
    if reference == 'Multibox' and stage == 'proposal':
        components = (
            sim_calc.NegSqDistSimilarity(),
            bipartite_matcher.GreedyBipartiteMatcher(),
            mean_stddev_box_coder.MeanStddevBoxCoder(),
        )
    elif reference == 'FasterRCNN' and stage == 'proposal':
        components = (
            sim_calc.IouSimilarity(),
            argmax_matcher.ArgMaxMatcher(
                matched_threshold=0.7,
                unmatched_threshold=0.3,
                force_match_for_each_row=True,
                use_matmul_gather=use_matmul_gather),
            faster_rcnn_box_coder.FasterRcnnBoxCoder(
                scale_factors=[10.0, 10.0, 5.0, 5.0]),
        )
    elif reference == 'FasterRCNN' and stage == 'detection':
        components = (
            sim_calc.IouSimilarity(),
            # Uses all proposals with IOU < 0.5 as candidate negatives.
            argmax_matcher.ArgMaxMatcher(
                matched_threshold=0.5,
                negatives_lower_than_unmatched=True,
                use_matmul_gather=use_matmul_gather),
            faster_rcnn_box_coder.FasterRcnnBoxCoder(
                scale_factors=[10.0, 10.0, 5.0, 5.0]),
        )
    elif reference == 'FastRCNN':
        components = (
            sim_calc.IouSimilarity(),
            argmax_matcher.ArgMaxMatcher(
                matched_threshold=0.5,
                unmatched_threshold=0.1,
                force_match_for_each_row=False,
                negatives_lower_than_unmatched=False,
                use_matmul_gather=use_matmul_gather),
            faster_rcnn_box_coder.FasterRcnnBoxCoder(),
        )
    else:
        raise ValueError('No valid combination of reference and stage.')

    similarity, matcher, coder = components
    return TargetAssigner(similarity,
                          matcher,
                          coder,
                          negative_class_weight=negative_class_weight)
Ejemplo n.º 9
0
 def graph_fn(similarity_matrix):
     """Match `similarity_matrix` with no matched threshold
     (matched_threshold=None) and return the unmatched-column indicator."""
     return (argmax_matcher.ArgMaxMatcher(matched_threshold=None)
             .match(similarity_matrix)
             .unmatched_column_indicator())
Ejemplo n.º 10
0
 def test_invalid_arguments_unmatched_thres_larger_than_matched_thres(self):
     """Expects ValueError when unmatched_threshold (2) exceeds
     matched_threshold (1)."""
     thresholds = dict(matched_threshold=1, unmatched_threshold=2)
     with self.assertRaises(ValueError):
         argmax_matcher.ArgMaxMatcher(**thresholds)
Ejemplo n.º 11
0
 def test_invalid_arguments_no_matched_threshold(self):
     """Expects ValueError when an unmatched_threshold is given while
     matched_threshold is None."""
     thresholds = dict(matched_threshold=None, unmatched_threshold=4)
     with self.assertRaises(ValueError):
         argmax_matcher.ArgMaxMatcher(**thresholds)
Ejemplo n.º 12
0
 def test_invalid_arguments_corner_case_negatives_lower_than_thres_false(
         self):
     """Expects ValueError when both thresholds are equal (1) and
     negatives_lower_than_unmatched is False."""
     matcher_kwargs = dict(matched_threshold=1,
                           unmatched_threshold=1,
                           negatives_lower_than_unmatched=False)
     with self.assertRaises(ValueError):
         argmax_matcher.ArgMaxMatcher(**matcher_kwargs)
Ejemplo n.º 13
0
 def test_valid_arguments_corner_case(self):
     """Constructs a matcher with both thresholds equal to 1; the test
     passes as long as construction does not raise."""
     thresholds = dict(matched_threshold=1, unmatched_threshold=1)
     argmax_matcher.ArgMaxMatcher(**thresholds)