def test_pearl_mutter_hvp_2x2(self, a_val, b_val, x_val, y_val, vector):
        """Test Hessian-vector product for a function with two variables.

        The objective is f(x, y) = a * x**2 + b * y**2, where x and y are the
        first two parameters of a ``HelperPolicy``. The product returned by
        ``PearlmutterHvp`` is compared against ``vector @ H``, where H is the
        Hessian obtained from ``compute_hessian``.
        """
        # Wrap the scalar test parameters: coefficients become length-1
        # lists and the direction becomes a (1, n) float32 row vector.
        a_list = [a_val]
        b_list = [b_val]
        direction = np.array([vector], dtype=np.float32)

        policy = HelperPolicy(n_vars=2)
        all_params = policy.get_params()
        x = all_params[0]
        y = all_params[1]
        a = tf.constant(a_list)
        b = tf.constant(b_list)
        f = a * (x**2) + b * (y**2)

        hessian = compute_hessian(f, [x, y])
        expected_tensor = tf.matmul(direction, hessian)
        reg_coeff = 1e-5
        hvp_op = PearlmutterHvp()

        self.sess.run(tf.compat.v1.global_variables_initializer())
        self.sess.run(x.assign([x_val]))
        self.sess.run(y.assign([y_val]))
        hvp_op.update_hvp(f, policy, (a, b), reg_coeff)
        # Evaluate with the same coefficient values used to build ``a``/``b``.
        hx = hvp_op.build_eval((np.array(a_list), np.array(b_list)))
        computed_hvp = hx(direction[0])
        expected_hvp = expected_tensor.eval()
        assert np.allclose(computed_hvp, expected_hvp, atol=1e-6)

    def test_pearl_mutter_hvp_1x1(self):
        """Test Hessian-vector product for a function with one variable.

        For f(x) = a * x**2 the Hessian is 2 * a, so the expected product with
        the direction vector is 2 * a * vector. The graph is built with a
        constant ``a`` of 0.0; the real coefficient (5.0) is only supplied
        through ``build_eval``.
        """
        coeff = np.array([5.0])
        direction = np.array([10.0])
        # Second derivative of a * x**2 w.r.t. x is 2 * a; multiply by the
        # direction to get the expected Hessian-vector product.
        expected_hvp = 2 * coeff * direction

        policy = HelperPolicy(n_vars=1)
        x = policy.get_params()[0]
        # NOTE(review): ``a`` holds 0.0 at graph-build time; the expected
        # value above presumes the value passed to ``build_eval`` is what the
        # HVP actually uses — confirm against PearlmutterHvp.
        a = tf.constant([0.0])
        f = a * (x**2)
        reg_coeff = 1e-5
        hvp = PearlmutterHvp()

        self.sess.run(tf.compat.v1.global_variables_initializer())
        hvp.update_hvp(f, policy, (a, ), reg_coeff)
        hx = hvp.build_eval(np.array([coeff]))
        computed_hvp = hx(direction)
        # NOTE(review): ``expected_hvp`` ignores reg_coeff; presumably any
        # regularization contribution stays within np.allclose's defaults.
        assert np.allclose(computed_hvp, expected_hvp)

    def test_pickleable(self):
        """Test that PearlmutterHvp still works after a pickle round-trip.

        Evaluates the Hessian-vector product, pickles and unpickles the
        ``PearlmutterHvp`` object, rebuilds its graph with ``update_hvp``, and
        checks that an evaluator built from the *unpickled* object returns the
        same product as before pickling.
        """
        policy = HelperPolicy(n_vars=1)
        x = policy.get_params()[0]
        a_val = np.array([5.0])
        a = tf.constant([0.0])
        f = a * (x**2)
        vector = np.array([10.0])
        reg_coeff = 1e-5
        hvp = PearlmutterHvp()

        self.sess.run(tf.compat.v1.global_variables_initializer())
        hvp.update_hvp(f, policy, (a, ), reg_coeff)
        hx = hvp.build_eval(np.array([a_val]))
        before_pickle = hx(vector)

        hvp = pickle.loads(pickle.dumps(hvp))
        hvp.update_hvp(f, policy, (a, ), reg_coeff)
        # Rebuild the evaluator from the unpickled object. Reusing the old
        # ``hx`` would only re-evaluate the original instance, so the
        # unpickled object would never be exercised.
        hx = hvp.build_eval(np.array([a_val]))
        after_pickle = hx(vector)
        # ``np.equal`` returns an element-wise array, which is ambiguous in an
        # ``assert`` for multi-element results; compare the whole output. The
        # two results come from separately built graph ops, so compare within
        # floating-point tolerance.
        assert np.allclose(before_pickle, after_pickle)