Ejemplo n.º 1
0
def optimize(node):
    """Optimize a graph with a single output node.

    The graph is passed through ``distribute_tree`` and ``linearize``; every
    einsum subtree found by ``find_sub_einsumtree`` is fused, pruned of
    identity nodes and replaced by the result of ``generate_optimal_tree``.
    Afterwards the graph is decloned, einsum expressions are rewritten, node
    inputs are re-set and ``dedup`` is called on the output.

    Args:
        node: The output node.
    Returns:
        node: The newly generated node.
    """
    node = distribute_tree(node)
    linearize(node)

    all_nodes = find_topo_sort([node])
    ret_node = PseudoNode(node)
    with OutputInjectedMode(all_nodes):
        for out_node_p, in_nodes in find_sub_einsumtree(ret_node):
            new_z = fuse_einsums(out_node_p.node, in_nodes)
            prune_identity_nodes(new_z)
            new_z = generate_optimal_tree(new_z)
            replace_node(out_node_p, new_z)

    output_node = declone(ret_node.node)

    # Dedicated loop variables are used on purpose: re-binding the output
    # variable inside these loops would make the dedup call and the return
    # value depend on whichever node the topological sort yields last.
    for graph_node in find_topo_sort([output_node]):
        if isinstance(graph_node, ad.EinsumNode):
            rewrite_einsum_expr(graph_node)

    for graph_node in find_topo_sort([output_node]):
        if graph_node.inputs != []:
            graph_node.set_inputs(graph_node.inputs)

    dedup(output_node)
    return output_node
Ejemplo n.º 2
0
def generate_sequential_optimal_tree(einsum_nodes, input_nodes):
    """
    Rebuild einsum expressions so that they follow the dimension tree.

    Parameters
    ----------
    einsum_nodes : list
        Einsum nodes to be computed according to the dimension tree.
    input_nodes : list
        Input nodes; within each einsum node they are contracted following
        the sequence from the end of this list towards its start.

    Returns
    -------
        List of einsum nodes producing the same results as einsum_nodes
        while obeying the dimension tree calculation sequence. When both
        arguments hold exactly one element, einsum_nodes is returned as is.

    Examples
    --------
    >>> einsum_node_A = ad.einsum("abcd,bm,cm,dm->am", X, B, C, D)
    >>> einsum_node_B = ad.einsum("abcd,am,cm,dm->bm", X, A, C, D)
    >>> einsum_node_C = ad.einsum("abcd,am,bm,dm->cm", X, A, B, D)
    >>> dt = generate_sequential_optimal_tree([einsum_node_A, einsum_node_B, einsum_node_C], [A, B, C])
    >>> dt
    [ad.einsum('bm,abm->am', B, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
    ad.einsum('am,abm->bm', A, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
    ad.einsum('am,bm,abcm->cm', A, B, ad.einsum('abcd,dm->abcm', X, D)),
    ]
    (einsum strings may be different)
    """

    if len(einsum_nodes) == 1 and len(input_nodes) == 1:
        return einsum_nodes

    def _constrained_order(position, einsum_node):
        # Inputs after `position` come first (reversed), then the ones before
        # it; only those that are actual inputs of `einsum_node` are kept.
        later = input_nodes[position + 1:]
        later.reverse()
        candidates = later + input_nodes[:position]
        return [n for n in candidates if n in einsum_node.inputs]

    new_nodes = [
        generate_optimal_tree_w_constraint(
            einsum_node, _constrained_order(position, einsum_node))
        for position, einsum_node in enumerate(einsum_nodes)
    ]

    # Some einsum strings are not in canonical form after the constrained
    # tree generation, so they are rewritten once more before dedup.
    graph_nodes = find_topo_sort(new_nodes)
    with OutputInjectedMode(graph_nodes):
        for graph_node in graph_nodes:
            if isinstance(graph_node, ad.EinsumNode):
                rewrite_einsum_expr(graph_node)
            if graph_node.inputs != []:
                graph_node.set_inputs(graph_node.inputs)

    dedup(*new_nodes)
    remove_transposes(find_topo_sort(new_nodes))
    return new_nodes
