def test_s2s_hvp(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x = ad.Variable(name="x", shape=[3])
        H = ad.Variable(name="H", shape=[3, 3])
        v = ad.Variable(name="v", shape=[3])
        y = ad.einsum("a,ab,b->", x, H, x)

        grad_x, = ad.gradients(y, [x])
        Hv, = ad.hvp(output_node=y, node_list=[x], vector_list=[v])

        x_val = T.tensor([1., 2., 3.])  # 3
        v_val = T.tensor([1., 2., 3.])  # 3
        H_val = T.tensor([[2., 0., 0.], [0., 2., 0.], [0., 0., 2.]])  # 3x3

        expected_yval = T.einsum("a,ab,b->", x_val, H_val, x_val)
        expected_grad_x_val = 2 * T.einsum("ab,b->a", H_val, x_val)
        expected_hv_val = T.tensor([4., 8., 12.])

        StS = SourceToSource()
        forward_str = StS.forward([y], backend=datatype)
        m = import_code(forward_str)
        y_val_s2s, = m.forward([x_val, H_val])
        grad_str = StS.gradients(y, [x], backend=datatype)
        m = import_code(grad_str)
        grad_x_val_s2s, = m.gradients([x_val, H_val])
        hvp_str = StS.hvp(y, [x], [v], backend=datatype)
        m = import_code(hvp_str)
        Hv_val_s2s, = m.hvp([x_val, H_val, v_val])

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val_s2s, expected_yval)
        assert T.array_equal(grad_x_val_s2s, expected_grad_x_val)
        assert T.array_equal(Hv_val_s2s, expected_hv_val)

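# A standalone sanity check (numpy; not part of the test suite, the helper
# name is ours) of the arithmetic the HVP tests rely on: for y = x^T H x the
# Hessian is H + H^T, so the Hessian-vector product is (H + H^T) v. With
# H = 2I and v = [1, 2, 3] this gives [4, 8, 12], i.e. expected_hv_val above.
def _hvp_reference_check():
    import numpy as np

    H = 2 * np.eye(3)
    v = np.array([1., 2., 3.])
    hv = (H + H.T) @ v  # Hessian of x^T H x applied to v
    assert np.allclose(hv, [4., 8., 12.])
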
def test_hvp2(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x = ad.Variable(name="x", shape=[3, 1])
        H = ad.Variable(name="H", shape=[3, 3])
        v = ad.Variable(name="v", shape=[3, 1])
        y = ad.sum(
            ad.einsum("ab,bc->ac", ad.einsum("ab,bc->ac", ad.transpose(x), H),
                      x))

        grad_x, = ad.gradients(y, [x])
        Hv, = ad.hvp(output_node=y, node_list=[x], vector_list=[v])

        executor = ad.Executor([y, grad_x, Hv])
        x_val = T.tensor([[1.], [2.], [3.]])  # 3x1
        v_val = T.tensor([[1.], [2.], [3.]])  # 3x1
        H_val = T.tensor([[2., 0., 0.], [0., 2., 0.], [0., 0., 2.]])  # 3x3

        y_val, grad_x_val, Hv_val = executor.run(feed_dict={
            x: x_val,
            H: H_val,
            v: v_val
        })

        Hx = T.dot(H_val, x_val)
        expected_yval = T.sum(T.dot(T.transpose(x_val), Hx))
        expected_grad_x_val = 2 * Hx
        expected_hv_val = T.tensor([[4.], [8.], [12.]])

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x_val, expected_grad_x_val)
        assert T.array_equal(Hv_val, expected_hv_val)

def cpd_gradient_descent(size, rank, learning_rate):
    dim = 3

    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        grad_A, grad_B, grad_C = ad.gradients(loss, [A, B, C])
        executor = ad.Executor([loss, grad_A, grad_B, grad_C])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        for i in range(100):
            loss_val, grad_A_val, grad_B_val, grad_C_val = executor.run(
                feed_dict={
                    input_tensor: input_tensor_val,
                    A: A_val,
                    B: B_val,
                    C: C_val
                })

            A_val -= learning_rate * grad_A_val
            B_val -= learning_rate * grad_B_val
            C_val -= learning_rate * grad_C_val
            print(f'At iteration {i} the loss is: {loss_val}')

def test_inner_product_hvp():
    for datatype in backends:
        T.set_backend(datatype)

        x = ad.Variable(name="x", shape=[3, 1])
        v = ad.Variable(name="v", shape=[3, 1])
        y = ad.sum(ad.einsum("ab,bc->ac", ad.transpose(x), x))

        grad_x, = ad.gradients(y, [x])
        Hv, = ad.hvp(output_node=y, node_list=[x], vector_list=[v])

        executor = ad.Executor([y, grad_x, Hv])
        x_val = T.tensor([[1.], [2.], [3.]])  # 3x1
        v_val = T.tensor([[1.], [2.], [3.]])  # 3x1

        y_val, grad_x_val, Hv_val = executor.run(feed_dict={
            x: x_val,
            v: v_val
        })

        expected_yval = T.sum(T.dot(T.transpose(x_val), x_val))
        expected_grad_x_val = 2 * x_val
        expected_hv_val = 2 * v_val

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x_val, expected_grad_x_val)
        assert T.array_equal(Hv_val, expected_hv_val)

def test_einsum():
    for datatype in backends:
        T.set_backend(datatype)

        x2 = ad.Variable(name="x2", shape=[3, 2])
        x3 = ad.Variable(name="x3", shape=[2, 3])
        matmul = ad.einsum('ik,kj->ij', x2, x3)
        y = ad.sum(matmul)

        grad_x2, grad_x3 = ad.gradients(y, [x2, x3])

        executor = ad.Executor([y, grad_x2, grad_x3])
        x2_val = T.tensor([[1, 2], [3, 4], [5, 6]])  # 3x2
        x3_val = T.tensor([[7, 8, 9], [10, 11, 12]])  # 2x3

        y_val, grad_x2_val, grad_x3_val = executor.run(feed_dict={
            x2: x2_val,
            x3: x3_val
        })

        expected_grad_sum = T.ones_like(T.dot(x2_val, x3_val))
        expected_yval = T.sum(T.dot(x2_val, x3_val))
        expected_grad_x2_val = T.dot(expected_grad_sum, T.transpose(x3_val))
        expected_grad_x3_val = T.dot(T.transpose(x2_val), expected_grad_sum)

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x2_val, expected_grad_x2_val)
        assert T.array_equal(grad_x3_val, expected_grad_x3_val)

def test_add_mul_mix_3(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x2 = ad.Variable(name="x2", shape=[3])
        x3 = ad.Variable(name="x3", shape=[3])
        z = x2 * x2 + x2 + x3 + 3
        y = ad.sum(z * z + x3)

        grad_x2, grad_x3 = ad.gradients(y, [x2, x3])

        executor = ad.Executor([y, grad_x2, grad_x3])
        x2_val = 2 * T.ones(3)
        x3_val = 3 * T.ones(3)
        y_val, grad_x2_val, grad_x3_val = executor.run(feed_dict={
            x2: x2_val,
            x3: x3_val
        })

        z_val = x2_val * x2_val + x2_val + x3_val + 3
        expected_yval = z_val * z_val + x3_val
        expected_grad_x2_val = 2 * (x2_val * x2_val + x2_val + x3_val +
                                    3) * (2 * x2_val + 1)
        expected_grad_x3_val = 2 * (x2_val * x2_val + x2_val + x3_val + 3) + 1

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, T.sum(expected_yval))
        assert T.array_equal(grad_x2_val, expected_grad_x2_val)
        assert T.array_equal(grad_x3_val, expected_grad_x3_val)

def test_add_mul_mix_2(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x1 = ad.Variable(name="x1", shape=[3])
        x2 = ad.Variable(name="x2", shape=[3])
        x3 = ad.Variable(name="x3", shape=[3])
        x4 = ad.Variable(name="x4", shape=[3])
        y = ad.sum(x1 + x2 * x3 * x4)

        grad_x1, grad_x2, grad_x3, grad_x4 = ad.gradients(y, [x1, x2, x3, x4])

        executor = ad.Executor([y, grad_x1, grad_x2, grad_x3, grad_x4])
        x1_val = 1 * T.ones(3)
        x2_val = 2 * T.ones(3)
        x3_val = 3 * T.ones(3)
        x4_val = 4 * T.ones(3)
        y_val, grad_x1_val, grad_x2_val, grad_x3_val, grad_x4_val = executor.run(
            feed_dict={
                x1: x1_val,
                x2: x2_val,
                x3: x3_val,
                x4: x4_val
            })

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, T.sum(x1_val + x2_val * x3_val * x4_val))
        assert T.array_equal(grad_x1_val, T.ones_like(x1_val))
        assert T.array_equal(grad_x2_val, x3_val * x4_val)
        assert T.array_equal(grad_x3_val, x2_val * x4_val)
        assert T.array_equal(grad_x4_val, x2_val * x3_val)

def cpd_newton(size, rank):
    dim = 3

    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])

        grads = ad.gradients(loss, [A, B, C])
        Hvps = ad.hvp(output_node=loss,
                      node_list=[A, B, C],
                      vector_list=[v_A, v_B, v_C])

        executor_grads = ad.Executor([loss] + grads)
        executor_Hvps = ad.Executor(Hvps)

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        for i in range(100):

            def hess_fn(v):
                return executor_Hvps.run(
                    feed_dict={
                        A: A_val,
                        B: B_val,
                        C: C_val,
                        input_tensor: input_tensor_val,
                        v_A: v[0],
                        v_B: v[1],
                        v_C: v[2]
                    })

            loss_val, grad_A_val, grad_B_val, grad_C_val = executor_grads.run(
                feed_dict={
                    A: A_val,
                    B: B_val,
                    C: C_val,
                    input_tensor: input_tensor_val
                })

            delta = conjugate_gradient(
                hess_fn=hess_fn,
                grads=[grad_A_val, grad_B_val, grad_C_val],
                error_tol=1e-9,
                max_iters=250)

            A_val -= delta[0]
            B_val -= delta[1]
            C_val -= delta[2]
            print(f'At iteration {i} the loss is: {loss_val}')

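# cpd_newton delegates the Newton solve to an external conjugate_gradient
# helper that works matrix-free through hess_fn on lists of factor-shaped
# tensors. For reference, a minimal textbook CG solve for a single symmetric
# positive-definite system is sketched below (numpy; the name and flat-vector
# interface are ours, not the helper's actual API).
def _cg_reference(matvec, b, tol=1e-9, max_iters=250):
    import numpy as np

    x = np.zeros_like(b)
    r = b - matvec(x)  # initial residual
    p = r.copy()
    rs = r @ r
    for _ in range(max_iters):
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p  # next conjugate direction
        rs = rs_new
    return x
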
def cpd_als_shared_exec(dim, size, rank, num_iter, input_val=[]):
    A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)

    full_hessian = ad.hessian(loss, A_list)
    hessians = [full_hessian[i][i] for i in range(len(full_hessian))]

    grads = ad.gradients(loss, A_list)

    updates = [
        ad.tensordot(ad.tensorinv(hes), grad, [[2, 3], [0, 1]])
        for (hes, grad) in zip(hessians, grads)
    ]

    new_A_list = [simplify(A - update) for (A, update) in zip(A_list, updates)]
    new_A_list = generate_sequential_optimal_tree(new_A_list, A_list)
    executor = ad.Executor(new_A_list)
    executor_loss = ad.Executor([simplify(loss)])

    if input_val == []:
        A_val_list, input_tensor_val = init_rand_cp(dim, size, rank)
    else:
        A_val_list, input_tensor_val = input_val

    for iter in range(num_iter):
        t0 = time.time()
        # als iterations
        for i in range(len(A_list)):

            feed_dict = dict(zip(A_list, A_val_list))
            feed_dict.update({input_tensor: input_tensor_val})

            if i == 0:
                A_val_list[0], = executor.run(feed_dict=feed_dict,
                                              out_nodes=[new_A_list[0]])
            else:
                A_val_list[i], = executor.run(feed_dict=feed_dict,
                                              reset_graph=False,
                                              evicted_inputs=[A_list[i - 1]],
                                              out_nodes=[new_A_list[i]])

        feed_dict = dict(zip(A_list, A_val_list))
        feed_dict.update({input_tensor: input_tensor_val})
        loss_val, = executor_loss.run(feed_dict=feed_dict)

        print(f'At iteration {iter} the loss is: {loss_val}')
        t1 = time.time()
        print(f"[ {iter} ] Sweep took {t1 - t0} seconds")

    return A_val_list

def test_negative(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x2 = ad.Variable(name="x2", shape=[3])
        y = ad.sum(-x2)

        grad_x2, = ad.gradients(y, [x2])

        executor = ad.Executor([y, grad_x2])
        x2_val = 2 * T.ones(3)
        y_val, grad_x2_val = executor.run(feed_dict={x2: x2_val})

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, T.sum(-x2_val))
        assert T.array_equal(grad_x2_val, -T.ones_like(x2_val))

def test_summation_einsum(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x = ad.Variable(name="x", shape=[2, 2])
        x_sum = ad.einsum('ij->', x)

        grad_x, = ad.gradients(x_sum, [x])

        executor = ad.Executor([x_sum, grad_x])
        x_val = T.tensor([[1., 2.], [3., 4.]])

        x_sum_val, grad_x_val = executor.run(feed_dict={x: x_val})

        expected_x_sum_val = T.sum(x_val)
        expected_grad_x_val = T.ones_like(x_val)

        assert T.array_equal(x_sum_val, expected_x_sum_val)
        assert T.array_equal(grad_x_val, expected_grad_x_val)

def tucker_als_graph_shared_exec(dim, size, rank):
    """ Build the graph used for Tucker ALS with shared execution.

    Parameters
    ----------
    dim: dimensionality of the input tensor
    size: the size of each dimension of the input tensor
    rank: the rank of the decomposition

    Returns
    -------
    tg: a TuckerGraph object
    executor_updates: a shared executor for the update graphs
    executor_loss: an executor for the Tucker loss
    loss: the optimized graph for the Tucker loss
    updates: a list containing the update graph for each dimension
    intermediates: list of einsum nodes. Each node is the objective
        that one Tucker ALS step optimizes
    """
    tg = TuckerGraph(dim, size, rank)

    updates = []
    for i in range(dim):

        core_A = tg.intermediates[i]
        hes = ad.hessian(tg.losses[i], [core_A])
        hes = hes[0][0]
        grad, = ad.gradients(tg.losses[i], [core_A])

        new_core_A = core_A - ad.tensordot(
            ad.tensorinv(hes), grad,
            [[i + dim for i in range(dim)], [i for i in range(dim)]])

        updates.append(simplify(new_core_A))

    loss = simplify(tg.losses[0])
    for i in range(1, len(tg.losses)):
        assert loss.name == simplify(tg.losses[i]).name

    updates = generate_sequential_optimal_tree(updates, tg.A_list)
    executor_updates = ad.Executor(updates)
    executor_loss = ad.Executor([loss])

    return tg, executor_updates, executor_loss, loss, updates, tg.intermediates

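# The update built above is a Newton step on the core tensor: the Hessian of
# the loss w.r.t. a dim-way core is a 2*dim-way tensor, and
# tensordot(tensorinv(hes), grad, ...) applies its inverse to the gradient.
# A standalone numpy sketch of the same operation for dim = 2 (the helper
# name is ours):
def _tucker_newton_step_reference(core, hes, grad):
    import numpy as np

    # core: (r, r), hes: (r, r, r, r), grad: (r, r)
    hes_inv = np.linalg.tensorinv(hes, ind=2)  # invert hes as an (r*r, r*r) matrix
    step = np.tensordot(hes_inv, grad, axes=([2, 3], [0, 1]))
    return core - step
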
def test_cpd_grad(backendopt):
    dim = 3

    for datatype in backendopt:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        grad_A, grad_B, grad_C = ad.gradients(loss, [A, B, C])
        executor = ad.Executor([loss, grad_A, grad_B, grad_C])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        loss_val, grad_A_val, grad_B_val, grad_C_val = executor.run(
            feed_dict={
                input_tensor: input_tensor_val,
                A: A_val,
                B: B_val,
                C: C_val
            })

        expected_output_tensor = T.einsum("ia,ja,ka->ijk", A_val, B_val, C_val)
        expected_residual = expected_output_tensor - input_tensor_val
        expected_norm_error = T.norm(expected_residual)
        expected_loss = expected_norm_error * expected_norm_error

        expected_contract_residual_A = 2 * T.einsum("ijk,ia->ajk",
                                                    expected_residual, A_val)
        expected_contract_residual_B = 2 * T.einsum("ijk,ja->iak",
                                                    expected_residual, B_val)
        expected_contract_residual_C = 2 * T.einsum("ijk,ka->ija",
                                                    expected_residual, C_val)

        expected_grad_A = T.einsum("iak,ka->ia", expected_contract_residual_B,
                                   C_val)
        expected_grad_B = T.einsum("ajk,ka->ja", expected_contract_residual_A,
                                   C_val)
        expected_grad_C = T.einsum("ajk,ja->ka", expected_contract_residual_A,
                                   B_val)

        assert abs(loss_val - expected_loss) < 1e-8
        assert T.norm(grad_A_val - expected_grad_A) < 1e-8
        assert T.norm(grad_B_val - expected_grad_B) < 1e-8
        assert T.norm(grad_C_val - expected_grad_C) < 1e-8

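# A standalone numpy sketch (helper name is ours) of the gradient formula the
# assertions above encode: with loss = ||einsum("ia,ja,ka->ijk", A, B, C) - X||^2,
# the gradient w.r.t. each factor contracts the residual with the other two
# factors, e.g. dloss/dA = 2 * einsum("ijk,ja,ka->ia", residual, B, C).
def _cpd_grad_reference(A, B, C, X):
    import numpy as np

    residual = np.einsum("ia,ja,ka->ijk", A, B, C) - X
    grad_A = 2 * np.einsum("ijk,ja,ka->ia", residual, B, C)
    grad_B = 2 * np.einsum("ijk,ia,ka->ja", residual, A, C)
    grad_C = 2 * np.einsum("ijk,ia,ja->ka", residual, A, B)
    return grad_A, grad_B, grad_C
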
def test_inner_product_einsum(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x = ad.Variable(name="x", shape=[2])
        x_inner = ad.einsum('i,i->', x, x)

        grad_x, = ad.gradients(x_inner, [x])

        executor = ad.Executor([x_inner, grad_x])
        x_val = T.tensor([3., 4.])  # 2

        y_val, grad_x_val = executor.run(feed_dict={x: x_val})

        expected_yval = T.norm(x_val)**2
        expected_grad_x_val = 2 * x_val

        assert isinstance(x_inner, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x_val, expected_grad_x_val)

def test_summation_einsum_2(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x = ad.Variable(name="x", shape=[2, 2])
        y = ad.Variable(name="y", shape=[2, 2])
        out = ad.sum(ad.einsum('ij,ab->ab', x, y))

        grad_x, = ad.gradients(out, [x])

        executor = ad.Executor([out, grad_x])
        x_val = T.tensor([[1., 2.], [3., 4.]])
        y_val = T.tensor([[5., 6.], [7., 8.]])

        out_val, grad_x_val = executor.run(feed_dict={x: x_val, y: y_val})

        expected_out_val = T.sum(T.einsum('ij,ab->ab', x_val, y_val))
        expected_grad_x_val = T.sum(y_val) * T.ones_like(x_val)

        assert T.array_equal(out_val, expected_out_val)
        assert T.array_equal(grad_x_val, expected_grad_x_val)

def test_transpose_einsum(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x = ad.Variable(name="x", shape=[3, 2])
        y = ad.sum(ad.einsum("ij->ji", x))

        grad_x, = ad.gradients(y, [x])

        executor = ad.Executor([y, grad_x])
        x_val = T.tensor([[1, 2], [3, 4], [5, 6]])  # 3x2

        y_val, grad_x_val = executor.run(feed_dict={x: x_val})

        expected_yval = T.sum(T.transpose(x_val))
        expected_grad_x_val = T.ones_like(x_val)

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x_val, expected_grad_x_val)

def test_trace_einsum(backendopt):
    for datatype in backendopt:
        if datatype == 'taco':
            # Currently taco doesn't support the same subscript appearing
            # twice in one operand.
            continue
        T.set_backend(datatype)

        x = ad.Variable(name="x", shape=[2, 2])
        trace = ad.einsum('ii->', x)

        grad_x, = ad.gradients(trace, [x])

        executor = ad.Executor([trace, grad_x])
        x_val = T.tensor([[1., 2.], [3., 4.]])

        trace_val, grad_x_val = executor.run(feed_dict={x: x_val})

        expected_trace_val = T.einsum('ii->', x_val)
        expected_grad_x_val = T.identity(2)

        assert T.array_equal(trace_val, expected_trace_val)
        assert T.array_equal(grad_x_val, expected_grad_x_val)

def tucker_als_graph(dim, size, rank):
    """ Build the graph used for Tucker ALS.

    Parameters
    ----------
    dim: dimensionality of the input tensor
    size: the size of each dimension of the input tensor
    rank: the rank of the decomposition

    Returns
    -------
    tg: a TuckerGraph object
    executors_update: list of executors. Each executor performs one step
        of Tucker ALS
    executor_loss: an executor for the Tucker loss
    intermediates: list of einsum nodes. Each node is the objective
        that one Tucker ALS step optimizes
    """
    tg = TuckerGraph(dim, size, rank)

    executors_update = []

    for i in range(dim):

        core_A = tg.intermediates[i]
        hes = ad.hessian(tg.losses[i], [core_A])
        hes = hes[0][0]
        grad, = ad.gradients(tg.losses[i], [core_A])

        new_core_A = core_A - ad.tensordot(
            ad.tensorinv(hes), grad,
            [[i + dim for i in range(dim)], [i for i in range(dim)]])

        executor = ad.Executor([simplify(new_core_A)])
        executors_update.append(executor)

    executor_loss = ad.Executor([simplify(tg.losses[0])])

    return tg, executors_update, executor_loss, tg.intermediates

def test_norm(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x = ad.Variable(name="x", shape=[3, 2])
        y = ad.norm(x)
        z = y**2

        grad_x, = ad.gradients(z, [x])

        executor = ad.Executor([z, grad_x])
        x_val = T.tensor([[1., 2.], [3., 4.], [5., 6.]])  # 3x2

        z_val, grad_x_val = executor.run(feed_dict={x: x_val})

        expected_zval = T.norm(x_val)**2
        expected_grad_x_val = 2 * x_val

        assert isinstance(z, ad.Node)
        assert T.array_equal(z_val, expected_zval)
        assert T.array_equal(grad_x_val, expected_grad_x_val)

def cpd_als(dim, size, rank, num_iter, input_val=[]):
    A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)

    full_hessian = ad.hessian(loss, A_list)
    hessians = [full_hessian[i][i] for i in range(len(full_hessian))]

    grads = ad.gradients(loss, A_list)

    updates = [
        ad.tensordot(ad.tensorinv(hes), grad, [[2, 3], [0, 1]])
        for (hes, grad) in zip(hessians, grads)
    ]

    new_A_list = [simplify(A - update) for (A, update) in zip(A_list, updates)]

    executor = ad.Executor(new_A_list)
    executor_loss = ad.Executor([simplify(loss)])

    if input_val == []:
        A_val_list, input_tensor_val = init_rand_cp(dim, size, rank)
    else:
        A_val_list, input_tensor_val = input_val

    for iter in range(num_iter):
        # als iterations
        for i in range(len(A_list)):

            feed_dict = dict(zip(A_list, A_val_list))
            feed_dict.update({input_tensor: input_tensor_val})

            A_val_list[i], = executor.run(feed_dict=feed_dict,
                                          out_nodes=[new_A_list[i]])

        feed_dict = dict(zip(A_list, A_val_list))
        feed_dict.update({input_tensor: input_tensor_val})
        loss_val, = executor_loss.run(feed_dict=feed_dict)

        print(f'At iteration {iter} the loss is: {loss_val}')

    return A_val_list

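# The update graphs in cpd_als realize the classic CP-ALS normal equations.
# For reference, one update of factor A for a 3-way tensor can be written
# directly in numpy as below (helper name is ours; the graph version instead
# inverts the diagonal Hessian block with ad.tensorinv):
def _als_update_A_reference(X, B, C):
    import numpy as np

    # Khatri-Rao product of B and C, rows ordered to match X.reshape(I, J*K)
    kr = np.einsum("ja,ka->jka", B, C).reshape(-1, B.shape[1])
    gram = (B.T @ B) * (C.T @ C)  # Hadamard product of the Gram matrices
    # MTTKRP followed by a solve against the Gram matrix
    return X.reshape(X.shape[0], -1) @ kr @ np.linalg.pinv(gram)
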
def test_einsum_3op(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x2 = ad.Variable(name="x2", shape=[3, 2])
        x3 = ad.Variable(name="x3", shape=[2, 3])
        x4 = ad.Variable(name="x4", shape=[3, 2])
        matmul = ad.einsum('ik,kj,jl->il', x2, x3, x4)
        y = ad.sum(matmul)

        grad_x2, grad_x3, grad_x4 = ad.gradients(y, [x2, x3, x4])

        executor = ad.Executor([y, grad_x2, grad_x3, grad_x4])
        x2_val = T.tensor([[1, 2], [3, 4], [5, 6]])  # 3x2
        x3_val = T.tensor([[7, 8, 9], [10, 11, 12]])  # 2x3
        x4_val = T.tensor([[1, 2], [3, 4], [5, 6]])  # 3x2

        y_val, grad_x2_val, grad_x3_val, grad_x4_val = executor.run(feed_dict={
            x2: x2_val,
            x3: x3_val,
            x4: x4_val
        })

        expected_grad_sum = T.ones_like(T.dot(T.dot(x2_val, x3_val), x4_val))
        expected_yval = T.sum(T.dot(T.dot(x2_val, x3_val), x4_val))
        expected_grad_x2_val = T.einsum("il, kj, jl->ik", expected_grad_sum,
                                        x3_val, x4_val)
        expected_grad_x3_val = T.einsum("ik, il, jl->kj", x2_val,
                                        expected_grad_sum, x4_val)
        expected_grad_x4_val = T.einsum("ik, kj, il->jl", x2_val, x3_val,
                                        expected_grad_sum)

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x2_val, expected_grad_x2_val)
        assert T.array_equal(grad_x3_val, expected_grad_x3_val)
        assert T.array_equal(grad_x4_val, expected_grad_x4_val)

def test_mul_two_vars(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x2 = ad.Variable(name="x2", shape=[3])
        x3 = ad.Variable(name="x3", shape=[3])
        y = ad.sum(x2 * x3)

        grad_x2, grad_x3 = ad.gradients(y, [x2, x3])

        executor = ad.Executor([y, grad_x2, grad_x3])
        x2_val = 2 * T.ones(3)
        x3_val = 3 * T.ones(3)
        y_val, grad_x2_val, grad_x3_val = executor.run(feed_dict={
            x2: x2_val,
            x3: x3_val
        })

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, T.sum(x2_val * x3_val))
        assert T.array_equal(grad_x2_val, x3_val)
        assert T.array_equal(grad_x3_val, x2_val)

def cpd_nls(size, rank, regularization=1e-7, mode='ad'):
    """
    mode: ad / optimized / jax
    """
    assert mode in {'ad', 'jax', 'optimized'}

    dim = 3

    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)
        T.seed(1)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])

        grads = ad.gradients(loss, [A, B, C])
        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        if mode == 'jax':
            from source import SourceToSource
            StS = SourceToSource()
            StS.forward(JtJvps,
                        file=open("examples/jax_jtjvp.py", "w"),
                        function_name='jtjvp',
                        backend='jax')

        executor_grads = ad.Executor([loss] + grads)
        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        executor_JtJvps = ad.Executor(JtJvps)
        optimizer = cp_nls_optimizer(input_tensor_val, [A_val, B_val, C_val])

        regu_increase = False
        normT = T.norm(input_tensor_val)
        time_all, fitness = 0., 0.

        for i in range(10):
            t0 = time.time()

            def hess_fn(v):
                if mode == 'optimized':
                    from examples.cpd_jtjvp_optimized import jtjvp
                    return jtjvp([v[0], B_val, C_val, v[1], A_val, v[2]])
                elif mode == 'ad':
                    return executor_JtJvps.run(
                        feed_dict={
                            A: A_val,
                            B: B_val,
                            C: C_val,
                            input_tensor: input_tensor_val,
                            v_A: v[0],
                            v_B: v[1],
                            v_C: v[2]
                        })
                elif mode == 'jax':
                    from examples.jax_jtjvp import jtjvp
                    return jtjvp([B_val, C_val, v[0], A_val, v[1], v[2]])

            loss_val, grad_A_val, grad_B_val, grad_C_val = executor_grads.run(
                feed_dict={
                    A: A_val,
                    B: B_val,
                    C: C_val,
                    input_tensor: input_tensor_val
                })

            res = math.sqrt(loss_val)
            fitness = 1 - res / normT
            print(f"[ {i} ] Residual is {res} fitness is: {fitness}")
            print(f"Regularization is: {regularization}")

            [A_val, B_val, C_val], total_cg_time = optimizer.step(
                hess_fn=hess_fn,
                grads=[grad_A_val / 2, grad_B_val / 2, grad_C_val / 2],
                regularization=regularization)

            t1 = time.time()
            print(f"[ {i} ] Sweep took {t1 - t0} seconds")
            time_all += t1 - t0

            if regularization < 1e-07:
                regu_increase = True
            elif regularization > 1:
                regu_increase = False
            if regu_increase:
                regularization = regularization * 2
            else:
                regularization = regularization / 2

    return total_cg_time, fitness

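# cpd_nls is Gauss-Newton: ad.jtjvps builds J^T J v for the residual
# r(A, B, C) = einsum("ia,ja,ka->ijk", A, B, C) - X instead of the exact
# Hessian. A standalone numpy sketch of that product restricted to the A
# block (the helper name is ours; the full operator also carries the B and C
# cross terms):
def _cpd_jtjvp_A_reference(B, C, v_A):
    import numpy as np

    # J v: directional derivative of the residual along v_A
    Jv = np.einsum("ia,ja,ka->ijk", v_A, B, C)
    # J^T (J v), mapped back to the shape of A; this equals
    # v_A @ ((B^T B) * (C^T C))
    return np.einsum("ijk,ja,ka->ia", Jv, B, C)
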