def test_lifted_structure_loss(self):
    """Compare LiftedStructureLoss with a hand-computed reference value.

    Five embeddings are built from the angles 0, 20, 40, 60, 80 with labels
    [0, 0, 1, 1, 2]. For every positive pair the reference sums
    exp(neg_margin - distance) over the negative pairs that start at either
    member of that pair, then adds relu(log(sum) + positive_distance) ** 2
    to the running total. The total is divided by twice the number of
    positive pairs and must be close to the value the loss function returns.
    """
    margin = 0.5
    angles = [0, 20, 40, 60, 80]
    criterion = LiftedStructureLoss(neg_margin=margin)
    # 2D embeddings
    embeddings = torch.tensor(
        [c_f.angle_to_coord(angle) for angle in angles],
        requires_grad=True,
        dtype=torch.float,
    )
    labels = torch.LongTensor([0, 0, 1, 1, 2])

    loss = criterion(embeddings, labels)
    # Also exercises the backward pass of the loss under test.
    loss.backward()

    def euclidean(u, v):
        # Plain L2 distance between two embedding rows.
        return torch.sqrt(torch.sum((u - v) ** 2))

    # Ordered index pairs sharing a label / differing in label.
    positive_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
    negative_pairs = [
        (0, 2), (0, 3), (0, 4),
        (1, 2), (1, 3), (1, 4),
        (2, 0), (2, 1), (2, 4),
        (3, 0), (3, 1), (3, 4),
        (4, 0), (4, 1), (4, 2), (4, 3),
    ]

    per_pair_terms = []
    for first, second in positive_pairs:
        pos_distance = euclidean(embeddings[first], embeddings[second])
        # Only negatives anchored at one of the two positive-pair members
        # contribute; iteration order follows negative_pairs.
        neg_sum = sum(
            torch.exp(margin - euclidean(embeddings[src], embeddings[dst]))
            for src, dst in negative_pairs
            if src in (first, second)
        )
        per_pair_terms.append(
            torch.relu(torch.log(neg_sum) + pos_distance) ** 2
        )

    expected = sum(per_pair_terms) / (2 * len(positive_pairs))
    self.assertTrue(torch.isclose(loss, expected))
def test_with_no_valid_pairs(self):
    """Check that a batch with a single sample yields a loss equal to 0.

    Only one embedding and one label are supplied, so no pair can be formed.
    backward() is still called on the result before the value is asserted.
    """
    criterion = LiftedStructureLoss(neg_margin=0.5)
    angles = [0]
    coords = [c_f.angle_to_coord(angle) for angle in angles]
    # 2D embeddings
    embeddings = torch.tensor(coords, requires_grad=True, dtype=torch.float)
    labels = torch.LongTensor([0])

    loss = criterion(embeddings, labels)
    loss.backward()

    self.assertEqual(loss, 0)