Ejemplo n.º 3
0
def test_einsum_multiuse_auto_copy(backendopt):
    r"""
        Test autolinearization and auto fuse.
        A    B   inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

        Next: we would need to autoprune.
    """
    # The docstring is a raw string because the diagram contains backslashes
    # that would otherwise be parsed as (invalid) escape sequences.

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        linearize(output)
        all_nodes = find_topo_sort([output])
        # Collect the CloneNode instances present in the graph after
        # linearize; together with b they are the inputs handed to the fuse.
        cloned_nodes = [
            tmp for tmp in all_nodes if isinstance(tmp, ad.CloneNode)
        ]

        out_new = fuse_einsums(output, [*cloned_nodes, b])
        # Test that every input is now fused.
        assert all(not isinstance(x, ad.EinsumNode) for x in out_new.inputs)

        assert tree_eq(output, out_new, [*cloned_nodes, b])
Ejemplo n.º 4
0
def test_einsum_gen_corner_case(backendopt):
    """
    Check that generate_optimal_tree yields only two-input non-variable nodes.

    Note: Numpy contraction path cannot find the opt path for this expression.
        It will output the same expression as the input.
    --------    E    --------
    |       |       |       |
    a       b       c       d
    |       |       |       |
    A - e - B - f - C - g - D
    |       |       |       |
    h       i       j       k
    |       |       |       |

    The backendopt argument is accepted but not used in the body.
    """
    size = 5
    A = ad.Variable(name="A", shape=[size] * 3)
    B = ad.Variable(name="B", shape=[size] * 4)
    C = ad.Variable(name="C", shape=[size] * 4)
    D = ad.Variable(name="D", shape=[size] * 3)
    E = ad.Variable(name="E", shape=[size] * 4)

    output = ad.einsum('aeh,bfie,cgjf,dgk,abcd->hijk', A, B, C, D, E)
    new_output = generate_optimal_tree(output)

    for graph_node in find_topo_sort([new_output]):
        if isinstance(graph_node, ad.VariableNode):
            continue
        # Every non-variable node must take exactly two inputs.
        assert len(graph_node.inputs) == 2
Ejemplo n.º 5
0
def get_common_ancestor(root, leaves, in_node):
    """
    Find the ancestor of in_node inside a tree defined by root and leaves.
    A leaf in_node may have multiple parents in this tree.

    Parameters
    ----------
    root: Tree root.
    leaves: A list of leaf nodes that define the inputs of the subtree.
    in_node: one of the nodes in leaves; several intermediate nodes may have
        it as a child.

    Returns
    ----------
    ancestor: The first einsum node, in topological order, whose subtree
        contains as many occurrences of in_node as leaves does. None is
        returned implicitly when no such node exists.
    """

    assert in_node in leaves

    def _occurrences(nodes):
        # Occurrences are counted by identity, not by equality.
        return sum(1 for n in nodes if n is in_node)

    target_count = _occurrences(leaves)

    for candidate in find_topo_sort([root], leaves):
        if not isinstance(candidate, ad.EinsumNode):
            continue
        # We want the smallest subtree whose inputs contain all in_node(s).
        covered_leaves = [
            n for n in get_all_nodes([candidate], leaves) if n in leaves
        ]
        if _occurrences(covered_leaves) == target_count:
            return candidate
Ejemplo n.º 6
0
def print_computation_graph(output_node_list, input_nodes=None):
    """Print the Graphviz DOT source of a computation graph.

    Args:
        output_node_list: a non-empty list of output nodes.
        input_nodes: optional list forwarded to ``find_topo_sort`` as its
            second argument. Defaults to a fresh empty list per call.
    """
    assert len(output_node_list) > 0
    # A None default avoids sharing one mutable list between calls.
    if input_nodes is None:
        input_nodes = []

    topo_order = find_topo_sort(output_node_list, input_nodes)

    inputs = [n for n in topo_order if isinstance(n, ad.VariableNode)]
    with OutputInjectedMode(topo_order):

        dot = Digraph(comment='Poorman Computation Graph')

        # Variable nodes share one rank and are coloured aquamarine.
        with dot.subgraph() as s:
            s.attr(rank='same')
            for n in inputs:
                s.node(n.name, style='filled', color='aquamarine3')
        # Output nodes share one rank and are coloured thistle.
        with dot.subgraph() as s:
            s.attr(rank='same')
            for n in output_node_list:
                s.node(n.name, style='filled', color='thistle')
        # Every remaining node is coloured light blue.
        with dot.subgraph() as s:
            for n in topo_order:
                if (n not in output_node_list and n not in inputs):
                    s.node(n.name, style='filled', color='lightblue')

        # Label each node and draw one edge per (input -> node) pair.
        for node in topo_order:
            dot.node(node.name, graph_name(node))
            for node_i in node.inputs:
                dot.edge(node_i.name, node.name)

        print(dot.source)
