def test_hessian_vector_product():
    """Check the Hessian-vector product of a one-variable quadratic.

    The function under test is f(x) = a * x**2, whose second derivative is
    the constant 2 * a. The product with a vector v is therefore compared
    against 2 * a * v.
    """
    coeff = torch.tensor([5.0])
    x = torch.tensor([10.0], requires_grad=True)
    direction = torch.tensor([10.0])

    def quadratic():
        return coeff * (x**2)

    # Analytic reference: d^2/dx^2 (a * x^2) = 2 * a.
    reference = ((2 * coeff) * direction).detach()

    hvp_fn = _build_hessian_vector_product(quadratic, [x])
    actual = hvp_fn(direction).detach()

    assert np.allclose(actual, reference)
def test_hessian_vector_product_2x2(a_val, b_val, x_val, y_val, vector):
    """Check the HVP of a two-variable function with separable terms.

    The function under test is f(x, y) = a * x**2 + b * y**2. The reference
    value is the matrix product of the vector with the Hessian returned by
    ``compute_hessian``; the comparison uses an absolute tolerance of 1e-6.
    """
    coeff_a = torch.tensor([a_val])
    coeff_b = torch.tensor([b_val])
    # The extra list nesting gives torch.mm a 2-D left operand
    # (assumes `vector` is a flat sequence of numbers -- TODO confirm).
    row_vector = torch.tensor([vector])
    # Unlike the coefficients, x and y are built from the bare values.
    x = torch.tensor(x_val, requires_grad=True)
    y = torch.tensor(y_val, requires_grad=True)

    def objective():
        return coeff_a * (x**2) + coeff_b * (y**2)

    hessian = compute_hessian(objective(), [x, y])
    reference = torch.mm(row_vector, hessian).detach()

    hvp_fn = _build_hessian_vector_product(objective, [x, y])
    # The builder's callable receives the un-nested (first-row) vector.
    actual = hvp_fn(row_vector[0]).detach()

    assert np.allclose(actual, reference, atol=1e-6)
def test_hessian_vector_product_2x2_non_diagonal(a_val, b_val, x_val, y_val,
                                                 vector):
    """Check the HVP of a two-variable function with mixed terms.

    The function under test is
    f(x, y) = a * x**3 + b * y**3 + x**2 * y + y**2 * x.
    The x**2 * y and y**2 * x terms couple the two variables, so the
    Hessian has non-zero off-diagonal entries. The reference value is the
    matrix product of the vector with the Hessian returned by
    ``compute_hessian``.
    """
    coeff_a = torch.tensor([a_val])
    coeff_b = torch.tensor([b_val])
    # The extra list nesting gives torch.mm a 2-D left operand
    # (assumes `vector` is a flat sequence of numbers -- TODO confirm).
    row_vector = torch.tensor([vector])
    x = torch.tensor([x_val], requires_grad=True)
    y = torch.tensor([y_val], requires_grad=True)

    def objective():
        return (coeff_a * (x**3) + coeff_b * (y**3) + (x**2) * y +
                (y**2) * x)

    hessian = compute_hessian(objective(), [x, y])
    reference = torch.mm(row_vector, hessian).detach()

    hvp_fn = _build_hessian_vector_product(objective, [x, y])
    # The builder's callable receives the un-nested (first-row) vector.
    actual = hvp_fn(row_vector[0]).detach()

    assert np.allclose(actual, reference)