def test_cpd_jtjvp_optimized(benchmark):
    """Benchmark the optimized, deduplicated JtJ-vector products of the CPD residual.

    For each backend in ``BACKEND_TYPES``:
    - build the CPD graph;
    - form ``ad.jtjvps`` of the residual w.r.t. the three factors;
    - ``optimize`` each product and ``dedup`` them together;
    - check every result is an ``ad.AddNode``;
    - time ``Executor.run`` on values drawn with ``init_rand_cp``.
    """
    for backend in BACKEND_TYPES:
        T.set_backend(backend)

        factor_nodes, input_tensor, _loss, residual = cpd_graph(dim, size, rank)
        A, B, C = factor_nodes
        v_A, v_B, v_C = [
            ad.Variable(name=label, shape=[size, rank])
            for label in ("v_A", "v_B", "v_C")
        ]

        factor_vals, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = factor_vals
        # A second draw supplies the vectors the JtJ products are applied to.
        vector_vals, _ = init_rand_cp(dim, size, rank)
        v_A_val, v_B_val, v_C_val = vector_vals

        products = ad.jtjvps(output_node=residual,
                             node_list=[A, B, C],
                             vector_list=[v_A, v_B, v_C])
        products = [optimize(expr) for expr in products]
        dedup(*products)
        assert all(isinstance(expr, ad.AddNode) for expr in products)

        feed_dict = {
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
            v_A: v_A_val,
            v_B: v_B_val,
            v_C: v_C_val,
        }
        benchmark(ad.Executor(products).run, feed_dict=feed_dict)
def test_cpd_jtjvp(benchmark):
    """Benchmark the reference ``expect_jtjvp_val`` once per backend in BACKEND_TYPES.

    Two draws from ``init_rand_cp`` supply the factor values and the vectors.
    The tensor returned by each draw is not used here.
    """
    for backend in BACKEND_TYPES:
        T.set_backend(backend)
        factor_vals, _ = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = factor_vals
        vector_vals, _ = init_rand_cp(dim, size, rank)
        v_A_val, v_B_val, v_C_val = vector_vals
        benchmark(expect_jtjvp_val,
                  A_val, B_val, C_val,
                  v_A_val, v_B_val, v_C_val)
def test_cpd_shared_exec(backendopt):
    """Check one sweep of ``cpd_als_shared_exec`` against closed-form ALS updates.

    The reference updates A, then B, then C. Each update is
    ``einsum(X, P, Q) @ inv((P^T P) * (Q^T Q))``, where P and Q are the two other
    factors, and it uses the factors already updated earlier in the same sweep.
    """
    dim = 3
    for backend in backendopt:
        T.set_backend(backend)
        input_val = init_rand_cp(dim, size, rank)
        factor_vals, input_tensor_val = input_val
        A_val, B_val, C_val = factor_vals
        outputs = cpd_als_shared_exec(dim, size, rank, 1, input_val)

        def solve(subscripts, first, second):
            # Contract the tensor with the two fixed factors, then apply the
            # inverse of the Hadamard product of their Gram matrices.
            return T.einsum(subscripts, input_tensor_val, first, second) @ T.inv(
                (T.transpose(first) @ first) * (T.transpose(second) @ second))

        # expected values
        A_val = solve("abc,bk,ck->ak", B_val, C_val)
        B_val = solve("abc,ak,ck->bk", A_val, C_val)
        C_val = solve("abc,ak,bk->ck", A_val, B_val)

        for idx, expected in enumerate((A_val, B_val, C_val)):
            assert T.norm(outputs[idx] - expected) < 1e-8
def cpd_gradient_descent(size, rank, learning_rate):
    """Run 100 plain gradient-descent steps on the CPD loss for every backend.

    Each step evaluates the loss and the gradients w.r.t. the three factors,
    subtracts ``learning_rate * grad`` from each factor, and prints the loss.
    """
    dim = 3
    for backend in BACKEND_TYPES:
        T.set_backend(backend)
        factor_nodes, input_tensor, loss, _residual = cpd_graph(dim, size, rank)
        A, B, C = factor_nodes
        grad_A, grad_B, grad_C = ad.gradients(loss, [A, B, C])
        executor = ad.Executor([loss, grad_A, grad_B, grad_C])

        factor_vals, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = factor_vals

        for i in range(100):
            feed_dict = {
                input_tensor: input_tensor_val,
                A: A_val,
                B: B_val,
                C: C_val,
            }
            loss_val, grad_A_val, grad_B_val, grad_C_val = executor.run(
                feed_dict=feed_dict)
            # Augmented assignment per name: in place for mutable arrays,
            # a rebind for immutable ones.
            A_val -= learning_rate * grad_A_val
            B_val -= learning_rate * grad_B_val
            C_val -= learning_rate * grad_C_val
            print(f'At iteration {i} the loss is: {loss_val}')
