Example #1
0
    def test_convert_to_triplets(self):
        """Check ``lmu.convert_to_triplets`` on three pair-index inputs.

        Each input is an ``(a1, p, a2, n)`` tuple describing positive pairs
        ``(a1[i], p[i])`` and negative pairs ``(a2[j], n[j])``. The expected
        values below encode that triplets are formed only for anchors that
        occur in both ``a1`` and ``a2``.

        Outputs are compared through ``.tolist()`` so that the assertions do
        not depend on whether a tuple or a list of tensors is returned, and
        so that a shape mismatch yields an assertion failure rather than the
        "Boolean value of Tensor with more than one value is ambiguous"
        RuntimeError that comparing sequences of tensors with ``==`` raises.
        """
        # Case 1: the positive anchors {0,1,2,3} and negative anchors
        # {4,5,6,7} are disjoint, so every returned index tensor is empty.
        a1 = torch.LongTensor([0, 1, 2, 3])
        p = torch.LongTensor([4, 4, 4, 4])
        a2 = torch.LongTensor([4, 5, 6, 7])
        n = torch.LongTensor([5, 5, 6, 6])
        triplets = lmu.convert_to_triplets((a1, p, a2, n), labels=torch.arange(7))
        for x in triplets:
            self.assertEqual(len(x), 0)

        # Case 2: only anchor 0 is shared -> the single triplet (0, 4, 5),
        # built from positive pair (0, p[0]=4) and negative pair (0, n[0]=5).
        a2 = torch.LongTensor([0, 4, 5, 6])
        triplets = lmu.convert_to_triplets((a1, p, a2, n), labels=torch.arange(7))
        self.assertEqual([x.tolist() for x in triplets], [[0], [4], [5]])

        # Case 3: anchor 0 has positives {5, 7} and negatives {9, 12}
        # (2 x 2 triplets); anchors 1 and 2 contribute one triplet each.
        a1 = torch.LongTensor([0, 1, 0, 2])
        p = torch.LongTensor([5, 6, 7, 8])
        a2 = torch.LongTensor([0, 1, 2, 0])
        n = torch.LongTensor([9, 10, 11, 12])
        triplets = lmu.convert_to_triplets((a1, p, a2, n), labels=torch.arange(13))
        # Stack to shape (num_triplets, 3) so each row is one (a, p, n) triplet.
        triplets = torch.stack(triplets, dim=1)
        found_set = {tuple(t.tolist()) for t in triplets}
        correct_triplets = {
            (0, 5, 9),
            (0, 5, 12),
            (0, 7, 9),
            (0, 7, 12),
            (1, 6, 10),
            (2, 8, 11),
        }

        self.assertEqual(found_set, correct_triplets)
    def test_convert_to_triplets(self):
        """Check ``lmu.convert_to_triplets`` for disjoint and shared anchors.

        The first input has no anchor common to the positive and negative
        pairs, so every output must be empty. The second shares anchor 0 and
        must yield exactly the triplet (0, 4, 5).

        Results are compared as plain lists (via ``.tolist()``). A list of
        tensors never compares equal to a tuple of tensors, so comparing the
        raw return value against a list literal would fail whenever the
        function returns a tuple.
        """
        a1 = torch.LongTensor([0, 1, 2, 3])
        p = torch.LongTensor([4, 4, 4, 4])
        a2 = torch.LongTensor([4, 5, 6, 7])
        n = torch.LongTensor([5, 5, 6, 6])
        triplets = lmu.convert_to_triplets((a1, p, a2, n), labels=torch.arange(7))
        for x in triplets:
            self.assertEqual(len(x), 0)

        # Anchor 0 now appears in both lists: positive p[0]=4, negative n[0]=5.
        a2 = torch.LongTensor([0, 4, 5, 6])
        triplets = lmu.convert_to_triplets((a1, p, a2, n), labels=torch.arange(7))
        self.assertEqual([x.tolist() for x in triplets], [[0], [4], [5]])
    def compute_loss(self, embeddings, labels, indices_tuple):
        """Combine a triplet loss with positive/negative contrastive terms.

        ``indices_tuple`` and ``labels`` are converted to
        (anchor, positive, negative) index tensors by
        ``lmu.convert_to_triplets`` using ``self.triplets_per_anchor``.
        Distances are read from the pairwise matrix returned by
        ``lmu.get_pairwise_mat(embeddings, embeddings, use_similarity=False,
        squared=False)``. Per triplet the loss is::

            alpha * relu(d(a, p) - d(a, n) + triplet_margin)
                + relu(d(a, p) - pos_margin)
                + relu(neg_margin - d(a, n))

        Side effects: resets, then sets, the ``num_non_zero_*`` counters on
        ``self`` (counts of strictly positive entries of each loss term).

        Returns:
            ``self.zero_losses()`` when no triplets are produced; otherwise
            ``{"loss": {"losses": <per-triplet loss>, "indices": anchor_idx,
            "reduction_type": "element"}}``.
        """
        self.num_non_zero_pos_pairs, self.num_non_zero_neg_pairs = 0, 0
        self.num_non_zero_triplets_triplet_loss_only = 0
        self.num_non_zero_triplets = 0

        anchor_idx, positive_idx, negative_idx = lmu.convert_to_triplets(
            indices_tuple, labels, t_per_anchor=self.triplets_per_anchor)

        if len(anchor_idx) == 0:
            return self.zero_losses()

        # Only the anchor-positive and anchor-negative entries of the
        # pairwise matrix feed the loss terms below.
        mat = lmu.get_pairwise_mat(embeddings,
                                   embeddings,
                                   use_similarity=False,
                                   squared=False)
        a_p_dist = mat[anchor_idx, positive_idx]
        a_n_dist = mat[anchor_idx, negative_idx]

        # Triplet term: hinge on the gap between positive and negative distance.
        triplet_loss = F.relu(a_p_dist - a_n_dist + self.triplet_margin)
        self.num_non_zero_triplets_triplet_loss_only = (triplet_loss >
                                                        0).nonzero().size(0)

        # Positive contrastive term: penalize positives farther than pos_margin.
        contrastive_pos = F.relu(a_p_dist - self.pos_margin)
        self.num_non_zero_pos_pairs = (contrastive_pos > 0).nonzero().size(0)

        # Negative contrastive term: penalize negatives closer than neg_margin.
        contrastive_neg = F.relu(self.neg_margin - a_n_dist)
        self.num_non_zero_neg_pairs = (contrastive_neg > 0).nonzero().size(0)

        full_loss = self.alpha * triplet_loss + (contrastive_pos +
                                                 contrastive_neg)
        self.num_non_zero_triplets = (full_loss > 0).nonzero().size(0)

        loss_dict = {
            "loss": {
                "losses": full_loss,
                "indices": anchor_idx,
                "reduction_type": "element"
            }
        }
        return loss_dict
    def compute_loss(self, embeddings, labels, indices_tuple):
        """Combine a triplet loss with contrastive terms on explicit distances.

        ``indices_tuple`` and ``labels`` are converted to
        (anchor, positive, negative) index tensors by
        ``lmu.convert_to_triplets`` using ``self.triplets_per_anchor``.
        Distances are computed with ``F.pairwise_distance`` using
        ``self.distance_norm`` as the norm degree. With
        ``d_neg = min(d(a, n), d(p, n))`` (the closer of the two negative
        distances), the per-triplet loss is::

            relu(d(a, p) - d_neg + triplet_margin)
                + relu(d(a, p) - pos_margin)
                + relu(neg_margin - d_neg)

        Side effects: resets, then sets, the ``num_non_zero_*`` counters on
        ``self`` (counts of strictly positive entries of each loss term).

        Returns:
            ``self.zero_losses()`` when no triplets are produced; otherwise
            ``{"loss": {"losses": <per-triplet loss>, "indices": anchor_idx,
            "reduction_type": "element"}}``.
        """
        self.num_non_zero_pos_pairs, self.num_non_zero_neg_pairs = 0, 0
        self.num_non_zero_triplets_triplet_loss_only = 0
        self.num_non_zero_triplets = 0

        anchor_idx, positive_idx, negative_idx = lmu.convert_to_triplets(
            indices_tuple, labels, t_per_anchor=self.triplets_per_anchor)
        if len(anchor_idx) == 0:
            return self.zero_losses()

        anchors = embeddings[anchor_idx]
        positives = embeddings[positive_idx]
        negatives = embeddings[negative_idx]

        a_p_dist = F.pairwise_distance(anchors, positives, self.distance_norm)
        a_n_dist = F.pairwise_distance(anchors, negatives, self.distance_norm)
        p_n_dist = F.pairwise_distance(positives, negatives, self.distance_norm)
        # Closer of the anchor-negative and positive-negative distances; shared
        # by the triplet and negative contrastive terms, so computed once.
        neg_dist = torch.min(a_n_dist, p_n_dist)

        # Triplet term.
        triplet_loss = F.relu(a_p_dist - neg_dist + self.triplet_margin)
        self.num_non_zero_triplets_triplet_loss_only = (triplet_loss > 0).nonzero().size(0)

        # Positive contrastive term: penalize positives farther than pos_margin.
        contrastive_pos = F.relu(a_p_dist - self.pos_margin)
        self.num_non_zero_pos_pairs = (contrastive_pos > 0).nonzero().size(0)

        # Negative contrastive term: penalize negatives closer than neg_margin.
        contrastive_neg = F.relu(self.neg_margin - neg_dist)
        self.num_non_zero_neg_pairs = (contrastive_neg > 0).nonzero().size(0)

        full_loss = triplet_loss + contrastive_pos + contrastive_neg
        self.num_non_zero_triplets = (full_loss > 0).nonzero().size(0)

        loss_dict = {"loss": {"losses": full_loss, "indices": anchor_idx, "reduction_type": "element"}}
        return loss_dict
 def pair_based_loss(self, mat, labels, indices_tuple):
     """Compute the pair losses and attach an extra triplet term.

     ``indices_tuple`` is unpacked as ``(a1, p, a2, n)``. The values
     ``mat[a1, p]`` (positive pairs) and ``mat[a2, n]`` (negative pairs) are
     gathered, with an empty list standing in when the corresponding anchor
     tensor is empty, and handed to ``self._compute_loss``.

     The pairs are then expanded into all triplets with
     ``lmu.convert_to_triplets(..., t_per_anchor='all')``. The term
     ``triplet_weight * relu(mat[a, p]**power - mat[a, n]**power + triplet_margin)``
     is stored under the ``'triplet_loss'`` key of the dict returned by
     ``_compute_loss``, with ``"reduction_type": "triplet"``.

     NOTE(review): the sign of the triplet hinge assumes ``mat`` holds
     distances rather than similarities — confirm against callers.
     """
     pos_anchor, pos_other, neg_anchor, neg_other = indices_tuple
     pos_pair = mat[pos_anchor, pos_other] if len(pos_anchor) > 0 else []
     neg_pair = mat[neg_anchor, neg_other] if len(neg_anchor) > 0 else []
     loss_dict = self._compute_loss(pos_pair, neg_pair, indices_tuple)

     triplets = lmu.convert_to_triplets(indices_tuple, labels, t_per_anchor='all')
     anchors, positives, negatives = triplets
     ap_term = mat[anchors, positives] ** self.power
     an_term = mat[anchors, negatives] ** self.power
     hinge = F.relu(ap_term - an_term + self.triplet_margin)

     loss_dict['triplet_loss'] = {
         "losses": self.triplet_weight * hinge,
         "indices": triplets,
         "reduction_type": "triplet",
     }
     return loss_dict