def optimize(node):
    """Optimize a graph with a single output node.

    Steps, in order:

    1. ``distribute_tree`` and ``linearize`` are applied to the input graph.
    2. Every sub-einsum tree reported by ``find_sub_einsumtree`` is fused with
       ``fuse_einsums``, cleaned with ``prune_identity_nodes``, rebuilt with
       ``generate_optimal_tree`` and spliced back via ``replace_node``.
    3. The result is ``declone``-d, every einsum expression is rewritten with
       ``rewrite_einsum_expr``, every node with inputs has them re-set, and
       ``dedup`` is run on the output node.

    Args:
        node: The output node.
    Returns:
        node: The newly generated node.
    """
    node = distribute_tree(node)
    linearize(node)

    all_nodes = find_topo_sort([node])
    ret_node = PseudoNode(node)
    with OutputInjectedMode(all_nodes):
        for out_node_p, in_nodes in find_sub_einsumtree(ret_node):
            new_z = fuse_einsums(out_node_p.node, in_nodes)
            prune_identity_nodes(new_z)
            new_z = generate_optimal_tree(new_z)
            replace_node(out_node_p, new_z)

    node = declone(ret_node.node)

    # Use dedicated loop variables below. Rebinding `node` inside the loops
    # would make the final dedup/return depend on the output node happening
    # to be the last element of the topological order.
    for current in find_topo_sort([node]):
        if isinstance(current, ad.EinsumNode):
            # presumably puts the einsum string into a canonical form so that
            # equal expressions can be matched by dedup -- TODO confirm
            rewrite_einsum_expr(current)

    # Second pass, kept separate from the rewrite pass on purpose: inputs are
    # re-set only after every einsum expression has been rewritten.
    for current in find_topo_sort([node]):
        if current.inputs != []:
            current.set_inputs(current.inputs)

    dedup(node)
    return node
def generate_sequential_optimal_tree(einsum_nodes, input_nodes):
    """
    Rebuild einsum expressions so they follow a dimension-tree ordering.

    Parameters
    ----------
    einsum_nodes : list
        Einsum nodes to be recomputed according to the dimension tree.
    input_nodes : list
        Input nodes; their contraction inside ``einsum_nodes`` follows the
        sequence from the end of this list to its start.

    Returns
    -------
    list
        Einsum nodes producing the same results as ``einsum_nodes`` while
        obeying the dimension-tree calculation sequence. When both arguments
        hold exactly one element, ``einsum_nodes`` itself is returned.

    Examples
    --------
    >>> einsum_node_A = ad.einsum("abcd,bm,cm,dm->am", X, B, C, D)
    >>> einsum_node_B = ad.einsum("abcd,am,cm,dm->bm", X, A, C, D)
    >>> einsum_node_C = ad.einsum("abcd,am,bm,dm->cm", X, A, B, D)
    >>> dt = generate_sequential_optimal_tree(
    ...     [einsum_node_A, einsum_node_B, einsum_node_C], [A, B, C])
    >>> dt
    [ad.einsum('bm,abm->am', B, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
     ad.einsum('am,abm->bm', A, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
     ad.einsum('am,bm,abcm->cm', A, B, ad.einsum('abcd,dm->abcm', X, D)),
    ]
    (einsum strings may be different)
    """
    # Trivial case: a single einsum over a single input needs no reordering.
    if len(einsum_nodes) == 1 and len(input_nodes) == 1:
        return einsum_nodes

    def _constrained_tree(position, einsum_node):
        # Inputs after `position` come first in reversed order, followed by
        # the inputs before `position`; the input at `position` is skipped.
        trailing = input_nodes[position + 1:]
        trailing.reverse()
        candidates = trailing + input_nodes[:position]
        # Keep only the candidates that really are inputs of this einsum.
        order = [cand for cand in candidates if cand in einsum_node.inputs]
        return generate_optimal_tree_w_constraint(einsum_node, order)

    new_nodes = [
        _constrained_tree(position, einsum_node)
        for position, einsum_node in enumerate(einsum_nodes)
    ]

    # generate_optimal_tree_w_constraint can leave einsum strings in a
    # non-canonical form, so rewrite them once more before deduplicating.
    topo_nodes = find_topo_sort(new_nodes)
    with OutputInjectedMode(topo_nodes):
        for current in topo_nodes:
            if isinstance(current, ad.EinsumNode):
                rewrite_einsum_expr(current)
            if current.inputs != []:
                current.set_inputs(current.inputs)

    dedup(*new_nodes)
    remove_transposes(find_topo_sort(new_nodes))
    return new_nodes
