コード例 #1
0
    def filter_candidate_pool(self, candidate_pool, char):
        """Keep only the candidates whose image is close enough to `char`'s.

        The secondary metric (chosen by self.secondary_distance_type among the
        metrics that Distance(self.img_format) provides) is evaluated between
        the image of the anchor `char` and the image of each candidate, both
        looked up in self._label_img_map. A candidate survives when that
        distance is <= self.secondary_filter_threshold. The anchor itself
        (same code point) is skipped.

        Args:
            candidate_pool: Set of Char, the set of possible confusables.
            char: Char, single character whose corresponding image must exist
                in self.img_dir.

        Returns:
            List of (candidate, distance) tuples for the surviving candidates,
            in the iteration order of `candidate_pool`.

        Raises:
            ValueError: if self.secondary_distance_type is not one of the
                available image metrics.
        """
        metric_table = Distance(self.img_format).get_metrics()
        if self.secondary_distance_type not in metric_table.keys():
            raise ValueError(
                "Expect secondary_distance_type to be one of {}.".format(
                    metric_table.keys()))
        image_distance = metric_table[self.secondary_distance_type]

        kept = []
        for other in candidate_pool:
            # Compare code points so the anchor character is never reported
            # as its own confusable.
            if ord(char) != ord(other):
                score = calculate_from_path(image_distance,
                                            self._label_img_map[char],
                                            self._label_img_map[other])
                if score <= self.secondary_filter_threshold:
                    kept.append((other, score))
        return kept
コード例 #2
0
    def test_emb_metrics(self):
        """Embedding metrics give the expected manhattan/euclidean values and
        reject list inputs (TypeError) and length mismatches (ValueError)."""
        metrics = Distance(ImgFormat.EMBEDDINGS).get_metrics()
        manhattan, euclidean = metrics.manhattan, metrics.euclidean
        e0, e1, e123 = self.emb_0, self.emb_1, self.emb_123

        # Manhattan distance on the fixture vectors.
        self.assertEqual(manhattan(e0, e1), 3)
        self.assertEqual(manhattan(e1, e1), 0)
        self.assertAlmostEqual(manhattan(e123, e1), 3)
        self.assertAlmostEqual(manhattan(e0, e123), 6)

        # Euclidean distance on the fixture vectors.
        self.assertAlmostEqual(euclidean(e0, e1), np.sqrt(3))
        self.assertEqual(euclidean(e1, e1), 0)
        self.assertAlmostEqual(euclidean(e123, e0),
                               np.sqrt((1**2) + (2**2) + (3**2)))

        # Each metric must raise TypeError for a plain-list argument and
        # ValueError when the second vector has a different length.
        for metric in (manhattan, euclidean):
            with self.assertRaises(TypeError):
                metric(e0.tolist(), e1)
            with self.assertRaises(ValueError):
                metric(e0, np.ones(4))
コード例 #3
0
    def get_candidate_pool_for_char(self, char):
        """Obtain the candidates for confusables for specified 'char'.

        For every reduced representation in self._reps (the PCA outputs), the
        self.n_candidates labels nearest to `char` under the primary distance
        metric (self.primary_distance_type) are selected. The candidate pool is
        the union of those labels over all representations. The anchor `char`
        itself is not excluded here.

        Args:
            char: Char, single character, must exist in self.labels.

        Returns:
            candidate_pool: Set of Char, the union of nearest labels across
                all representations.
            candidate_dis: Dict mapping "PCA<dim>" (dim = embs.shape[1] of the
                representation) to a list of (distance, label) tuples holding
                that representation's nearest labels, sorted by ascending
                distance with ties broken by label.

        Raises:
            ValueError: if self.primary_distance_type is not one of the
                embedding metrics. `self.labels.index(char)` also fails when
                `char` is absent from self.labels.
        """
        import heapq

        # Row of the anchor character in labels and in every embedding matrix.
        idx = self.labels.index(char)

        # Resolve and validate the primary metric once: it does not depend on
        # the representation, and validating up front also catches a bad
        # metric name when self._reps is empty.
        embedding_metrics = Distance(ImgFormat.EMBEDDINGS).get_metrics()
        if self.primary_distance_type not in embedding_metrics.keys():
            raise ValueError(
                "Expect primary_distance_type to be one of {}.".format(
                    embedding_metrics.keys()))
        primary_dis = embedding_metrics[self.primary_distance_type]

        candidate_pool = set()
        candidate_dis = dict()
        for embs in self._reps:
            emb_anchor = embs[idx]
            distances = [primary_dis(emb_anchor, emb) for emb in embs]

            # The n_candidates smallest (distance, label) pairs, ascending.
            # Equivalent to sorted(...)[:n_candidates]; returns [] when
            # n_candidates is 0 and every pair when it exceeds the label count.
            top_n = heapq.nsmallest(self.n_candidates,
                                    zip(distances, self.labels))

            candidate_dis["PCA" + str(embs.shape[1])] = top_n
            candidate_pool.update(label for _, label in top_n)

        return candidate_pool, candidate_dis
