def test_batch_matmul(self):
    """Check batched ``InterpolatedLazyTensor.matmul`` against a dense reference.

    The lazy tensor is built from a batch of 5 base matrices plus left/right
    interpolation indices and values.  The reference result is computed as
    ``left @ base @ right^T @ test_matrix`` with explicitly densified
    interpolation matrices.  Both the forward result and the gradients with
    respect to the base matrix and the left interpolation values are compared.
    """
    n_batch = 5

    def tiled(rows, dtype):
        # Tile the 3 x 2 pattern three times along the rows and once per batch
        # entry, giving a (5, 9, 2) tensor.
        return torch.tensor(rows, dtype=dtype).repeat(n_batch, 3, 1)

    def with_grad_pair(values):
        # Return (original, independent clone), both tracking gradients, so
        # the lazy path and the dense path accumulate gradients separately.
        duplicate = values.clone()
        values.requires_grad = True
        duplicate.requires_grad = True
        return values, duplicate

    def densify(indices, values):
        # Scatter each batch entry's interpolation values into a dense 9 x 6
        # matrix and stack the results along a new leading batch dimension.
        per_batch = [
            torch.zeros(9, 6).scatter_(1, idx, val)
            for idx, val in zip(indices, values)
        ]
        return torch.stack(per_batch)

    left_interp_indices = tiled([[2, 3], [3, 4], [4, 5]], torch.long)
    left_interp_values, left_interp_values_copy = with_grad_pair(
        tiled([[1, 2], [0.5, 1], [1, 3]], torch.float)
    )

    right_interp_indices = tiled([[0, 1], [1, 2], [2, 3]], torch.long)
    right_interp_values, right_interp_values_copy = with_grad_pair(
        tiled([[1, 2], [2, 0.5], [1, 3]], torch.float)
    )

    # Batch of Gram matrices (A^T A) used as the base of the lazy tensor.
    raw = torch.randn(n_batch, 6, 6)
    base_tensor, base_tensor_copy = with_grad_pair(raw.transpose(-1, -2).matmul(raw))

    test_matrix = torch.randn(n_batch, 9, 4)

    # Lazy path.
    interp_lazy_tensor = InterpolatedLazyTensor(
        NonLazyTensor(base_tensor),
        left_interp_indices,
        left_interp_values,
        right_interp_indices,
        right_interp_values,
    )
    res = interp_lazy_tensor.matmul(test_matrix)

    # Dense reference path, built only from the cloned tensors.
    left_matrix = densify(left_interp_indices, left_interp_values_copy)
    right_matrix = densify(right_interp_indices, right_interp_values_copy)
    dense_interp = left_matrix.matmul(base_tensor_copy).matmul(right_matrix.transpose(-1, -2))
    actual = dense_interp.matmul(test_matrix)

    self.assertTrue(approx_equal(res, actual))

    res.sum().backward()
    actual.sum().backward()

    self.assertTrue(approx_equal(base_tensor.grad, base_tensor_copy.grad))
    self.assertTrue(approx_equal(left_interp_values.grad, left_interp_values_copy.grad))