def test_einsum_rewrite_duplicate_input(backendopt):
    """Rewriting must give identical subscripts when one input is used twice.

    Both einsums contract the same variable with itself; they differ only in
    the order in which the two operand subscripts are written.

    Args:
        backendopt: supplied by the test harness; not used in the body.
    """
    operand = ad.Variable(name="a", shape=[3, 2])
    first = ad.einsum('ca,cb->ab', operand, operand)
    second = ad.einsum('cb,ca->ab', operand, operand)

    for expr in (first, second):
        rewrite_einsum_expr(expr)

    assert first.einsum_subscripts == second.einsum_subscripts
def test_einsum_equal_repeated_transpose(backendopt):
    """Two spellings of the same A-with-A contraction must rewrite identically.

    ``x`` and ``y`` differ only in index letters and in the order of the two
    (identical) operands, so after rewriting both the subscripts and the
    input lists are expected to match.

    NOTE(review): a later test in this listing reuses this exact function
    name. If both live in the same module, the later definition replaces this
    one and this test is never collected -- consider renaming one of them.

    Args:
        backendopt: supplied by the test harness; not used in the body.
    """
    A = ad.Variable(name="A", shape=[3, 5])
    x = ad.einsum('or,ob->br', A, A)
    y = ad.einsum('eb,ed->bd', A, A)

    # The return value of rewrite_einsum_expr is not needed here; only the
    # in-place effect on x and y is checked.
    rewrite_einsum_expr(x)
    rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs
def test_einsum_equal_repeated_transpose(backendopt):
    """Full contractions that only swap the two A operands must rewrite equal.

    ``x`` and ``y`` contract ``A, A, B`` down to a scalar; the subscripts of
    the two ``A`` operands are exchanged between them. After rewriting, both
    the subscripts and the input lists are expected to match.

    NOTE(review): an earlier test in this listing has this exact function
    name. If both live in the same module, this definition replaces the
    earlier one, which is then never collected -- consider renaming one of
    them.

    Args:
        backendopt: supplied by the test harness; not used in the body.
    """
    A = ad.Variable(name="A", shape=[3, 3])
    B = ad.Variable(name="B", shape=[3, 3])
    x = ad.einsum("ac,ba,bc->", A, A, B)
    y = ad.einsum("ba,ac,bc->", A, A, B)

    # The return value of rewrite_einsum_expr is not needed here; only the
    # in-place effect on x and y is checked.
    rewrite_einsum_expr(x)
    rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs
def test_einsum_equal(backendopt):
    """The same matmul written two ways must rewrite to the same expression.

    The second einsum lists the operands in the opposite order and uses
    different index letters, but describes the same ``a1 @ a2`` contraction.

    Args:
        backendopt: supplied by the test harness; not used in the body.
    """
    left = ad.Variable(name="a1", shape=[3, 2])
    right = ad.Variable(name="a2", shape=[2, 3])

    first = ad.einsum('ik,kj->ij', left, right)
    second = ad.einsum('ml,sm->sl', right, left)

    for expr in (first, second):
        rewrite_einsum_expr(expr)

    assert first.einsum_subscripts == second.einsum_subscripts
    assert first.inputs == second.inputs
def test_einsum_equal_uf_assign_order(backendopt):
    """Equivalent five-operand einsums must rewrite to the same expression.

    ``x`` and ``y`` describe the same contraction of ``A, A, B, B`` and an
    identity node, but list the operands in a different order and use
    different index letters. After rewriting, both the subscripts and the
    input lists are expected to match.

    Args:
        backendopt: supplied by the test harness; not used in the body.
    """
    A = ad.Variable(name="A", shape=[3, 3])
    B = ad.Variable(name="B", shape=[3, 3])
    I = ad.identity(10)
    x = ad.einsum('pb,or,ob,pr,st->srtb', B, A, A, B, I)
    y = ad.einsum('eb,ed,fb,fd,ac->abcd', A, A, B, B, I)

    # The return value of rewrite_einsum_expr is not needed here; only the
    # in-place effect on x and y is checked.
    rewrite_einsum_expr(x)
    rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs
