def test_forward_one_batch(self):
    """Score one chunk with CosComparator and check values and gradients.

    The chunk holds 2 positive lhs/rhs embeddings and 4 negative lhs/rhs
    embeddings, all 3-dimensional. The positive scores and both sets of
    negative scores are compared against precomputed reference tensors.
    Then the sum of all scores is backpropagated, and every input must
    receive a gradient that is nonzero somewhere.
    """
    cos = CosComparator()

    def leaf(data):
        # Build a grad-tracking leaf tensor so its .grad can be inspected.
        return torch.tensor(data, requires_grad=True)

    # Positives have shape (1, 2, 3) and negatives have shape (1, 4, 3),
    # as the nesting of the literals shows.
    lhs_pos = leaf([[[0.8931, 0.2241, 0.4241], [0.6557, 0.2492, 0.4157]]])
    rhs_pos = leaf([[[0.9220, 0.2892, 0.7408], [0.1476, 0.6079, 0.1835]]])
    lhs_neg = leaf(
        [
            [
                [0.3836, 0.7648, 0.0965],
                [0.8929, 0.8947, 0.4877],
                [0.4754, 0.3163, 0.3422],
                [0.7967, 0.6736, 0.2966],
            ]
        ]
    )
    rhs_neg = leaf(
        [
            [
                [0.6116, 0.6010, 0.9500],
                [0.2541, 0.7715, 0.7477],
                [0.2360, 0.5923, 0.7536],
                [0.1290, 0.3088, 0.2731],
            ]
        ]
    )
    inputs = (lhs_pos, rhs_pos, lhs_neg, rhs_neg)

    # prepare() is applied to each input in the same positional order
    # that the comparator call expects.
    pos_scores, lhs_neg_scores, rhs_neg_scores = cos(
        *[cos.prepare(emb) for emb in inputs]
    )

    self.assertTensorEqual(pos_scores, torch.tensor([[0.9741, 0.6106]]))
    expected_lhs_neg = torch.tensor(
        [[[0.6165, 0.8749, 0.9664, 0.8701], [0.9607, 0.8663, 0.7494, 0.8224]]]
    )
    self.assertTensorEqual(lhs_neg_scores, expected_lhs_neg)
    expected_rhs_neg = torch.tensor(
        [[[0.8354, 0.6406, 0.6626, 0.6856], [0.9063, 0.7439, 0.7648, 0.7810]]]
    )
    self.assertTensorEqual(rhs_neg_scores, expected_rhs_neg)

    # Every input must be connected to the output through autograd.
    total = pos_scores.sum() + lhs_neg_scores.sum() + rhs_neg_scores.sum()
    total.backward()
    for emb in inputs:
        self.assertTrue((emb.grad != 0).any())
def test_forward_one_batch(self):
    """Score one chunk with BiasedComparator(CosComparator()).

    The fixture values match those of the CosComparator forward test in
    this file: 2 positive lhs/rhs embeddings and 4 negative lhs/rhs
    embeddings, all 3-dimensional. The expected scores differ from the
    plain cosine case. NOTE(review): presumably this is because the
    wrapper adds per-embedding bias terms; confirm against
    BiasedComparator. After the value checks, the summed scores are
    backpropagated, and each input must get a nonzero gradient somewhere.
    """
    biased = BiasedComparator(CosComparator())

    def leaf(data):
        # Build a grad-tracking leaf tensor so its .grad can be inspected.
        return torch.tensor(data, requires_grad=True)

    # Positives have shape (1, 2, 3) and negatives have shape (1, 4, 3),
    # as the nesting of the literals shows.
    lhs_pos = leaf([[[0.8931, 0.2241, 0.4241], [0.6557, 0.2492, 0.4157]]])
    rhs_pos = leaf([[[0.9220, 0.2892, 0.7408], [0.1476, 0.6079, 0.1835]]])
    lhs_neg = leaf(
        [
            [
                [0.3836, 0.7648, 0.0965],
                [0.8929, 0.8947, 0.4877],
                [0.4754, 0.3163, 0.3422],
                [0.7967, 0.6736, 0.2966],
            ]
        ]
    )
    rhs_neg = leaf(
        [
            [
                [0.6116, 0.6010, 0.9500],
                [0.2541, 0.7715, 0.7477],
                [0.2360, 0.5923, 0.7536],
                [0.1290, 0.3088, 0.2731],
            ]
        ]
    )
    inputs = (lhs_pos, rhs_pos, lhs_neg, rhs_neg)

    # prepare() is applied to each input in the same positional order
    # that the comparator call expects.
    pos_scores, lhs_neg_scores, rhs_neg_scores = biased(
        *[biased.prepare(emb) for emb in inputs]
    )

    self.assertTensorEqual(pos_scores, torch.tensor([[2.8086, 1.5434]]))
    expected_lhs_neg = torch.tensor(
        [[[1.7830, 2.5800, 2.3283, 2.4269], [1.5172, 2.0194, 1.4850, 1.9369]]]
    )
    self.assertTensorEqual(lhs_neg_scores, expected_lhs_neg)
    expected_rhs_neg = torch.tensor(
        [[[2.5017, 2.0980, 2.1129, 1.9578], [2.2670, 1.8759, 1.8838, 1.7381]]]
    )
    self.assertTensorEqual(rhs_neg_scores, expected_rhs_neg)

    # Every input must be connected to the output through autograd.
    total = pos_scores.sum() + lhs_neg_scores.sum() + rhs_neg_scores.sum()
    total.backward()
    for emb in inputs:
        self.assertTrue((emb.grad != 0).any())