def test_matmul_mat_with_two_matrices(self):
    """Check a two-factor ``MulLazyTensor`` matmul against a dense reference.

    Two matrices from ``make_random_mat(20, 5)`` are wrapped in
    ``RootLazyTensor`` and combined with ``MulLazyTensor``; the result of
    ``.matmul(vec)`` is compared with ``prod([m1 @ m1^T, m2 @ m2^T]) @ vec``
    computed on detached copies. Both the forward values and the gradients
    w.r.t. the two matrices and ``vec`` must agree to within 1% relative
    error (elementwise maximum).

    NOTE(review): ``prod`` and ``make_random_mat`` are defined outside this
    method; ``prod`` is presumably an elementwise product of the list
    entries -- confirm against its definition.
    """
    mat1 = make_random_mat(20, 5)
    mat2 = make_random_mat(20, 5)
    vec = torch.randn(20, 7, requires_grad=True)

    # Independent leaf copies so the reference path accumulates its own
    # gradients, separate from the lazy-tensor path.
    mat1_copy = mat1.clone().detach().requires_grad_(True)
    mat2_copy = mat2.clone().detach().requires_grad_(True)
    vec_copy = vec.clone().detach().requires_grad_(True)

    # Forward
    res = MulLazyTensor(RootLazyTensor(mat1), RootLazyTensor(mat2)).matmul(vec)
    actual = prod(
        [
            mat1_copy.matmul(mat1_copy.transpose(-1, -2)),
            mat2_copy.matmul(mat2_copy.transpose(-1, -2)),
        ]
    ).matmul(vec_copy)
    # Use the unittest assertion (not a bare ``assert``) so the check is not
    # stripped under ``python -O`` and matches the other checks below.
    self.assertLess(torch.max(((res - actual) / actual).abs()), 0.01)

    # Backward
    res.sum().backward()
    actual.sum().backward()
    self.assertLess(torch.max(((mat1.grad - mat1_copy.grad) / mat1_copy.grad).abs()), 0.01)
    self.assertLess(torch.max(((mat2.grad - mat2_copy.grad) / mat2_copy.grad).abs()), 0.01)
    self.assertLess(torch.max(((vec.grad - vec_copy.grad) / vec_copy.grad).abs()), 0.01)
def test_batch_matmul_mat_with_five_matrices(self):
    """Check a batched five-factor ``MulLazyTensor`` matmul against a dense reference.

    Five matrices from ``make_random_mat(20, rank=4, batch_size=5)`` are each
    wrapped in ``RootLazyTensor`` and combined with ``MulLazyTensor``. The
    result of ``.matmul`` on a ``(5, 20, 7)`` right-hand side is compared
    with ``prod([m @ m^T for each m]) @ rhs`` computed on detached copies.
    Forward values and all gradients must agree to within 1% relative error
    (elementwise maximum).

    NOTE(review): ``prod`` and ``make_random_mat`` are defined outside this
    method; ``prod`` is presumably an elementwise product of the list
    entries -- confirm against its definition.
    """
    tolerance = 0.01
    num_factors = 5

    # Inputs for the lazy-tensor path (created in the same order as before:
    # the five factors first, then the right-hand side).
    factors = [make_random_mat(20, rank=4, batch_size=5) for _ in range(num_factors)]
    rhs = torch.randn(5, 20, 7, requires_grad=True)

    # Independent leaf copies for the dense reference path.
    factor_refs = [factor.clone().detach().requires_grad_(True) for factor in factors]
    rhs_ref = rhs.clone().detach().requires_grad_(True)

    # Forward
    lazy_product = MulLazyTensor(*[RootLazyTensor(factor) for factor in factors])
    res = lazy_product.matmul(rhs)
    dense_terms = [ref.matmul(ref.transpose(-1, -2)) for ref in factor_refs]
    actual = prod(dense_terms).matmul(rhs_ref)
    forward_rel_err = ((res - actual) / actual).abs()
    self.assertLess(torch.max(forward_rel_err), tolerance)

    # Backward
    res.sum().backward()
    actual.sum().backward()
    # Compare gradients factor by factor, then the right-hand side last.
    for tensor, reference in zip(factors + [rhs], factor_refs + [rhs_ref]):
        grad_rel_err = ((tensor.grad - reference.grad) / reference.grad).abs()
        self.assertLess(torch.max(grad_rel_err), tolerance)