def test_tree_distribution_ppE(dist_op, backendopt):
    """Distributivity check: ((A + B) + C) * G expands to AG + BG + CG.

    The inner (A + B) node has the outer (A + B) + C as its parent,
    so distribution must recurse through nested dist_op nodes.
    """
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[3, 2])
        g = ad.Variable(name="g", shape=[2, 2])

        summed = dist_op(dist_op(a, b), c)
        output = ad.einsum('ik,kk->ik', summed, g)

        distributed = distribute_tree(output)
        # The distributed root must itself be a dist_op (sum of products).
        assert isinstance(distributed, dist_op)
        assert tree_eq(output, distributed, [a, b, c, g])
def test_einsum_fuse_graph(backendopt):
    """Fuse an einsum whose result is consumed twice.

    Subtle case: the intermediate einsum (BC) feeds both the middle
    einsum (ABC) and the final one, so fusion must handle the reuse.

        A   B   C
        |    \ /
        |     es
        |    /|
        |  /  |
        es    |
          \   |
            es

    Here es is einsum.
    """
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        bc = ad.einsum('ik, kj->ij', b, c)      # 3x3 intermediate
        abc = ad.einsum('ik, kj->ij', a, bc)    # 3x3
        out = ad.einsum('jk, ki->ji', abc, bc)  # 3x3, reuses bc

        linearize(out)
        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        new_z = fuse_einsums(out.node, ins)
        assert tree_eq(out.node, new_z, [a, b, c])
def test_tree_distribution_w_add_output(dist_op, backendopt):
    """
        Test C * (A + B) + D * (E + F)
            = (C * A + C * B) + (D * E + D * F)

        Distribution when the graph root itself is a dist_op: the root
        stays a dist_op and each child expands into a dist_op of einsums.
    """
    # NOTE: the docstring previously read "F * (D + E)", but the code
    # contracts d with dist_op(e, f); the docstring is corrected here.

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 3])
        c = ad.Variable(name="c", shape=[3, 3])

        d = ad.Variable(name="d", shape=[3, 3])
        e = ad.Variable(name="e", shape=[3, 3])
        f = ad.Variable(name="f", shape=[3, 3])

        out1 = ad.einsum('ik,kj->ij', c, dist_op(a, b))
        out2 = ad.einsum('ik,kj->ij', d, dist_op(e, f))
        output = dist_op(out1, out2)
        new_output = distribute_tree(output)
        assert isinstance(new_output, dist_op)
        # Each child of the root must have been distributed as well.
        for input_node in new_output.inputs:
            assert isinstance(input_node, dist_op)
        assert tree_eq(output, new_output, [a, b, c, d, e, f])
