def test_hessian_vector_product():
    """Test Hessian-vector product for a function with one variable.

    The objective is ``coeff * param**2``. Its Hessian with respect to
    ``param`` is the constant ``2 * coeff``, so the reference product is
    ``2 * coeff * direction``. The callable built by
    ``_build_hessian_vector_product`` must reproduce that product.
    """
    coeff = torch.tensor([5.0])
    param = torch.tensor([10.0], requires_grad=True)
    direction = torch.tensor([10.0])

    def objective():
        return coeff * param**2

    # Analytic second derivative of coeff * param**2, applied to direction.
    reference = (2 * coeff * direction).detach()

    hvp_fn = _build_hessian_vector_product(objective, [param])
    result = hvp_fn(direction).detach()

    assert np.allclose(result, reference)
def test_hessian_vector_product_2x2(a_val, b_val, x_val, y_val, vector):
    """Test for a function with two variables.

    The objective is ``a * x**2 + b * y**2``. The reference product is
    ``vector`` (wrapped as a single-row matrix) multiplied by the Hessian
    that ``compute_hessian`` returns. The callable from
    ``_build_hessian_vector_product`` is applied to the un-wrapped
    ``vector`` and must match within ``atol=1e-6``.
    """
    coeff_a = torch.tensor([a_val])
    coeff_b = torch.tensor([b_val])
    # Keep the caller's ``vector`` untouched; use a 1 x n row matrix for mm.
    row = torch.tensor([vector])
    x_var = torch.tensor(x_val, requires_grad=True)
    y_var = torch.tensor(y_val, requires_grad=True)

    def objective():
        return coeff_a * x_var**2 + coeff_b * y_var**2

    hessian = compute_hessian(objective(), [x_var, y_var])
    reference = row.mm(hessian).detach()

    hvp_fn = _build_hessian_vector_product(objective, [x_var, y_var])
    result = hvp_fn(row[0]).detach()

    assert np.allclose(result, reference, atol=1e-6)
def test_hessian_vector_product_2x2_non_diagonal(a_val, b_val, x_val, y_val,
                                                 vector):
    """Test for a function with two variables and non-diagonal Hessian.

    The objective is ``a*x**3 + b*y**3 + x**2*y + y**2*x``. Its mixed terms
    generally produce non-zero off-diagonal Hessian entries. The reference
    product is ``vector`` (as a single-row matrix) times the Hessian from
    ``compute_hessian``, compared with default ``np.allclose`` tolerances.
    """
    coeff_a = torch.tensor([a_val])
    coeff_b = torch.tensor([b_val])
    row = torch.tensor([vector])
    x_var = torch.tensor([x_val], requires_grad=True)
    y_var = torch.tensor([y_val], requires_grad=True)

    def objective():
        # Term order matches the original summation to keep float results
        # bit-identical.
        return (coeff_a * x_var**3 + coeff_b * y_var**3 + x_var**2 * y_var +
                y_var**2 * x_var)

    hessian = compute_hessian(objective(), [x_var, y_var])
    reference = row.mm(hessian).detach()

    hvp_fn = _build_hessian_vector_product(objective, [x_var, y_var])
    result = hvp_fn(row[0]).detach()

    assert np.allclose(result, reference)