def test_rewrite_expr(backendopt):
    """
    Check that rewriting normalizes the index letters of an einsum.

    Two matmuls over the same operands, written with different index
    letters, must end up with identical subscripts.
    """
    left = ad.Variable(name="a1", shape=[3, 2])
    right = ad.Variable(name="a2", shape=[2, 3])

    first = ad.einsum('ik,kj->ij', left, right)
    second = ad.einsum('sm,ml->sl', left, right)

    for expr in (first, second):
        rewrite_einsum_expr(expr)

    assert first.einsum_subscripts == second.einsum_subscripts
def prune_single_inv_node(einsum_node, inv_node):
    """
    Prune ``inv_node`` from ``einsum_node`` when the conditions are met.

    Note:
    1. The node can only be optimized when the input of inv is an einsum node.
    2. Only the case where the split nodes are different from the remaining
       ones is supported.
        For example: ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, C)
        will be optimized to ad.einsum("ab,bc->ac", C, ad.identity()),
        but ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, B)
        cannot be optimized.

    Parameters
    ----------
    einsum_node: The fused einsum node
    inv_node: the input inv node to be pruned

    Returns
    -------
    If the einsum_node cannot be optimized, then return the input einsum_node.
    If it can be optimized, return the optimized einsum node.

    """
    # Function-scope imports are kept as in the original (presumably to avoid
    # a circular import between these modules -- TODO confirm).
    from autohoot.einsum_graph.expr_generator import rewrite_einsum_expr
    from autohoot.graph_ops.optimal_tree import split_einsum

    inv_node_input = inv_node.inputs[0]
    if not isinstance(inv_node_input, ad.EinsumNode):
        logger.info("inv input is not einsum node, can't prune inv")
        return einsum_node

    if not set(inv_node_input.inputs).issubset(set(einsum_node.inputs)):
        logger.info(
            "inv inputs is not subset of einsum node inputs, can't prune inv")
        return einsum_node

    einsum_inputs_in_inv = [
        n for n in einsum_node.inputs if n in inv_node_input.inputs
    ]
    if len(einsum_inputs_in_inv) < len(inv_node_input.inputs):
        logger.info(
            "number of inv inputs is more than corresponding einsum inputs, can't prune inv"
        )
        return einsum_node

    # Split the einsum so that the inputs shared with the inverse's einsum are
    # grouped apart from all the other inputs.
    split_einsum_node = split_einsum(
        einsum_node,
        list(set(einsum_node.inputs) - set(inv_node_input.inputs)))

    # Assign pseudo nodes and chars
    in_subs, out_subs, _ = parse_einsum_input(
        (split_einsum_node.einsum_subscripts, *split_einsum_node.inputs))
    in_subs_list = in_subs.split(',')

    # Both are bound only inside the loop below; start from None so that a
    # split einsum lacking either of them is detected instead of raising
    # UnboundLocalError further down.
    p_einsum_input = None
    p_inv_input = None
    updated_p_in_nodes = []
    for i, node in enumerate(split_einsum_node.inputs):
        if isinstance(node, ad.EinsumNode):
            p_einsum_input = PseudoNode(node=node, subscript=in_subs_list[i])
        elif node is inv_node:
            p_inv_input = PseudoNode(node=node, subscript=in_subs_list[i])
        else:
            updated_p_in_nodes.append(
                PseudoNode(node=node, subscript=in_subs_list[i]))

    if p_einsum_input is None or p_inv_input is None:
        # Honor the documented contract: not optimizable -> return the input.
        logger.info(
            "split einsum lacks the inv node or the matching einsum input, can't prune inv"
        )
        return einsum_node

    contract_char = "".join(
        set(p_einsum_input.subscript) & set(p_inv_input.subscript))
    uncontract_str = "".join(
        set("".join([p_einsum_input.subscript, p_inv_input.subscript])) -
        set(contract_char))

    is_matmul = (len(p_einsum_input.subscript) == 2
                 and len(p_inv_input.subscript) == 2
                 and len(contract_char) == 1 and len(uncontract_str) == 2)
    if not is_matmul:
        # this is not a matmul. Just return the initial node
        logger.info(
            "the op between inv input and the selected einsum is not matmul, can't prune inv"
        )
        return einsum_node

    if (p_einsum_input.subscript[0] == p_inv_input.subscript[0]
            or p_einsum_input.subscript[1] == p_inv_input.subscript[1]):
        # the str is like "ab,ac", and one einsum needs to be transposed to compare
        p_in_subs, p_out_subs, _ = parse_einsum_input(
            (p_einsum_input.node.einsum_subscripts,
             *p_einsum_input.node.inputs))
        einsum_input = ad.einsum(
            f"{p_in_subs}->{p_out_subs[1]}{p_out_subs[0]}",
            *p_einsum_input.node.inputs)
    else:
        einsum_input = p_einsum_input.node

    # Rewrite both expressions before comparing them by name below.
    rewrite_einsum_expr(einsum_input)
    rewrite_einsum_expr(inv_node_input)

    if einsum_input.name != inv_node_input.name:
        logger.info(
            "inv input and the selected einsum have different expressions, can't prune inv"
        )
        return einsum_node

    # prune the inv node: the inverse and the matching einsum are replaced by
    # an identity node over the two uncontracted indices.
    updated_p_in_nodes = updated_p_in_nodes + [
        PseudoNode(node=ad.identity(inv_node_input.shape[0]),
                   subscript=uncontract_str)
    ]

    return generate_new_einsum(updated_p_in_nodes, out_subs)
