def test_einsum_multitier(backendopt):
    """Fusing every einsum subtree of a two-tier einsum/add graph keeps its value.

    The graph is einsum("ij, jk->ik", set1 + set2, set3 + set4), where each
    setN is a tree produced by get_tree.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        inputs_1, z_1 = get_tree("set1")
        inputs_2, z_2 = get_tree("set2")
        lhs = z_1 + z_2

        inputs_3, z_3 = get_tree("set3")
        inputs_4, z_4 = get_tree("set4")
        rhs = z_3 + z_4
        out = ad.einsum("ij, jk->ik", lhs, rhs)

        feed = gen_dict(inputs_1 + inputs_2 + inputs_3 + inputs_4)

        before, = ad.Executor([out]).run(feed_dict=feed)

        with OutputInjectedModeP(find_topo_sort_p([PseudoNode(out)])):
            for root_p, leaves in find_sub_einsumtree(PseudoNode(out)):
                replace_node(root_p, fuse_einsums(root_p.node, leaves))

        # A fresh executor evaluates the rewritten graph.
        after, = ad.Executor([out]).run(feed_dict=feed)

        assert float_eq(before, after)
def find_sub_einsumtree(output_node_p):
    # TMP Pseudo Mode.
    """
    Collect every maximal einsum subtree reachable from a pseudo node.

    Subtrees may overlap (a shared non-einsum node is visited once per path
    that reaches it).

    Arguments:
        output_node_p: the root to search from; must be a PseudoNode.
    Returns:
        A list of einsum trees of the form
        [[Pseudo root node, leaf nodes], ... ]
        Trees found below a subtree's leaves come before that subtree.
    """
    root = output_node_p.node

    if not isinstance(root, ad.EinsumNode):
        # Not an einsum: keep searching through every input.
        found = []
        for child in root.inputs:
            found.extend(find_sub_einsumtree(PseudoNode(child)))
        return found

    leaves = get_leaves(get_all_einsum_descendants(root))
    found = []
    for leaf in leaves:
        found.extend(find_sub_einsumtree(PseudoNode(leaf)))
    found.append([output_node_p, leaves])
    return found
def split_inv_einsum(inv_node):
    """
    Optimize the inverse of an einsum expression, such that
    inverse is operated on several smaller tensors.

    Parameters
    ----------
    inv_node: The inverse of a fused einsum node. Its single input must be an
        einsum node, none of whose inputs is itself an einsum node.

    Returns
    -------
    If the input node cannot be optimized (the index sets do not decompose),
    then return the input node.
    If it can be optimized, return a new einsum node combining the inverses
    of the independent sub-einsums.

    """
    einsum_node = inv_node.inputs[0]
    assert isinstance(einsum_node, ad.EinsumNode)
    # einsum_node is a fused einsum: none of its inputs may be an einsum.
    for node in einsum_node.inputs:
        assert not isinstance(node, ad.EinsumNode)

    in_subs, out_subs, _ = parse_einsum_input(
        (einsum_node.einsum_subscripts, *einsum_node.inputs))
    in_subs_list = in_subs.split(',')

    p_einsum_node = PseudoNode(node=einsum_node, subscript=out_subs)
    p_in_nodes = [
        PseudoNode(node=node, subscript=in_subs_list[i])
        for i, node in enumerate(einsum_node.inputs)
    ]

    dsets = inv_disjoint_sets(p_einsum_node, p_in_nodes)

    # If the node cannot be decomposed, just return the input node
    if len(dsets) == 1:
        return inv_node

    new_inputs = []
    for dset in dsets:
        # Inputs touching any index of this disjoint set.
        decomp_inputs = [
            pnode for pnode in p_in_nodes
            if any(char in dset for char in pnode.subscript)
        ]
        # Output indices of the whole einsum that belong to this set, in
        # their original output order. (Named separately so that the outer
        # `out_subs` is not shadowed.)
        decomp_out_subs = "".join(
            char for char in p_einsum_node.subscript if char in dset)

        decomp_node = generate_new_einsum(decomp_inputs, decomp_out_subs)

        # The inverse is taken over a half/half split of the output indices.
        # NOTE(review): assumes len(decomp_out_subs) is even — confirm.
        decomp_node.set_in_indices_length(len(decomp_out_subs) // 2)

        input_node = PseudoNode(node=ad.tensorinv(decomp_node),
                                subscript=decomp_out_subs)
        new_inputs.append(input_node)

    return generate_new_einsum(new_inputs, p_einsum_node.subscript)
def test_einsum_subtree_clone(backendopt):
    r"""
        [Subtree clone]
        This case is rather subtle.
        We want to auto fuse
            A   B   C   D
            |    \ /    |
            |     es    |
            |    /  \   |
            |  /      \ |
            es         es
              \       /
                  +

        Here es is einsum. Fusing both einsum subtrees must leave the value
        of the output unchanged.
    """

    for datatype in backendopt:
        T.set_backend(datatype)
        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])
        d = ad.Variable(name="d", shape=[3, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3

        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3

        BCD = ad.einsum('jk, ki->ji', BC, d)  # 3x3

        out = ABC + BCD

        input_nodes = [a, b, c, d]
        generated_feed_dict = gen_dict(input_nodes)

        executor = ad.Executor([out])
        out_val, = executor.run(feed_dict=generated_feed_dict)

        with OutputInjectedModeP(find_topo_sort_p([PseudoNode(out)])):
            trees = find_sub_einsumtree(PseudoNode(out))
            assert len(trees) == 2
            for tree in trees:
                out_node, in_nodes = tree
                new_z = fuse_einsums(out_node.node, in_nodes)
                replace_node(out_node, new_z)

        # Build a fresh executor so that the rewritten graph is what gets
        # evaluated, instead of relying on the executor created before the
        # rewrite.
        executor = ad.Executor([out])
        new_out_val, = executor.run(feed_dict=generated_feed_dict)

        assert float_eq(out_val, new_out_val)
Exemple #5
0
def dedup_transpose(graph, node, trans_node, trans_indices):
    """
    Substitute trans_node for node in every consumer of node within graph.

    Each consumer's einsum subscript for the substituted input is permuted
    by trans_indices, so that the consumer keeps computing the same value.

    Parameters
    ----------
    graph: list of nodes denoting a connected graph; must contain both node
        and trans_node.
    node: node to be replaced.
    trans_node: the transposed node that will replace node.
    trans_indices: the transpose indices; position k of the new subscript is
        character trans_indices[k] of the old one.
    """
    assert node in graph
    assert trans_node in graph

    with OutputInjectedModeP([PseudoNode(n) for n in graph]):
        for consumer in node.outputs:
            # NOTE: currently we cannot deal with non-einsum nodes.
            assert isinstance(consumer, ad.EinsumNode)
            in_subs, out_subs, _ = parse_einsum_input(
                (consumer.einsum_subscripts, *consumer.inputs))
            subs_per_input = in_subs.split(',')

            for idx, operand in enumerate(consumer.inputs):
                if operand is not node:
                    continue
                # Swap the operand in place and permute its subscript.
                consumer.inputs[idx] = trans_node
                old_subs = subs_per_input[idx]
                subs_per_input[idx] = "".join(old_subs[j]
                                              for j in trans_indices)

            consumer.einsum_subscripts = (",".join(subs_per_input) + "->" +
                                          out_subs)
            consumer.set_inputs(consumer.inputs)
def optimize(node):
    """Optimize a graph with a single output node.

    Distributes the tree, fuses every einsum subtree (pruning identities and
    choosing a contraction tree for each), canonicalizes einsum expressions,
    and deduplicates same-named nodes.

    Args:
        node: The output node.
    Returns:
        node: The newly generated node.
    """
    node = distribute_tree(node)
    linearize(node)

    all_nodes = find_topo_sort([node])
    ret_node = PseudoNode(node)
    with OutputInjectedMode(all_nodes):
        trees = find_sub_einsumtree(ret_node)
        for out_node_p, in_nodes in trees:
            new_z = fuse_einsums(out_node_p.node, in_nodes)
            prune_identity_nodes(new_z)
            new_z = generate_optimal_tree(new_z)
            replace_node(out_node_p, new_z)

    # Keep the output in its own name. The original loops reused `node` as
    # the loop variable and only returned the output because topo-sort
    # happens to end with it.
    output_node = declone(ret_node.node)
    for n in find_topo_sort([output_node]):
        if isinstance(n, ad.EinsumNode):
            rewrite_einsum_expr(n)

    for n in find_topo_sort([output_node]):
        if n.inputs != []:
            # Re-set inputs so that derived attributes (e.g. names) are
            # refreshed after the rewrites above.
            n.set_inputs(n.inputs)

    dedup(output_node)
    return output_node
def test_einsum_fuse_graph(backendopt):
    r"""
        [Fuse einsum used twice]
        This case is rather subtle.
        We want to auto fuse
            A   B   C 
            |    \ /  
            |     es  
            |    /|   
            |  /  |   
            es    |   
              \   |   
                es
        Here es is einsum.
    """
    for datatype in backendopt:
        T.set_backend(datatype)
        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3
        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3
        out = ad.einsum('jk, ki->ji', ABC, BC)  # 3x3

        # BC feeds two einsums; linearize clones it so that the whole graph
        # forms a single einsum tree.
        linearize(out)
        (p_root, leaves), = find_sub_einsumtree(PseudoNode(out))
        fused = fuse_einsums(p_root.node, leaves)

        assert tree_eq(p_root.node, fused, [a, b, c])
def test_einsum_fuse_w_identity(backendopt):
    r"""
        [Fuse einsum with multiple identities]
        We want to fuse
            A   identity  identity
            |    \       /
            |     \     /
            |      \   /
            |       es
            |     /
            |    /
            |   /
            |  /
            es
        Here es is einsum.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        eye_product = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', a, eye_product)

        (p_root, leaves), = find_sub_einsumtree(PseudoNode(out))
        fused = fuse_einsums(p_root.node, leaves)
        assert tree_eq(p_root.node, fused, [a])
