def test_find_matching_cluster_best_overlap(self):
    """Checks which cluster index `wbf.find_matching_cluster` picks.

    It is called once with a single candidate and once with two candidates.
    With two candidates, the second one is expected to be chosen, as the test
    name suggests. All vectors have 7 float32 entries.
    NOTE(review): the field layout of these vectors (label/score/coords?) is
    defined in `wbf`; confirm there.
    """
    target_box = tf.constant([1, 3, 1, 13, 2, 1, 1], dtype=tf.float32)
    first_candidate = tf.constant([1, 1, 1, 11, 2, 1, 1], dtype=tf.float32)
    second_candidate = tf.constant([1, 2, 1, 12, 2, 1, 1], dtype=tf.float32)

    # With a single candidate cluster, index 0 is expected.
    single_result = wbf.find_matching_cluster((first_candidate,), target_box)
    self.assertAllClose(single_result, 0)

    # When the second candidate is added, it is expected to win (index 1).
    pair_result = wbf.find_matching_cluster(
        (first_candidate, second_candidate), target_box)
    self.assertAllClose(pair_result, 1)
def test_find_matching_cluster_matches(self):
    """Checks that the selected index follows the position of the matching cluster.

    One candidate is identical to the query box. The other differs from it in
    entries 1 and 2. Whichever position the identical candidate occupies in
    the tuple is the index `wbf.find_matching_cluster` is expected to return.
    """
    query_box = tf.constant([1, 1, 1, 2, 2, 1, 1], dtype=tf.float32)
    same_as_query = tf.constant([1, 1, 1, 2, 2, 1, 1], dtype=tf.float32)
    differs_from_query = tf.constant([1, 3, 3, 2, 2, 1, 1], dtype=tf.float32)

    # Each case is (candidate clusters in order, expected matching index).
    # Cases are evaluated in this order: identical candidate first, then last.
    cases = (
        ((same_as_query, differs_from_query), 0),
        ((differs_from_query, same_as_query), 1),
    )
    for candidates, expected_index in cases:
        found = wbf.find_matching_cluster(candidates, query_box)
        self.assertAllClose(found, expected_index)