Ejemplo n.º 7
0
def test_cpd_hessian_optimize_offdiag(backendopt):
    """Run optimize on two off-diagonal CPD Hessian blocks and execute them.

    ``size`` and ``rank`` are not defined in this function; they are read
    from the enclosing scope.
    """
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        factors, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = factors
        factor_vals, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = factor_vals

        hessian = ad.hessian(loss, [A, B, C])
        hessian_offdiag = [hessian[0][1], hessian[1][0]]
        for block in hessian_offdiag:
            # NOTE(review): the return value of optimize is discarded here,
            # so the checks below look at the node that was passed in.
            # Confirm whether `block = optimize(block)` was intended.
            optimize(block)
            assert isinstance(block, ad.AddNode)
            num_operations = sum(
                1 for x in find_topo_sort([block])
                if isinstance(x, ad.OpNode))
            # This is currently non-deterministic.
            # assert num_operations == 14

        executor = ad.Executor(hessian_offdiag)
        feed = {
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        }
        hes_diag_vals = executor.run(feed_dict=feed)
Ejemplo n.º 8
0
 def _sub_forward(self, output_node_list):
     """Build the source text of the forward pass.

     Walks the topological order of output_node_list and emits one indented
     line per VariableNode (initial assignment) and per OpNode (intermediate
     assignment). Returns the concatenated text.
     """
     topo_order = find_topo_sort(output_node_list)
     chunks = [indent_line('# forward pass starts')]
     for node in topo_order:
         if isinstance(node, ad.VariableNode):
             chunks.append(indent_line(self._assign_init_variable(node)))
         elif isinstance(node, ad.OpNode):
             chunks.append(indent_line(self._assign_mid_variable(node)))
     return ''.join(chunks)
Ejemplo n.º 9
0
def test_remove_transposes_multiple_trans():
    """Two transpose chains ending in the same layout must get equal names."""
    source = ad.Variable(name="a", shape=[2] * 4)

    reversed_view = ad.einsum("abcd->dcba", source)
    swapped_view = ad.einsum("abcd->abdc", source)

    outputs = [
        ad.einsum("dcba->badc", reversed_view),
        ad.einsum("abdc->badc", swapped_view),
    ]

    remove_transposes(find_topo_sort(outputs))

    first, second = outputs
    assert first.name == second.name
Ejemplo n.º 10
0
def test_cpd_hessian_optimize_diag(backendopt):
    """Optimize the diagonal CPD Hessian blocks and check their values.

    ``size`` and ``rank`` are not defined in this function; they are read
    from the enclosing scope.
    """
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        factors, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = factors
        factor_vals, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = factor_vals

        hessian = ad.hessian(loss, [A, B, C])
        hessian_diag = [hessian[idx][idx] for idx in range(3)]
        for block in hessian_diag:
            block = optimize(block)
            assert isinstance(block, ad.AddNode)
            num_operations = sum(
                1 for x in find_topo_sort([block])
                if isinstance(x, ad.OpNode))
            # Use this assertion to test the optimize function.
            # 5 operations:
            # 1. T.einsum('ca,cb->ab',A,A),
            # 2. T.einsum('ca,cb->ab',B,B),
            # 3. T.einsum('ab,ab->ab',T.einsum('ca,cb->ab',A,A),T.einsum('ca,cb->ab',B,B)),
            # 4. T.einsum('bd,ac->abcd',T.einsum('ab,ab->ab',T.einsum('ca,cb->ab',A,A),T.einsum('ca,cb->ab',B,B)),T.identity(10)),
            # 5. (T.einsum('bd,ac->abcd',T.einsum('ab,ab->ab',T.einsum('ca,cb->ab',A,A),T.einsum('ca,cb->ab',B,B)),T.identity(10))+
            # T.einsum('bd,ac->abcd',T.einsum('ab,ab->ab',T.einsum('ca,cb->ab',A,A),T.einsum('ca,cb->ab',B,B)),T.identity(10)))
            assert num_operations == 5

        executor = ad.Executor(hessian_diag)
        feed = {
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        }
        hes_diag_vals = executor.run(feed_dict=feed)

        # Each diagonal block is built from the two other factor matrices.
        factor_pairs = [(B_val, C_val), (A_val, C_val), (A_val, B_val)]
        expected_hes_diag_val = [
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', left, left, right, right,
                         T.identity(size))
            for left, right in factor_pairs
        ]
        for idx, expected in enumerate(expected_hes_diag_val):
            assert T.norm(hes_diag_vals[idx] - expected) < 1e-8
