def test_hvp2(backendopt):
    """Check value, gradient and Hessian-vector product of y = sum(x^T H x).

    For every backend in ``backendopt`` the graph is built from a 3x1
    variable ``x``, a 3x3 variable ``H`` and a 3x1 direction ``v``; the
    executor outputs are then compared against directly computed values.
    """
    for backend in backendopt:
        T.set_backend(backend)

        # Symbolic inputs.
        x = ad.Variable(name="x", shape=[3, 1])
        H = ad.Variable(name="H", shape=[3, 3])
        v = ad.Variable(name="v", shape=[3, 1])

        # y = sum(x^T H x), built as two chained matrix products.
        xt_h = ad.einsum("ab,bc->ac", ad.transpose(x), H)
        y = ad.sum(ad.einsum("ab,bc->ac", xt_h, x))

        [grad_x] = ad.gradients(y, [x])
        [Hv] = ad.hvp(output_node=y, node_list=[x], vector_list=[v])

        executor = ad.Executor([y, grad_x, Hv])

        # Concrete inputs.
        x_val = T.tensor([[1.], [2.], [3]])  # 3x1
        v_val = T.tensor([[1.], [2.], [3]])  # 3x1
        H_val = T.tensor([[2., 0., 0.], [0., 2., 0.], [0., 0., 2.]])  # 3x3

        feeds = {x: x_val, H: H_val, v: v_val}
        y_val, grad_x_val, Hv_val = executor.run(feed_dict=feeds)

        expected_yval = T.sum(T.transpose(x_val) @ H_val @ x_val)
        # H_val is symmetric, so the gradient (H + H^T) x reduces to 2 H x.
        expected_grad_x_val = 2 * H_val @ x_val
        # With H = 2I the Hessian H + H^T is 4I, hence Hv = 4 v.
        expected_hv_val = T.tensor([[4.], [8.], [12.]])

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x_val, expected_grad_x_val)
        assert T.array_equal(Hv_val, expected_hv_val)
def test_inner_product(backendopt):
    """Check the value and gradient of sum(x x^T) for a row vector x.

    For every backend in ``backendopt`` the scalar ``x x^T`` is built with an
    einsum matrix product, and both it and its gradient with respect to ``x``
    are compared against ``norm(x)**2`` and ``2 * x``.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        # The declared shape must agree with the 1x2 tensor fed below.
        x = ad.Variable(name="x", shape=[1, 2])
        x_inner = ad.sum(ad.einsum("ab,bc->ac", x, ad.transpose(x)))

        grad_x, = ad.gradients(x_inner, [x])

        executor = ad.Executor([x_inner, grad_x])
        x_val = T.tensor([[3., 4.]])  # 1x2

        y_val, grad_x_val = executor.run(feed_dict={x: x_val})

        # For x = [3, 4]: norm is 5, so the inner product is 25.
        expected_yval = T.norm(x_val)**2
        expected_grad_x_val = 2 * x_val

        assert isinstance(x_inner, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x_val, expected_grad_x_val)
def test_transpose(backendopt):
    """Check the value and gradient of sum(transpose(x)) for a 3x2 input.

    Runs once for every backend in ``backendopt``. The gradient of a sum
    over all entries is expected to be a tensor of ones shaped like ``x``.
    """
    for backend in backendopt:
        T.set_backend(backend)

        x = ad.Variable(name="x", shape=[3, 2])
        y = ad.sum(ad.transpose(x))

        [grad_x] = ad.gradients(y, [x])

        executor = ad.Executor([y, grad_x])
        x_val = T.tensor([[1, 2], [3, 4], [5, 6]])  # 3x2

        feeds = {x: x_val}
        y_val, grad_x_val = executor.run(feed_dict=feeds)

        expected_yval = T.sum(T.transpose(x_val))
        expected_grad_x_val = T.ones_like(x_val)

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x_val, expected_grad_x_val)
def parse_jax_transpose(parameters, innodes):
    """Build an ``ad.transpose`` node from parsed transpose parameters.

    Args:
        parameters: Mapping that provides the axis order under the
            ``'permutation'`` key (presumably the parameters of a JAX
            transpose primitive -- confirm against the caller).
        innodes: Sequence holding exactly one input node.

    Returns:
        The result of ``ad.transpose`` applied to the single input node with
        the given permutation.

    Raises:
        AssertionError: If ``innodes`` does not contain exactly one element.
    """
    assert len(innodes) == 1
    operand = innodes[0]
    permutation = parameters['permutation']
    return ad.transpose(operand, permutation)