def optimize(node):
    """Optimize a graph with a single output node.

    Args:
        node: The output node.
    Returns:
        node: The newly generated node.
    """
    node = distribute_tree(node)
    linearize(node)
    all_nodes = find_topo_sort([node])
    ret_node = PseudoNode(node)
    with OutputInjectedMode(all_nodes):
        trees = find_sub_einsumtree(ret_node)
        for tree in trees:
            out_node_p, in_nodes = tree
            # Fuse each einsum subtree into one einsum, prune identity
            # inputs, then re-bracket it into an optimal contraction tree.
            new_z = fuse_einsums(out_node_p.node, in_nodes)
            prune_identity_nodes(new_z)
            new_z = generate_optimal_tree(new_z)
            replace_node(out_node_p, new_z)
    node = declone(ret_node.node)

    # Canonicalize einsum expressions so that dedup can recognize identical
    # subtrees. Use a separate loop variable to avoid shadowing the output
    # `node`.
    all_nodes = find_topo_sort([node])
    for n in all_nodes:
        if isinstance(n, ad.EinsumNode):
            rewrite_einsum_expr(n)
    # Re-register inputs so node metadata reflects the rewritten expressions.
    for n in find_topo_sort([node]):
        if n.inputs != []:
            n.set_inputs(n.inputs)

    dedup(node)
    return node
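# A minimal usage sketch for optimize(), not part of the original module. It
# assumes the ad.Variable / ad.einsum / ad.Executor API used by the tests
# below; the shapes and names here are illustrative only.
def _optimize_usage_sketch():
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    # A nested einsum graph; optimize() can fuse the two contractions into a
    # single einsum and re-bracket it into an optimal contraction tree.
    out = ad.einsum("ab,bc->ac", A, ad.einsum("ab,bc->ac", A, B))
    out = optimize(out)
    # Run it like the executors in the tests below:
    #     ad.Executor([out]).run(feed_dict={A: A_val, B: B_val})
    return out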
def test_cpd_jtjvp_optimized(benchmark):
    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list
        v_A_list, _ = init_rand_cp(dim, size, rank)
        v_A_val, v_B_val, v_C_val = v_A_list

        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])

        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        for node in JtJvps:
            assert isinstance(node, ad.AddNode)

        executor_JtJvps = ad.Executor(JtJvps)

        jtjvp_val = benchmark(executor_JtJvps.run,
                              feed_dict={
                                  A: A_val,
                                  B: B_val,
                                  C: C_val,
                                  input_tensor: input_tensor_val,
                                  v_A: v_A_val,
                                  v_B: v_B_val,
                                  v_C: v_C_val
                              })
def generate_sequential_optimal_tree(einsum_nodes, input_nodes):
    """Regenerate einsum expressions based on the dimension tree.

    Parameters
    ----------
    einsum_nodes : list
        List of einsum nodes to be calculated based on the dimension tree.
    input_nodes : list
        List of input nodes whose contraction in the einsum_nodes obeys
        the sequence from the list end to the list start.

    Returns
    -------
        List of einsum nodes whose results are the same as einsum_nodes,
        while obeying the dimension tree calculation sequence.

    Examples
    --------
    >>> einsum_node_A = ad.einsum("abcd,bm,cm,dm->am", X, B, C, D)
    >>> einsum_node_B = ad.einsum("abcd,am,cm,dm->bm", X, A, C, D)
    >>> einsum_node_C = ad.einsum("abcd,am,bm,dm->cm", X, A, B, D)
    >>> dt = generate_sequential_optimal_tree(
    ...     [einsum_node_A, einsum_node_B, einsum_node_C], [A, B, C])
    >>> dt
    [ad.einsum('bm,abm->am', B, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
     ad.einsum('am,abm->bm', A, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
     ad.einsum('am,bm,abcm->cm', A, B, ad.einsum('abcd,dm->abcm', X, D)),
    ]
    (einsum strings may be different)
    """
    if len(einsum_nodes) == 1 and len(input_nodes) == 1:
        return einsum_nodes

    new_nodes = []
    for (i, node) in enumerate(einsum_nodes):
        contract_order = input_nodes[i + 1:]
        contract_order.reverse()
        contract_order = contract_order + input_nodes[:i]
        # Get the subarray that is the inputs of node.
        contract_order = list(
            filter(lambda n: n in node.inputs, contract_order))

        new_nodes.append(
            generate_optimal_tree_w_constraint(node, contract_order))

    # After generate_optimal_tree_w_constraint, some einsum strings are not
    # in the canonical format and need to be rewritten again for dedup.
    all_nodes = find_topo_sort(new_nodes)
    with OutputInjectedMode(all_nodes):
        for node in all_nodes:
            if isinstance(node, ad.EinsumNode):
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)

    dedup(*new_nodes)
    remove_transposes(find_topo_sort(new_nodes))
    return new_nodes
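# A sketch of the docstring example above, not part of the original module.
# It assumes factor matrices of shape [size, rank] and a dim-4 input tensor;
# X, A, B, C, D are the names used in the Examples section.
def _dimension_tree_sketch(size, rank):
    X = ad.Variable(name="X", shape=[size, size, size, size])
    A = ad.Variable(name="A", shape=[size, rank])
    B = ad.Variable(name="B", shape=[size, rank])
    C = ad.Variable(name="C", shape=[size, rank])
    D = ad.Variable(name="D", shape=[size, rank])
    einsum_node_A = ad.einsum("abcd,bm,cm,dm->am", X, B, C, D)
    einsum_node_B = ad.einsum("abcd,am,cm,dm->bm", X, A, C, D)
    einsum_node_C = ad.einsum("abcd,am,bm,dm->cm", X, A, B, D)
    # The three outputs share the partial contraction
    # ad.einsum('abcd,dm->abcm', X, D), which the dimension tree computes once.
    return generate_sequential_optimal_tree(
        [einsum_node_A, einsum_node_B, einsum_node_C], [A, B, C])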
def test_dedup():
    """ Dedup the tree. """
    a = ad.Variable(name="a", shape=[2, 2])
    b = ad.Variable(name="b", shape=[2, 2])

    c = a + b
    d = a + b
    z = c + d

    dedup(z)
    # Assert object level equivalence.
    assert z.inputs[0] == z.inputs[1]
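# A hypothetical companion sketch, not part of the original test file. Per the
# comments in generate_sequential_optimal_tree, dedup needs einsum strings in
# canonical form, so two re-lettered but equivalent contractions should merge
# only after rewrite_einsum_expr canonicalizes them.
def _dedup_einsum_sketch():
    a = ad.Variable(name="a", shape=[2, 2])
    b = ad.Variable(name="b", shape=[2, 2])
    # 'ab,bc->ac' and 'xy,yz->xz' denote the same contraction.
    c = ad.einsum("ab,bc->ac", a, b)
    d = ad.einsum("xy,yz->xz", a, b)
    z = c + d

    all_nodes = find_topo_sort([z])
    with OutputInjectedMode(all_nodes):
        for node in all_nodes:
            if isinstance(node, ad.EinsumNode):
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
    dedup(z)
    return z.inputs[0] == z.inputs[1]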
def test_cpd_jtjvp_optimize(backendopt):
    dim = 3

    for datatype in backendopt:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list
        v_A_list, _ = init_rand_cp(dim, size, rank)
        v_A_val, v_B_val, v_C_val = v_A_list

        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])

        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        for node in JtJvps:
            assert isinstance(node, ad.AddNode)

        executor_JtJvps = ad.Executor(JtJvps)

        jtjvp_val = executor_JtJvps.run(
            feed_dict={
                A: A_val,
                B: B_val,
                C: C_val,
                input_tensor: input_tensor_val,
                v_A: v_A_val,
                v_B: v_B_val,
                v_C: v_C_val
            })

        expected_hvp_val = expect_jtjvp_val(A_val, B_val, C_val, v_A_val,
                                            v_B_val, v_C_val)

        assert T.norm(jtjvp_val[0] - expected_hvp_val[0]) < 1e-8
        assert T.norm(jtjvp_val[1] - expected_hvp_val[1]) < 1e-8
        assert T.norm(jtjvp_val[2] - expected_hvp_val[2]) < 1e-8
def cpd_nls(size, rank, regularization=1e-7, mode='ad'):
    """
    mode: ad / optimized / jax
    """
    assert mode in {'ad', 'jax', 'optimized'}

    dim = 3

    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)
        T.seed(1)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])

        grads = ad.gradients(loss, [A, B, C])
        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        if mode == 'jax':
            from source import SourceToSource
            StS = SourceToSource()
            StS.forward(JtJvps,
                        file=open("examples/jax_jtjvp.py", "w"),
                        function_name='jtjvp',
                        backend='jax')

        executor_grads = ad.Executor([loss] + grads)
        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        executor_JtJvps = ad.Executor(JtJvps)
        optimizer = cp_nls_optimizer(input_tensor_val, [A_val, B_val, C_val])

        regu_increase = False
        normT = T.norm(input_tensor_val)
        time_all, fitness = 0., 0.

        for i in range(10):
            t0 = time.time()

            def hess_fn(v):
                if mode == 'optimized':
                    from examples.cpd_jtjvp_optimized import jtjvp
                    return jtjvp([v[0], B_val, C_val, v[1], A_val, v[2]])
                elif mode == 'ad':
                    return executor_JtJvps.run(
                        feed_dict={
                            A: A_val,
                            B: B_val,
                            C: C_val,
                            input_tensor: input_tensor_val,
                            v_A: v[0],
                            v_B: v[1],
                            v_C: v[2]
                        })
                elif mode == 'jax':
                    from examples.jax_jtjvp import jtjvp
                    return jtjvp([B_val, C_val, v[0], A_val, v[1], v[2]])

            loss_val, grad_A_val, grad_B_val, grad_C_val = executor_grads.run(
                feed_dict={
                    A: A_val,
                    B: B_val,
                    C: C_val,
                    input_tensor: input_tensor_val
                })

            res = math.sqrt(loss_val)
            fitness = 1 - res / normT
            print(f"[ {i} ] Residual is {res} fitness is: {fitness}")
            print(f"Regularization is: {regularization}")

            [A_val, B_val, C_val], total_cg_time = optimizer.step(
                hess_fn=hess_fn,
                grads=[grad_A_val / 2, grad_B_val / 2, grad_C_val / 2],
                regularization=regularization)

            t1 = time.time()
            print(f"[ {i} ] Sweep took {t1 - t0} seconds")
            time_all += t1 - t0

            # Oscillate the regularization between 1e-7 and 1: once it drops
            # below 1e-7 start doubling it each sweep, and once it exceeds 1
            # start halving it.
            if regularization < 1e-07:
                regu_increase = True
            elif regularization > 1:
                regu_increase = False
            if regu_increase:
                regularization = regularization * 2
            else:
                regularization = regularization / 2

        return total_cg_time, fitness
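# An illustrative driver for cpd_nls(), not part of the original module; the
# concrete size/rank values are placeholders following the conventions of the
# tests above.
def _run_cpd_nls_example():
    total_cg_time, fitness = cpd_nls(size=64, rank=10, mode='ad')
    print(f"CG time: {total_cg_time}, fitness: {fitness}")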