def simplify(output_node):
    """Simplify a graph with a single output node.

    The simplified form will distribute selected operations
    (+), and fuse all connected einsums.

    The passes run in this order: distribute + fuse, optimize inverse nodes,
    fuse again, prune orthonormal matmuls, prune inverse nodes, prune
    scalar/identity nodes, collapse symmetric expressions, and finally a
    sympy-based simplification when the output is a ``ad.DistributiveNode``.

    Args:
        output_node: The output node.
    Returns:
        node: The newly generated node.
    """
    def fuse_all_einsums(node):
        # Linearize the graph, fuse every sub-einsum tree found from the
        # output, prune identity nodes from each fused einsum, and return the
        # decloned output node.
        linearize(node)
        ret_node = PseudoNode(node)
        all_pnodes = find_topo_sort_p([ret_node])
        with OutputInjectedModeP(all_pnodes):
            trees = find_sub_einsumtree(ret_node)
            for tree in trees:
                out_node_p, in_nodes = tree
                new_z = fuse_einsums(out_node_p.node, in_nodes)
                prune_identity_nodes(new_z)
                replace_node(out_node_p, new_z)
        node = declone(ret_node.node)
        return node

    output_node = distribute_graph_w_linearize(output_node)
    output_node = fuse_all_einsums(output_node)

    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    # optimize inverse
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that we can collapse the add node.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                # presumably refreshes state derived from the inputs after
                # upstream nodes were rewritten/replaced -- TODO confirm
                node.set_inputs(node.inputs)
            if isinstance(node, ad.TensorInverseNode):
                new_inv_node = optimize_inverse(node)
                replace_node(pnode, new_inv_node)

    # fuse again
    output_node = output_pnode.node
    output_node = fuse_all_einsums(output_node)

    # prune the orthonormal matmuls
    # NOTE(review): output_pnode was created before the second fusion, while
    # output_node was just reassigned from fuse_all_einsums -- confirm that
    # output_pnode still wraps the current graph here.
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                new_node = prune_orthonormal_matmuls(node)
                replace_node(pnode, new_node)

    # prune inverse nodes
    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                new_node = prune_inv_node(node)
                replace_node(pnode, new_node)

    # prune the scalar nodes and remove unnecessary identity nodes
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                prune_identity_nodes(node)
                new_node = prune_scalar_nodes(node)
                replace_node(pnode, new_node)

    # collapse symmetric expressions: every node is compared against each
    # node that precedes it in the topological order.
    all_pnodes = find_topo_sort_p([output_pnode])
    for i in range(len(all_pnodes)):
        for j in range(i):
            collapse_symmetric_expr(all_pnodes[i].node, all_pnodes[j].node)

    # Node types that are not passed to sympy as inputs.
    sympy_input_types = (ad.DistributiveNode, ad.ScalarNode, ad.MulNode)
    # sympy_simplify the distributed nodes
    # NOTE(review): output_node is not reassigned from output_pnode.node after
    # the pruning passes above -- confirm it cannot be stale if replace_node
    # swapped out the root.
    if isinstance(output_node, ad.DistributiveNode):
        sympy_inputs = []
        all_nodes = find_topo_sort([output_node])
        for node in all_nodes:
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that they can be reduced by sympy.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if not isinstance(node, sympy_input_types):
                sympy_inputs.append(node)
        output_node = sympy_simplify(output_node, sympy_inputs)

    return output_node