def test_matmul_batch(self):
    """Check ``InterpolatedLazyTensor.matmul`` when the right-hand side has batch size 1.

    The lazy tensor holds a batch of 5 base matrices while ``test_matrix`` has
    a leading dimension of 1.  The result is compared against the dense
    product ``left @ base @ right^T @ test_matrix`` built from hand-written
    interpolation matrices.
    """
    left_interp_indices = torch.LongTensor([[2, 3], [3, 4], [4, 5]]).repeat(5, 3, 1)
    left_interp_values = torch.tensor([[1, 2], [0.5, 1], [1, 3]], dtype=torch.float).repeat(5, 3, 1)
    right_interp_indices = torch.LongTensor([[0, 1], [1, 2], [2, 3]]).repeat(5, 3, 1)
    right_interp_values = torch.tensor([[1, 2], [2, 0.5], [1, 3]], dtype=torch.float).repeat(5, 3, 1)

    # Batch of Gram matrices (A^T A) used as the base of the lazy tensor.
    raw = torch.randn(5, 6, 6)
    base_lazy_tensor_mat = raw.transpose(1, 2).matmul(raw)
    base_lazy_tensor_mat.requires_grad = True

    test_matrix = torch.randn(1, 9, 4)

    interp_lazy_tensor = InterpolatedLazyTensor(
        NonLazyTensor(base_lazy_tensor_mat),
        left_interp_indices,
        left_interp_values,
        right_interp_indices,
        right_interp_values,
    )
    res = interp_lazy_tensor.matmul(test_matrix)

    # Dense interpolation matrices: a 3-row pattern tiled three times along
    # the rows (9 x 6) and repeated for each of the 5 batch entries.
    left_pattern = [
        [0, 0, 1, 2, 0, 0],
        [0, 0, 0, 0.5, 1, 0],
        [0, 0, 0, 0, 1, 3],
    ]
    right_pattern = [
        [1, 2, 0, 0, 0, 0],
        [0, 2, 0.5, 0, 0, 0],
        [0, 0, 1, 3, 0, 0],
    ]
    left_matrix = torch.tensor(left_pattern, dtype=torch.float).repeat(5, 3, 1)
    right_matrix = torch.tensor(right_pattern, dtype=torch.float).repeat(5, 3, 1)

    dense_interp = left_matrix.matmul(base_lazy_tensor_mat)
    dense_interp = dense_interp.matmul(right_matrix.transpose(-1, -2))
    actual = dense_interp.matmul(test_matrix)

    self.assertTrue(approx_equal(res, actual))
def test_matmul(self):
    """Check non-batch ``InterpolatedLazyTensor.matmul`` against a dense reference.

    The lazy result is compared with ``left @ base @ right^T @ test_matrix``,
    where the dense interpolation matrices are scattered from cloned
    interpolation values.  After back-propagating through both results, the
    gradients with respect to the base matrix and the left interpolation
    values are compared.
    """
    left_interp_indices = torch.tensor([[2, 3], [3, 4], [4, 5]], dtype=torch.long).repeat(3, 1)
    left_interp_values = torch.tensor([[1, 2], [0.5, 1], [1, 3]], dtype=torch.float).repeat(3, 1)
    left_interp_values_copy = left_interp_values.clone()
    left_interp_values.requires_grad = True
    left_interp_values_copy.requires_grad = True

    right_interp_indices = torch.tensor([[0, 1], [1, 2], [2, 3]], dtype=torch.long).repeat(3, 1)
    right_interp_values = torch.tensor([[1, 2], [2, 0.5], [1, 3]], dtype=torch.float).repeat(3, 1)
    right_interp_values_copy = right_interp_values.clone()
    right_interp_values.requires_grad = True
    right_interp_values_copy.requires_grad = True

    # Gram matrix (A^T A) used as the base of the lazy tensor.
    base_lazy_tensor_mat = torch.randn(6, 6)
    base_lazy_tensor_mat = base_lazy_tensor_mat.t().matmul(base_lazy_tensor_mat)
    base_tensor = base_lazy_tensor_mat
    # The reference copy must be a separate tensor.  If it aliased
    # ``base_tensor``, both backward passes would accumulate into one ``.grad``
    # and the gradient assertion below would compare a tensor with itself.
    # Clone before enabling grad so the copy is an independent leaf.
    base_tensor_copy = base_tensor.clone()
    base_tensor.requires_grad = True
    base_tensor_copy.requires_grad = True
    base_lazy_tensor = NonLazyTensor(base_tensor)

    test_matrix = torch.randn(9, 4)

    # Lazy path.
    interp_lazy_tensor = InterpolatedLazyTensor(
        base_lazy_tensor,
        left_interp_indices,
        left_interp_values,
        right_interp_indices,
        right_interp_values,
    )
    res = interp_lazy_tensor.matmul(test_matrix)

    # Dense reference path, built only from the cloned tensors.
    left_matrix = torch.zeros(9, 6)
    right_matrix = torch.zeros(9, 6)
    left_matrix.scatter_(1, left_interp_indices, left_interp_values_copy)
    right_matrix.scatter_(1, right_interp_indices, right_interp_values_copy)

    actual = left_matrix.matmul(base_tensor_copy).matmul(right_matrix.t()).matmul(test_matrix)
    self.assertTrue(approx_equal(res, actual))

    res.sum().backward()
    actual.sum().backward()

    self.assertTrue(approx_equal(base_tensor.grad, base_tensor_copy.grad))
    self.assertTrue(approx_equal(left_interp_values.grad, left_interp_values_copy.grad))