def test_tree_distribution_two_layers(dist_op, backendopt):
    """Distributivity through two einsum layers: ((A + B) * G) * C.

    Expected expansion: AGC + BGC. Note that (A + B) * G is contracted
    first, so the distribution must propagate through both einsums.
    """
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        g = ad.Variable(name="g", shape=[2, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        inner = ad.einsum('ik, kk->ik', dist_op(a, b), g)
        output = ad.einsum('ik,kj->ij', inner, c)

        distributed = distribute_tree(output)
        assert isinstance(distributed, dist_op)
        assert tree_eq(output, distributed, [a, b, c, g])
def test_einsum_fuse_w_identity(backendopt):
    """Fuse an einsum chain that contains multiple identity inputs.

        A   identity  identity
        |    \       /
        |     \     /
        |      \   /
        |       es
        |     /
        |    /
        |   /
        |  /
        es

    Here es is einsum.
    """
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[3, 3])
        identity_prod = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', a, identity_prod)

        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        fused = fuse_einsums(out.node, ins)
        assert tree_eq(out.node, fused, [a])
def test_prune_identity(backendopt):
    """Identity inputs of an einsum are pruned by index relabeling."""
    for backend in backendopt:
        T.set_backend(backend)

        a1 = ad.Variable(name="a1", shape=[3, 3])
        a2 = ad.Variable(name="a2", shape=[3, 3])
        i1 = ad.identity(3)
        i2 = ad.identity(3)
        i3 = ad.identity(3)

        # In "ab,cd,ac,be,ef->abdf" the identity i1 forces indices a and c
        # to coincide, so it can be removed by relabeling:
        #     ad.einsum("ab,ad,be,ef->abdf", a1, a2, i2, i3).
        # i2 and i3 then collapse into a single identity because index e
        # is internal, giving:
        #     ad.einsum("ab,ad,bf->abdf", a1, a2, i2).
        out = ad.einsum("ab,cd,ac,be,ef->abdf", a1, a2, i1, i2, i3)
        prune_identity_nodes(out)
        out_expect = ad.einsum("ab,ad,bf->abdf", a1, a2, i2)
        assert len(out.inputs) == 3
        assert tree_eq(out, out_expect, [a1, a2])
def test_einsum_multiuse_auto_copy(backendopt):
    """Auto-linearize a doubly-used input, then auto-fuse.

        A    B   inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

    Next step (not covered here): auto-pruning.
    """
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        linearize(output)
        clones = [
            node for node in find_topo_sort([output])
            if isinstance(node, ad.CloneNode)
        ]

        fused = fuse_einsums(output, [*clones, b])
        # After fusing, no einsum node may remain among the inputs.
        assert not any(isinstance(x, ad.EinsumNode) for x in fused.inputs)
        assert tree_eq(output, fused, [*clones, b])
def test_einsum_multiuse(backendopt):
    """Manually fuse a graph where A was pre-split into two variables.

        A    B   inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

    Here we assume A was split into two vars by some other operation.
    """
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a1", shape=[3, 2])
        a_copy = ad.Variable(name="a2", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a_copy, c)

        # Build the fused graph and check equivalence.
        fused = fuse_einsums(output, [a, a_copy, b])
        assert tree_eq(output, fused, [a, a_copy, b])
# Exemple #9
# 0
def test_dimension_tree_w_identity():
    """Dimension-tree construction when one factor (B) is an identity."""
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.identity(2)
    C = ad.Variable(name="C", shape=[2, 2])
    X = ad.Variable(name="X", shape=[2, 2, 2])

    node_A = ad.einsum("abc,bm,cm->am", X, B, C)
    node_B = ad.einsum("abc,am,cm->bm", X, A, C)
    node_C = ad.einsum("abc,am,bm->cm", X, A, B)

    dt = generate_sequential_optimal_tree([node_A, node_B, node_C], [A, B, C])

    assert tree_eq(dt[0], node_A, [A, C, X])
    assert tree_eq(dt[1], node_B, [A, C, X])
    assert tree_eq(dt[2], node_C, [A, B, X])
def test_einsum(backendopt):
    """Fusing the reference einsum tree must leave its value unchanged."""
    for backend in backendopt:
        T.set_backend(backend)

        inputs, z = get_tree()
        fused = fuse_einsums(z, inputs)
        assert tree_eq(z, fused, inputs)
# Exemple #11
# 0
def test_einsum_gen_custom(backendopt):
    """Optimal-tree generation with an explicitly supplied contraction path."""
    for datatype in backendopt:
        # Was missing: the loop iterated over backends without activating
        # them, so every iteration ran on the same (default) backend.
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[2, 2])
        b = ad.Variable(name="b", shape=[2, 5])
        c = ad.Variable(name="c", shape=[5, 2])

        output = ad.einsum('ij,jk,kl->il', a, b, c)
        # path=[(1, 2), (0, 1)]: contract b with c first, then with a.
        new_output = generate_optimal_tree(output, path=[(1, 2), (0, 1)])
        assert tree_eq(output, new_output, [a, b, c])
# Exemple #12
# 0
def test_tensordot(backendopt):
    """tensordot over axes [[1, 3], [0, 1]] matches the equivalent einsum."""
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[3, 3, 3, 3])
        b = ad.Variable(name="b", shape=[3, 3, 3, 3])
        via_tensordot = ad.tensordot(a, b, axes=[[1, 3], [0, 1]])
        via_einsum = ad.einsum("abcd,bdef->acef", a, b)

        assert tree_eq(via_tensordot, via_einsum, [a, b])