Exemple #9
0
def prune_identity_nodes(einsum_node):
    """
        Reduce the number of identity nodes in the
        einsum_node's inputs. Inplace update.

        All identity inputs are dropped, and every group of indices linked
        by identities is renamed to one representative character. Identity
        nodes are re-added only where an output index still needs one:
          * its representative was already used by an earlier output index
            (two output dims cannot share one character), or
          * its representative is carried by no remaining input and occurs
            only once in the output. Without an identity, that output
            character would be unbound (e.g. 'ab,c->ac' with identity 'ab'
            would otherwise become the invalid 'c->ac').

        NOTE(review): an identity-linked index group that touches neither a
        remaining input nor the output (e.g. a full trace of an identity)
        is dropped together with its scalar factor — confirm callers cannot
        produce this.

        Args:
            einsum_node: A fused einsum node. Non-einsum nodes are ignored.
    """
    if not isinstance(einsum_node, ad.EinsumNode):
        return

    uf_str, p_outnode, p_innodes = generate_einsum_info(einsum_node)
    whole_str = p_outnode.subscript + "".join(
        [node.subscript for node in p_innodes])

    p_identity_nodes = [
        pnode for pnode in p_innodes if isinstance(pnode.node, ad.IdentityNode)
    ]
    p_variable_nodes = [
        pnode for pnode in p_innodes
        if not isinstance(pnode.node, ad.IdentityNode)
    ]

    # each disjoint set in uf_identity represents the indices
    # linked by identity node
    uf_identity = UF(list(whole_str))
    for pnode in p_identity_nodes:
        uf_identity.connect(pnode.subscript[0], pnode.subscript[1])

    # input_indices_set: representatives carried by the remaining inputs.
    # output_indices_set: representatives already placed in the output.
    input_indices_set, output_indices_set = set(), set()
    for pnode in p_variable_nodes:
        # replace subscripts by the root chars
        sub_list = [uf_identity.root(char) for char in pnode.subscript]
        pnode.subscript = "".join(sub_list)
        input_indices_set |= set(sub_list)

    out_roots = [uf_identity.root(char) for char in p_outnode.subscript]

    p_updated_inputs = list(p_variable_nodes)
    out_sub_list = []
    for i, uf_root_char in enumerate(out_roots):
        is_repeated = uf_root_char in output_indices_set
        # A root that occurs more than once in the output is bound by the
        # identity added for its later occurrences, so only a single
        # occurrence with no carrying input is unbound.
        is_unbound = (uf_root_char not in input_indices_set
                      and out_roots.count(uf_root_char) == 1)
        if is_repeated or is_unbound:
            # Assign a fresh char and add one identity node to the inputs
            # to express the constraint (and, for an unbound root, to give
            # the output dimension its size).
            new_char = uf_str.cg.getchar()
            out_sub_list.append(new_char)
            p_identity_node = PseudoNode(node=ad.identity(
                einsum_node.shape[i]),
                                         subscript=f"{uf_root_char}{new_char}")
            p_updated_inputs.append(p_identity_node)
        else:
            # directly assign the root char to the subscripts
            out_sub_list.append(uf_root_char)
            output_indices_set.add(uf_root_char)
    p_outnode.subscript = "".join(out_sub_list)

    new_input_subs = [pnode.subscript for pnode in p_updated_inputs]
    new_subscripts = ",".join(new_input_subs) + "->" + p_outnode.subscript
    einsum_node.einsum_subscripts = new_subscripts
    einsum_node.set_inputs([pnode.node for pnode in p_updated_inputs])