def test_cpd_hessian_optimize_offdiag(backendopt):
    """Optimize the off-diagonal CPD Hessian blocks H[0][1] and H[1][0] and evaluate them.

    For each backend, runs ``optimize`` on each block and asserts that the
    *optimized* node is an ``ad.AddNode``. Then evaluates both blocks with
    ``ad.Executor``.

    NOTE(review): the evaluated values are not compared against a reference,
    and the operation-count assertion is disabled (see below).
    """
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)
        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list
        hessian = ad.hessian(loss, [A, B, C])
        hessian_offdiag = [hessian[0][1], hessian[1][0]]
        for node in hessian_offdiag:
            # Check the node returned by optimize (the original code discarded
            # it and asserted on the unoptimized node).
            node = optimize(node)
            assert isinstance(node, ad.AddNode)
            num_operations = len(
                list(
                    filter(lambda x: isinstance(x, ad.OpNode),
                           find_topo_sort([node]))))
            # This is currently non-deterministic.
            # assert num_operations == 14
        executor = ad.Executor(hessian_offdiag)
        hes_offdiag_vals = executor.run(feed_dict={
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        })
def test_cpd_als_sktensor(benchmark):
    """Benchmark one ``sk_cp_als`` iteration (random init) on a tensor from ``init_rand_cp``.

    The benchmark repeats once per entry of ``BACKEND_TYPES``, but no backend is
    set here. ``sk_cp_als``/``dtensor`` presumably come from scikit-tensor
    (TODO: confirm).
    """
    for _backend in BACKEND_TYPES:
        _, tensor_val = init_rand_cp(dim, size, rank)
        benchmark(sk_cp_als,
                  dtensor(tensor_val),
                  rank=rank,
                  max_iter=1,
                  init='random')
def test_cpd_jtjvp_optimize(backendopt):
    """Check optimized, deduplicated JtJ-vector products of the CPD residual.

    Each product must be an ``ad.AddNode`` after ``optimize``. Its value must
    match ``expect_jtjvp_val`` to within 1e-8 in norm.
    """
    dim = 3
    for backend in backendopt:
        T.set_backend(backend)
        factor_nodes, input_tensor, _loss, residual = cpd_graph(dim, size, rank)
        A, B, C = factor_nodes
        v_A, v_B, v_C = [
            ad.Variable(name=label, shape=[size, rank])
            for label in ("v_A", "v_B", "v_C")
        ]

        factor_vals, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = factor_vals
        vector_vals, _ = init_rand_cp(dim, size, rank)
        v_A_val, v_B_val, v_C_val = vector_vals

        products = ad.jtjvps(output_node=residual,
                             node_list=[A, B, C],
                             vector_list=[v_A, v_B, v_C])
        products = [optimize(expr) for expr in products]
        dedup(*products)
        assert all(isinstance(expr, ad.AddNode) for expr in products)

        executor = ad.Executor(products)
        feed_dict = {
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
            v_A: v_A_val,
            v_B: v_B_val,
            v_C: v_C_val,
        }
        computed = executor.run(feed_dict=feed_dict)
        reference = expect_jtjvp_val(A_val, B_val, C_val,
                                     v_A_val, v_B_val, v_C_val)
        for idx in range(3):
            assert T.norm(computed[idx] - reference[idx]) < 1e-8
