コード例 #1
0
def test_collapse_symmetric_expr_complex():
    """Two einsums that describe the same tensor network, differing only by
    index permutations allowed by the declared symmetries of H1, H2 and H3,
    should collapse to the same name.

    out1:
    A1 - a - A2 - b - A3
    |        |        |
    c        d        e
    |        |        |
    H1 - f - H2 - g - H3
    |        |        |
    h        i        j

    out2:
    a        b        c
    |        |        |
    H1 - d - H2 - e - H3
    |        |        |
    f        g        h
    |        |        |
    A1 - i - A2 - j - A3
    """
    H1 = ad.Variable(name="H1", shape=[2, 2, 2], symmetry=[[0, 2]])
    H2 = ad.Variable(name="H2", shape=[2, 2, 2, 2], symmetry=[[0, 2]])
    H3 = ad.Variable(name="H3", shape=[2, 2, 2], symmetry=[[0, 1]])

    # The A tensors need names distinct from the H tensors; sharing "H1"/"H2"/
    # "H3" would make different operands indistinguishable by name.
    A1 = ad.Variable(name="A1", shape=[2, 2])
    A2 = ad.Variable(name="A2", shape=[2, 2, 2])
    A3 = ad.Variable(name="A3", shape=[2, 2])

    out1 = ad.einsum("ca,dab,eb,cfh,dgif,ejg->hij", A1, A2, A3, H1, H2, H3)
    out2 = ad.einsum("fi,gij,hj,adf,begd,che->abc", A1, A2, A3, H1, H2, H3)

    collapse_symmetric_expr(out1, out2)

    assert out1.name == out2.name
コード例 #2
0
def test_collapse_expr_w_identity():
    """Contracting with the identity or with its transpose is the same
    expression, so the two einsums should collapse to one name."""
    mat = ad.Variable(name="a", shape=[2, 2])
    eye = ad.identity(2)

    via_rows = ad.einsum("ab,bc->ac", mat, eye)
    via_cols = ad.einsum("ab,cb->ac", mat, eye)
    collapse_symmetric_expr(via_rows, via_cols)

    assert via_rows.name == via_cols.name
コード例 #3
0
def test_cannot_collapse_expr():
    """With no symmetry declared on ``h``, contracting over (i, k) and over
    (j, l) are different expressions and must keep distinct names."""
    tensor = ad.Variable(name="h", shape=[2, 2, 2, 2])
    mat = ad.Variable(name="a", shape=[2, 2])

    contract_ik = ad.einsum("ijkl,ik->jl", tensor, mat)
    contract_jl = ad.einsum("ijkl,jl->ik", tensor, mat)
    collapse_symmetric_expr(contract_ik, contract_jl)

    assert contract_ik.name != contract_jl.name
コード例 #4
0
def test_collapse_symmetry_w_multiple_contraction_ind():
    """For a symmetric H, x1^T H x2 and x2^T H x1 are the same scalar, so
    swapping which vector contracts with each index should collapse."""
    sym = ad.Variable(name="H", shape=[2, 2], symmetry=[[0, 1]])
    vec1 = ad.Variable(name="x1", shape=[2])
    vec2 = ad.Variable(name="x2", shape=[2])

    first = ad.einsum("ab,a,b->", sym, vec1, vec2)
    second = ad.einsum("ab,b,a->", sym, vec1, vec2)
    collapse_symmetric_expr(first, second)

    assert first.name == second.name
コード例 #5
0
def test_cannot_collapse_symmetric_expr():
    """Cases where collapse_symmetric_expr must leave the names untouched."""
    tensor = ad.Variable(name="h", shape=[2, 2, 2, 2], symmetry=[[0, 1], [2, 3]])
    mat = ad.Variable(name="a", shape=[2, 2])

    # The operands are plain variables, not einsum nodes.
    collapse_symmetric_expr(tensor, mat)
    assert tensor.name != mat.name

    mismatched_pairs = [
        # Same operands, but passed to einsum in a different order.
        (("ijkl,ik->jl", tensor, mat), ("jl,ijkl->ik", mat, tensor)),
        # Different output shape.
        (("ijkl,ik->jl", tensor, mat), ("ijkl,ik->jkl", tensor, mat)),
    ]
    for lhs_args, rhs_args in mismatched_pairs:
        lhs = ad.einsum(*lhs_args)
        rhs = ad.einsum(*rhs_args)
        collapse_symmetric_expr(lhs, rhs)
        assert lhs.name != rhs.name
