예제 #1
0
    def test_circle_loss(self):
        margin, gamma = 0.25, 256
        Op, On = 1 + margin, -margin
        delta_p, delta_n = 1 - margin, margin
        loss_func = CircleLoss(m=margin, gamma=gamma)

        for dtype in TEST_DTYPES:
            embedding_angles = [0, 20, 40, 60, 80]
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype).to(self.device)  #2D embeddings
            labels = torch.LongTensor([0, 0, 1, 1, 2])

            loss = loss_func(embeddings, labels)
            loss.backward()

            pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
            neg_pairs = [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
                         (2, 0), (2, 1), (2, 4), (3, 0), (3, 1), (3, 4),
                         (4, 0), (4, 1), (4, 2), (4, 3)]

            correct_total = 0
            for i in range(len(embeddings)):
                pos_logits = []
                neg_logits = []
                for a, p in pos_pairs:
                    if a == i:
                        anchor, positive = embeddings[a], embeddings[p]
                        ap_sim = torch.matmul(anchor, positive)
                        logit_p = -gamma * torch.relu(Op - ap_sim) * (ap_sim -
                                                                      delta_p)
                        pos_logits.append(logit_p.unsqueeze(0))

                for a, n in neg_pairs:
                    if a == i:
                        anchor, negative = embeddings[a], embeddings[n]
                        an_sim = torch.matmul(anchor, negative)
                        logit_n = gamma * torch.relu(an_sim - On) * (an_sim -
                                                                     delta_n)
                        neg_logits.append(logit_n.unsqueeze(0))

                if len(pos_logits) == 0 or len(neg_logits) == 0:
                    pass
                else:
                    pos_logits = torch.cat(
                        pos_logits,
                        dim=0) if len(pos_logits) > 1 else pos_logits[0]
                    neg_logits = torch.cat(
                        neg_logits,
                        dim=0) if len(neg_logits) > 1 else neg_logits[0]
                    correct_total += torch.nn.functional.softplus(
                        torch.logsumexp(pos_logits, dim=0) +
                        torch.logsumexp(neg_logits, dim=0))

            correct_total /= 4  # only 4 of the embeddings have both pos and neg pairs
            self.assertTrue(torch.isclose(loss, correct_total))
예제 #2
0
 def test_overflow(self):
     margin, gamma = 0.4, 300
     loss_func = CircleLoss(m=margin, gamma=gamma)
     for dtype in TEST_DTYPES:
         embedding_angles = [0, 20, 40, 60, 80]
         embeddings = torch.tensor([c_f.angle_to_coord(a) for a in embedding_angles], requires_grad=True, dtype=dtype).to(self.device) #2D embeddings
         labels = torch.LongTensor([0, 0, 1, 1, 2])
         loss = loss_func(embeddings, labels)
         loss.backward()
         self.assertTrue(not torch.isnan(loss) and not torch.isinf(loss))
예제 #3
0
 def test_with_no_valid_pairs(self):
     margin, gamma = 0.4, 80
     loss_func = CircleLoss(m=margin, gamma=gamma)
     for dtype in TEST_DTYPES:
         embedding_angles = [0]
         embeddings = torch.tensor([c_f.angle_to_coord(a) for a in embedding_angles], requires_grad=True, dtype=dtype).to(self.device) #2D embeddings
         labels = torch.LongTensor([0])
         loss = loss_func(embeddings, labels)
         loss.backward()
         self.assertEqual(loss, 0)
예제 #4
0
 def test_overflow(self):
     margin, gamma = 0.4, 300
     loss_func = CircleLoss(m=margin, gamma=gamma)
     for dtype in TEST_DTYPES:
         embedding_angles = torch.arange(0, 180)
         embeddings = torch.tensor(
             [c_f.angle_to_coord(a) for a in embedding_angles],
             requires_grad=True,
             dtype=dtype).to(self.device)  #2D embeddings
         labels = torch.randint(low=0, high=10, size=(180, ))
         loss = loss_func(embeddings, labels)
         loss.backward()
         self.assertTrue(not torch.isnan(loss) and not torch.isinf(loss))
         self.assertTrue(loss > 0)
    def test_circle_loss(self):
        margin, gamma = 0.4, 2
        Op, On = 1 + margin, -margin
        delta_p, delta_n = 1 - margin, margin
        loss_func = CircleLoss(m=margin, gamma=gamma)

        for dtype in [torch.float16, torch.float32, torch.float64]:
            embedding_angles = [0, 20, 40, 60, 80]
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype).to(self.device)  #2D embeddings
            labels = torch.LongTensor([0, 0, 1, 1, 2])

            loss = loss_func(embeddings, labels)
            loss.backward()

            pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
            neg_pairs = [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
                         (2, 0), (2, 1), (2, 4), (3, 0), (3, 1), (3, 4),
                         (4, 0), (4, 1), (4, 2), (4, 3)]

            correct_total = 0
            totals = []
            for i in range(len(embeddings)):
                pos_exp = 0
                neg_exp = 0
                for a, p in pos_pairs:
                    if a == i:
                        anchor, positive = embeddings[a], embeddings[p]
                        ap_sim = torch.matmul(anchor, positive)
                        logit_p = -gamma * torch.relu(Op - ap_sim) * (ap_sim -
                                                                      delta_p)
                        pos_exp += torch.exp(logit_p)

                for a, n in neg_pairs:
                    if a == i:
                        anchor, negative = embeddings[a], embeddings[n]
                        an_sim = torch.matmul(anchor, negative)
                        logit_n = gamma * torch.relu(an_sim - On) * (an_sim -
                                                                     delta_n)
                        neg_exp += torch.exp(logit_n)

                totals.append(torch.log(1 + pos_exp * neg_exp))
                correct_total += torch.log(1 + pos_exp * neg_exp)

            correct_total /= 4  # only 4 of the embeddings have both pos and neg pairs
            self.assertTrue(torch.isclose(loss, correct_total))
예제 #6
0
 def __init__(self, **kwargs):
     super(Baseline, self).__init__()
     self.__dict__.update(kwargs)
     self.dropout = nn.Dropout(self.dropout_rate)
     self.gru1 = nn.GRU(input_size=self.d_v,
                        hidden_size=self.d_v_h,
                        bidirectional=True)
     self.gru2 = nn.GRU(input_size=self.d_s,
                        hidden_size=self.d_s_h,
                        bidirectional=True)
     self.fc1 = nn.Linear(2 * self.d_v_h, self.final_dim)
     self.fc2 = nn.Linear(2 * self.d_s_h, self.final_dim)
     self.norm1 = nn.BatchNorm1d(self.final_dim)
     self.norm2 = nn.BatchNorm1d(self.final_dim)
     self.mrr_metric = RetrievalMRR()
     self.circle_loss = CircleLoss(m=self.margin)