def cpd_newton(size, rank):
    """Run 100 Newton-type iterations on the CPD loss for every backend.

    Each iteration:
    - evaluates the loss and gradients;
    - calls ``conjugate_gradient`` (error_tol=1e-9, max_iters=250) with a
      Hessian-vector-product callback built from ``ad.hvp``, which presumably
      solves H * delta = grad (TODO: confirm);
    - subtracts ``delta`` from the factors;
    - prints the loss.
    """
    dim = 3
    for backend in BACKEND_TYPES:
        T.set_backend(backend)
        factor_nodes, input_tensor, loss, _residual = cpd_graph(dim, size, rank)
        A, B, C = factor_nodes
        v_A, v_B, v_C = [
            ad.Variable(name=label, shape=[size, rank])
            for label in ("v_A", "v_B", "v_C")
        ]
        grads = ad.gradients(loss, [A, B, C])
        Hvps = ad.hvp(output_node=loss,
                      node_list=[A, B, C],
                      vector_list=[v_A, v_B, v_C])
        executor_grads = ad.Executor([loss] + grads)
        executor_Hvps = ad.Executor(Hvps)

        factor_vals, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = factor_vals

        def hess_fn(v):
            # Closure: reads whatever A_val/B_val/C_val are bound to when called.
            return executor_Hvps.run(feed_dict={
                A: A_val,
                B: B_val,
                C: C_val,
                input_tensor: input_tensor_val,
                v_A: v[0],
                v_B: v[1],
                v_C: v[2],
            })

        for i in range(100):
            loss_val, grad_A_val, grad_B_val, grad_C_val = executor_grads.run(
                feed_dict={
                    A: A_val,
                    B: B_val,
                    C: C_val,
                    input_tensor: input_tensor_val,
                })
            delta = conjugate_gradient(
                hess_fn=hess_fn,
                grads=[grad_A_val, grad_B_val, grad_C_val],
                error_tol=1e-9,
                max_iters=250)
            A_val -= delta[0]
            B_val -= delta[1]
            C_val -= delta[2]
            print(f'At iteration {i} the loss is: {loss_val}')
def cpd_als_shared_exec(dim, size, rank, num_iter, input_val=None):
    """Run ``num_iter`` ALS sweeps for a CP decomposition using one shared executor.

    For every factor the update ``A - tensordot(tensorinv(H_AA), grad_A)`` is
    built symbolically from the diagonal Hessian block and the gradient, then
    simplified. ``generate_sequential_optimal_tree`` then rewrites the updates
    jointly.

    Within a sweep:
    - the first update runs the executor normally;
    - each later update passes ``reset_graph=False`` and evicts only the factor
      updated just before it (``A_list[i - 1]``). Presumably this lets
      intermediates that do not depend on that factor be reused
      (TODO: confirm in ``ad.Executor``).

    Args:
        dim, size, rank: forwarded to ``cpd_graph`` and ``init_rand_cp``.
        num_iter: number of sweeps.
        input_val: optional ``(factor_values, tensor_value)`` pair in the form
            returned by ``init_rand_cp``. When ``None`` (or ``[]``, the former
            default), a problem is generated with ``init_rand_cp``. The
            caller's factor list is copied, not modified.

    Returns:
        The list of factor values after the last sweep.

    Prints the loss and the wall-clock time of every sweep.
    """
    A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
    full_hessian = ad.hessian(loss, A_list)
    # Only the diagonal blocks are needed: each factor is updated with the
    # others held fixed.
    hessians = [row[i] for i, row in enumerate(full_hessian)]
    grads = ad.gradients(loss, A_list)
    updates = [
        ad.tensordot(ad.tensorinv(hes), grad, [[2, 3], [0, 1]])
        for hes, grad in zip(hessians, grads)
    ]
    new_A_list = [simplify(A - update) for A, update in zip(A_list, updates)]
    new_A_list = generate_sequential_optimal_tree(new_A_list, A_list)
    executor = ad.Executor(new_A_list)
    executor_loss = ad.Executor([simplify(loss)])

    # ``[]`` is still accepted for backward compatibility with the old default.
    if input_val is None or input_val == []:
        A_val_list, input_tensor_val = init_rand_cp(dim, size, rank)
    else:
        A_val_list, input_tensor_val = input_val
        # Copy so the sweeps below do not overwrite the caller's factor list
        # (e.g. a benchmark reusing the same input across rounds).
        A_val_list = list(A_val_list)

    for iteration in range(num_iter):
        t0 = time.time()
        # als iterations
        for i in range(len(A_list)):
            feed_dict = dict(zip(A_list, A_val_list))
            feed_dict.update({input_tensor: input_tensor_val})
            if i == 0:
                A_val_list[0], = executor.run(feed_dict=feed_dict,
                                              out_nodes=[new_A_list[0]])
            else:
                A_val_list[i], = executor.run(feed_dict=feed_dict,
                                              reset_graph=False,
                                              evicted_inputs=[A_list[i - 1]],
                                              out_nodes=[new_A_list[i]])
        feed_dict = dict(zip(A_list, A_val_list))
        feed_dict.update({input_tensor: input_tensor_val})
        loss_val, = executor_loss.run(feed_dict=feed_dict)
        print(f'At iteration {iteration} the loss is: {loss_val}')
        t1 = time.time()
        print(f"[ {iteration} ] Sweep took {t1 - t0} seconds")
    return A_val_list
