def test_finite_difference_hvp_2x2_non_diagonal(self, a_val, b_val, x_val,
                                                    y_val, vector):
        """Test Hessian-vector product for a function with two variables whose Hessian
        is non-diagonal.

        The objective is ``f = a*x**3 + b*y**3 + x**2*y + y**2*x``. Its Hessian
        ``[[6*a*x + 2*y, 2*x + 2*y], [2*x + 2*y, 6*b*y + 2*x]]`` has non-zero
        off-diagonal terms. The finite-difference HVP evaluated at
        ``(x_val, y_val)`` is compared against ``vector @ H``, where ``H`` is
        built by ``compute_hessian``.

        Args:
            a_val: Scalar coefficient of ``x**3`` (wrapped into a 1-element
                list below).
            b_val: Scalar coefficient of ``y**3`` (wrapped likewise).
            x_val: Value assigned to the first policy parameter.
            y_val: Value assigned to the second policy parameter.
            vector: Direction of the Hessian-vector product; presumably a
                length-2 sequence supplied by a parametrize decorator outside
                this view -- confirm.
        """
        a_val = [a_val]
        b_val = [b_val]
        # Shape (1, n) so it can be used as a row vector in tf.matmul below.
        vector = np.array([vector], dtype=np.float32)

        policy = HelperPolicy(n_vars=2)
        params = policy.get_params()
        x, y = params[0], params[1]
        a = tf.constant(a_val)
        b = tf.constant(b_val)
        f = a * (x**3) + b * (y**3) + (x**2) * y + (y**2) * x

        expected_hessian = compute_hessian(f, [x, y])
        # Row vector times the (presumably 2x2) Hessian gives the expected HVP.
        expected_hvp = tf.matmul(vector, expected_hessian)
        reg_coeff = 1e-5
        hvp = FiniteDifferenceHvp(base_eps=1)
        # Move the parameters to the evaluation point so that both the
        # finite-difference HVP and the analytic Hessian are computed at
        # (x_val, y_val). Assumes x and y are tf.Variables; runs in the same
        # default session that ``expected_hvp.eval()`` relies on below.
        x.assign([x_val]).eval()
        y.assign([y_val]).eval()
        hvp.update_hvp(f, policy, (a, b), reg_coeff)
        hx = hvp.build_eval((np.array(a_val), np.array(b_val)))
        computed_hvp = hx(vector[0])
        expected_hvp = expected_hvp.eval()
        assert np.allclose(computed_hvp, expected_hvp)

    def test_finite_difference_hvp(self):
        """Test Hessian-vector product for a function with one variable.

        The objective is ``coeff * param**2``, whose second derivative is
        ``2 * coeff``. ``coeff`` is built as a constant 0.0, but the expected
        value is computed from 5.0, the value handed to ``build_eval``. The
        test therefore presumably relies on ``build_eval`` feeding that value
        in place of the constant -- confirm against FiniteDifferenceHvp.
        """
        policy = HelperPolicy(n_vars=1)
        param = policy.get_params()[0]
        coeff_value = np.array([5.0])
        coeff = tf.constant([0.0])
        objective = coeff * (param**2)

        direction = np.array([10.0])
        # Analytic HVP of coeff * param**2: (2 * coeff) * direction.
        expected = (2 * coeff_value) * direction

        fd_hvp = FiniteDifferenceHvp()
        fd_hvp.update_hvp(objective, policy, (coeff, ), 1e-5)
        eval_hvp = fd_hvp.build_eval(np.array([coeff_value]))
        assert np.allclose(eval_hvp(direction), expected)

    def test_pickleable(self):
        """Test that FiniteDifferenceHvp gives the same result after pickling.

        The HVP of ``a * x**2`` is evaluated once with the original object and
        once with an unpickled copy. The copy has ``update_hvp`` called on it
        again and a fresh evaluator built from it. Both results must be
        identical.
        """
        policy = HelperPolicy(n_vars=1)
        x = policy.get_params()[0]
        a_val = np.array([5.0])
        a = tf.constant([0.0])
        f = a * (x**2)
        vector = np.array([10.0])
        reg_coeff = 1e-5
        hvp = FiniteDifferenceHvp()
        hvp.update_hvp(f, policy, (a, ), reg_coeff)
        hx = hvp.build_eval(np.array([a_val]))
        before_pickle = hx(vector)

        hvp = pickle.loads(pickle.dumps(hvp))
        hvp.update_hvp(f, policy, (a, ), reg_coeff)
        # Rebuild the evaluator from the *unpickled* object. Reusing the old
        # ``hx`` would only re-run the original instance and test nothing.
        hx = hvp.build_eval(np.array([a_val]))
        after_pickle = hx(vector)
        # np.equal returns an element-wise array whose truth value is
        # ambiguous for more than one element; compare whole arrays instead.
        assert np.array_equal(before_pickle, after_pickle)