# Exemple #13
# 0
def test_einsum_gen(backendopt):
    """Optimal contraction tree for the 5-factor 'pi,qj,ijkl,rk,sl->pqrs' einsum."""
    for datatype in backendopt:
        # Was missing: the loop iterated over backends without activating
        # them, so every iteration ran on the same (default) backend.
        T.set_backend(datatype)

        N = 10
        C = ad.Variable(name="C", shape=[N, N])
        I = ad.Variable(name="I", shape=[N, N, N, N])

        output = ad.einsum('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)
        new_output = generate_optimal_tree(output)
        assert tree_eq(output, new_output, [C, I])
        # The optimizer should produce a binary contraction at the root.
        assert len(new_output.inputs) == 2
# Exemple #14
# 0
def test_simple_dmrg_tree():
    """Sequential optimal tree for a 3-site DMRG-style sandwich network."""
    # A-row tensors (top/bottom of the sandwich) and X-row tensors (middle).
    A1 = ad.Variable(name="A1", shape=[3, 2])
    A2 = ad.Variable(name="A2", shape=[3, 3, 2])
    A3 = ad.Variable(name="A3", shape=[3, 2])

    X1 = ad.Variable(name="X1", shape=[3, 2, 2])
    X2 = ad.Variable(name="X2", shape=[3, 3, 2, 2])
    X3 = ad.Variable(name="X3", shape=[3, 2, 2])
    """
        The network and indices positions are as follows:

        A1 - f - A2 - g - A3
        |        |        |
        c        d        e
        |        |        |
        X1 - a - X2 - b - X3
        |        |        |
        h        i        j
        |        |        |
        A1 - k - A2 - l - A3

    """
    # Each einsum_node_Ai contracts the whole network with site i's
    # A-tensors left out (their indices become output indices).
    einsum_node_A1 = ad.einsum("ach,abdi,bej,fgd,kli,ge,lj->fckh", X1, X2, X3,
                               A2, A2, A3, A3)
    einsum_node_A2 = ad.einsum("ach,abdi,bej,fc,kh,ge,lj->fgdkli", X1, X2, X3,
                               A1, A1, A3, A3)
    einsum_node_A3 = ad.einsum("ach,abdi,bej,fc,kh,fgd,kli->gelj", X1, X2, X3,
                               A1, A1, A2, A2)

    dt = generate_sequential_optimal_tree(
        [einsum_node_A1, einsum_node_A2, einsum_node_A3], [A1, A2, A3])

    assert tree_eq(dt[0], einsum_node_A1, [X1, X2, X3, A1, A1, A2, A2, A3, A3])
    assert tree_eq(dt[1], einsum_node_A2, [X1, X2, X3, A1, A1, A2, A2, A3, A3])

    # In the correct contraction path, only X3 should be contracted with A3,
    # all other X nodes should be contracted later.
    einsum_inputs = list(
        filter(lambda node: isinstance(node, ad.EinsumNode),
               find_topo_sort(dt)))
    assert sorted(einsum_inputs[0].inputs,
                  key=lambda node: node.name) == sorted(
                      [A3, A3, X3], key=lambda node: node.name)
# Exemple #15
# 0
def test_mps(backendopt):
    """MpsGraph output equals the direct 4-site MPS einsum contraction."""
    for datatype in backendopt:
        T.set_backend(datatype)

        mps_graph = MpsGraph.create(4, ranks=[5, 6, 7])
        # Construct the executor only as a construction smoke test; the
        # handle itself is unused (the previous `executor = ...` binding
        # was dead).
        ad.Executor([mps_graph.output])

        expect_mps = ad.einsum('ab,acd,cef,eg->bdfg', *mps_graph.inputs)

        assert tree_eq(mps_graph.output, expect_mps, mps_graph.inputs)
# Exemple #16
# 0
def test_mpo(backendopt):
    """MpoGraph output equals the direct 4-site MPO einsum contraction."""
    for datatype in backendopt:
        T.set_backend(datatype)

        mpo_graph = MpoGraph.create(4, ranks=[5, 6, 7])
        # Construct the executor only as a construction smoke test; the
        # handle itself is unused (the previous `executor = ...` binding
        # was dead).
        ad.Executor([mpo_graph.output])

        expect_mpo = ad.einsum('abc,adef,dghi,gjk->behjcfik',
                               *mpo_graph.inputs)

        assert tree_eq(mpo_graph.output, expect_mpo, mpo_graph.inputs)
# Exemple #17
# 0
def test_mul_by_const_float(backendopt):
    """5 * x and 5.0 * x must build identically named, equal graphs."""
    for backend in backendopt:
        T.set_backend(backend)

        x = ad.Variable(name="x", shape=[3])
        via_int = ad.sum(5 * x)
        via_float = ad.sum(5.0 * x)

        assert via_int.name == via_float.name
        assert tree_eq(via_int, via_float, [x])
# Exemple #18
# 0
def test_split_einsum_dup(backendopt):
    """Split an einsum that has a duplicated input (B appears twice)."""
    for datatype in backendopt:
        # Was missing: the loop iterated over backends without activating
        # them (every other backend loop in this file calls set_backend).
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])

        einsum_node = ad.einsum("ab,bc,cd->ad", A, B, B)
        split_input_nodes = [A]
        new_einsum = split_einsum(einsum_node, split_input_nodes)

        assert len(new_einsum.inputs) == 2  # A, einsum(B, B)
        assert tree_eq(new_einsum, einsum_node, [A, B])