def test_cpd_hessian_optimize_diag(backendopt):
    """Check ``optimize`` on the diagonal CPD Hessian blocks, and their values.

    Each optimized block must be an ``ad.AddNode`` with exactly 5 ``ad.OpNode``
    operations. The evaluated blocks must match
    ``2 * einsum('eb,ed,fb,fd,ac->abcd', P, P, Q, Q, I)``, where P and Q are
    the two other factors.
    """
    dim = 3
    for backend in backendopt:
        T.set_backend(backend)
        factor_nodes, input_tensor, loss, _residual = cpd_graph(dim, size, rank)
        A, B, C = factor_nodes
        factor_vals, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = factor_vals

        hessian = ad.hessian(loss, [A, B, C])
        hessian_diag = [hessian[k][k] for k in range(3)]
        for block in hessian_diag:
            optimized = optimize(block)
            assert isinstance(optimized, ad.AddNode)
            op_nodes = [
                n for n in find_topo_sort([optimized])
                if isinstance(n, ad.OpNode)
            ]
            # Use this assertion to test the optimize function. 5 operations:
            # 1. T.einsum('ca,cb->ab',A,A)
            # 2. T.einsum('ca,cb->ab',B,B)
            # 3. T.einsum('ab,ab->ab', <1>, <2>)
            # 4. T.einsum('bd,ac->abcd', <3>, T.identity(10))
            # 5. <4> + <4>
            assert len(op_nodes) == 5

        executor = ad.Executor(hessian_diag)
        hes_diag_vals = executor.run(feed_dict={
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        })

        def reference_block(first, second):
            return 2 * T.einsum('eb,ed,fb,fd,ac->abcd', first, first, second,
                                second, T.identity(size))

        expected = [
            reference_block(B_val, C_val),
            reference_block(A_val, C_val),
            reference_block(A_val, B_val),
        ]
        for idx in range(3):
            assert T.norm(hes_diag_vals[idx] - expected[idx]) < 1e-8
def test_cpd_als_tensorly(benchmark):
    """Benchmark one ``parafac`` iteration (random init, tol=0) for each backend.

    ``tl``/``parafac`` presumably refer to TensorLy (TODO: confirm). The input
    tensor comes from ``init_rand_cp`` and is converted to a float64
    ``tl.tensor``.
    """
    for backend in BACKEND_TYPES:
        tl.set_backend(backend)
        assert tl.get_backend() == backend
        _, tensor_val = init_rand_cp(dim, size, rank)
        tl_tensor = tl.tensor(tensor_val, dtype='float64')
        benchmark(parafac,
                  tl_tensor,
                  rank=rank,
                  init='random',
                  tol=0,
                  n_iter_max=1,
                  verbose=0)
def test_cpd_grad(backendopt):
    """Compare the AD loss and factor gradients of the CPD graph with einsum references.

    The reference uses the following definitions:
    - residual ``R = einsum('ia,ja,ka->ijk', A, B, C) - X``;
    - loss ``||R||^2``;
    - gradients obtained by contracting ``2 * R`` with the two other factors.
    """
    dim = 3
    for backend in backendopt:
        T.set_backend(backend)
        factor_nodes, input_tensor, loss, _residual = cpd_graph(dim, size, rank)
        A, B, C = factor_nodes
        grad_A, grad_B, grad_C = ad.gradients(loss, [A, B, C])
        executor = ad.Executor([loss, grad_A, grad_B, grad_C])

        factor_vals, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = factor_vals
        loss_val, grad_A_val, grad_B_val, grad_C_val = executor.run(
            feed_dict={
                input_tensor: input_tensor_val,
                A: A_val,
                B: B_val,
                C: C_val,
            })

        residual_val = (T.einsum("ia,ja,ka->ijk", A_val, B_val, C_val) -
                        input_tensor_val)
        residual_norm = T.norm(residual_val)
        expected_loss = residual_norm * residual_norm

        # Twice the residual, contracted with one factor along its mode.
        twice_res_A = 2 * T.einsum("ijk,ia->ajk", residual_val, A_val)
        twice_res_B = 2 * T.einsum("ijk,ja->iak", residual_val, B_val)
        expected_grads = [
            T.einsum("iak,ka->ia", twice_res_B, C_val),
            T.einsum("ajk,ka->ja", twice_res_A, C_val),
            T.einsum("ajk,ja->ka", twice_res_A, B_val),
        ]

        assert abs(loss_val - expected_loss) < 1e-8
        for computed, expected in zip((grad_A_val, grad_B_val, grad_C_val),
                                      expected_grads):
            assert T.norm(computed - expected) < 1e-8
