def optimize(node):
    """Optimize a graph with a single output node.

    The graph is distributed and linearized. Each einsum subtree found by
    find_sub_einsumtree is fused into one einsum, identity inputs are pruned,
    and the fused einsum is re-expanded by generate_optimal_tree. Afterwards
    the einsum expressions are rewritten and the graph is deduplicated.

    Args:
        node: The output node.
    Returns:
        node: The newly generated node.
    """
    node = distribute_tree(node)
    linearize(node)
    all_nodes = find_topo_sort([node])
    ret_node = PseudoNode(node)
    with OutputInjectedMode(all_nodes):
        trees = find_sub_einsumtree(ret_node)
        for out_node_p, in_nodes in trees:
            new_z = fuse_einsums(out_node_p.node, in_nodes)
            prune_identity_nodes(new_z)
            new_z = generate_optimal_tree(new_z)
            replace_node(out_node_p, new_z)

    # ret_node.node is read back here, presumably because replace_node can
    # swap in a new root.
    output_node = declone(ret_node.node)

    # Use dedicated loop variables: reusing `node` here would make the
    # function return whatever node the loop ended on, which is only the
    # root by coincidence of the topological ordering.
    for graph_node in find_topo_sort([output_node]):
        if isinstance(graph_node, ad.EinsumNode):
            rewrite_einsum_expr(graph_node)

    # Re-apply the inputs of every non-leaf node (presumably to refresh state
    # derived from them after the rewrites — confirm against set_inputs).
    for graph_node in find_topo_sort([output_node]):
        if graph_node.inputs != []:
            graph_node.set_inputs(graph_node.inputs)

    dedup(output_node)
    return output_node
def generate_sequential_optimal_tree(einsum_nodes, input_nodes):
    """Regenerate einsum expressions so that they follow a dimension tree.

    Parameters
    ----------
    einsum_nodes : list
        List of einsum nodes to be calculated based on the dimension tree.
    input_nodes : list
        List of input nodes whose contraction in the einsum_nodes obeys the
        sequence from the list end to the list start.

    Returns
    -------
    list
        Einsum nodes whose results are the same as einsum_nodes, while
        obeying the dimension tree calculation sequence. When there is exactly
        one einsum node and one input node, einsum_nodes is returned unchanged.

    Examples
    --------
    >>> einsum_node_A = ad.einsum("abcd,bm,cm,dm->am", X, B, C, D)
    >>> einsum_node_B = ad.einsum("abcd,am,cm,dm->bm", X, A, C, D)
    >>> einsum_node_C = ad.einsum("abcd,am,bm,dm->cm", X, A, B, D)
    >>> dt = generate_sequential_optimal_tree([einsum_node_A, einsum_node_B, einsum_node_C], [A, B, C])
    >>> dt
    [ad.einsum('bm,abm->am', B, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
     ad.einsum('am,abm->bm', A, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
     ad.einsum('am,bm,abcm->cm', A, B, ad.einsum('abcd,dm->abcm', X, D)),
    ]
    (einsum strings may be different)
    """
    if len(einsum_nodes) == 1 and len(input_nodes) == 1:
        return einsum_nodes

    new_nodes = []
    for idx, einsum_node in enumerate(einsum_nodes):
        # Inputs after position idx are contracted from the list end backwards,
        # followed by the inputs before idx in their original order.
        candidates = input_nodes[idx + 1:][::-1] + input_nodes[:idx]
        # Keep only the candidates that actually feed this einsum node.
        contract_order = [
            candidate for candidate in candidates
            if candidate in einsum_node.inputs
        ]
        new_nodes.append(
            generate_optimal_tree_w_constraint(einsum_node, contract_order))

    # generate_optimal_tree_w_constraint may leave some einsum strings in a
    # non-canonical form; rewrite them so that dedup can merge equal nodes.
    graph_nodes = find_topo_sort(new_nodes)
    with OutputInjectedMode(graph_nodes):
        for graph_node in graph_nodes:
            if isinstance(graph_node, ad.EinsumNode):
                rewrite_einsum_expr(graph_node)
            if graph_node.inputs != []:
                graph_node.set_inputs(graph_node.inputs)

    dedup(*new_nodes)
    remove_transposes(find_topo_sort(new_nodes))
    return new_nodes
