예제 #1
0
 def test_no_pos(self):
     # Degenerate batch: zero positives and therefore zero rows of negatives
     # (each row carries 3 negative scores). The loss must come out as a
     # scalar zero, and back-propagating through it must not raise.
     criterion = RankingLossFunction(margin=1.0)
     positives = torch.empty((0, ), requires_grad=True)
     negatives = torch.empty((0, 3), requires_grad=True)
     result = criterion(positives, negatives)
     self.assertTensorEqual(result, torch.zeros(()))
     result.backward()
예제 #2
0
 def test_forward_bad(self):
     # Worst case: every positive score (-1.0) is below every negative
     # score (0.0), so all 3 x 5 (positive, negative) pairs violate the
     # margin. The expected 30.0 equals 15 pairs * (1.0 - (-1.0) + 0.0),
     # i.e. a per-pair hinge max(0, margin - pos + neg) summed over pairs.
     pos_scores = torch.full((3, ), -1.0, requires_grad=True)
     neg_scores = torch.zeros((3, 5), requires_grad=True)
     loss_fn = RankingLossFunction(margin=1.0)
     loss = loss_fn(pos_scores, neg_scores)
     self.assertTensorEqual(loss, torch.tensor(30.0))
     loss.backward()
     # Every pair is active (positive hinge), so both inputs should receive
     # a non-zero gradient; check it just as test_forward does.
     self.assertTrue((pos_scores.grad != 0).any())
     self.assertTrue((neg_scores.grad != 0).any())
예제 #3
0
 def test_forward_good(self):
     # Best case: every positive score (2.0) beats every negative score
     # (1.0) by exactly the margin (1.0), so each per-pair hinge
     # max(0, margin - pos + neg) is 0 and the total loss is a scalar zero.
     # The fill values must be floats: with an integer fill value,
     # torch.full infers an integer dtype, and integer tensors cannot be
     # created with requires_grad=True (RuntimeError).
     pos_scores = torch.full((3, ), 2.0, requires_grad=True)
     neg_scores = torch.full((3, 5), 1.0, requires_grad=True)
     # Margin passed by keyword, matching the other tests in this file.
     loss_fn = RankingLossFunction(margin=1.0)
     loss = loss_fn(pos_scores, neg_scores)
     self.assertTensorEqual(loss, torch.zeros(()))
     loss.backward()
예제 #4
0
 def test_forward(self):
     # Realistic scores in [0, 1) with margin 1.0. Every one of the 3 x 5
     # (positive, negative) pairs has a positive hinge term
     # max(0, margin - pos + neg), so the expected loss is their plain sum:
     # per-row sums 3.3637 + 4.4826 + 5.6012 = 13.4475.
     pos_scores = torch.tensor([0.8181, 0.5700, 0.3506], requires_grad=True)
     neg_scores = torch.tensor(
         [
             [0.4437, 0.6573, 0.9986, 0.2548, 0.0998],
             [0.6175, 0.4061, 0.4582, 0.5382, 0.3126],
             [0.9869, 0.2028, 0.1667, 0.0044, 0.9934],
         ],
         requires_grad=True,
     )
     # Margin passed by keyword, matching the other tests in this file; this
     # also keeps working if the constructor's margin is keyword-only
     # (NOTE(review): constructor signature not visible here — confirm).
     loss_fn = RankingLossFunction(margin=1.0)
     loss = loss_fn(pos_scores, neg_scores)
     self.assertTensorEqual(loss, torch.tensor(13.4475))
     loss.backward()
     # Gradients must flow back to both score tensors.
     self.assertTrue((pos_scores.grad != 0).any())
     self.assertTrue((neg_scores.grad != 0).any())
예제 #5
0
    def test_forward_weight(self):
        # Same scores as the unweighted forward test (total 13.4475, per-row
        # sums 3.3637, 4.4826, 5.6012), now with one weight per positive.
        positives = torch.tensor([0.8181, 0.5700, 0.3506], requires_grad=True)
        negatives = torch.tensor(
            [
                [0.4437, 0.6573, 0.9986, 0.2548, 0.0998],
                [0.6175, 0.4061, 0.4582, 0.5382, 0.3126],
                [0.9869, 0.2028, 0.1667, 0.0044, 0.9934],
            ],
            requires_grad=True,
        )
        criterion = RankingLossFunction(margin=1.0)

        # A uniform weight of 1.23 scales the whole loss by 1.23, and
        # gradients must reach both score tensors.
        uniform_weight = torch.full((3, ), 1.23)
        result = criterion(positives, negatives, uniform_weight)
        self.assertTensorEqual(result, torch.tensor(13.4475 * 1.23))
        result.backward()
        for scores in (positives, negatives):
            self.assertTrue((scores.grad != 0).any())

        # Non-uniform weights: 2.4658 = 0.2 * 3.3637 + 0.4 * 4.4826, i.e.
        # each row's hinge sum is scaled by its weight and the third row
        # (weight 0.0) drops out entirely.
        per_row_weight = torch.tensor([0.2, 0.4, 0.0])
        result = criterion(positives, negatives, per_row_weight)
        self.assertTensorEqual(result, torch.tensor(2.4658))