def cpd_als(dim, size, rank, num_iter, input_val=None):
    """Run ``num_iter`` alternating-least-squares sweeps for a CP decomposition.

    For every factor the update ``A - tensordot(tensorinv(H_AA), grad_A)`` is
    built symbolically from the diagonal Hessian block and the gradient of the
    loss, then simplified. A sweep updates the factors in order. Each update
    uses the factors already updated earlier in that sweep.

    Args:
        dim, size, rank: forwarded to ``cpd_graph`` and ``init_rand_cp``.
        num_iter: number of sweeps.
        input_val: optional ``(factor_values, tensor_value)`` pair in the form
            returned by ``init_rand_cp``. When ``None`` (or ``[]``, the former
            default), a problem is generated with ``init_rand_cp``. The
            caller's factor list is copied, not modified.

    Returns:
        The list of factor values after the last sweep.

    Prints the loss after every sweep.
    """
    A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
    full_hessian = ad.hessian(loss, A_list)
    # Only the diagonal blocks are needed: each factor is updated with the
    # others held fixed.
    hessians = [row[i] for i, row in enumerate(full_hessian)]
    grads = ad.gradients(loss, A_list)
    updates = [
        ad.tensordot(ad.tensorinv(hes), grad, [[2, 3], [0, 1]])
        for hes, grad in zip(hessians, grads)
    ]
    new_A_list = [simplify(A - update) for A, update in zip(A_list, updates)]
    executor = ad.Executor(new_A_list)
    executor_loss = ad.Executor([simplify(loss)])

    # ``[]`` is still accepted for backward compatibility with the old default.
    if input_val is None or input_val == []:
        A_val_list, input_tensor_val = init_rand_cp(dim, size, rank)
    else:
        A_val_list, input_tensor_val = input_val
        # Copy so the sweeps below do not overwrite the caller's factor list.
        A_val_list = list(A_val_list)

    for iteration in range(num_iter):
        # als iterations
        for i in range(len(A_list)):
            feed_dict = dict(zip(A_list, A_val_list))
            feed_dict.update({input_tensor: input_tensor_val})
            A_val_list[i], = executor.run(feed_dict=feed_dict,
                                          out_nodes=[new_A_list[i]])
        feed_dict = dict(zip(A_list, A_val_list))
        feed_dict.update({input_tensor: input_tensor_val})
        loss_val, = executor_loss.run(feed_dict=feed_dict)
        print(f'At iteration {iteration} the loss is: {loss_val}')
    return A_val_list
def test_cpd_hessian_simplify(backendopt):
    """Check ``simplify`` on the diagonal CPD Hessian blocks, and their values.

    After ``simplify``, the first input of each block must have exactly 5
    inputs. The evaluated blocks must match
    ``2 * einsum('eb,ed,fb,fd,ac->abcd', P, P, Q, Q, I)``, where P and Q are
    the two other factors.
    """
    dim = 3
    for backend in backendopt:
        T.set_backend(backend)
        factor_nodes, input_tensor, loss, _residual = cpd_graph(dim, size, rank)
        A, B, C = factor_nodes
        factor_vals, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = factor_vals

        hessian = ad.hessian(loss, [A, B, C])
        # TODO (issue #101): test the off-diagonal elements
        hessian_diag = [hessian[k][k] for k in range(3)]
        for block in hessian_diag:
            simplified = simplify(block)
            assert len(simplified.inputs[0].inputs) == 5

        executor = ad.Executor(hessian_diag)
        hes_diag_vals = executor.run(feed_dict={
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        })

        def reference_block(first, second):
            return 2 * T.einsum('eb,ed,fb,fd,ac->abcd', first, first, second,
                                second, T.identity(size))

        expected = [
            reference_block(B_val, C_val),
            reference_block(A_val, C_val),
            reference_block(A_val, B_val),
        ]
        for idx in range(3):
            assert T.norm(hes_diag_vals[idx] - expected[idx]) < 1e-8
