def test_s2s_hvp(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)
        x = ad.Variable(name="x", shape=[3])
        H = ad.Variable(name="H", shape=[3, 3])
        v = ad.Variable(name="v", shape=[3])
        y = ad.einsum("a,ab,b->", x, H, x)

        grad_x, = ad.gradients(y, [x])
        Hv, = ad.hvp(output_node=y, node_list=[x], vector_list=[v])

        x_val = T.tensor([1., 2., 3.])  # 3
        v_val = T.tensor([1., 2., 3.])  # 3
        H_val = T.tensor([[2., 0., 0.], [0., 2., 0.], [0., 0., 2.]])  # 3x3

        expected_yval = T.einsum("a,ab,b->", x_val, H_val, x_val)
        expected_grad_x_val = 2 * T.einsum("ab,b->a", H_val, x_val)
        expected_hv_val = T.tensor([4., 8., 12.])

        StS = SourceToSource()
        forward_str = StS.forward([y], backend=datatype)
        m = import_code(forward_str)
        y_val_s2s, = m.forward([x_val, H_val])
        grad_str = StS.gradients(y, [x], backend=datatype)
        m = import_code(grad_str)
        grad_x_val_s2s, = m.gradients([x_val, H_val])
        hvp_str = StS.hvp(y, [x], [v], backend=datatype)
        m = import_code(hvp_str)
        Hv_val_s2s, = m.hvp([x_val, H_val, v_val])

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val_s2s, expected_yval)
        assert T.array_equal(grad_x_val_s2s, expected_grad_x_val)
        assert T.array_equal(Hv_val_s2s, expected_hv_val)

def test_inner_product_hvp(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)
        x = ad.Variable(name="x", shape=[3, 1])
        v = ad.Variable(name="v", shape=[3, 1])
        y = ad.sum(ad.einsum("ab,bc->ac", ad.transpose(x), x))

        grad_x, = ad.gradients(y, [x])
        Hv, = ad.hvp(output_node=y, node_list=[x], vector_list=[v])

        executor = ad.Executor([y, grad_x, Hv])
        x_val = T.tensor([[1.], [2.], [3.]])  # 3x1
        v_val = T.tensor([[1.], [2.], [3.]])  # 3x1

        y_val, grad_x_val, Hv_val = executor.run(feed_dict={
            x: x_val,
            v: v_val
        })

        expected_yval = T.sum(T.dot(T.transpose(x_val), x_val))
        expected_grad_x_val = 2 * x_val
        expected_hv_val = 2 * v_val

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x_val, expected_grad_x_val)
        assert T.array_equal(Hv_val, expected_hv_val)

def test_hvp2(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)
        x = ad.Variable(name="x", shape=[3, 1])
        H = ad.Variable(name="H", shape=[3, 3])
        v = ad.Variable(name="v", shape=[3, 1])
        y = ad.sum(
            ad.einsum("ab,bc->ac", ad.einsum("ab,bc->ac", ad.transpose(x), H),
                      x))

        grad_x, = ad.gradients(y, [x])
        Hv, = ad.hvp(output_node=y, node_list=[x], vector_list=[v])

        executor = ad.Executor([y, grad_x, Hv])
        x_val = T.tensor([[1.], [2.], [3.]])  # 3x1
        v_val = T.tensor([[1.], [2.], [3.]])  # 3x1
        H_val = T.tensor([[2., 0., 0.], [0., 2., 0.], [0., 0., 2.]])  # 3x3

        y_val, grad_x_val, Hv_val = executor.run(feed_dict={
            x: x_val,
            H: H_val,
            v: v_val
        })

        Hx = T.dot(H_val, x_val)
        expected_yval = T.sum(T.dot(T.transpose(x_val), Hx))
        expected_grad_x_val = 2 * Hx
        expected_hv_val = T.tensor([[4.], [8.], [12.]])

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x_val, expected_grad_x_val)
        assert T.array_equal(Hv_val, expected_hv_val)

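# A minimal, self-contained NumPy sketch (illustrative only, not part of the
# test suite above; the plain-numpy check is an assumption) of the analytic
# identity behind expected_hv_val in test_hvp2: for y = x^T H x the Hessian
# is H + H^T, so the Hessian-vector product is (H + H^T) v, which reduces to
# 2 H v for the symmetric H_val used there.
def _analytic_hvp_reference():
    import numpy as np
    H = 2.0 * np.eye(3)
    v = np.array([[1.], [2.], [3.]])
    hv = (H + H.T) @ v  # Hessian-vector product of y = x^T H x
    assert np.allclose(hv, np.array([[4.], [8.], [12.]]))
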
def cpd_newton(size, rank):
    dim = 3

    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])

        grads = ad.gradients(loss, [A, B, C])
        Hvps = ad.hvp(output_node=loss,
                      node_list=[A, B, C],
                      vector_list=[v_A, v_B, v_C])

        executor_grads = ad.Executor([loss] + grads)
        executor_Hvps = ad.Executor(Hvps)

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        for i in range(100):

            def hess_fn(v):
                return executor_Hvps.run(
                    feed_dict={
                        A: A_val,
                        B: B_val,
                        C: C_val,
                        input_tensor: input_tensor_val,
                        v_A: v[0],
                        v_B: v[1],
                        v_C: v[2]
                    })

            loss_val, grad_A_val, grad_B_val, grad_C_val = executor_grads.run(
                feed_dict={
                    A: A_val,
                    B: B_val,
                    C: C_val,
                    input_tensor: input_tensor_val
                })

            delta = conjugate_gradient(
                hess_fn=hess_fn,
                grads=[grad_A_val, grad_B_val, grad_C_val],
                error_tol=1e-9,
                max_iters=250)

            A_val -= delta[0]
            B_val -= delta[1]
            C_val -= delta[2]
            print(f'At iteration {i} the loss is: {loss_val}')

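# A minimal driver sketch for cpd_newton (illustrative; the size and rank
# values below are assumptions, not taken from the original examples):
if __name__ == "__main__":
    cpd_newton(size=20, rank=5)
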
def _get_sub_hvp(cls, index, mpo_graph, mps_graph):
    # rebuild mps graph
    intermediate_set = {
        mps_graph.inputs[index], mps_graph.inputs[index + 1]
    }
    split_input_nodes = list(set(mps_graph.inputs) - intermediate_set)
    mps = split_einsum(mps_graph.output, split_input_nodes)

    # get the intermediate node
    intermediate, = [
        node for node in mps.inputs if isinstance(node, ad.EinsumNode)
    ]

    mps_outer_product = ad.tensordot(mps, mps, axes=[[], []])
    mpo_axes = list(range(len(mpo_graph.output.shape)))

    # The 0.5 factor makes sure that the hvp can be written as an einsum
    objective = 0.5 * ad.tensordot(
        mps_outer_product, mpo_graph.output, axes=[mpo_axes, mpo_axes])

    vnode = ad.Variable(name=f"v{index}", shape=intermediate.shape)
    hvp, = ad.hvp(objective, [intermediate], [vnode])

    return intermediate, hvp, vnode