def test_finite_difference_hvp_2x2_non_diagonal(self, a_val, b_val, x_val,
                                                y_val, vector):
    """Compare FiniteDifferenceHvp with an exact product for a 2-variable f.

    The function is ``f(x, y) = a*x**3 + b*y**3 + x**2*y + y**2*x``. Its
    mixed terms make the Hessian non-diagonal. The finite-difference
    estimate is checked against ``vector @ H``, where ``H`` comes from
    ``compute_hessian``.
    """
    # Keep the coefficients as Python lists: tf.constant on a list of
    # floats yields float32, whereas a float64 numpy array would not.
    a_list = [a_val]
    b_list = [b_val]
    row_vector = np.array([vector], dtype=np.float32)

    policy = HelperPolicy(n_vars=2)
    params = policy.get_params()
    x = params[0]
    y = params[1]
    a = tf.constant(a_list)
    b = tf.constant(b_list)
    f = a * (x**3) + b * (y**3) + (x**2) * y + (y**2) * x

    # Reference product: the row vector times the analytic Hessian.
    exact_hvp_op = tf.matmul(row_vector, compute_hessian(f, [x, y]))

    # Kept small so it presumably does not dominate the allclose
    # comparison (the expected value has no regularization) — verify.
    reg_coeff = 1e-5
    fd_hvp = FiniteDifferenceHvp(base_eps=1)

    self.sess.run(tf.compat.v1.global_variables_initializer())
    for var, value in ((x, x_val), (y, y_val)):
        self.sess.run(var.assign([value]))
    fd_hvp.update_hvp(f, policy, (a, b), reg_coeff)
    eval_hvp = fd_hvp.build_eval((np.array(a_list), np.array(b_list)))
    computed = eval_hvp(row_vector[0])
    expected = exact_hvp_op.eval()
    assert np.allclose(computed, expected)
def test_finite_difference_hvp(self):
    """Check FiniteDifferenceHvp on a single-variable quadratic.

    For ``f(x) = a * x**2`` the Hessian is ``2 * a``, so the product with a
    direction ``v`` is ``2 * a * v``. In the graph, ``a`` is a constant
    ``[0.0]``. The evaluator is built with the coefficient ``[5.0]``, and
    the expected value is computed from that coefficient. This presumably
    relies on build_eval feeding it in place of ``a`` — verify.
    """
    policy = HelperPolicy(n_vars=1)
    x = policy.get_params()[0]
    coeff_feed = np.array([5.0])
    a = tf.constant([0.0])
    f = a * x**2

    direction = np.array([10.0])
    expected = 2 * coeff_feed * direction

    finite_diff = FiniteDifferenceHvp()
    self.sess.run(tf.compat.v1.global_variables_initializer())
    finite_diff.update_hvp(f, policy, (a, ), 1e-5)

    hessian_times = finite_diff.build_eval(np.array([coeff_feed]))
    assert np.allclose(hessian_times(direction), expected)
def test_pickleable(self):
    """Check that FiniteDifferenceHvp gives the same product after pickling.

    The estimator is round-tripped through ``pickle`` and re-bound to the
    same function and policy with ``update_hvp``. A fresh evaluator is then
    built from the *unpickled* object, and its result is compared with the
    result obtained before pickling.
    """
    policy = HelperPolicy(n_vars=1)
    x = policy.get_params()[0]
    a_val = np.array([5.0])
    a = tf.constant([0.0])
    f = a * (x**2)
    vector = np.array([10.0])
    reg_coeff = 1e-5
    hvp = FiniteDifferenceHvp()

    self.sess.run(tf.compat.v1.global_variables_initializer())
    hvp.update_hvp(f, policy, (a, ), reg_coeff)
    hx = hvp.build_eval(np.array([a_val]))
    before_pickle = hx(vector)

    hvp = pickle.loads(pickle.dumps(hvp))
    hvp.update_hvp(f, policy, (a, ), reg_coeff)
    # Rebuild the evaluator from the unpickled object. Reusing the old
    # ``hx`` would only call the closure built by the pre-pickle instance,
    # so the unpickled object's evaluation would never be exercised.
    hx = hvp.build_eval(np.array([a_val]))
    after_pickle = hx(vector)
    # np.equal returns an element-wise array whose truth value is ambiguous
    # for more than one element; compare the arrays as a whole instead.
    assert np.allclose(before_pickle, after_pickle)