コード例 #6
0
def simplify(output_node):
    """Simplify a graph with a single output node.

    The simplified form will distribute selected operations (+) and fuse
    all connected einsums. The passes then run in this order:
      1. optimize inverse nodes,
      2. fuse einsums again,
      3. prune orthonormal matmuls,
      4. prune inverse nodes,
      5. prune identity and scalar nodes,
      6. collapse symmetric einsum expressions,
      7. sympy-simplify, when the output is a DistributiveNode.

    Args:
        output_node: The output node.
    Returns:
        node: The newly generated node.
    """
    def fuse_all_einsums(node):
        linearize(node)
        ret_node = PseudoNode(node)
        all_pnodes = find_topo_sort_p([ret_node])
        with OutputInjectedModeP(all_pnodes):
            trees = find_sub_einsumtree(ret_node)
            for tree in trees:
                out_node_p, in_nodes = tree
                new_z = fuse_einsums(out_node_p.node, in_nodes)
                prune_identity_nodes(new_z)
                replace_node(out_node_p, new_z)

        node = declone(ret_node.node)
        return node

    def run_einsum_pass(root_pnode, transform):
        """Replace every einsum reachable from root_pnode by transform(node)."""
        all_pnodes = find_topo_sort_p([root_pnode])
        with OutputInjectedModeP(all_pnodes):
            for pnode in all_pnodes:
                node = pnode.node
                if node.inputs != []:
                    # Re-set inputs; presumably refreshes state derived from
                    # them after upstream replacements -- TODO confirm.
                    node.set_inputs(node.inputs)
                if isinstance(node, ad.EinsumNode):
                    replace_node(pnode, transform(node))

    def prune_identity_and_scalar(node):
        """Remove identity operands in place, then prune scalar operands."""
        prune_identity_nodes(node)
        return prune_scalar_nodes(node)

    output_node = distribute_graph_w_linearize(output_node)
    output_node = fuse_all_einsums(output_node)

    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    # optimize inverse
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that we can collapse the add node.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.TensorInverseNode):
                new_inv_node = optimize_inverse(node)
                replace_node(pnode, new_inv_node)

    # fuse again
    output_node = fuse_all_einsums(output_pnode.node)

    # Wrap the *fused* root. The previous pseudo node still points at the
    # pre-fusion graph, so running the passes below on it would operate on a
    # stale graph and their results would be lost.
    output_pnode = PseudoNode(output_node)

    # prune the orthonormal matmuls
    run_einsum_pass(output_pnode, prune_orthonormal_matmuls)
    # prune inverse nodes
    run_einsum_pass(output_pnode, prune_inv_node)
    # prune the scalar nodes and remove unnecessary identity nodes
    run_einsum_pass(output_pnode, prune_identity_and_scalar)

    # replace_node updates pnode.node, so the root itself may have been
    # replaced by one of the passes above; read it back.
    output_node = output_pnode.node

    # collapse symmetric expressions (each node against every earlier node)
    all_pnodes = find_topo_sort_p([output_pnode])
    for i, pnode_i in enumerate(all_pnodes):
        for pnode_j in all_pnodes[:i]:
            collapse_symmetric_expr(pnode_i.node, pnode_j.node)

    sympy_input_types = (ad.DistributiveNode, ad.ScalarNode, ad.MulNode)
    # sympy_simplify the distributed nodes
    if isinstance(output_node, ad.DistributiveNode):
        sympy_inputs = []
        all_nodes = find_topo_sort([output_node])
        for node in all_nodes:
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that they can be reduced by sympy.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if not isinstance(node, sympy_input_types):
                sympy_inputs.append(node)
        output_node = sympy_simplify(output_node, sympy_inputs)

    return output_node