def test_einsum_multitier(backendopt):
    """Fusing the einsum subtrees of a multi-tier einsum/add graph keeps its value.

    Two sums of ``get_tree`` results are contracted with an einsum. Every
    subtree returned by ``find_sub_einsumtree`` is fused in place, and the
    graph is re-evaluated with a freshly built executor.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        inputs_a, tree_a = get_tree("set1")
        inputs_b, tree_b = get_tree("set2")
        left_sum = tree_a + tree_b

        inputs_c, tree_c = get_tree("set3")
        inputs_d, tree_d = get_tree("set4")
        right_sum = tree_c + tree_d

        out = ad.einsum("ij, jk->ik", left_sum, right_sum)
        feed_dict = gen_dict(inputs_a + inputs_b + inputs_c + inputs_d)

        expected, = ad.Executor([out]).run(feed_dict=feed_dict)

        with OutputInjectedModeP(find_topo_sort_p([PseudoNode(out)])):
            for root_pnode, leaves in find_sub_einsumtree(PseudoNode(out)):
                replace_node(root_pnode,
                             fuse_einsums(root_pnode.node, leaves))

        actual, = ad.Executor([out]).run(feed_dict=feed_dict)
        assert float_eq(expected, actual)
def find_sub_einsumtree(output_node_p):
    """Collect the einsum subtrees reachable from a pseudo node.

    Different subtrees may overlap.

    Arguments:
        output_node_p: the root of the search; must be a PseudoNode.

    Returns:
        A list of ``[pseudo_root_node, leaf_nodes]`` pairs. For an einsum
        root, the trees found below its leaves come first, followed by the
        root's own tree. For a non-einsum root, the result is the
        concatenation of the trees found under each of its inputs.
    """
    root = output_node_p.node
    trees = []

    if not isinstance(root, ad.EinsumNode):
        # Not an einsum itself: search for einsum trees below every input.
        for child in root.inputs:
            trees.extend(find_sub_einsumtree(PseudoNode(child)))
        return trees

    leaves = get_leaves(get_all_einsum_descendants(root))
    # Subtrees hanging below the leaves are reported before this tree.
    for leaf in leaves:
        trees.extend(find_sub_einsumtree(PseudoNode(leaf)))
    trees.append([output_node_p, leaves])
    return trees
def split_inv_einsum(inv_node):
    """
    Optimize the inverse of an einsum expression so that the inverse is
    applied to several smaller tensors.

    Parameters
    ----------
    inv_node: An inverse node whose input is a fused einsum node (none of
        that einsum's inputs may itself be an einsum node).

    Returns
    -------
    ``inv_node`` itself when ``inv_disjoint_sets`` yields a single disjoint
    set. Otherwise, a new einsum over one ``ad.tensorinv`` per disjoint set.
    Each inverse wraps the sub-einsum of the inputs touching that set.
    """
    fused = inv_node.inputs[0]
    assert isinstance(fused, ad.EinsumNode)
    # The einsum must already be fused: no einsum among its operands.
    for operand in fused.inputs:
        assert not isinstance(operand, ad.EinsumNode)

    in_subs, out_subs, _ = parse_einsum_input(
        (fused.einsum_subscripts, *fused.inputs))
    operand_subs = in_subs.split(',')

    p_fused = PseudoNode(node=fused, subscript=out_subs)
    p_operands = [
        PseudoNode(node=operand, subscript=operand_subs[idx])
        for idx, operand in enumerate(fused.inputs)
    ]

    dsets = inv_disjoint_sets(p_fused, p_operands)

    # Nothing can be split off: keep the original inverse node.
    if len(dsets) == 1:
        return inv_node

    p_inverses = []
    for dset in dsets:
        members = [
            p for p in p_operands if any(ch in dset for ch in p.subscript)
        ]
        sub_out = "".join(ch for ch in p_fused.subscript if ch in dset)
        sub_einsum = generate_new_einsum(members, sub_out)
        # The sub-einsum receives half of its output rank as in-indices length.
        sub_einsum.set_in_indices_length(len(sub_out) // 2)
        p_inverses.append(
            PseudoNode(node=ad.tensorinv(sub_einsum), subscript=sub_out))

    return generate_new_einsum(p_inverses, p_fused.subscript)
def test_einsum_subtree_clone(backendopt):
    r"""
        [Subtree clone]
        This case is rather subtle.
        We want to auto fuse
            A   B   C   D
            |    \ /    |
            |    es     |
            |   /  \    |
            |  /    \   |
             es      es
               \    /
                  +
        Here es is einsum.

        The docstring is a raw string because the diagram contains
        backslashes, which are invalid escape sequences in a normal string.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])
        d = ad.Variable(name="d", shape=[3, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3

        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3

        BCD = ad.einsum('jk, ki->ji', BC, d)  # 3x3

        out = ABC + BCD

        input_nodes = [a, b, c, d]
        generated_feed_dict = gen_dict(input_nodes)

        executor = ad.Executor([out])
        out_val, = executor.run(feed_dict=generated_feed_dict)

        with OutputInjectedModeP(find_topo_sort_p([PseudoNode(out)])):
            trees = find_sub_einsumtree(PseudoNode(out))
            # The shared BC einsum yields two separate einsum trees.
            assert len(trees) == 2
            for tree in trees:
                out_node, in_nodes = tree
                new_z = fuse_einsums(out_node.node, in_nodes)
                replace_node(out_node, new_z)

        # NOTE(review): this reuses the executor built before fusion, whereas
        # test_einsum_multitier builds a fresh one; confirm that Executor
        # re-evaluates the mutated graph rather than a cached order.
        new_out_val, = executor.run(feed_dict=generated_feed_dict)
        assert float_eq(out_val, new_out_val)
def dedup_transpose(graph, node, trans_node, trans_indices):
    """
    Replace ``node`` by ``trans_node`` in every consumer of ``node`` within
    ``graph``, permuting the consumer's subscripts accordingly.

    Parameters
    ----------
    graph: list of nodes denoting a connected graph.
    node: node to be replaced.
    trans_node: the transposed node that will replace node.
    trans_indices: the transpose indices. For every replaced operand, the new
        subscript is ``old_subscript[j] for j in trans_indices``.
    """
    assert node in graph
    assert trans_node in graph

    pseudo_graph = [PseudoNode(n) for n in graph]
    with OutputInjectedModeP(pseudo_graph):
        for consumer in node.outputs:
            # NOTE: currently we cannot deal with non-einsum nodes.
            assert isinstance(consumer, ad.EinsumNode)
            in_subs, out_subs, _ = parse_einsum_input(
                (consumer.einsum_subscripts, *consumer.inputs))
            operand_subs = in_subs.split(',')

            for position, operand in enumerate(consumer.inputs):
                if operand is not node:
                    continue
                consumer.inputs[position] = trans_node
                old_sub = operand_subs[position]
                operand_subs[position] = "".join(
                    old_sub[j] for j in trans_indices)

            consumer.einsum_subscripts = (",".join(operand_subs) + "->" +
                                          out_subs)
            consumer.set_inputs(consumer.inputs)
def optimize(node):
    """Optimize a graph with a single output node.

    Steps: distribute the tree, linearize it, then fuse every einsum subtree.
    Each fused einsum has its identity nodes pruned and is rebuilt with
    ``generate_optimal_tree``. After that the graph is decloned, every
    einsum expression is rewritten, inputs are refreshed via ``set_inputs``,
    and nodes are deduplicated.

    Args:
        node: The output node.
    Returns:
        node: The newly generated node.
    """
    node = distribute_tree(node)
    linearize(node)

    all_nodes = find_topo_sort([node])
    ret_node = PseudoNode(node)
    with OutputInjectedMode(all_nodes):
        trees = find_sub_einsumtree(ret_node)
        for tree in trees:
            out_node_p, in_nodes = tree
            new_z = fuse_einsums(out_node_p.node, in_nodes)
            prune_identity_nodes(new_z)
            new_z = generate_optimal_tree(new_z)
            replace_node(out_node_p, new_z)

    node = declone(ret_node.node)

    # Use a dedicated loop variable so that ``node`` keeps referring to the
    # output node, instead of whichever node a loop visited last.
    for n in find_topo_sort([node]):
        if isinstance(n, ad.EinsumNode):
            rewrite_einsum_expr(n)

    for n in find_topo_sort([node]):
        if n.inputs != []:
            n.set_inputs(n.inputs)

    dedup(node)
    return node
def test_einsum_fuse_graph(backendopt):
    r"""
        [Fuse einsum used twice]
        This case is rather subtle.
        We want to auto fuse
            A   B   C
            |    \ /
            |     es
            |    /|
            |   / |
             es   |
               \  |
                 es
        Here es is einsum.

        The docstring is a raw string because the diagram contains
        backslashes, which are invalid escape sequences in a normal string.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3
        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3
        out = ad.einsum('jk, ki->ji', ABC, BC)  # 3x3

        # BC feeds two einsums; linearize clones it so one tree covers all.
        linearize(out)
        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        new_z = fuse_einsums(out.node, ins)
        assert tree_eq(out.node, new_z, [a, b, c])
def test_einsum_fuse_w_identity(backendopt):
    r"""
        [Fuse einsum with multiple identities]
        We want to fuse
            A    identity  identity
            |        \      /
            |         \    /
            |          \  /
            |           es
            |          /
            |         /
            |        /
            |       /
              es
        Here es is einsum.

        The docstring is a raw string because the diagram contains
        backslashes, which are invalid escape sequences in a normal string.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        es_identity = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', a, es_identity)

        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        new_out = fuse_einsums(out.node, ins)
        assert tree_eq(out.node, new_out, [a])
def prune_identity_nodes(einsum_node):
    """
    Reduce the number of identity nodes in the einsum_node's inputs.
    Inplace update.

    Each identity input ties its two index characters together (union-find
    over all characters of the expression). The identity inputs are dropped,
    and every remaining input subscript is rewritten to use the root
    character of each index's set. If two output positions end up with the
    same root, the later position gets a fresh character. A new identity
    input is then added to link that character to the root, because the
    output subscript cannot repeat a character.

    Args:
        einsum_node: An fused einsum node. Any other node is left untouched.
    Returns:
        None. ``einsum_node.einsum_subscripts`` and its inputs are updated.
    """
    if not (isinstance(einsum_node, ad.EinsumNode)):
        return

    uf_str, p_outnode, p_innodes = generate_einsum_info(einsum_node)
    # Every character used anywhere in the expression (output + inputs).
    whole_str = p_outnode.subscript + "".join(
        [node.subscript for node in p_innodes])

    p_identity_nodes = list(
        filter(lambda pnode: isinstance(pnode.node, ad.IdentityNode),
               p_innodes))
    p_variable_nodes = [
        pnode for pnode in p_innodes if pnode not in p_identity_nodes
    ]

    # each disjoint set in uf_identity represents the indices
    # linked by identity node
    uf_identity = UF(list(whole_str))
    for pnode in p_identity_nodes:
        uf_identity.connect(pnode.subscript[0], pnode.subscript[1])

    # NOTE(review): input_indices_set is filled below but never read. An
    # output index whose root character appears in no remaining input would
    # produce an output character absent from all inputs — confirm whether
    # this set was meant to guard that case.
    input_indices_set, output_indices_set = set(), set()

    for pnode in p_variable_nodes:
        # replace subscripts by the root chars
        sub_list = [uf_identity.root(char) for char in pnode.subscript]
        pnode.subscript = "".join(sub_list)
        input_indices_set |= set(sub_list)

    # Aliases p_variable_nodes; identity nodes appended below land in it too.
    p_updated_inputs = p_variable_nodes
    out_sub_list = []
    for i, char in enumerate(p_outnode.subscript):
        uf_root_char = uf_identity.root(char)
        if uf_root_char in output_indices_set:
            # we cannot assign the same char to two indices in the
            # output. Therefore, assign a new char, and add one
            # identity node to the inputs to show the constraint.
            new_char = uf_str.cg.getchar()
            out_sub_list.append(new_char)
            p_identity_node = PseudoNode(node=ad.identity(
                einsum_node.shape[i]),
                                         subscript=f"{uf_root_char}{new_char}")
            p_updated_inputs.append(p_identity_node)
        else:
            # directly assign the root char to the subscripts
            out_sub_list.append(uf_root_char)
            output_indices_set.add(uf_root_char)

    p_outnode.subscript = "".join(out_sub_list)
    new_input_subs = [pnode.subscript for pnode in p_updated_inputs]
    new_subscripts = ",".join(new_input_subs) + "->" + p_outnode.subscript
    einsum_node.einsum_subscripts = new_subscripts
    einsum_node.set_inputs([pnode.node for pnode in p_updated_inputs])
def distribute_tree(output):
    """ Distribute a tree of einsum and add nodes.

    NOTE: the output node should be a linearized node. Behavior undefined if
    there are other kind of nodes.

    The loop repeats until no candidate is left:
        1. Rebuild the pseudo-node topological order of the graph.
        2. Pick the first ``ad.DistributiveNode`` that has at least one
           output and whose outputs are all einsum nodes.
        3. Replace each of those einsum consumers by
           ``_distribute(binary_op, consumer)``.

    Args:
        output: The output of a tree.
    Returns:
        output: a newly generated node with add operands distributed.
    """
    def first_distributable(pnodes):
        # First DistributiveNode whose (non-empty) consumers are all einsums.
        for candidate in (p.node for p in pnodes):
            if not isinstance(candidate, ad.DistributiveNode):
                continue
            if len(candidate.outputs) < 1:
                continue
            if all(isinstance(o, ad.EinsumNode) for o in candidate.outputs):
                return candidate
        return None

    while True:
        pnodes = find_topo_sort_p([PseudoNode(output)])
        with OutputInjectedModeP(pnodes):
            binary_op = first_distributable(pnodes)
            if binary_op is None:
                break
            for consumer in binary_op.outputs:
                if isinstance(consumer, ad.DistributiveNode):
                    continue
                assert isinstance(consumer, ad.EinsumNode)
                distributed = _distribute(binary_op, consumer)
                replace_node(PseudoNode(consumer), distributed)
                # Keep tracking the root when the output itself is rewritten.
                if consumer == output:
                    output = distributed

    # This is needed for source generation.
    output.set_inputs(output.inputs)
    return output
def test_einsum_fuse_only_identity(backendopt):
    """Fusing an einsum tree whose leaves are all identity nodes keeps its value."""
    for datatype in backendopt:
        T.set_backend(datatype)

        identity_product = ad.einsum('ik,kj->ij', ad.identity(3),
                                     ad.identity(3))
        out = ad.einsum('ai,ij->aj', ad.identity(3), identity_product)

        (root_pnode, leaves), = find_sub_einsumtree(PseudoNode(out))
        fused = fuse_einsums(root_pnode.node, leaves)
        assert tree_eq(root_pnode.node, fused, [])
def fuse_all_einsums(node):
    """Linearize the graph under ``node`` and fuse every einsum subtree.

    Each fused einsum has its identity inputs pruned before it replaces the
    subtree root. The graph is decloned at the end.

    Args:
        node: The output node (linearized in place).
    Returns:
        The decloned output node, which may be a different object than
        ``node`` when the root itself was an einsum.
    """
    linearize(node)
    root = PseudoNode(node)
    with OutputInjectedModeP(find_topo_sort_p([root])):
        for out_node_p, in_nodes in find_sub_einsumtree(root):
            fused = fuse_einsums(out_node_p.node, in_nodes)
            prune_identity_nodes(fused)
            replace_node(out_node_p, fused)
    return declone(root.node)
def test_einsum_find_subtree_after_linearization(backendopt):
    r"""
        An einsum graph like
            A    B   inputs
            |\   |
            | \  |
            |  \ |
            |   C
            |  /
            | /
            output

        will produce

        An einsum graph like
            A    B   inputs
            |\   |
            | A1 |
            |   \|
            A2   C
            |   /
            |  /
            output

        The subtree inputs must then be [A1, A2, B] rather than A, B.

        The docstring is a raw string because the diagrams contain
        backslashes, which are invalid escape sequences in a normal string.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        feed_dict = gen_dict([a, b])
        executor = ad.Executor([output])
        out_val, = executor.run(feed_dict=feed_dict)

        # New graph
        linearize(output)
        tree, = find_sub_einsumtree(PseudoNode(output))
        assert (len(tree[1]) == 3)
def dedup(*nodes):
    """Remove the duplicate nodes with same name.

    Among nodes sharing a name, the one appearing last in the topological
    order is kept, and every other occurrence is replaced by it.

    Args:
        nodes: One or many nodes.
    """
    assert len(nodes) > 0

    topo_order = find_topo_sort_p([PseudoNode(n) for n in nodes])
    with OutputInjectedModeP(topo_order):
        # Later entries overwrite earlier ones: each name maps to its last node.
        canonical = {pnode.node.name: pnode.node for pnode in topo_order}
        keep = set(canonical.values())

        for pnode in topo_order:
            if pnode.node not in keep:
                replace_node(pnode, canonical[pnode.node.name])
def linearize(output_node):
    """Linearize a graph by adding clone nodes for computation optimization.

    For every node with more than one entry in its ``outputs`` list, each
    consumer's inputs are rebuilt so that every input with that node's name
    becomes a fresh ``copy_tree`` of it.

    Args:
        output_node: A single node.
    Returns:
        None. Update is inplace.

    NOTE: If you ever need to debug this function, the generated name is
        inconsistent because of the added edges.
    """
    # Need to create new nodes for whichever node that has 2 or more outgoing edges.
    all_pnodes = find_topo_sort_p([PseudoNode(output_node)])
    # Inject outputs relationship.
    with OutputInjectedModeP(all_pnodes):
        for pn in all_pnodes:
            n = pn.node
            if len(n.outputs) > 1:
                # dict.fromkeys de-duplicates like set() but keeps the order in
                # which consumers appear, so clones are created in a
                # deterministic order instead of a hash-dependent one.
                for n_o in dict.fromkeys(n.outputs):
                    n_o.set_inputs([
                        tmp if tmp.name != n.name else copy_tree(n)
                        for tmp in n_o.inputs
                    ])
def prune_single_inv_node(einsum_node, inv_node):
    """
    Prune the inv_node in the einsum node if the conditions are met.

    Note:
    1. can only optimize the node when the input of inv is an einsum node.
    2. only supports the case when the split nodes are different from the remaining ones.
        For example: ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, C) will be
        optimized to ad.einsum("ab,bc->ac", C, ad.identity()),
        but we cannot optimize ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, B).

    Parameters
    ----------
    einsum_node: The fused einsum node
    inv_node: the input inv node to be pruned

    Returns
    -------
    If the einsum_node cannot be optimized, then return the input einsum_node.
    If it can be optimized, return the optimized einsum node.
    """
    from autohoot.einsum_graph.expr_generator import rewrite_einsum_expr
    from autohoot.graph_ops.optimal_tree import split_einsum

    inv_node_input = inv_node.inputs[0]
    if not isinstance(inv_node_input, ad.EinsumNode):
        logger.info("inv input is not einsum node, can't prune inv")
        return einsum_node

    if not set(inv_node_input.inputs).issubset(set(einsum_node.inputs)):
        logger.info(
            "inv inputs is not subset of einsum node inputs, can't prune inv")
        return einsum_node

    einsum_inputs_in_inv = [
        n for n in einsum_node.inputs if n in inv_node_input.inputs
    ]
    if len(einsum_inputs_in_inv) < len(inv_node_input.inputs):
        logger.info(
            "number of inv inputs is more than corresponding einsum inputs, can't prune inv"
        )
        return einsum_node

    # Group the inputs shared with the inverse into one sub-einsum.
    split_einsum_node = split_einsum(
        einsum_node,
        list(set(einsum_node.inputs) - set(inv_node_input.inputs)))

    # Assign pseudo nodes and chars
    in_subs, out_subs, _ = parse_einsum_input(
        (split_einsum_node.einsum_subscripts, *split_einsum_node.inputs))
    in_subs_list = in_subs.split(',')

    p_einsum_input = None
    p_inv_input = None
    updated_p_in_nodes = []
    for i, node in enumerate(split_einsum_node.inputs):
        if isinstance(node, ad.EinsumNode):
            p_einsum_input = PseudoNode(node=node, subscript=in_subs_list[i])
        elif node is inv_node:
            p_inv_input = PseudoNode(node=node, subscript=in_subs_list[i])
        else:
            updated_p_in_nodes.append(
                PseudoNode(node=node, subscript=in_subs_list[i]))

    # Previously a missing einsum/inv operand raised UnboundLocalError below.
    if p_einsum_input is None or p_inv_input is None:
        logger.info(
            "split einsum lacks an einsum input or the inv input, can't prune inv"
        )
        return einsum_node

    contract_char = "".join(
        set(p_einsum_input.subscript) & set(p_inv_input.subscript))
    # Sorted so the identity subscript does not depend on set iteration order.
    uncontract_str = "".join(
        sorted(
            set(p_einsum_input.subscript + p_inv_input.subscript) -
            set(contract_char)))

    if not (len(p_einsum_input.subscript) == 2
            and len(p_inv_input.subscript) == 2 and len(contract_char) == 1
            and len(uncontract_str) == 2):
        # this is not a matmul. Just return the initial node
        logger.info(
            "the op between inv input and the selected einsum is not matmul, can't prune inv"
        )
        return einsum_node

    if (p_einsum_input.subscript[0] == p_inv_input.subscript[0]
            or p_einsum_input.subscript[1] == p_inv_input.subscript[1]):
        # the str is like "ab,ac", and one einsum needs to be transposed to compare
        p_in_subs, p_out_subs, _ = parse_einsum_input(
            (p_einsum_input.node.einsum_subscripts,
             *p_einsum_input.node.inputs))
        einsum_input = ad.einsum(
            f"{p_in_subs}->{p_out_subs[1]}{p_out_subs[0]}",
            *p_einsum_input.node.inputs)
    else:
        einsum_input = p_einsum_input.node

    # Canonicalize both expressions so equal einsums get equal names.
    rewrite_einsum_expr(einsum_input)
    rewrite_einsum_expr(inv_node_input)

    if einsum_input.name != inv_node_input.name:
        logger.info(
            "inv input and the selected einsum have different expressions, can't prune inv"
        )
        return einsum_node

    # prune the inv node: X @ inv(X) collapses to an identity.
    updated_p_in_nodes = updated_p_in_nodes + [
        PseudoNode(node=ad.identity(inv_node_input.shape[0]),
                   subscript=uncontract_str)
    ]

    return generate_new_einsum(updated_p_in_nodes, out_subs)
def rewrite_einsum_expr(einsum_node):
    """Rewrite the einsum expression of a node. Inplace update.

    Every dimension of the einsum output and of each input becomes a DimInfo.
    ``cross_einsum_connect`` links dimensions sharing a literal, and
    characters are assigned per union-find set. The inputs are then sorted by
    node name plus subscript. The output dimensions carry the temporary name
    '_temp_einsum' so that the character assignment order is consistent.

    Args:
        einsum_node: An einsum node; duplicate inputs are allowed.
    Returns:
        uf (type: graph_ops.graph_optimizer.UF): the union_find set of the input
    """
    assert (isinstance(einsum_node, ad.EinsumNode))

    # The output's pseudo node goes first; it is split off again below.
    out_dims = [
        DimInfo(node=einsum_node, dim_index=d, temp_node_name='_temp_einsum')
        for d in range(len(einsum_node.shape))
    ]
    pnodes = [PseudoNode(node=einsum_node, dims_info=out_dims)]
    for position, operand in enumerate(einsum_node.inputs):
        operand_dims = [
            DimInfo(node=operand, dim_index=d, node_index=position)
            for d in range(len(operand.shape))
        ]
        pnodes.append(PseudoNode(node=operand, dims_info=operand_dims))

    flat_dims = [dim for pnode in pnodes for dim in pnode.dims_info]

    # Connect dims with the same literal, then give each set one character.
    uf = UF(flat_dims)
    cross_einsum_connect(uf, einsum_node, flat_dims)
    uf.assign()
    for pnode in pnodes:
        pnode.generate_subscript(uf)

    output_pnode, *input_pnodes = pnodes
    # Sort based on both the node name and subscript.
    input_pnodes.sort(key=lambda p: p.node.name + p.subscript)

    new_subscripts = (",".join(p.subscript for p in input_pnodes) + "->" +
                      output_pnode.subscript)
    einsum_node.einsum_subscripts = new_subscripts
    einsum_node.set_inputs([p.node for p in input_pnodes])
    logger.info(f"Rewrite to new subscript: {new_subscripts}")
    return uf
def prune_orthonormal_matmuls(einsum_node):
    """
    Remove the matrices of a einsum_node if M @ M.T like structures exist.

    Inputs are grouped by node name among ``ad.MatrixNode`` inputs whose
    ``orthonormal`` attribute is set. Within a group, a pair (A, B) is
    replaced by ``ad.identity`` over their orthonormal-index characters when
    both of these hold: the pair shares the contraction-index character but
    has different orthonormal-index characters, and that contraction
    character appears in exactly two subscripts of the whole expression.

    Args:
        einsum_node: An fused einsum node.
    Return:
        A new ``ad.einsum`` node (always a fresh node, even if nothing was
        pruned).
    """
    # A map from the orthonormal matrix mode to (orthonormal_index, contraction_index)
    # NOTE(review): confirm the 'column'/'row' convention matches the
    # definition of MatrixNode.orthonormal.
    orthonormal_indices_map = {'column': (0, 1), 'row': (1, 0)}

    _, p_outnode, p_innodes = generate_einsum_info(einsum_node)

    # Computed once from the original inputs; not refreshed after pruning.
    subs_list = [pnode.subscript
                 for pnode in p_innodes] + [p_outnode.subscript]

    # Group orthonormal matrix inputs by node name.
    ortho_pnode_map = {}
    for pnode in p_innodes:
        if isinstance(pnode.node,
                      ad.MatrixNode) and pnode.node.orthonormal != None:
            nodename = pnode.node.name
            if nodename in ortho_pnode_map:
                ortho_pnode_map[nodename].append(pnode)
            else:
                ortho_pnode_map[nodename] = [pnode]

    for pnodes in ortho_pnode_map.values():
        if len(pnodes) < 2:
            continue

        # Each pseudo node may be consumed by at most one pruned pair.
        remaining_pnodes = pnodes
        pnodes_subs = list(itertools.combinations(pnodes, 2))

        for pnodes_binary_input in pnodes_subs:
            if not set(pnodes_binary_input).issubset(set(remaining_pnodes)):
                continue

            pnode_A, pnode_B = pnodes_binary_input
            o_index, c_index = orthonormal_indices_map[
                pnode_A.node.orthonormal]

            # Criteria for the pruning: the o_index of two inputs are different,
            # and the c_index only appear in these two nodes.
            c_index_is_equal = pnode_A.subscript[c_index] == pnode_B.subscript[
                c_index]
            o_index_not_equal = pnode_A.subscript[
                o_index] != pnode_B.subscript[o_index]
            if not (c_index_is_equal and o_index_not_equal):
                continue

            num_subs_w_cindex = len(
                list(
                    filter(lambda subs: pnode_A.subscript[c_index] in subs,
                           subs_list)))
            if not num_subs_w_cindex == 2:
                continue

            remaining_pnodes = [
                pnode for pnode in remaining_pnodes
                if not pnode in pnodes_binary_input
            ]
            p_innodes = [
                pnode for pnode in p_innodes
                if not pnode in pnodes_binary_input
            ]
            # The pair contracts to an identity linking their free indices.
            i_node = ad.identity(pnode_A.node.shape[o_index])
            i_subs = f"{pnode_A.subscript[o_index]}{pnode_B.subscript[o_index]}"
            p_innodes.append(PseudoNode(node=i_node, subscript=i_subs))

    new_input_subs = [pnode.subscript for pnode in p_innodes]
    new_subscripts = ",".join(new_input_subs) + "->" + p_outnode.subscript
    new_inputs = [pnode.node for pnode in p_innodes]

    return ad.einsum(new_subscripts, *new_inputs)
def simplify(output_node):
    """Simplify a graph with a single output node.

    The simplified form will distribute selected operations (+), and fuse
    all connected einsums. After fusion the pass runs the following steps:
    optimize inverse nodes, fuse again, prune orthonormal matmuls, prune
    inverse nodes, prune identity and scalar nodes, and collapse symmetric
    expressions. Finally, ``sympy_simplify`` runs when the output is a
    distributive node.

    Args:
        output_node: The output node.
    Returns:
        node: The newly generated node.
    """
    def fuse_all_einsums(node):
        # Linearize, fuse every einsum subtree, prune identities, declone.
        linearize(node)
        ret_node = PseudoNode(node)
        all_pnodes = find_topo_sort_p([ret_node])
        with OutputInjectedModeP(all_pnodes):
            trees = find_sub_einsumtree(ret_node)
            for tree in trees:
                out_node_p, in_nodes = tree
                new_z = fuse_einsums(out_node_p.node, in_nodes)
                prune_identity_nodes(new_z)
                replace_node(out_node_p, new_z)

        node = declone(ret_node.node)
        return node

    def replace_einsums(root_pnode, transform):
        # Refresh each node's inputs and replace every einsum node by
        # transform(node). When the root itself is replaced, replace_node
        # updates root_pnode.node.
        all_pnodes = find_topo_sort_p([root_pnode])
        with OutputInjectedModeP(all_pnodes):
            for pnode in all_pnodes:
                node = pnode.node
                if node.inputs != []:
                    node.set_inputs(node.inputs)
                if isinstance(node, ad.EinsumNode):
                    replace_node(pnode, transform(node))

    def prune_identity_and_scalar(node):
        # prune_identity_nodes works in place; prune_scalar_nodes returns.
        prune_identity_nodes(node)
        return prune_scalar_nodes(node)

    output_node = distribute_graph_w_linearize(output_node)
    output_node = fuse_all_einsums(output_node)

    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    # optimize inverse
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that we can collapse the add node.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.TensorInverseNode):
                new_inv_node = optimize_inverse(node)
                replace_node(pnode, new_inv_node)

    # fuse again
    output_node = fuse_all_einsums(output_pnode.node)
    # fuse_all_einsums may return a different root object, so every later
    # pass must start from the fused root, not the pre-fusion one.
    output_pnode = PseudoNode(output_node)

    # prune the orthonormal matmuls
    replace_einsums(output_pnode, prune_orthonormal_matmuls)
    # prune inverse nodes
    replace_einsums(output_pnode, prune_inv_node)
    # prune the scalar nodes and remove unnecessary identity nodes
    replace_einsums(output_pnode, prune_identity_and_scalar)
    # A pruning pass may have replaced the root einsum itself.
    output_node = output_pnode.node

    # collapse symmetric expressions
    all_pnodes = find_topo_sort_p([output_pnode])
    for i in range(len(all_pnodes)):
        for j in range(i):
            collapse_symmetric_expr(all_pnodes[i].node, all_pnodes[j].node)

    sympy_input_types = (ad.DistributiveNode, ad.ScalarNode, ad.MulNode)
    # sympy_simplify the distributed nodes
    if isinstance(output_node, ad.DistributiveNode):
        sympy_inputs = []
        all_nodes = find_topo_sort([output_node])
        for node in all_nodes:
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that they can be reduced by sympy.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if not isinstance(node, sympy_input_types):
                sympy_inputs.append(node)
        output_node = sympy_simplify(output_node, sympy_inputs)

    return output_node
def fuse_einsums(output_node, input_nodes):
    """
    Find and fuse einsums.

    Every node between ``output_node`` and ``input_nodes`` gets one DimInfo
    per dimension. For each intermediate einsum node, ``cross_einsum_connect``
    unions the dimensions it ties together, and one character is assigned per
    union-find set. A single einsum over ``input_nodes`` is then built with
    the resulting subscripts.

    Parameters:
        output_node: An einsum node; the root of the tree to fuse.
        input_nodes: The leaves of the tree. Each node must have attribute
            inputs, which makes it a sparse graph representation.
    Returns:
        A graph with fused intermediate einsum nodes. Represented by
        output_node (a new ``ad.einsum`` node).

    Note: inputs of a node can have same node. But one node can't go to two
        output nodes
    """
    # First assume everything einsum.
    logger.info('Start fusing einsum')

    # Making this automatic.
    # Assume output_node is einsum and their children are einsum of any number
    # of input nodes
    assert (isinstance(output_node, ad.EinsumNode))

    pseudo_nodes = []

    # Get all the nodes from output_node down to (and including) the input
    # nodes in the computation graph. Note that the order doesn't matter!
    all_nodes = find_topo_sort([output_node], input_nodes)

    pseudo_input_nodes = []
    pseudo_output_node = None

    # We first represent each dim as a different character, and then union.
    # Create a map
    for k, node in enumerate(all_nodes):
        # Side effect: stores the DimInfo list on the graph node itself.
        node.dims_info = [
            DimInfo(node=node, dim_index=i, node_index=k)
            for i in range(len(node.shape))
        ]
        pnode = PseudoNode(node=node, dims_info=node.dims_info)
        pseudo_nodes.append(pnode)
        if node in input_nodes:
            pseudo_input_nodes.append(pnode)
        if node == output_node:
            pseudo_output_node = pnode

    # NOTE(review): the set difference makes the order of einsum_pseudo_nodes
    # hash-dependent; presumably harmless because union-find connections are
    # order-independent — confirm uf.assign() does not depend on union order.
    intermediate_nodes = list(set(pseudo_nodes) - set(pseudo_input_nodes))
    einsum_pseudo_nodes = list(
        filter(lambda x: isinstance(x.node, ad.EinsumNode),
               intermediate_nodes))

    all_dims_info = sum([node.dims_info for node in pseudo_nodes], [])

    # For any two dims with the same literal, get their pos and connect.
    uf = UF(all_dims_info)
    for node in einsum_pseudo_nodes:
        all_dims_info = node_dims_info(node)
        cross_einsum_connect(uf, node.node, all_dims_info)

    uf.assign()
    # Assign literals
    for node in pseudo_nodes:
        node.generate_subscript(uf)

    new_input_subs = [node.subscript for node in pseudo_input_nodes]
    new_subscripts = ",".join(
        new_input_subs) + "->" + pseudo_output_node.subscript
    logger.info(f"Generated new subscript: {new_subscripts}")
    ##########################################
    output_node = ad.einsum(new_subscripts,
                            *[node.node for node in pseudo_input_nodes])

    return output_node