def test_einsum_multiuse_auto_copy(backendopt):
    r"""Test autolinearization and auto fuse.

    The docstring is raw because the diagram contains backslashes, which
    would otherwise be (invalid) escape sequences.

        A    B      inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

    Next: we would need to autoprune.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        linearize(output)
        all_nodes = find_topo_sort([output])
        cloned_nodes = [
            tmp for tmp in all_nodes if isinstance(tmp, ad.CloneNode)
        ]
        out_new = fuse_einsums(output, [*cloned_nodes, b])

        # Test that every input is now fused, i.e. no einsum input remains.
        assert all(not isinstance(x, ad.EinsumNode) for x in out_new.inputs)

        assert tree_eq(output, out_new, [*cloned_nodes, b])
def test_einsum_gen_corner_case(backendopt):
    """generate_optimal_tree must yield only binary contractions here.

    Note: Numpy contraction path cannot find the opt path for this
    expression. It will output the same expression as the input.

        ----------------- E -----------------
        |          |           |            |
        a          b           c            d
        |          |           |            |
        A -- e --  B  -- f --  C  -- g --   D
        |          |           |            |
        h          i           j            k
        |          |           |            |
    """
    size = 5
    A = ad.Variable(name="A", shape=[size] * 3)
    B = ad.Variable(name="B", shape=[size] * 4)
    C = ad.Variable(name="C", shape=[size] * 4)
    D = ad.Variable(name="D", shape=[size] * 3)
    E = ad.Variable(name="E", shape=[size] * 4)

    output = ad.einsum('aeh,bfie,cgjf,dgk,abcd->hijk', A, B, C, D, E)
    new_output = generate_optimal_tree(output)

    # Leaves are skipped; every other node must contract exactly two inputs.
    for graph_node in find_topo_sort([new_output]):
        if isinstance(graph_node, ad.VariableNode):
            continue
        assert len(graph_node.inputs) == 2
def get_common_ancestor(root, leaves, in_node):
    """Find the einsum node whose subtree covers every occurrence of in_node.

    The tree is defined by root and leaves. in_node may occur several times
    in leaves (counted by identity), i.e. it can have multiple parents.

    Parameters
    ----------
    root: Tree root.
    leaves: A list of leaf nodes that define the inputs of the subtree.
    in_node: One of the nodes in leaves; several intermediate nodes may have
        it as a child.

    Returns
    ----------
    ancestor: The first EinsumNode in find_topo_sort order whose subtree
        contains as many occurrences of in_node as leaves does (intended to
        be the smallest such subtree), or None if there is none.
    """
    assert in_node in leaves
    expected_count = sum(1 for leaf in leaves if leaf is in_node)

    for candidate in find_topo_sort([root], leaves):
        if not isinstance(candidate, ad.EinsumNode):
            continue
        found_count = sum(
            1 for n in get_all_nodes([candidate], leaves)
            if n in leaves and n is in_node)
        if found_count == expected_count:
            return candidate
    return None
def print_computation_graph(output_node_list, input_nodes=None):
    """Print the Graphviz DOT source of a computation graph.

    Args:
        output_node_list: A non-empty list of output nodes.
        input_nodes: Optional list of nodes passed to find_topo_sort as the
            traversal boundary. Defaults to an empty list. ``None`` is used
            instead of a mutable ``[]`` default, which would be shared across
            calls.
    """
    assert len(output_node_list) > 0
    if input_nodes is None:
        input_nodes = []

    topo_order = find_topo_sort(output_node_list, input_nodes)
    inputs = [x for x in topo_order if isinstance(x, ad.VariableNode)]

    with OutputInjectedMode(topo_order):
        dot = Digraph(comment='Poorman Computation Graph')

        # Variable (input) nodes share one rank.
        with dot.subgraph() as s:
            s.attr(rank='same')
            for n in inputs:
                s.node(n.name, style='filled', color='aquamarine3')

        # Output nodes share one rank.
        with dot.subgraph() as s:
            s.attr(rank='same')
            for n in output_node_list:
                s.node(n.name, style='filled', color='thistle')

        # Every remaining (intermediate) node.
        with dot.subgraph() as s:
            for n in topo_order:
                if n not in output_node_list and n not in inputs:
                    s.node(n.name, style='filled', color='lightblue')

        # Labels and edges from each input to its consumer.
        for node in topo_order:
            dot.node(node.name, graph_name(node))
            for node_i in node.inputs:
                dot.edge(node_i.name, node.name)

    print(dot.source)
def test_cpd_hessian_optimize_offdiag(backendopt):
    """Optimizing off-diagonal CPD hessian blocks must preserve their values."""
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)
        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        hessian = ad.hessian(loss, [A, B, C])
        hessian_offdiag = [hessian[0][1], hessian[1][0]]

        feed_dict = {
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        }
        # Reference values from the un-optimized graphs.
        hes_offdiag_vals = ad.Executor(hessian_offdiag).run(
            feed_dict=feed_dict)

        optimized_offdiag = []
        for node in hessian_offdiag:
            assert isinstance(node, ad.AddNode)
            # optimize() returns a new graph; previously the result was
            # discarded, so nothing about it was actually checked.
            optimized_offdiag.append(optimize(node))
        # The operation count of the optimized graph (once expected to be 14)
        # is currently non-deterministic, so it is not asserted.

        optimized_vals = ad.Executor(optimized_offdiag).run(
            feed_dict=feed_dict)
        for expected_val, optimized_val in zip(hes_offdiag_vals,
                                               optimized_vals):
            assert T.norm(optimized_val - expected_val) < 1e-8
def _sub_forward(self, output_node_list):
    """Forward pass subroutine.

    Walks the graph of output_node_list in find_topo_sort order and emits one
    indented source line per VariableNode (via _assign_init_variable) or
    OpNode (via _assign_mid_variable). Other node types are skipped.

    Returns:
        The generated source text.
    """
    ordered_nodes = find_topo_sort(output_node_list)
    chunks = [indent_line('# forward pass starts')]
    for current in ordered_nodes:
        if isinstance(current, ad.VariableNode):
            statement = self._assign_init_variable(current)
        elif isinstance(current, ad.OpNode):
            statement = self._assign_mid_variable(current)
        else:
            continue
        chunks.append(indent_line(statement))
    return ''.join(chunks)
def test_remove_transposes_multiple_trans():
    """Two transpose chains composing to abcd->badc end up with equal names."""
    source = ad.Variable(name="a", shape=[2, 2, 2, 2])

    # Chain 1: abcd -> dcba -> badc.
    reversed_view = ad.einsum("abcd->dcba", source)
    out_via_reverse = ad.einsum("dcba->badc", reversed_view)

    # Chain 2: abcd -> abdc -> badc.
    swapped_view = ad.einsum("abcd->abdc", source)
    out_via_swap = ad.einsum("abdc->badc", swapped_view)

    remove_transposes(find_topo_sort([out_via_reverse, out_via_swap]))
    assert out_via_reverse.name == out_via_swap.name
def test_cpd_hessian_optimize_diag(backendopt):
    """Optimized diagonal CPD hessian blocks: structure and values."""
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)
        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        hessian = ad.hessian(loss, [A, B, C])
        hessian_diag = [hessian[0][0], hessian[1][1], hessian[2][2]]

        optimized_diag = []
        for node in hessian_diag:
            node = optimize(node)
            assert isinstance(node, ad.AddNode)
            num_operations = sum(
                1 for x in find_topo_sort([node]) if isinstance(x, ad.OpNode))
            # Use this assertion to test the optimize function.
            # 5 operations:
            # 1. T.einsum('ca,cb->ab',A,A)
            # 2. T.einsum('ca,cb->ab',B,B)
            # 3. T.einsum('ab,ab->ab',<1>,<2>)
            # 4. T.einsum('bd,ac->abcd',<3>,T.identity(10))
            # 5. <4> + <4>
            assert num_operations == 5
            optimized_diag.append(node)

        # Evaluate the optimized graphs so that the value check below
        # actually validates optimize(); previously the un-optimized hessians
        # were evaluated.
        executor = ad.Executor(optimized_diag)
        hes_diag_vals = executor.run(feed_dict={
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        })

        expected_hes_diag_val = [
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', B_val, B_val, C_val, C_val,
                         T.identity(size)),
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', A_val, A_val, C_val, C_val,
                         T.identity(size)),
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', A_val, A_val, B_val, B_val,
                         T.identity(size))
        ]
        for val, expected_val in zip(hes_diag_vals, expected_hes_diag_val):
            assert T.norm(val - expected_val) < 1e-8
def _sub_gTv(self, vector_list):
    """Subroutine of g and v inner product.

    Emits assignments for the vectors, for every intermediate node of the
    inner product that was not already emitted (not in
    self.topo_order_gradients, not a vector, not the inner product itself),
    and finally ``_gTv = ...``. The inner product node is renamed to '_gTv'.

    Returns:
        (inner_product_node, file_string): the inner product node and the
        generated source text.
    """
    file_string = '\n'
    file_string += indent_line('# inner product of g and v starts')
    for node in vector_list:
        file_string += indent_line(self._assign_init_variable(node))
    inner_product_node = inner_product(vector_list, self.gradient_list)
    topo_order = find_topo_sort([inner_product_node])
    for node in topo_order:
        if node not in self.topo_order_gradients and \
                node is not inner_product_node and \
                node not in vector_list:
            # Wrap with indent_line like every other emitted statement (the
            # sibling subroutines do the same); it was missing here.
            file_string += indent_line(self._assign_mid_variable(node))
    file_string += indent_line(
        f'_gTv = {inner_product_node.s2s_expr(inner_product_node.inputs)}')
    inner_product_node.name = '_gTv'
    return inner_product_node, file_string
def test_remove_transposes():
    """A chain with transposed intermediates must collapse onto the plain one."""
    a = ad.Variable(name="a", shape=[2, 2, 2, 2])
    b = ad.Variable(name="b", shape=[2, 2])
    # c and d previously were (by copy-paste) also named "b", which made the
    # name comparison below less discriminating.
    c = ad.Variable(name="c", shape=[2, 2])
    d = ad.Variable(name="d", shape=[2, 2])

    # Chain 1 keeps the natural index order.
    ab1 = ad.einsum("abcd,de->abce", a, b)
    abc1 = ad.einsum("abce,ce->abe", ab1, c)
    abcd1 = ad.einsum("abe,be->ae", abc1, d)

    # Chain 2 computes the same result through transposed intermediates.
    ab2 = ad.einsum("abcd,de->ecba", a, b)
    abc2 = ad.einsum("ecba,ce->eba", ab2, c)
    abcd2 = ad.einsum("eba,be->ae", abc2, d)

    remove_transposes(find_topo_sort([abcd1, abcd2]))
    assert abcd1.name == abcd2.name
def test_dmrg_shared_exec_graph():
    """DMRG hessians built as a dimension tree share intermediate einsums."""
    from graph_ops.graph_transformer import simplify
    from graph_ops.graph_als_optimizer import generate_sequential_optimal_tree
    from utils import find_topo_sort

    num, rank, size = 4, 3, 2
    mpo_ranks = [rank] * (num - 1)
    mps_ranks = [rank] * (num - 1)

    dg = DmrgGraph.create(num, mpo_ranks, mps_ranks, size)
    for i, hes in enumerate(dg.hessians):
        simplified = simplify(hes)
        # Check the simplified hessian (previously the pre-simplify node was
        # checked); it is what generate_sequential_optimal_tree receives.
        assert isinstance(simplified, ad.EinsumNode)
        dg.hessians[i] = simplified

    dg.hessians = generate_sequential_optimal_tree(dg.hessians,
                                                   dg.mps_inputs)
    # 8 input variables (4 H term in MPO, 4 A term in MPS), 7 einsum nodes
    assert len(find_topo_sort(dg.hessians)) == 15
def test_simple_dmrg_tree():
    """Sequential optimal tree for a 3-site DMRG network."""
    A1 = ad.Variable(name="A1", shape=[3, 2])
    A2 = ad.Variable(name="A2", shape=[3, 3, 2])
    A3 = ad.Variable(name="A3", shape=[3, 2])
    X1 = ad.Variable(name="X1", shape=[3, 2, 2])
    X2 = ad.Variable(name="X2", shape=[3, 3, 2, 2])
    X3 = ad.Variable(name="X3", shape=[3, 2, 2])

    # The network and indices positions are as follows:
    #
    #     A1 - f - A2 - g - A3
    #     |        |        |
    #     c        d        e
    #     |        |        |
    #     X1 - a - X2 - b - X3
    #     |        |        |
    #     h        i        j
    #     |        |        |
    #     A1 - k - A2 - l - A3

    einsum_node_A1 = ad.einsum("ach,abdi,bej,fgd,kli,ge,lj->fckh", X1, X2, X3,
                               A2, A2, A3, A3)
    einsum_node_A2 = ad.einsum("ach,abdi,bej,fc,kh,ge,lj->fgdkli", X1, X2, X3,
                               A1, A1, A3, A3)
    einsum_node_A3 = ad.einsum("ach,abdi,bej,fc,kh,fgd,kli->gelj", X1, X2, X3,
                               A1, A1, A2, A2)

    dt = generate_sequential_optimal_tree(
        [einsum_node_A1, einsum_node_A2, einsum_node_A3], [A1, A2, A3])

    assert tree_eq(dt[0], einsum_node_A1,
                   [X1, X2, X3, A1, A1, A2, A2, A3, A3])
    assert tree_eq(dt[1], einsum_node_A2,
                   [X1, X2, X3, A1, A1, A2, A2, A3, A3])
    # The third tree was previously left unchecked.
    assert tree_eq(dt[2], einsum_node_A3,
                   [X1, X2, X3, A1, A1, A2, A2, A3, A3])

    # In the correct contraction path, only X3 should be contracted with A3,
    # all other X nodes should be contracted later.
    einsum_inputs = [
        node for node in find_topo_sort(dt)
        if isinstance(node, ad.EinsumNode)
    ]
    assert sorted(einsum_inputs[0].inputs,
                  key=lambda node: node.name) == sorted(
                      [A3, A3, X3], key=lambda node: node.name)
def _sub_hvp(self, inner_product_node, node_list):
    """Subroutine of hvp.

    Builds the gradient map of inner_product_node and emits source for the
    graph of the gradients of node_list. Nodes that are keys of
    forward_to_hvp_map are skipped. Nodes that are values of it are emitted as
    ``_grad2<forward name> = ...`` and renamed accordingly. Everything else
    goes through _assign_mid_variable.

    Returns:
        The generated source text.
    """
    pieces = ['\n']
    pieces.append(
        indent_line('# backward pass of inner product of g and v starts'))

    self.forward_to_hvp_map = ad.gradients_map(inner_product_node)
    self.hvp_to_forward_map = invert_dict(self.forward_to_hvp_map)

    hvp_nodes = [self.forward_to_hvp_map[fwd] for fwd in node_list]
    for hvp_node in find_topo_sort(hvp_nodes):
        if hvp_node in self.forward_to_hvp_map:
            continue
        if hvp_node not in self.forward_to_hvp_map.values():
            pieces.append(indent_line(self._assign_mid_variable(hvp_node)))
            continue
        forward_node = self.hvp_to_forward_map[hvp_node]
        pieces.append(
            indent_line(
                f'_grad2{forward_node.name} = {hvp_node.s2s_expr(hvp_node.inputs)}'
            ))
        hvp_node.name = f'_grad2{forward_node.name}'
    return ''.join(pieces)
def _sub_gradients(self, output_node, node_list):
    """Gradient pass subroutine.

    Emits the forward pass of output_node, then the backward pass. Along the
    way it sets forward_to_grad_map, grad_to_forward_map, gradient_list (the
    gradients of node_list) and topo_order_gradients on self.

    Returns:
        The generated source text.
    """
    pieces = [
        self._sub_forward([output_node]),
        '\n',
        indent_line('# backward pass starts'),
    ]

    self.forward_to_grad_map = ad.gradients_map(output_node)
    self.grad_to_forward_map = invert_dict(self.forward_to_grad_map)
    self.gradient_list = [
        self.forward_to_grad_map[fwd] for fwd in node_list
    ]
    self.topo_order_gradients = find_topo_sort(self.gradient_list)

    for grad_node in self.topo_order_gradients:
        if grad_node in self.forward_to_grad_map:
            # Keys of the map are skipped.
            continue
        if grad_node in self.forward_to_grad_map.values():
            pieces.append(indent_line(self._assign_grad_variable(grad_node)))
        else:
            pieces.append(indent_line(self._assign_mid_variable(grad_node)))
    return ''.join(pieces)
def test_dimension_tree_4d():
    """A 4-way dimension tree shares intermediates and preserves every result."""
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    C = ad.Variable(name="C", shape=[2, 2])
    D = ad.Variable(name="D", shape=[2, 2])
    X = ad.Variable(name="X", shape=[2, 2, 2, 2])

    reference_nodes = [
        ad.einsum("abcd,bm,cm,dm->am", X, B, C, D),
        ad.einsum("abcd,am,cm,dm->bm", X, A, C, D),
        ad.einsum("abcd,am,bm,dm->cm", X, A, B, D),
        ad.einsum("abcd,am,bm,cm->dm", X, A, B, C),
    ]
    dt = generate_sequential_optimal_tree(reference_nodes, [A, B, C, D])

    # 5 inputs, 4 outputs, 5 intermediates
    assert len(find_topo_sort(dt)) == 14
    for idx, reference in enumerate(reference_nodes):
        assert tree_eq(dt[idx], reference, [A, B, C, D, X])
def simplify(output_node):
    """Simplify a graph with a single output node.

    The simplified form will distribute selected operations (+), and fuse all
    connected einsums. The passes, in order, are:
      1. distribute and fuse einsums;
      2. rewrite einsum expressions and optimize TensorInverseNodes;
      3. fuse again;
      4. prune orthonormal matmuls;
      5. prune inverse nodes;
      6. prune identity and scalar nodes;
      7. collapse symmetric expressions;
      8. sympy-simplify when the output is a DistributiveNode.

    Args:
        output_node: The output node.
    Returns:
        node: The newly generated node.
    """
    def fuse_all_einsums(node):
        linearize(node)
        ret_node = PseudoNode(node)
        all_pnodes = find_topo_sort_p([ret_node])
        with OutputInjectedModeP(all_pnodes):
            trees = find_sub_einsumtree(ret_node)
            for out_node_p, in_nodes in trees:
                new_z = fuse_einsums(out_node_p.node, in_nodes)
                prune_identity_nodes(new_z)
                replace_node(out_node_p, new_z)
        return declone(ret_node.node)

    output_node = distribute_graph_w_linearize(output_node)
    output_node = fuse_all_einsums(output_node)

    # optimize inverse
    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that we can collapse the add node.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.TensorInverseNode):
                new_inv_node = optimize_inverse(node)
                replace_node(pnode, new_inv_node)

    # fuse again
    output_node = fuse_all_einsums(output_pnode.node)

    # prune the orthonormal matmuls
    # fuse_all_einsums may return a new root, so wrap the fused graph in a new
    # pseudo node; reusing the old one would prune the pre-fusion graph.
    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                new_node = prune_orthonormal_matmuls(node)
                replace_node(pnode, new_node)

    # prune inverse nodes
    # Start from the current root: replace_node may have swapped it above.
    output_pnode = PseudoNode(output_pnode.node)
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                new_node = prune_inv_node(node)
                replace_node(pnode, new_node)

    # prune the scalar nodes and remove unnecessary identity nodes
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                prune_identity_nodes(node)
                new_node = prune_scalar_nodes(node)
                replace_node(pnode, new_node)

    # Pick up the root after the pruning passes, which may have replaced it;
    # otherwise a pruned root would be silently dropped from the result.
    output_node = output_pnode.node

    # collapse symmetric expressions
    all_pnodes = find_topo_sort_p([output_pnode])
    for i, pnode_i in enumerate(all_pnodes):
        for pnode_j in all_pnodes[:i]:
            collapse_symmetric_expr(pnode_i.node, pnode_j.node)

    sympy_input_types = (ad.DistributiveNode, ad.ScalarNode, ad.MulNode)
    # sympy_simplify the distributed nodes
    if isinstance(output_node, ad.DistributiveNode):
        sympy_inputs = []
        all_nodes = find_topo_sort([output_node])
        for node in all_nodes:
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that they can be reduced by sympy.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if not isinstance(node, sympy_input_types):
                sympy_inputs.append(node)
        output_node = sympy_simplify(output_node, sympy_inputs)

    return output_node
def fuse_einsums(output_node, input_nodes):
    """Find and fuse einsums.

    All nodes between output_node and input_nodes are merged into one einsum
    whose operands are input_nodes.

    Parameters:
        output_node: Root of the einsum tree; must be an ad.EinsumNode.
        input_nodes: The leaves of the tree.
        Each node must have attribute inputs, which makes it a sparse graph
        representation.
    Returns:
        A graph with fused intermediate einsum nodes. Represented by
        output_node.
    Note: inputs of a node can have same node. But one node can't go to two
        output nodes.
    Side effect: sets ``dims_info`` on every node of the visited subgraph.
    """
    logger.info('Start fusing einsum')
    # Assume output_node is einsum and their children are einsum of any
    # number of input nodes.
    assert isinstance(output_node, ad.EinsumNode)

    # All nodes of the tree, including input_nodes, in topological order.
    all_nodes = find_topo_sort([output_node], input_nodes)

    pseudo_nodes = []
    pseudo_input_nodes = []
    pseudo_output_node = None

    # We first represent each dim as a different character, and then union.
    for k, node in enumerate(all_nodes):
        node.dims_info = [
            DimInfo(node=node, dim_index=i, node_index=k)
            for i in range(len(node.shape))
        ]
        pnode = PseudoNode(node=node, dims_info=node.dims_info)
        pseudo_nodes.append(pnode)
        if node in input_nodes:
            pseudo_input_nodes.append(pnode)
        if node is output_node:
            pseudo_output_node = pnode

    # Keep the topological order instead of set-iteration order, so that the
    # union-find is built in the same order on every run.
    input_pnode_set = set(pseudo_input_nodes)
    einsum_pseudo_nodes = [
        pnode for pnode in pseudo_nodes
        if pnode not in input_pnode_set
        and isinstance(pnode.node, ad.EinsumNode)
    ]

    # Flatten in one pass (sum(lists, []) is quadratic).
    all_dims_info = [dim for pnode in pseudo_nodes for dim in pnode.dims_info]

    # For any two dims with the same literal, get their pos and connect.
    uf = UF(all_dims_info)
    for pnode in einsum_pseudo_nodes:
        pnode_dims_info = node_dims_info(pnode)
        cross_einsum_connect(uf, pnode.node, pnode_dims_info)
    uf.assign()

    # Assign literals
    for pnode in pseudo_nodes:
        pnode.generate_subscript(uf)

    new_input_subs = [pnode.subscript for pnode in pseudo_input_nodes]
    new_subscripts = ",".join(
        new_input_subs) + "->" + pseudo_output_node.subscript
    logger.info("Generated new subscript: %s", new_subscripts)

    return ad.einsum(new_subscripts,
                     *[pnode.node for pnode in pseudo_input_nodes])