Ejemplo n.º 11
0
 def _sub_gTv(self, vector_list):
     """Subroutine of g and v inner product.

     Emits the initial assignment of every node in vector_list, then the
     intermediate assignments needed for the inner product of vector_list
     with self.gradient_list, and finally the ``_gTv = ...`` line.

     Reads self.gradient_list and self.topo_order_gradients.

     Returns:
         A tuple (inner_product_node, file_string). The node's name is set
         to '_gTv' before returning.
     """
     file_string = '\n'
     file_string += indent_line('# inner product of g and v starts')
     for node in vector_list:
         file_string += indent_line(self._assign_init_variable(node))
     inner_product_node = inner_product(vector_list, self.gradient_list)
     topo_order = find_topo_sort([inner_product_node])
     for node in topo_order:
         # Skip nodes in self.topo_order_gradients, the vectors already
         # emitted above, and the final node (assigned separately below).
         if node not in self.topo_order_gradients and \
                 node is not inner_product_node and \
                 node not in vector_list:
             # Wrapped in indent_line like every other emitted statement;
             # the sibling subroutines do the same for mid variables.
             file_string += indent_line(self._assign_mid_variable(node))
     file_string += indent_line(
         f'_gTv = {inner_product_node.s2s_expr(inner_product_node.inputs)}')
     inner_product_node.name = '_gTv'
     return inner_product_node, file_string
Ejemplo n.º 12
0
def test_remove_transposes():
    """Two contraction chains that differ only by transposes must get equal names."""
    a = ad.Variable(name="a", shape=[2] * 4)
    # NOTE(review): b, c and d are all created with name="b"; confirm this
    # is intentional and not a copy-paste slip.
    b = ad.Variable(name="b", shape=[2] * 2)
    c = ad.Variable(name="b", shape=[2] * 2)
    d = ad.Variable(name="b", shape=[2] * 2)

    # The same three contractions, once in plain and once in transposed
    # index order.
    plain_ab = ad.einsum("abcd,de->abce", a, b)
    transposed_ab = ad.einsum("abcd,de->ecba", a, b)

    plain_abc = ad.einsum("abce,ce->abe", plain_ab, c)
    transposed_abc = ad.einsum("ecba,ce->eba", transposed_ab, c)

    outputs = [
        ad.einsum("abe,be->ae", plain_abc, d),
        ad.einsum("eba,be->ae", transposed_abc, d),
    ]

    remove_transposes(find_topo_sort(outputs))

    plain_out, transposed_out = outputs
    assert plain_out.name == transposed_out.name