コード例 #4
0
    def test_rgb_metrics(self):
        """RGB metrics (manhattan, sum_squared, cross_correlation) give the
        expected values on the fixture images and validate their inputs."""
        metrics = Distance(ImgFormat.RGB).get_metrics()
        rgb_0, rgb_255 = self.img_rgb_0, self.img_rgb_255
        top_left, bot_right = self.img_rgb_topleft, self.img_rgb_botright

        # Manhattan distance.
        self.assertEqual(metrics.manhattan(rgb_0, rgb_255), 255.0)
        self.assertEqual(metrics.manhattan(rgb_255, rgb_0), 255.0)
        self.assertAlmostEqual(
            metrics.manhattan(top_left, bot_right),
            (abs(255 - 200) + abs(255 - 100)) / (3 * 3))
        self.assertEqual(metrics.manhattan(rgb_255, rgb_255), 0)

        # Sum squared distance.
        self.assertEqual(metrics.sum_squared(rgb_0, rgb_255), 1.0)
        self.assertEqual(metrics.sum_squared(rgb_0, top_left), 1.0)
        self.assertAlmostEqual(
            metrics.sum_squared(top_left, bot_right), 0.049633607)
        self.assertAlmostEqual(
            metrics.sum_squared(rgb_255, bot_right), 0.005283143)
        self.assertEqual(metrics.sum_squared(rgb_255, rgb_255), 0)

        # Cross correlation distance.
        self.assertEqual(metrics.cross_correlation(rgb_0, rgb_255), 0)
        self.assertAlmostEqual(
            metrics.cross_correlation(top_left, bot_right), 0.9755619)
        self.assertAlmostEqual(
            metrics.cross_correlation(rgb_255, bot_right), 0.99759716)
        self.assertEqual(metrics.cross_correlation(rgb_255, rgb_255), 1.0)

        # Each metric must raise TypeError for a plain-list argument and
        # ValueError when the gray fixture is paired with an RGB image.
        gray_0 = self.img_gray_0
        for metric in (metrics.manhattan, metrics.sum_squared,
                       metrics.cross_correlation):
            with self.assertRaises(TypeError):
                metric(rgb_0.tolist(), rgb_255)
            with self.assertRaises(ValueError):
                metric(gray_0, rgb_255)
コード例 #5
0
    def test_img_format_setter(self):
        """img_format is stored both when passed to the constructor and when
        assigned afterwards; assigning an int must raise TypeError."""
        distance = Distance(img_format=ImgFormat.A8)
        self.assertEqual(distance._img_format, ImgFormat.A8)

        # Reassign through the public property.
        distance.img_format = ImgFormat.EMBEDDINGS
        self.assertEqual(distance._img_format, ImgFormat.EMBEDDINGS)

        # A non-ImgFormat value must be rejected.
        with self.assertRaises(TypeError):
            distance.img_format = 5
コード例 #6
0
    def test_default_init(self):
        """Test default initialization. When default initialization value
        changes, or any private attribute does not match public attribute, this
        test will fail."""
        dis = Distance()

        self.assertEqual(dis._img_format, ImgFormat.RGB)