Example #1
0
 def test_forward_one_batch(self):
     """Check CosComparator scores and gradient flow on a single batch.

     The batch holds two positive pairs (``lhs_pos`` / ``rhs_pos``, shape
     (1, 2, 3)) and four negatives per side (``lhs_neg`` / ``rhs_neg``, shape
     (1, 4, 3)). Expected scores are hard-coded 4-decimal reference values;
     ``assertTensorEqual`` is presumably tolerance-based (TODO confirm).
     After backpropagating the sum of all scores, every input must have
     received a non-None, not-all-zero gradient.
     """
     comparator = CosComparator()
     lhs_pos = torch.tensor(
         [[[0.8931, 0.2241, 0.4241], [0.6557, 0.2492, 0.4157]]], requires_grad=True
     )
     rhs_pos = torch.tensor(
         [[[0.9220, 0.2892, 0.7408], [0.1476, 0.6079, 0.1835]]], requires_grad=True
     )
     lhs_neg = torch.tensor(
         [
             [
                 [0.3836, 0.7648, 0.0965],
                 [0.8929, 0.8947, 0.4877],
                 [0.4754, 0.3163, 0.3422],
                 [0.7967, 0.6736, 0.2966],
             ]
         ],
         requires_grad=True,
     )
     rhs_neg = torch.tensor(
         [
             [
                 [0.6116, 0.6010, 0.9500],
                 [0.2541, 0.7715, 0.7477],
                 [0.2360, 0.5923, 0.7536],
                 [0.1290, 0.3088, 0.2731],
             ]
         ],
         requires_grad=True,
     )
     pos_scores, lhs_neg_scores, rhs_neg_scores = comparator(
         comparator.prepare(lhs_pos),
         comparator.prepare(rhs_pos),
         comparator.prepare(lhs_neg),
         comparator.prepare(rhs_neg),
     )
     # pos_scores[0, i] is the cosine similarity of lhs_pos[i] and rhs_pos[i].
     self.assertTensorEqual(pos_scores, torch.tensor([[0.9741, 0.6106]]))
     # The expected values match lhs_neg_scores[0, i, j] = cos(rhs_pos[i], lhs_neg[j])
     # and rhs_neg_scores[0, i, j] = cos(lhs_pos[i], rhs_neg[j]).
     self.assertTensorEqual(
         lhs_neg_scores,
         torch.tensor(
             [[[0.6165, 0.8749, 0.9664, 0.8701], [0.9607, 0.8663, 0.7494, 0.8224]]]
         ),
     )
     self.assertTensorEqual(
         rhs_neg_scores,
         torch.tensor(
             [[[0.8354, 0.6406, 0.6626, 0.6856], [0.9063, 0.7439, 0.7648, 0.7810]]]
         ),
     )
     (pos_scores.sum() + lhs_neg_scores.sum() + rhs_neg_scores.sum()).backward()
     # Check for None first: a detached input would leave ``grad`` as None,
     # and ``None != 0`` is a plain bool without ``.any()``, which would raise
     # a confusing AttributeError instead of a clear assertion failure.
     for name, tensor in (
         ("lhs_pos", lhs_pos),
         ("rhs_pos", rhs_pos),
         ("lhs_neg", lhs_neg),
         ("rhs_neg", rhs_neg),
     ):
         self.assertIsNotNone(tensor.grad, f"{name} received no gradient")
         self.assertTrue(
             (tensor.grad != 0).any(), f"{name} received an all-zero gradient"
         )
Example #2
0
 def test_forward_one_batch(self):
     """Check BiasedComparator(CosComparator()) scores and gradient flow.

     Uses the same single batch of inputs as the plain cosine test: two
     positive pairs (``lhs_pos`` / ``rhs_pos``, shape (1, 2, 3)) and four
     negatives per side (``lhs_neg`` / ``rhs_neg``, shape (1, 4, 3)). The
     expected scores are hard-coded reference values that include whatever
     bias term BiasedComparator adds (its exact form is defined elsewhere —
     not visible here). ``assertTensorEqual`` is presumably tolerance-based
     (TODO confirm). After backpropagating the sum of all scores, every input
     must have received a non-None, not-all-zero gradient.
     """
     comparator = BiasedComparator(CosComparator())
     lhs_pos = torch.tensor(
         [[[0.8931, 0.2241, 0.4241], [0.6557, 0.2492, 0.4157]]], requires_grad=True
     )
     rhs_pos = torch.tensor(
         [[[0.9220, 0.2892, 0.7408], [0.1476, 0.6079, 0.1835]]], requires_grad=True
     )
     lhs_neg = torch.tensor(
         [
             [
                 [0.3836, 0.7648, 0.0965],
                 [0.8929, 0.8947, 0.4877],
                 [0.4754, 0.3163, 0.3422],
                 [0.7967, 0.6736, 0.2966],
             ]
         ],
         requires_grad=True,
     )
     rhs_neg = torch.tensor(
         [
             [
                 [0.6116, 0.6010, 0.9500],
                 [0.2541, 0.7715, 0.7477],
                 [0.2360, 0.5923, 0.7536],
                 [0.1290, 0.3088, 0.2731],
             ]
         ],
         requires_grad=True,
     )
     pos_scores, lhs_neg_scores, rhs_neg_scores = comparator(
         comparator.prepare(lhs_pos),
         comparator.prepare(rhs_pos),
         comparator.prepare(lhs_neg),
         comparator.prepare(rhs_neg),
     )
     self.assertTensorEqual(pos_scores, torch.tensor([[2.8086, 1.5434]]))
     self.assertTensorEqual(
         lhs_neg_scores,
         torch.tensor(
             [[[1.7830, 2.5800, 2.3283, 2.4269], [1.5172, 2.0194, 1.4850, 1.9369]]]
         ),
     )
     self.assertTensorEqual(
         rhs_neg_scores,
         torch.tensor(
             [[[2.5017, 2.0980, 2.1129, 1.9578], [2.2670, 1.8759, 1.8838, 1.7381]]]
         ),
     )
     (pos_scores.sum() + lhs_neg_scores.sum() + rhs_neg_scores.sum()).backward()
     # Check for None first: a detached input would leave ``grad`` as None,
     # and ``None != 0`` is a plain bool without ``.any()``, which would raise
     # a confusing AttributeError instead of a clear assertion failure.
     for name, tensor in (
         ("lhs_pos", lhs_pos),
         ("rhs_pos", rhs_pos),
         ("lhs_neg", lhs_neg),
         ("rhs_neg", rhs_neg),
     ):
         self.assertIsNotNone(tensor.grad, f"{name} received no gradient")
         self.assertTrue(
             (tensor.grad != 0).any(), f"{name} received an all-zero gradient"
         )