Ejemplo n.º 13
0
def test_dmrg_shared_exec_graph():
    """Check the node count of the shared DMRG Hessian execution graph."""

    from graph_ops.graph_transformer import simplify
    from graph_ops.graph_als_optimizer import generate_sequential_optimal_tree
    from utils import find_topo_sort

    num, rank, size = 4, 3, 2
    mpo_ranks = [rank] * (num - 1)
    mps_ranks = [rank] * (num - 1)

    dg = DmrgGraph.create(num, mpo_ranks, mps_ranks, size)
    for position, original in enumerate(dg.hessians):
        dg.hessians[position] = simplify(original)
        # NOTE(review): this checks the node handed to simplify, not the
        # simplified result stored above; confirm which one was intended.
        assert isinstance(original, ad.EinsumNode)
    dg.hessians = generate_sequential_optimal_tree(dg.hessians, dg.mps_inputs)

    # 8 input variables (4 H term in MPO, 4 A term in MPS), 7 einsum nodes
    expected_node_count = 15
    assert len(find_topo_sort(dg.hessians)) == expected_node_count
Ejemplo n.º 14
0
def test_simple_dmrg_tree():
    """Check the sequential optimal tree for a three-site DMRG network."""
    A1 = ad.Variable(name="A1", shape=[3, 2])
    A2 = ad.Variable(name="A2", shape=[3, 3, 2])
    A3 = ad.Variable(name="A3", shape=[3, 2])

    X1 = ad.Variable(name="X1", shape=[3, 2, 2])
    X2 = ad.Variable(name="X2", shape=[3, 3, 2, 2])
    X3 = ad.Variable(name="X3", shape=[3, 2, 2])
    # The network and indices positions are as follows:
    #
    #     A1 - f - A2 - g - A3
    #     |        |        |
    #     c        d        e
    #     |        |        |
    #     X1 - a - X2 - b - X3
    #     |        |        |
    #     h        i        j
    #     |        |        |
    #     A1 - k - A2 - l - A3
    einsum_node_A1 = ad.einsum("ach,abdi,bej,fgd,kli,ge,lj->fckh", X1, X2, X3,
                               A2, A2, A3, A3)
    einsum_node_A2 = ad.einsum("ach,abdi,bej,fc,kh,ge,lj->fgdkli", X1, X2, X3,
                               A1, A1, A3, A3)
    einsum_node_A3 = ad.einsum("ach,abdi,bej,fc,kh,fgd,kli->gelj", X1, X2, X3,
                               A1, A1, A2, A2)

    dt = generate_sequential_optimal_tree(
        [einsum_node_A1, einsum_node_A2, einsum_node_A3], [A1, A2, A3])

    assert tree_eq(dt[0], einsum_node_A1, [X1, X2, X3, A1, A1, A2, A2, A3, A3])
    assert tree_eq(dt[1], einsum_node_A2, [X1, X2, X3, A1, A1, A2, A2, A3, A3])

    def _by_name(node):
        return node.name

    # In the correct contraction path, only X3 should be contracted with A3,
    # all other X nodes should be contracted later.
    einsum_inputs = [
        node for node in find_topo_sort(dt)
        if isinstance(node, ad.EinsumNode)
    ]
    first_contraction = sorted(einsum_inputs[0].inputs, key=_by_name)
    assert first_contraction == sorted([A3, A3, X3], key=_by_name)
Ejemplo n.º 15
0
 def _sub_hvp(self, inner_product_node, node_list):
     """Build the source text of the Hessian-vector-product backward pass.

     Stores the gradient map of inner_product_node (and its inverse) on
     self, then walks the topological order of the mapped nodes of
     node_list. Nodes that are keys of the map are skipped; nodes that are
     values of the map are emitted as ``_grad2<forward name>`` and renamed
     accordingly; all others are emitted as intermediate assignments.
     """
     chunks = [
         '\n',
         indent_line('# backward pass of inner product of g and v starts'),
     ]
     self.forward_to_hvp_map = ad.gradients_map(inner_product_node)
     self.hvp_to_forward_map = invert_dict(self.forward_to_hvp_map)
     hvp_nodes = [self.forward_to_hvp_map[node] for node in node_list]
     for node in find_topo_sort(hvp_nodes):
         if node in self.forward_to_hvp_map:
             continue
         if node in self.forward_to_hvp_map.values():
             forward_node = self.hvp_to_forward_map[node]
             grad_name = f'_grad2{forward_node.name}'
             chunks.append(
                 indent_line(f'{grad_name} = {node.s2s_expr(node.inputs)}'))
             node.name = grad_name
         else:
             chunks.append(indent_line(self._assign_mid_variable(node)))
     return ''.join(chunks)
