def mine(self, embeddings, labels, ref_emb, ref_labels):
        """Select hard positive and hard negative pairs.

        A pairwise matrix is computed with ``self.distance(embeddings,
        ref_emb)``. For every anchor row, a positive is kept when it is at
        least as hard as the anchor's hardest negative, up to a margin of
        ``self.epsilon``; a negative is kept when it is at least as hard as
        the anchor's hardest positive, up to the same margin. When
        ``self.distance.is_inverted`` is true the matrix values are treated
        as similarities (larger means closer); otherwise as distances.

        Args:
            embeddings: Tensor of anchor embeddings (first argument of
                ``self.distance``).
            labels: Tensor of labels for ``embeddings``; its device is used
                for the empty result.
            ref_emb: Reference embeddings (second argument of
                ``self.distance``). If this is the very same object as
                ``embeddings``, the diagonal is excluded so that no sample
                is paired with itself.
            ref_labels: Labels for ``ref_emb``.

        Returns:
            A tuple ``(a1, p, a2, n)`` of index tensors: ``(a1, p)`` are the
            row/column indices of the mined positive pairs and ``(a2, n)``
            those of the mined negative pairs. All four are empty long
            tensors when there are no positive pairs or no negative pairs.
        """
        mat = self.distance(embeddings, ref_emb)
        a1, p, a2, n = get_all_pairs_indices(labels, ref_labels)

        if len(a1) == 0 or len(a2) == 0:
            # Mining compares positives against negatives, so with either
            # side missing there is nothing to select.
            return tuple(
                torch.tensor([], dtype=torch.long, device=labels.device)
                for _ in range(4)
            )

        # ``mat`` itself is reused (and modified in place) as the negative
        # matrix; only the positive matrix needs its own copy.
        mat_neg_sorting = mat
        mat_pos_sorting = mat.clone()

        # The sentinels are written into ``mat`` and its clone, so they must
        # be sized for the matrix dtype, which is not necessarily the dtype
        # of ``embeddings``.
        # NOTE(review): c_f.pos_inf / c_f.neg_inf are assumed to return the
        # extreme representable values for the given dtype - confirm.
        dtype = mat.dtype
        if self.distance.is_inverted:
            pos_ignore, neg_ignore = c_f.pos_inf(dtype), c_f.neg_inf(dtype)
        else:
            pos_ignore, neg_ignore = c_f.neg_inf(dtype), c_f.pos_inf(dtype)

        # Push entries of the "other" kind to the easy end of each row so
        # they can never become the hardest element or pass the threshold.
        mat_pos_sorting[a2, n] = pos_ignore
        mat_neg_sorting[a1, p] = neg_ignore
        if embeddings is ref_emb:
            # Self-comparisons are neither positives nor negatives.
            mat_pos_sorting.fill_diagonal_(pos_ignore)
            mat_neg_sorting.fill_diagonal_(neg_ignore)

        pos_sorted, pos_sorted_idx = torch.sort(mat_pos_sorting, dim=1)
        neg_sorted, neg_sorted_idx = torch.sort(mat_neg_sorting, dim=1)

        if self.distance.is_inverted:
            # Similarities: hardest negative is the largest value in the
            # row, hardest positive is the smallest.
            hard_pos_idx = torch.where(
                pos_sorted - self.epsilon < neg_sorted[:, -1].unsqueeze(1))
            hard_neg_idx = torch.where(
                neg_sorted + self.epsilon > pos_sorted[:, 0].unsqueeze(1))
        else:
            # Distances: hardest negative is the smallest value in the row,
            # hardest positive is the largest.
            hard_pos_idx = torch.where(
                pos_sorted + self.epsilon > neg_sorted[:, 0].unsqueeze(1))
            hard_neg_idx = torch.where(
                neg_sorted - self.epsilon < pos_sorted[:, -1].unsqueeze(1))

        # ``torch.where`` yields positions in the sorted matrices; map the
        # column positions back to the original column indices.
        a1 = hard_pos_idx[0]
        p = pos_sorted_idx[a1, hard_pos_idx[1]]
        a2 = hard_neg_idx[0]
        n = neg_sorted_idx[a2, hard_neg_idx[1]]

        return a1, p, a2, n
    def test_multi_similarity_miner(self):
        """Compare MultiSimilarityMiner output with a brute-force reference.

        For each dtype, 64 embeddings are built from the angles 0..63 via
        ``c_f.angle_to_coord`` and given random integer labels in [0, 10).
        Pair similarities are plain dot products. The expected pairs are:

        * positive ``(a, p)`` when its similarity is below the anchor's
          largest negative similarity plus ``epsilon``;
        * negative ``(a, n)`` when its similarity is above the anchor's
          smallest positive similarity minus ``epsilon``.

        The miner must return exactly these pairs (compared as sets).
        """
        epsilon = 0.1
        miner = MultiSimilarityMiner(epsilon)
        # The coordinates do not depend on dtype, so build them once.
        coords = [c_f.angle_to_coord(a) for a in torch.arange(0, 64)]
        for dtype in [torch.float16, torch.float32, torch.float64]:
            embeddings = torch.tensor(
                coords, requires_grad=True, dtype=dtype
            ).to(self.device)  # 2D embeddings
            labels = torch.randint(low=0, high=10, size=(64,))
            neg_inf, pos_inf = c_f.neg_inf(dtype), c_f.pos_inf(dtype)

            # Every ordered pair (i, j), i != j, with its similarity.
            pos_pairs = []
            neg_pairs = []
            num_samples = len(embeddings)
            for i in range(num_samples):
                for j in range(num_samples):
                    if j == i:
                        continue
                    sim = torch.matmul(embeddings[i], embeddings[j].t()).item()
                    if labels[i] == labels[j]:
                        pos_pairs.append((i, j, sim))
                    else:
                        neg_pairs.append((i, j, sim))

            # Per-anchor extrema, computed once instead of rescanning all
            # pairs of the other kind for every single pair.
            max_neg_sim = {}
            for anchor, _, sim in neg_pairs:
                if sim > max_neg_sim.get(anchor, neg_inf):
                    max_neg_sim[anchor] = sim
            min_pos_sim = {}
            for anchor, _, sim in pos_pairs:
                if sim < min_pos_sim.get(anchor, pos_inf):
                    min_pos_sim[anchor] = sim

            # Anchors without negatives (or positives) fall back to the
            # sentinel, exactly as an empty scan would.
            correct_pos_pairs = {
                (a, p)
                for a, p, sim in pos_pairs
                if sim < max_neg_sim.get(a, neg_inf) + epsilon
            }
            correct_neg_pairs = {
                (a, n)
                for a, n, sim in neg_pairs
                if sim > min_pos_sim.get(a, pos_inf) - epsilon
            }

            a1, p1, a2, n2 = miner(embeddings, labels)
            mined_pos_pairs = {(a.item(), p.item()) for a, p in zip(a1, p1)}
            mined_neg_pairs = {(a.item(), n.item()) for a, n in zip(a2, n2)}

            # assertEqual reports the differing elements on failure.
            self.assertEqual(mined_pos_pairs, correct_pos_pairs)
            self.assertEqual(mined_neg_pairs, correct_neg_pairs)