def test_cpd_als_shared_exec(benchmark):
    """Benchmark one ``cpd_als_shared_exec`` sweep, once per entry of BACKEND_TYPES.

    No backend is set here. The ``(factors, tensor)`` pair from
    ``init_rand_cp`` is passed as ``input_val``.
    """
    for _backend in BACKEND_TYPES:
        factors_and_tensor = init_rand_cp(dim, size, rank)
        benchmark(cpd_als_shared_exec, dim, size, rank, 1, factors_and_tensor)
def cpd_nls(size, rank, regularization=1e-7, mode='ad'):
    """Run 10 regularized nonlinear-least-squares sweeps of a CP decomposition (dim = 3).

    mode: ad / optimized / jax
        'ad'        -- evaluate JtJ-vector products with ``ad.Executor`` on the
                       optimized, deduplicated ``ad.jtjvps`` graph.
        'optimized' -- use ``jtjvp`` from ``examples.cpd_jtjvp_optimized``.
        'jax'       -- write the (unoptimized) JtJvp graph to
                       ``examples/jax_jtjvp.py`` with ``SourceToSource`` and
                       use the generated ``jtjvp``.

    After each sweep the regularization is halved until it drops below 1e-7.
    It is then doubled until it exceeds 1, and the cycle repeats.

    Returns:
        ``(total_cg_time, fitness)`` for the last backend in ``BACKEND_TYPES``:
        - ``total_cg_time`` is the CG time reported by the last
          ``optimizer.step`` call;
        - ``fitness`` is ``1 - ||residual|| / ||X||``, measured at the start
          of the last sweep.
    """
    assert mode in {'ad', 'jax', 'optimized'}
    dim = 3
    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)
        T.seed(1)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])
        grads = ad.gradients(loss, [A, B, C])
        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        if mode == 'jax':
            from source import SourceToSource
            StS = SourceToSource()
            # Close (and therefore flush) the generated module before it is
            # imported below; the original passed an unclosed file object.
            with open("examples/jax_jtjvp.py", "w") as jax_file:
                StS.forward(JtJvps,
                            file=jax_file,
                            function_name='jtjvp',
                            backend='jax')

        executor_grads = ad.Executor([loss] + grads)
        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        executor_JtJvps = ad.Executor(JtJvps)
        optimizer = cp_nls_optimizer(input_tensor_val, [A_val, B_val, C_val])

        # Resolve the external JtJvp implementation once, not on every call.
        # NOTE(review): Python caches imported modules, so with several
        # backends the regenerated jax module is not re-imported.
        if mode == 'optimized':
            from examples.cpd_jtjvp_optimized import jtjvp as optimized_jtjvp
        elif mode == 'jax':
            from examples.jax_jtjvp import jtjvp as jax_jtjvp

        def hess_fn(v):
            # Reads A_val/B_val/C_val at call time, i.e. the current factors.
            if mode == 'optimized':
                return optimized_jtjvp([v[0], B_val, C_val, v[1], A_val, v[2]])
            elif mode == 'ad':
                return executor_JtJvps.run(
                    feed_dict={
                        A: A_val,
                        B: B_val,
                        C: C_val,
                        input_tensor: input_tensor_val,
                        v_A: v[0],
                        v_B: v[1],
                        v_C: v[2]
                    })
            elif mode == 'jax':
                return jax_jtjvp([B_val, C_val, v[0], A_val, v[1], v[2]])

        regu_increase = False
        normT = T.norm(input_tensor_val)
        fitness = 0.

        for i in range(10):
            t0 = time.time()
            loss_val, grad_A_val, grad_B_val, grad_C_val = executor_grads.run(
                feed_dict={
                    A: A_val,
                    B: B_val,
                    C: C_val,
                    input_tensor: input_tensor_val
                })

            res = math.sqrt(loss_val)
            fitness = 1 - res / normT
            print(f"[ {i} ] Residual is {res} fitness is: {fitness}")
            print(f"Regularization is: {regularization}")

            # The loss is the squared residual norm, so its gradient is
            # 2 * J^T r. Halving it gives J^T r, which matches the JtJ products.
            [A_val, B_val, C_val], total_cg_time = optimizer.step(
                hess_fn=hess_fn,
                grads=[grad_A_val / 2, grad_B_val / 2, grad_C_val / 2],
                regularization=regularization)

            t1 = time.time()
            print(f"[ {i} ] Sweep took {t1 - t0} seconds")

            if regularization < 1e-07:
                regu_increase = True
            elif regularization > 1:
                regu_increase = False
            if regu_increase:
                regularization = regularization * 2
            else:
                regularization = regularization / 2

    return total_cg_time, fitness