Ejemplo n.º 16
0
    def _sub_gradients(self, output_node, node_list):
        """Build the source text of the forward pass followed by the gradients.

        Stores forward_to_grad_map, grad_to_forward_map, gradient_list and
        topo_order_gradients on self. Nodes that are keys of the gradient
        map are skipped; nodes that are values of the map are emitted through
        _assign_grad_variable, all others through _assign_mid_variable.
        """
        chunks = [
            self._sub_forward([output_node]),
            '\n',
            indent_line('# backward pass starts'),
        ]

        self.forward_to_grad_map = ad.gradients_map(output_node)
        self.grad_to_forward_map = invert_dict(self.forward_to_grad_map)
        self.gradient_list = [self.forward_to_grad_map[n] for n in node_list]
        self.topo_order_gradients = find_topo_sort(self.gradient_list)

        for node in self.topo_order_gradients:
            if node in self.forward_to_grad_map:
                continue
            if node in self.forward_to_grad_map.values():
                emit = self._assign_grad_variable
            else:
                emit = self._assign_mid_variable
            chunks.append(indent_line(emit(node)))
        return ''.join(chunks)
Ejemplo n.º 17
0
def test_dimension_tree_4d():
    """Check node count and equivalence of the 4-mode dimension tree."""
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    C = ad.Variable(name="C", shape=[2, 2])
    D = ad.Variable(name="D", shape=[2, 2])
    X = ad.Variable(name="X", shape=[2, 2, 2, 2])

    originals = [
        ad.einsum("abcd,bm,cm,dm->am", X, B, C, D),
        ad.einsum("abcd,am,cm,dm->bm", X, A, C, D),
        ad.einsum("abcd,am,bm,dm->cm", X, A, B, D),
        ad.einsum("abcd,am,bm,cm->dm", X, A, B, C),
    ]

    dt = generate_sequential_optimal_tree(originals, [A, B, C, D])

    # 5 inputs, 4 outputs, 5 intermediates
    assert len(find_topo_sort(dt)) == 14

    for position, original in enumerate(originals):
        assert tree_eq(dt[position], original, [A, B, C, D, X])
Ejemplo n.º 18
0
def simplify(output_node):
    """Simplify a graph with a single output node.
    The simplified form will distribute selected operations
    (+), and fuse all connected einsums.

    The passes run in this order: distribute, fuse einsums, optimize
    inverse nodes, fuse again, prune orthonormal matmuls, prune inverse
    nodes, prune scalar/identity nodes, collapse symmetric expressions and,
    when the output is a DistributiveNode, a final sympy simplification.

    Args:
        output_node: The output node.
    Returns:
        node: The newly generated node.
    """
    def fuse_all_einsums(node):
        """Linearize the graph, fuse every einsum subtree and declone it."""
        linearize(node)
        ret_node = PseudoNode(node)
        all_pnodes = find_topo_sort_p([ret_node])
        with OutputInjectedModeP(all_pnodes):
            trees = find_sub_einsumtree(ret_node)
            for tree in trees:
                out_node_p, in_nodes = tree
                new_z = fuse_einsums(out_node_p.node, in_nodes)
                prune_identity_nodes(new_z)
                replace_node(out_node_p, new_z)

        node = declone(ret_node.node)
        return node

    output_node = distribute_graph_w_linearize(output_node)
    output_node = fuse_all_einsums(output_node)

    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    # optimize inverse
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that we can collapse the add node.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.TensorInverseNode):
                new_inv_node = optimize_inverse(node)
                replace_node(pnode, new_inv_node)

    # fuse again
    output_node = output_pnode.node
    output_node = fuse_all_einsums(output_node)

    # prune the orthonormal matmuls
    # NOTE(review): output_pnode was created before the second fuse above;
    # confirm it still wraps the graph rooted at the re-fused output_node.
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                new_node = prune_orthonormal_matmuls(node)
                replace_node(pnode, new_node)

    # prune inverse nodes
    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                new_node = prune_inv_node(node)
                replace_node(pnode, new_node)

    # prune the scalar nodes and remove unnecessary identity nodes
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                prune_identity_nodes(node)
                new_node = prune_scalar_nodes(node)
                replace_node(pnode, new_node)

    # collapse symmetric expressions
    # Every pair (i, j) with j < i in topological order is visited once.
    all_pnodes = find_topo_sort_p([output_pnode])
    for i in range(len(all_pnodes)):
        for j in range(i):
            collapse_symmetric_expr(all_pnodes[i].node, all_pnodes[j].node)

    # Node types that are not collected as sympy inputs below.
    sympy_input_types = (ad.DistributiveNode, ad.ScalarNode, ad.MulNode)
    # sympy_simplify the distributed nodes
    # NOTE(review): output_node is not re-read from output_pnode.node after
    # the pruning passes; confirm replace_node cannot swap out the root.
    if isinstance(output_node, ad.DistributiveNode):
        sympy_inputs = []
        all_nodes = find_topo_sort([output_node])
        for node in all_nodes:
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that they can be reduced by sympy.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if not isinstance(node, sympy_input_types):
                sympy_inputs.append(node)
        output_node = sympy_simplify(output_node, sympy_inputs)

    return output_node