Exemplo n.º 3
0
    def test_multi_similarity_miner(self):
        """Compare MultiSimilarityMiner output with a brute-force reference.

        Runs for every dtype in ``TEST_DTYPES`` and for both
        ``CosineSimilarity`` and ``LpDistance``. 64 embeddings are built from
        the angles 0..63 via ``c_f.angle_to_coord`` and given random integer
        labels in [0, 10). Pair values are read from ``distance(embeddings)``.

        With ``distance.is_inverted`` true (larger value = harder negative):

        * positive ``(a, p)`` is expected when its value is below the
          anchor's largest negative value plus ``epsilon``;
        * negative ``(a, n)`` is expected when its value is above the
          anchor's smallest positive value minus ``epsilon``.

        Otherwise the comparisons are mirrored (smallest negative, largest
        positive). The miner must return exactly these pairs (as sets).
        """
        epsilon = 0.1
        # The coordinates depend on neither dtype nor distance.
        coords = [c_f.angle_to_coord(a) for a in torch.arange(0, 64)]
        for dtype in TEST_DTYPES:
            neg_inf, pos_inf = c_f.neg_inf(dtype), c_f.pos_inf(dtype)
            for distance in [CosineSimilarity(), LpDistance()]:
                miner = MultiSimilarityMiner(epsilon, distance=distance)
                embeddings = torch.tensor(
                    coords,
                    requires_grad=True,
                    dtype=dtype,
                ).to(TEST_DEVICE)  # 2D embeddings
                labels = torch.randint(low=0, high=10, size=(64,))
                mat = distance(embeddings)
                inverted = distance.is_inverted

                # Every ordered pair (i, j), i != j, with its matrix value.
                pos_pairs = []
                neg_pairs = []
                num_samples = len(embeddings)
                for i in range(num_samples):
                    for j in range(num_samples):
                        if j == i:
                            continue
                        if labels[i] == labels[j]:
                            pos_pairs.append((i, j, mat[i, j]))
                        else:
                            neg_pairs.append((i, j, mat[i, j]))

                # Per-anchor hardest negative / hardest positive, computed
                # once instead of rescanning all pairs for every pair.
                neg_start = neg_inf if inverted else pos_inf
                hardest_neg = {}
                for anchor, _, value in neg_pairs:
                    best = hardest_neg.get(anchor, neg_start)
                    if (value > best) if inverted else (value < best):
                        hardest_neg[anchor] = value

                pos_start = pos_inf if inverted else neg_inf
                hardest_pos = {}
                for anchor, _, value in pos_pairs:
                    best = hardest_pos.get(anchor, pos_start)
                    if (value < best) if inverted else (value > best):
                        hardest_pos[anchor] = value

                # Anchors without negatives (or positives) fall back to the
                # sentinel, exactly as an empty scan would.
                correct_pos_pairs = set()
                for anchor, pos, value in pos_pairs:
                    hardest = hardest_neg.get(anchor, neg_start)
                    if inverted:
                        is_hard = value < hardest + epsilon
                    else:
                        is_hard = value > hardest - epsilon
                    if is_hard:
                        correct_pos_pairs.add((anchor, pos))

                correct_neg_pairs = set()
                for anchor, neg, value in neg_pairs:
                    hardest = hardest_pos.get(anchor, pos_start)
                    if inverted:
                        is_hard = value > hardest - epsilon
                    else:
                        is_hard = value < hardest + epsilon
                    if is_hard:
                        correct_neg_pairs.add((anchor, neg))

                a1, p1, a2, n2 = miner(embeddings, labels)
                mined_pos_pairs = {
                    (a.item(), p.item()) for a, p in zip(a1, p1)
                }
                mined_neg_pairs = {
                    (a.item(), n.item()) for a, n in zip(a2, n2)
                }

                # assertEqual reports the differing elements on failure.
                self.assertEqual(mined_pos_pairs, correct_pos_pairs)
                self.assertEqual(mined_neg_pairs, correct_neg_pairs)