예제 #1
0
def build(region_similarity_calculator_config):
    """Builds region similarity calculator based on the configuration.

    Builds one of [IouSimilarity, IoaSimilarity, NegSqDistSimilarity] objects.
    See core/region_similarity_calculator.proto for details.

    Args:
      region_similarity_calculator_config: RegionSimilarityCalculator
        configuration proto.

    Returns:
      region_similarity_calculator: RegionSimilarityCalculator object.

    Raises:
      ValueError: If the config is not a
        region_similarity_calculator_pb2.RegionSimilarityCalculator proto, or
        if its `region_similarity` oneof is unset or names a calculator this
        builder does not know.
    """
    # A ValueError (not TypeError) is raised for a wrong config type to keep
    # the exception contract callers already rely on.
    if not isinstance(
            region_similarity_calculator_config,
            region_similarity_calculator_pb2.RegionSimilarityCalculator):
        raise ValueError(
            'region_similarity_calculator_config not of type '
            'region_similarity_calculator_pb2.RegionSimilarityCalculator')

    # WhichOneof yields the name of the populated field of the
    # `region_similarity` oneof, or None when no field is set; None falls
    # through to the final ValueError below.
    similarity_calculator = region_similarity_calculator_config.WhichOneof(
        'region_similarity')
    if similarity_calculator == 'iou_similarity':
        return region_similarity_calculator.IouSimilarity()
    if similarity_calculator == 'ioa_similarity':
        return region_similarity_calculator.IoaSimilarity()
    if similarity_calculator == 'neg_sq_dist_similarity':
        return region_similarity_calculator.NegSqDistSimilarity()

    raise ValueError('Unknown region similarity calculator.')
 def test_get_correct_pairwise_similarity_based_on_ioa(self):
     """Checks IoaSimilarity.compare in both argument orders.

     Each expected entry [i][j] equals the intersection area of box i of the
     first argument with box j of the second, divided by the area of box j.
     """
     first_corners = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
     second_corners = tf.constant([[3.0, 4.0, 6.0, 8.0],
                                   [14.0, 14.0, 15.0, 15.0],
                                   [0.0, 0.0, 20.0, 20.0]])

     # compare(first, second): normalised by the areas of the second boxes
     # (12, 1 and 400).
     expected_forward = [[2.0 / 12.0, 0, 6.0 / 400.0],
                         [1.0 / 12.0, 0.0, 5.0 / 400.0]]
     # compare(second, first): normalised by the areas of the first boxes
     # (6 and 5).
     expected_backward = [[2.0 / 6.0, 1.0 / 5.0],
                          [0, 0],
                          [6.0 / 6.0, 5.0 / 5.0]]

     first_boxes = box_list.BoxList(first_corners)
     second_boxes = box_list.BoxList(second_corners)

     calculator = region_similarity_calculator.IoaSimilarity()
     forward = calculator.compare(first_boxes, second_boxes)
     backward = calculator.compare(second_boxes, first_boxes)

     with self.test_session() as sess:
         forward_value, backward_value = sess.run([forward, backward])
         self.assertAllClose(forward_value, expected_forward)
         self.assertAllClose(backward_value, expected_backward)