Ejemplo n.º 19
0
def fuse_einsums(output_node, input_nodes):
    """
    Find and fuse einsums.

    Every node between input_nodes and output_node gets one DimInfo per
    dimension; dimensions connected through einsum nodes are merged in a
    union-find structure, subscripts are generated from the merged
    dimensions, and a single einsum over the input nodes is built.

        Parameters:
            output_node: an EinsumNode (asserted) that is the root of the
                tree to fuse.
            input_nodes: the nodes treated as inputs of the fused einsum.
            Each node must have attribute inputs, which makes it a sparse graph
            representation.
        Returns:
            A graph with fused intermediate einsum nodes. Represented by
            output_node.
    Note: inputs of a node can have same node. But one node can't go to two
    output nodes
    """
    # First assume everything einsum.
    logger.info('Start fusing einsum')

    # Assume output_node is einsum and their children are einsum of any number
    # of input nodes
    assert (isinstance(output_node, ad.EinsumNode))

    # Get all the nodes between the input nodes and the output node.
    all_nodes = find_topo_sort([output_node], input_nodes)

    pseudo_nodes = []
    pseudo_input_nodes = []
    pseudo_output_node = None

    # We first represent each dim as a different character, and then union.
    for k, node in enumerate(all_nodes):
        node.dims_info = [
            DimInfo(node=node, dim_index=i, node_index=k)
            for i in range(len(node.shape))
        ]
        pnode = PseudoNode(node=node, dims_info=node.dims_info)
        pseudo_nodes.append(pnode)
        if node in input_nodes:
            pseudo_input_nodes.append(pnode)
        if node == output_node:
            pseudo_output_node = pnode

    # Order-preserving, de-duplicated difference. A plain set difference
    # yields an arbitrary iteration order, which made the order of the
    # connection pass below vary between runs.
    pseudo_input_set = set(pseudo_input_nodes)
    intermediate_nodes = list(
        dict.fromkeys(p for p in pseudo_nodes if p not in pseudo_input_set))

    einsum_pseudo_nodes = [
        p for p in intermediate_nodes if isinstance(p.node, ad.EinsumNode)
    ]

    # Flatten in one pass instead of repeated list concatenation.
    all_dims_info = [dim for p in pseudo_nodes for dim in p.dims_info]

    # For any two dims with the same literal, get their pos and connect.
    uf = UF(all_dims_info)
    for pnode in einsum_pseudo_nodes:
        cross_einsum_connect(uf, pnode.node, node_dims_info(pnode))

    uf.assign()
    # Assign literals
    for pnode in pseudo_nodes:
        pnode.generate_subscript(uf)

    new_input_subs = [pnode.subscript for pnode in pseudo_input_nodes]
    new_subscripts = ",".join(
        new_input_subs) + "->" + pseudo_output_node.subscript
    logger.info(f"Generated new subscript: {new_subscripts}")

    output_node = ad.einsum(new_subscripts,
                            *[pnode.node for pnode in pseudo_input_nodes])

    return output_node