def distribute_tree(output):
    """ Distribute a tree of einsum and add nodes.

    NOTE: the output node should be a linearized node.
    Behavior undefined if there are other kind of nodes.

    Args:
        output: The output of a tree.

    Returns:
        output: a newly generated node with add operands distributed.

    Idea:
        Repeat until nothing changes: rebuild the topological order, pick the
        first distributive node whose consumers are all einsums, and
        distribute each of those einsums over it.
    """
    def find_distributable(pnodes):
        # First DistributiveNode (in topo order) that has at least one
        # consumer and whose consumers are all einsum nodes.
        for pnode in pnodes:
            candidate = pnode.node
            if not isinstance(candidate, ad.DistributiveNode):
                continue
            if len(candidate.outputs) < 1:
                continue
            if all(isinstance(user, ad.EinsumNode)
                   for user in candidate.outputs):
                return candidate
        return None

    while True:
        all_pnodes = find_topo_sort_p([PseudoNode(output)])
        with OutputInjectedModeP(all_pnodes):
            target = find_distributable(all_pnodes)
            if target is None:
                break
            for consumer in target.outputs:
                if isinstance(consumer, ad.DistributiveNode):
                    continue
                assert isinstance(consumer, ad.EinsumNode)
                distributed = _distribute(target, consumer)
                replace_node(PseudoNode(consumer), distributed)
                # Track the output if it was the einsum just rewritten.
                if consumer == output:
                    output = distributed
    # This is need for source generation.
    output.set_inputs(output.inputs)
    return output
