def test_inv_matmul(self):
        """Check ``InterpolatedLazyVariable.inv_matmul`` against a dense solve.

        The interpolated operator is ``W_left K W_right^T``. ``K`` is a 6x6
        base matrix, and each of the 3 rows of ``W`` places two interpolation
        weights at the given column indices. The reference builds the dense
        ``W`` matrices with ``scatter_`` and solves with ``gpytorch.inv_matmul``.
        The forward result and the gradients w.r.t. every differentiable input
        must agree.
        """
        base_lazy_variable_mat = torch.randn(6, 6)
        # A^T A makes the base matrix symmetric positive semi-definite.
        base_lazy_variable_mat = base_lazy_variable_mat.t().matmul(
            base_lazy_variable_mat)
        test_matrix = torch.randn(3, 4)

        # Index tensors are integer-valued and never differentiated (integer
        # tensors cannot require grad in current PyTorch), so no
        # requires_grad here.
        left_interp_indices = Variable(
            torch.LongTensor([[2, 3], [3, 4], [4, 5]]))
        left_interp_values = Variable(
            torch.Tensor([[1, 2], [0.5, 1], [1, 3]]), requires_grad=True)
        right_interp_indices = Variable(
            torch.LongTensor([[2, 3], [3, 4], [4, 5]]))
        right_interp_values = Variable(
            torch.Tensor([[1, 2], [0.5, 1], [1, 3]]), requires_grad=True)
        # Separate leaf Variables over the same data, so the dense reference
        # accumulates its own gradients independently of the lazy path.
        left_interp_values_copy = Variable(left_interp_values.data,
                                           requires_grad=True)
        right_interp_values_copy = Variable(right_interp_values.data,
                                            requires_grad=True)

        base_lazy_variable = Variable(base_lazy_variable_mat,
                                      requires_grad=True)
        base_lazy_variable_copy = Variable(base_lazy_variable_mat,
                                           requires_grad=True)
        test_matrix_var = Variable(test_matrix, requires_grad=True)
        test_matrix_var_copy = Variable(test_matrix, requires_grad=True)

        interp_lazy_var = InterpolatedLazyVariable(
            NonLazyVariable(base_lazy_variable),
            left_interp_indices,
            left_interp_values,
            right_interp_indices,
            right_interp_values,
        )
        res = interp_lazy_var.inv_matmul(test_matrix_var)

        # Dense reference: write each row's two weights into a 3x6 matrix.
        left_matrix = Variable(torch.zeros(3, 6))
        right_matrix = Variable(torch.zeros(3, 6))
        left_matrix.scatter_(1, left_interp_indices, left_interp_values_copy)
        right_matrix.scatter_(1, right_interp_indices,
                              right_interp_values_copy)
        actual_mat = left_matrix.matmul(base_lazy_variable_copy).matmul(
            right_matrix.transpose(-1, -2))
        actual = gpytorch.inv_matmul(actual_mat, test_matrix_var_copy)

        self.assertTrue(approx_equal(res.data, actual.data))

        # Backward pass: every differentiable input must receive the same
        # gradient from both paths.
        res.sum().backward()
        actual.sum().backward()

        self.assertTrue(
            approx_equal(base_lazy_variable.grad.data,
                         base_lazy_variable_copy.grad.data))
        self.assertTrue(
            approx_equal(left_interp_values.grad.data,
                         left_interp_values_copy.grad.data))
        self.assertTrue(
            approx_equal(right_interp_values.grad.data,
                         right_interp_values_copy.grad.data))
        self.assertTrue(
            approx_equal(test_matrix_var.grad.data,
                         test_matrix_var_copy.grad.data))
    def test_inv_matmul_batch(self):
        """Check batched ``InterpolatedLazyVariable.inv_matmul`` against a dense solve.

        Every batch entry uses the same 6x6 base matrix ``K`` and the same
        3-row, two-weights-per-row interpolation. The reference builds the
        dense ``(batch, 3, 6)`` interpolation matrices with ``scatter_`` and
        solves ``W_left K W_right^T X = B`` with ``gpytorch.inv_matmul``.
        The forward result and the gradients w.r.t. every differentiable input
        must agree.
        """
        batch_size = 5
        base_lazy_variable_mat = torch.randn(6, 6)
        # A^T A makes the base matrix symmetric positive semi-definite; the
        # same matrix is then tiled across the batch dimension.
        base_lazy_variable_mat = base_lazy_variable_mat.t().matmul(
            base_lazy_variable_mat).unsqueeze(0).repeat(batch_size, 1, 1)
        test_matrix = torch.randn(batch_size, 3, 4)

        # Index tensors are integer-valued and never differentiated (integer
        # tensors cannot require grad in current PyTorch), so no
        # requires_grad here.
        left_interp_indices = Variable(torch.LongTensor(
            [[2, 3], [3, 4], [4, 5]]).unsqueeze(0).repeat(batch_size, 1, 1))
        left_interp_values = Variable(torch.Tensor(
            [[1, 2], [0.5, 1], [1, 3]]).unsqueeze(0).repeat(batch_size, 1, 1),
                                      requires_grad=True)
        right_interp_indices = Variable(torch.LongTensor(
            [[2, 3], [3, 4], [4, 5]]).unsqueeze(0).repeat(batch_size, 1, 1))
        right_interp_values = Variable(torch.Tensor(
            [[1, 2], [0.5, 1], [1, 3]]).unsqueeze(0).repeat(batch_size, 1, 1),
                                       requires_grad=True)
        # Separate leaf Variables over the same data, so the dense reference
        # accumulates its own gradients independently of the lazy path.
        left_interp_values_copy = Variable(left_interp_values.data,
                                           requires_grad=True)
        right_interp_values_copy = Variable(right_interp_values.data,
                                            requires_grad=True)

        base_lazy_variable = Variable(base_lazy_variable_mat,
                                      requires_grad=True)
        base_lazy_variable_copy = Variable(base_lazy_variable_mat,
                                           requires_grad=True)
        test_matrix_var = Variable(test_matrix, requires_grad=True)
        test_matrix_var_copy = Variable(test_matrix, requires_grad=True)

        interp_lazy_var = InterpolatedLazyVariable(
            NonLazyVariable(base_lazy_variable),
            left_interp_indices,
            left_interp_values,
            right_interp_indices,
            right_interp_values,
        )
        res = interp_lazy_var.inv_matmul(test_matrix_var)

        # Dense reference: one batched scatter along the last dimension
        # writes each row's two weights into a (batch_size, 3, 6) matrix.
        left_matrix = Variable(torch.zeros(batch_size, 3, 6))
        right_matrix = Variable(torch.zeros(batch_size, 3, 6))
        left_matrix.scatter_(2, left_interp_indices, left_interp_values_copy)
        right_matrix.scatter_(2, right_interp_indices,
                              right_interp_values_copy)
        actual_mat = left_matrix.matmul(base_lazy_variable_copy).matmul(
            right_matrix.transpose(-1, -2))
        actual = gpytorch.inv_matmul(actual_mat, test_matrix_var_copy)

        self.assertTrue(approx_equal(res.data, actual.data))

        # Backward pass: every differentiable input must receive the same
        # gradient from both paths.
        res.sum().backward()
        actual.sum().backward()

        self.assertTrue(
            approx_equal(base_lazy_variable.grad.data,
                         base_lazy_variable_copy.grad.data))
        self.assertTrue(
            approx_equal(left_interp_values.grad.data,
                         left_interp_values_copy.grad.data))
        self.assertTrue(
            approx_equal(right_interp_values.grad.data,
                         right_interp_values_copy.grad.data))
        self.assertTrue(
            approx_equal(test_matrix_var.grad.data,
                         test_matrix_var_copy.grad.data))