def test_einsum_fuse_only_identity(backendopt):
    """Fuse an einsum tree built purely from identity nodes."""
    for backend in backendopt:
        T.set_backend(backend)

        identity_prod = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', ad.identity(3), identity_prod)

        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        fused = fuse_einsums(out.node, ins)
        # No variables at all: compare against an empty input list.
        assert tree_eq(out.node, fused, [])
def test_distribute_dup(dist_op, backendopt):
    """Distribute a graph where the same sum appears on both einsum sides."""
    from graph_ops.graph_transformer import distribute_graph_w_linearize

    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 3])
        c = ad.Variable(name="c", shape=[3, 3])

        # dist_op(a, c) is used twice, so linearization is required
        # before distribution.
        output = ad.einsum("ab,ab->", dist_op(a, c), dist_op(a, c))
        distributed = distribute_graph_w_linearize(output)
        assert tree_eq(output, distributed, [a, c])
def test_einsum_simple_rewrite(backendopt):
    """Fusing a single stand-alone einsum is an identity rewrite."""
    for backend in backendopt:
        T.set_backend(backend)

        a1 = ad.Variable(name="a1", shape=[3, 2])
        a2 = ad.Variable(name="a2", shape=[2, 3])
        product = ad.einsum('ik,kj->ij', a1, a2)
        rewritten = fuse_einsums(product, [a1, a2])
        assert tree_eq(product, rewritten, [a1, a2])
# Exemple #22
# 0
def test_prune_inv_nodes_transpose(backendopt):
    """inv(M.T) @ M.T should prune to an identity node."""
    for datatype in backendopt:
        # Was missing: the loop iterated over backends without activating
        # them (every other backend loop in this file calls set_backend).
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])

        inv_input = ad.einsum('ab,bc->ca', A, B)
        # inv(inv_input.T) @ inv_input.T
        output = ad.einsum('ca,cd,de->ae', ad.tensorinv(inv_input, ind=1), A,
                           B)
        new_output = prune_inv_node(output)

        assert isinstance(new_output, ad.IdentityNode)
        assert tree_eq(output, new_output, [A, B], tol=1e-6)
def test_prune_scalar_nodes(backendopt):
    """A scalar factor inside an einsum is pulled out as a mul-by-const."""
    for backend in backendopt:
        T.set_backend(backend)

        a1 = ad.Variable(name="a1", shape=[3, 3])
        a2 = ad.Variable(name="a2", shape=[3, 3])
        s = ad.scalar(3.)

        out = ad.einsum("ab,,ab->ab", a1, s, a2)
        pruned = prune_scalar_nodes(out)

        assert isinstance(pruned, ad.MulByConstNode)
        assert tree_eq(out, pruned, [a1, a2])
# Exemple #24
# 0
def test_split_einsum(backendopt):
    """Split a 5-factor einsum so A and B remain top-level inputs."""
    # The backendopt parameter was previously accepted but ignored; loop
    # over it and activate each backend, matching test_split_einsum_dup.
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])
        C = ad.Variable(name="C", shape=[2, 2])
        D = ad.Variable(name="D", shape=[2, 2])
        E = ad.Variable(name="E", shape=[2, 2])

        einsum_node = ad.einsum("ab,bc,cd,de,ef->af", A, B, C, D, E)
        split_input_nodes = [A, B]
        new_einsum = split_einsum(einsum_node, split_input_nodes)
        assert len(new_einsum.inputs) == 3  # A, B, einsum(C, D, E)
        assert tree_eq(new_einsum, einsum_node, [A, B, C, D, E])