def test_einsum_fuse_only_identity(backendopt):
    """Fusing an einsum tree whose leaves are all identity nodes.

    The tree is einsum('ai,ij->aj', I, einsum('ik,kj->ij', I, I)); the fused
    einsum must be equivalent to the original tree. The tree has no variable
    inputs, so tree_eq gets an empty input list.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        es_identity = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', ad.identity(3), es_identity)

        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        new_out = fuse_einsums(out.node, ins)
        assert tree_eq(out.node, new_out, [])
def test_einsum_find_subtree_after_linearization(backendopt):
    r"""
        An einsum graph like
        A    B   inputs 
        |\   |
        | \  |
        |  \ |
        |   C
        |  / 
        | /
        output

        will produce

        An einsum graph like
        A    B   inputs 
        |\   |
        | A1 |
        |  \ |
        A2  C
        |  / 
        | /
        output

        The subtree inputs must then be [A1, A2, B] rather than A, B.
        Linearization must also leave the value of the output unchanged.
    """

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        feed_dict = gen_dict([a, b])
        executor = ad.Executor([output])
        out_val, = executor.run(feed_dict=feed_dict)

        # New graph
        linearize(output)

        # linearize rewires the graph in place (inserting clones of A); check
        # that the output still evaluates to the same value.
        new_out_val, = ad.Executor([output]).run(feed_dict=feed_dict)
        assert float_eq(out_val, new_out_val)

        tree, = find_sub_einsumtree(PseudoNode(output))
        assert len(tree[1]) == 3
Exemple #14
0
def dedup(*nodes):
    """Collapse nodes sharing a name onto a single representative.

    For each name, the node appearing last in the topological order of the
    given outputs is kept; every other node with that name is replaced by it.

    Args:
        nodes: One or many output nodes.
    """
    assert len(nodes) > 0

    topo_order = find_topo_sort_p([PseudoNode(n) for n in nodes])
    with OutputInjectedModeP(topo_order):
        # Later entries overwrite earlier ones, so the last occurrence wins.
        canonical_by_name = {pnode.node.name: pnode.node for pnode in topo_order}
        canonical_nodes = set(canonical_by_name.values())

        for pnode in topo_order:
            current = pnode.node
            if current in canonical_nodes:
                continue
            replace_node(pnode, canonical_by_name[current.name])
def linearize(output_node):
    """Linearize a graph by adding clone nodes for computation optimization.

    Every node with two or more outgoing edges is replaced, in each of its
    consumers, by a fresh copy (copy_tree), so that afterwards each copy
    feeds exactly one input slot.

    Args:
        output_node: A single node.
    Returns: 
        None. Update is inplace. 

    NOTE: If you ever need to debug this function, the generated name is 
        inconsistent because of the added edges.

    """
    # Need to create new nodes for whichever node that has 2 or more outgoing edges.
    all_pnodes = find_topo_sort_p([PseudoNode(output_node)])
    # Inject outputs relationship.
    with OutputInjectedModeP(all_pnodes):
        for pn in all_pnodes:
            n = pn.node
            if len(n.outputs) <= 1:
                continue
            # dict.fromkeys de-duplicates the consumers like set() did, but
            # keeps their first-seen order. Iterating a set made the order
            # of copy_tree calls (and hence clone naming) vary between runs.
            # It also snapshots the consumers before they are rewired.
            for n_o in dict.fromkeys(n.outputs):
                n_o.set_inputs([
                    tmp if tmp.name != n.name else copy_tree(n)
                    for tmp in n_o.inputs
                ])
def prune_single_inv_node(einsum_node, inv_node):
    """
    Prune the inv_node in the einsum node if the conditions are met.

    Note:
    1. can only optimize the node when the input of inv is an einsum node.
    2. only supports the case when the splitted nodes are different from the remaining ones.
        For example: ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, C) will be
        optimized to ad.einsum("ab,bc->ac", C, ad.identity()),
        but we cannot optimize ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, B).

    Parameters
    ----------
    einsum_node: The fused einsum node
    inv_node: the input inv node to be pruned

    Returns
    -------
    If the einsum_node cannot be optimized, then return the input einsum_node.
    If it can be optimized, return the optimized einsum node.

    """
    from autohoot.einsum_graph.expr_generator import rewrite_einsum_expr
    from autohoot.graph_ops.optimal_tree import split_einsum

    inv_node_input = inv_node.inputs[0]
    if not isinstance(inv_node_input, ad.EinsumNode):
        logger.info("inv input is not einsum node, can't prune inv")
        return einsum_node

    inv_inputs_set = set(inv_node_input.inputs)
    if not inv_inputs_set.issubset(set(einsum_node.inputs)):
        logger.info(
            "inv inputs is not subset of einsum node inputs, can't prune inv")
        return einsum_node

    einsum_inputs_in_inv = [
        n for n in einsum_node.inputs if n in inv_inputs_set
    ]
    if len(einsum_inputs_in_inv) < len(inv_node_input.inputs):
        logger.info(
            "number of inv inputs is more than corresponding einsum inputs, can't prune inv"
        )
        return einsum_node

    # The distinct einsum inputs that are not inputs of the inverse, in their
    # original order (a set difference here made the order vary between runs).
    split_input_nodes = [
        n for n in dict.fromkeys(einsum_node.inputs)
        if n not in inv_inputs_set
    ]
    split_einsum_node = split_einsum(einsum_node, split_input_nodes)

    # Assign pseudo nodes and chars
    in_subs, out_subs, _ = parse_einsum_input(
        (split_einsum_node.einsum_subscripts, *split_einsum_node.inputs))
    in_subs_list = in_subs.split(',')

    p_einsum_input, p_inv_input = None, None
    updated_p_in_nodes = []
    for i, node in enumerate(split_einsum_node.inputs):
        if isinstance(node, ad.EinsumNode):
            p_einsum_input = PseudoNode(node=node, subscript=in_subs_list[i])
        elif node is inv_node:
            p_inv_input = PseudoNode(node=node, subscript=in_subs_list[i])
        else:
            updated_p_in_nodes.append(
                PseudoNode(node=node, subscript=in_subs_list[i]))

    # Previously a missing einsum / inv input raised UnboundLocalError.
    if p_einsum_input is None or p_inv_input is None:
        logger.info(
            "split einsum lacks an einsum input or the inv node, can't prune inv"
        )
        return einsum_node

    contract_char = "".join(
        set(p_einsum_input.subscript) & set(p_inv_input.subscript))
    uncontract_str = "".join(
        set("".join([p_einsum_input.subscript, p_inv_input.subscript])) -
        set(contract_char))

    if not (len(p_einsum_input.subscript) == 2 and len(p_inv_input.subscript)
            == 2 and len(contract_char) == 1 and len(uncontract_str) == 2):
        # this is not a matmul. Just return the initial node
        logger.info(
            "the op between inv input and the selected einsum is not matmul, can't prune inv"
        )
        return einsum_node

    if p_einsum_input.subscript[0] == p_inv_input.subscript[
            0] or p_einsum_input.subscript[1] == p_inv_input.subscript[1]:
        # the str is like "ab,ac", and one einsum needs to be transposed to compare
        p_in_subs, p_out_subs, _ = parse_einsum_input(
            (p_einsum_input.node.einsum_subscripts,
             *p_einsum_input.node.inputs))
        einsum_input = ad.einsum(
            f"{p_in_subs}->{p_out_subs[1]}{p_out_subs[0]}",
            *p_einsum_input.node.inputs)
    else:
        einsum_input = p_einsum_input.node

    # Canonicalize both expressions so that equal expressions get equal names.
    rewrite_einsum_expr(einsum_input)
    rewrite_einsum_expr(inv_node_input)

    if einsum_input.name != inv_node_input.name:
        logger.info(
            "inv input and the selected einsum have different expressions, can't prune inv"
        )
        return einsum_node

    # prune the inv node: X @ inv(X) is replaced by an identity on the
    # uncontracted indices (identity is symmetric, so char order is irrelevant).
    updated_p_in_nodes = updated_p_in_nodes + [
        PseudoNode(node=ad.identity(inv_node_input.shape[0]),
                   subscript=uncontract_str)
    ]

    return generate_new_einsum(updated_p_in_nodes, out_subs)
Exemple #17
0
def rewrite_einsum_expr(einsum_node):
    """
        Canonicalize the einsum expression of a node in place.

        Subscript characters are reassigned via a union-find over every
        dimension of the node and its inputs. The inputs are then sorted by
        (node name + subscript), so that equivalent expressions end up with
        the same subscripts and input order.

        Args:
            einsum_node: An einsum node; duplicate inputs are allowed.

        Returns:
            uf (type: graph_ops.graph_optimizer.UF): 
            the union_find set of the input
    """
    assert isinstance(einsum_node, ad.EinsumNode)

    # The einsum node comes first and uses a temporary name, so that the
    # character assignment order does not depend on its current name.
    pseudo_out = PseudoNode(node=einsum_node,
                            dims_info=[
                                DimInfo(node=einsum_node,
                                        dim_index=d,
                                        temp_node_name='_temp_einsum')
                                for d in range(len(einsum_node.shape))
                            ])

    pseudo_ins = []
    for idx, operand in enumerate(einsum_node.inputs):
        operand_dims = [
            DimInfo(node=operand, dim_index=d, node_index=idx)
            for d in range(len(operand.shape))
        ]
        pseudo_ins.append(PseudoNode(node=operand, dims_info=operand_dims))

    # Output dims first, then each input's dims in input order.
    all_dims_info = list(pseudo_out.dims_info)
    for pnode in pseudo_ins:
        all_dims_info.extend(pnode.dims_info)

    # For any two dims with the same literal, get their pos and connect.
    uf = UF(all_dims_info)
    cross_einsum_connect(uf, einsum_node, all_dims_info)
    uf.assign()

    # Assign literals (output first, as before the inputs).
    pseudo_out.generate_subscript(uf)
    for pnode in pseudo_ins:
        pnode.generate_subscript(uf)

    # NOTE(review): the key concatenates name and subscript, so e.g. name
    # "a" + "bc" and name "ab" + "c" compare equal; ties fall back to the
    # original input order (the sort is stable).
    pseudo_ins.sort(key=lambda pnode: pnode.node.name + pnode.subscript)

    new_subscripts = (",".join(pnode.subscript for pnode in pseudo_ins) +
                      "->" + pseudo_out.subscript)
    einsum_node.einsum_subscripts = new_subscripts
    einsum_node.set_inputs([pnode.node for pnode in pseudo_ins])
    logger.info(f"Rewrite to new subscript: {new_subscripts}")

    return uf
Exemple #18
0
def prune_orthonormal_matmuls(einsum_node):
    """
    Replace M @ M.T-like pairs of an orthonormal matrix with an identity.

    Two inputs that are the same orthonormal MatrixNode (matched by name) are
    pruned when they share the contraction index and differ on the
    orthonormal index, and the contraction index appears nowhere else
    (neither in other inputs nor in the output). An identity linking their
    orthonormal indices takes their place.

    Args:
        einsum_node: An fused einsum node.
    Return:
        A new einsum node (built even when nothing was pruned).
    """
    # orthonormal mode -> (orthonormal_index, contraction_index)
    orthonormal_indices_map = {'column': (0, 1), 'row': (1, 0)}

    _, p_outnode, p_innodes = generate_einsum_info(einsum_node)
    subs_list = [pnode.subscript for pnode in p_innodes]
    subs_list.append(p_outnode.subscript)

    # Group the orthonormal matrix inputs by node name.
    ortho_pnode_map = {}
    for pnode in p_innodes:
        matrix = pnode.node
        if isinstance(matrix, ad.MatrixNode) and matrix.orthonormal is not None:
            ortho_pnode_map.setdefault(matrix.name, []).append(pnode)

    for group in ortho_pnode_map.values():
        if len(group) < 2:
            continue

        remaining = group
        for pair in itertools.combinations(group, 2):
            # Skip pairs that reuse an already pruned input.
            if not set(pair).issubset(set(remaining)):
                continue

            pnode_A, pnode_B = pair
            o_index, c_index = orthonormal_indices_map[
                pnode_A.node.orthonormal]
            c_char = pnode_A.subscript[c_index]

            if c_char != pnode_B.subscript[c_index]:
                continue
            if pnode_A.subscript[o_index] == pnode_B.subscript[o_index]:
                continue
            # The contraction index must appear in exactly these two inputs.
            if sum(1 for subs in subs_list if c_char in subs) != 2:
                continue

            remaining = [pnode for pnode in remaining if pnode not in pair]
            p_innodes = [pnode for pnode in p_innodes if pnode not in pair]

            identity_node = ad.identity(pnode_A.node.shape[o_index])
            identity_subs = f"{pnode_A.subscript[o_index]}{pnode_B.subscript[o_index]}"
            p_innodes.append(
                PseudoNode(node=identity_node, subscript=identity_subs))

    new_subscripts = (",".join(pnode.subscript for pnode in p_innodes) +
                      "->" + p_outnode.subscript)
    return ad.einsum(new_subscripts, *[pnode.node for pnode in p_innodes])
def simplify(output_node):
    """Simplify a graph with a single output node.
    The simplified form will distribute selected operations
    (+), and fuse all connected einsums. Inverses, orthonormal matmuls,
    scalar and identity nodes are then pruned, symmetric expressions are
    collapsed, and distributed outputs are simplified with sympy.

    Args:
        output_node: The output node.
    Returns:
        node: The newly generated node.
    """
    def fuse_all_einsums(node):
        linearize(node)
        ret_node = PseudoNode(node)
        all_pnodes = find_topo_sort_p([ret_node])
        with OutputInjectedModeP(all_pnodes):
            trees = find_sub_einsumtree(ret_node)
            for tree in trees:
                out_node_p, in_nodes = tree
                new_z = fuse_einsums(out_node_p.node, in_nodes)
                prune_identity_nodes(new_z)
                replace_node(out_node_p, new_z)

        node = declone(ret_node.node)
        return node

    def rewrite_einsums(root_pnode, transform):
        # Replace every einsum reachable from root_pnode by transform(einsum);
        # a replaced root is reflected in root_pnode.node.
        all_pnodes = find_topo_sort_p([root_pnode])
        with OutputInjectedModeP(all_pnodes):
            for pnode in all_pnodes:
                node = pnode.node
                if node.inputs != []:
                    node.set_inputs(node.inputs)
                if isinstance(node, ad.EinsumNode):
                    replace_node(pnode, transform(node))

    def prune_identity_and_scalar(node):
        # prune_identity_nodes updates node in place.
        prune_identity_nodes(node)
        return prune_scalar_nodes(node)

    output_node = distribute_graph_w_linearize(output_node)
    output_node = fuse_all_einsums(output_node)

    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    # optimize inverse
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that we can collapse the add node.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.TensorInverseNode):
                new_inv_node = optimize_inverse(node)
                replace_node(pnode, new_inv_node)

    # fuse again
    output_node = fuse_all_einsums(output_pnode.node)
    # fuse_all_einsums can return a new root. Track it with a fresh pseudo
    # node; otherwise the next pass would walk the pre-fusion graph.
    output_pnode = PseudoNode(output_node)

    # prune the orthonormal matmuls
    rewrite_einsums(output_pnode, prune_orthonormal_matmuls)

    # prune inverse nodes (on the graph produced by the previous pass)
    rewrite_einsums(output_pnode, prune_inv_node)

    # prune the scalar nodes and remove unnecessary identity nodes
    rewrite_einsums(output_pnode, prune_identity_and_scalar)

    # The passes above may have replaced the root itself.
    output_node = output_pnode.node

    # collapse symmetric expressions
    all_pnodes = find_topo_sort_p([output_pnode])
    for i in range(len(all_pnodes)):
        for j in range(i):
            collapse_symmetric_expr(all_pnodes[i].node, all_pnodes[j].node)

    sympy_input_types = (ad.DistributiveNode, ad.ScalarNode, ad.MulNode)
    #sympy_simplify the distributed nodes
    if isinstance(output_node, ad.DistributiveNode):
        sympy_inputs = []
        all_nodes = find_topo_sort([output_node])
        for node in all_nodes:
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that they can be reduced by sympy.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if not isinstance(node, sympy_input_types):
                sympy_inputs.append(node)
        output_node = sympy_simplify(output_node, sympy_inputs)

    return output_node
def fuse_einsums(output_node, input_nodes):
    """
    Fuse the einsum tree between output_node and input_nodes into one einsum.

        Parameters:
            output_node: root of the tree; must be an EinsumNode.
            input_nodes: leaves of the tree. Each node must have attribute
                inputs, which makes it a sparse graph representation.
        Returns:
            A new einsum node over input_nodes that is equivalent to
            output_node.
        Side effect:
            Sets a `dims_info` attribute on every node visited between
            output_node and input_nodes.
    Note: inputs of a node can have same node. But one node can't go to two 
    output nodes
    """
    logger.info('Start fusing einsum')

    # Assume output_node is einsum and its children are einsums of any number
    # of input nodes.
    assert isinstance(output_node, ad.EinsumNode)

    # All nodes from output_node down to (and including) input_nodes.
    all_nodes = find_topo_sort([output_node], input_nodes)

    pseudo_nodes = []
    pseudo_input_nodes = []
    pseudo_output_node = None

    # We first represent each dim as a different character, and then union.
    for k, node in enumerate(all_nodes):
        node.dims_info = [
            DimInfo(node=node, dim_index=i, node_index=k)
            for i in range(len(node.shape))
        ]
        pnode = PseudoNode(node=node, dims_info=node.dims_info)
        pseudo_nodes.append(pnode)
        if node in input_nodes:
            pseudo_input_nodes.append(pnode)
        if node is output_node:
            pseudo_output_node = pnode

    # Intermediate einsum nodes, kept in topological order. The original used
    # a set difference, which made the connect order below vary between runs.
    pseudo_input_set = set(pseudo_input_nodes)
    einsum_pseudo_nodes = [
        pnode for pnode in pseudo_nodes if pnode not in pseudo_input_set
        and isinstance(pnode.node, ad.EinsumNode)
    ]

    all_dims_info = sum([pnode.dims_info for pnode in pseudo_nodes], [])

    # For any two dims with the same literal, get their pos and connect.
    uf = UF(all_dims_info)
    for pnode in einsum_pseudo_nodes:
        cross_einsum_connect(uf, pnode.node, node_dims_info(pnode))

    uf.assign()
    # Assign literals
    for pnode in pseudo_nodes:
        pnode.generate_subscript(uf)

    new_input_subs = [pnode.subscript for pnode in pseudo_input_nodes]
    new_subscripts = ",".join(
        new_input_subs) + "->" + pseudo_output_node.subscript
    logger.info(f"Generated new subscript: {new_subscripts}")

    return ad.einsum(new_subscripts,
                     *[pnode.node for pnode in pseudo_input_nodes])