# Exemple #25
# 0
def test_dimension_tree_4d():
    """Sequential optimal tree for the four order-4 MTTKRP-style einsums."""
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    C = ad.Variable(name="C", shape=[2, 2])
    D = ad.Variable(name="D", shape=[2, 2])
    X = ad.Variable(name="X", shape=[2, 2, 2, 2])

    node_A = ad.einsum("abcd,bm,cm,dm->am", X, B, C, D)
    node_B = ad.einsum("abcd,am,cm,dm->bm", X, A, C, D)
    node_C = ad.einsum("abcd,am,bm,dm->cm", X, A, B, D)
    node_D = ad.einsum("abcd,am,bm,cm->dm", X, A, B, C)

    dt = generate_sequential_optimal_tree([node_A, node_B, node_C, node_D],
                                          [A, B, C, D])

    # 5 inputs, 4 outputs, 5 intermediates
    assert len(find_topo_sort(dt)) == 14

    factors = [A, B, C, D, X]
    assert tree_eq(dt[0], node_A, factors)
    assert tree_eq(dt[1], node_B, factors)
    assert tree_eq(dt[2], node_C, factors)
    assert tree_eq(dt[3], node_D, factors)
def test_simplify_inv_w_identity(backendopt):
    """tensorinv of an identity simplifies back to an identity node."""
    for backend in backendopt:
        T.set_backend(backend)

        A = ad.Variable(name="A", shape=[2, 2])

        out = ad.einsum("ab,cd->acbd", A, ad.tensorinv(ad.identity(3)))
        simplified = simplify(out)

        assert isinstance(simplified, ad.EinsumNode)
        # The inverse collapsed: the second input is an identity again.
        assert isinstance(simplified.inputs[1], ad.IdentityNode)

        assert tree_eq(out, simplified, [A], tol=1e-6)
def test_fuse_subgraph(backendopt):
    """Fuse only the sub-chain above ab, leaving ab itself unfused."""
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[2, 2])
        b = ad.Variable(name="b", shape=[2, 2])
        c = ad.Variable(name="c", shape=[2, 2])
        d = ad.Variable(name="d", shape=[2, 2])

        ab = ad.einsum("ab,bc->ac", a, b)
        abc = ad.einsum("ab,bc->ac", ab, c)
        abcd = ad.einsum("ab,bc->ac", abc, d)

        # ab is passed as an input, so fusion stops at it.
        fused = fuse_einsums(abcd, [ab, c, d])
        assert tree_eq(abcd, fused, [a, b, c, d])
def test_simplify_inv_w_redundent_einsum(backendopt):
    """A pass-through einsum under tensorinv is removed by simplify."""
    for backend in backendopt:
        T.set_backend(backend)

        A = ad.Variable(name="A", shape=[2, 2])

        out = ad.einsum("ab,cd->abcd", A, ad.tensorinv(ad.einsum("ab->ab", A)))
        simplified = simplify(out)

        inv_node = simplified.inputs[1]
        # The redundant "ab->ab" einsum is gone: inv now wraps A directly.
        assert isinstance(inv_node.inputs[0], ad.VariableNode)

        assert tree_eq(out, simplified, [A], tol=1e-6)
def test_prune_identity_w_dup(backendopt):
    """Identity pruning collapses a chain containing a duplicated variable."""
    for backend in backendopt:
        T.set_backend(backend)

        a1 = ad.Variable(name="a1", shape=[3, 3])
        i1 = ad.identity(3)
        i2 = ad.identity(3)
        i3 = ad.identity(3)

        out = ad.einsum("ab,bc,cd,de,ef->af", a1, a1, i1, i2, i3)
        prune_identity_nodes(out)
        # All three identities vanish, leaving a1 @ a1.
        expected = ad.einsum("ab,bc->ac", a1, a1)
        assert len(out.inputs) == 2
        assert tree_eq(out, expected, [a1])
# Exemple #30
# 0
def test_kronecker_product_repeated_inputs(backendopt):
    """inv of a Kronecker product of A with itself factors into per-A inverses."""
    for backend in backendopt:
        T.set_backend(backend)

        A = ad.Variable(name="A", shape=[2, 2])

        kron = ad.einsum("ab,cd->acbd", A, A)
        inv = ad.tensorinv(kron)
        optimized = optimize_inverse(inv)

        assert isinstance(optimized, ad.EinsumNode)
        # Every factor of the optimized form must be a tensor inverse.
        assert all(
            isinstance(node, ad.TensorInverseNode)
            for node in optimized.inputs)

        assert tree_eq(inv, optimized, [A], tol=1e-5)