Exemplo n.º 1
0
def test_einsum_fuse_graph(backendopt):
    r"""
        [Fuse einsum used twice]
        This case is rather subtle.
        We want to auto fuse
            A   B   C
            |    \ /
            |     es
            |    /|
            |  /  |
            es    |
              \   |
                es
        Here es is einsum.

        The intermediate einsum BC is consumed twice (by ABC and by the
        output), so the graph is passed through `linearize` first; the single
        einsum tree then found in it is fused and compared with the original
        on the variables a, b, c.

        The docstring is raw so the backslashes in the diagram are not parsed
        as (invalid) escape sequences.
    """

    for datatype in backendopt:
        T.set_backend(datatype)
        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3

        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3

        out = ad.einsum('jk, ki->ji', ABC, BC)  # 3x3

        linearize(out)
        # Exactly one einsum tree is expected after linearization.
        tree, = find_sub_einsumtree(PseudoNode(out))
        # Use distinct names instead of rebinding `out` to the pseudo-node.
        out_pnode, in_nodes = tree
        new_z = fuse_einsums(out_pnode.node, in_nodes)

        assert tree_eq(out_pnode.node, new_z, [a, b, c])
Exemplo n.º 2
0
def test_einsum_multitier(backendopt):
    """Fuse every einsum subtree of a two-tier graph and compare values.

    Four trees from ``get_tree`` are summed pairwise, and the two sums are
    contracted by a final einsum. Each einsum subtree found in the result is
    then fused and spliced back with ``replace_node``. The graph must
    evaluate to the same value before and after the rewrite.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        inputs_1, z_1 = get_tree("set1")
        inputs_2, z_2 = get_tree("set2")
        lhs = z_1 + z_2

        inputs_3, z_3 = get_tree("set3")
        inputs_4, z_4 = get_tree("set4")
        rhs = z_3 + z_4
        out = ad.einsum("ij, jk->ik", lhs, rhs)

        feed = gen_dict(inputs_1 + inputs_2 + inputs_3 + inputs_4)

        (expected,) = ad.Executor([out]).run(feed_dict=feed)

        with OutputInjectedModeP(find_topo_sort_p([PseudoNode(out)])):
            for out_pnode, in_nodes in find_sub_einsumtree(PseudoNode(out)):
                fused = fuse_einsums(out_pnode.node, in_nodes)
                replace_node(out_pnode, fused)

        # A fresh executor evaluates the rewritten graph.
        (actual,) = ad.Executor([out]).run(feed_dict=feed)

        assert float_eq(expected, actual)
Exemplo n.º 3
0
def test_einsum_multiuse_auto_copy(backendopt):
    r"""
        Test autolinearization and auto fuse.
        A    B   inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

        A is used twice (directly by output and through C). After
        `linearize`, the CloneNode instances left in the graph (presumably
        the copies of A) are used, together with B, as the leaves of the
        fused einsum.

        Next: we would need to autoprune.
    """

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        linearize(output)
        all_nodes = find_topo_sort([output])
        cloned_nodes = [
            tmp for tmp in all_nodes if isinstance(tmp, ad.CloneNode)
        ]

        out_new = fuse_einsums(output, [*cloned_nodes, b])
        # Test that every input is now fused: no input of the fused node is
        # still an EinsumNode.
        assert all(not isinstance(x, ad.EinsumNode) for x in out_new.inputs)

        assert tree_eq(output, out_new, [*cloned_nodes, b])
Exemplo n.º 4
0
def test_einsum_multiuse(backendopt):
    r"""
        Test manual fuse.
        A    B   inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

        Note that here we assume A is split into 2 vars by some other operations.
        The two uses of A are therefore modelled by two separate variables,
        `a` ("a1") and `a_copy` ("a2"). All three variables are passed to
        `fuse_einsums` as leaves.

        The docstring is raw so the backslashes in the diagram are not parsed
        as (invalid) escape sequences.
    """

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        a_copy = ad.Variable(name="a2", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a_copy, c)
        # Build the fused graph and check it against the original.
        out_new = fuse_einsums(output, [a, a_copy, b])
        assert tree_eq(output, out_new, [a, a_copy, b])
Exemplo n.º 5
0
def test_einsum_fuse_w_identity(backendopt):
    r"""
        [Fuse einsum with multiple identities]
        We want to fuse
            A   identity  identity
            |    \       /
            |     \     /
            |      \   /
            |       es
            |     /
            |    /
            |   /
            |  /
            es
        Here es is einsum.

        Only `a` is fed to `tree_eq`. The identity operands are built with
        `ad.identity` and are not fed as inputs.

        The docstring is raw so the backslashes in the diagram are not parsed
        as (invalid) escape sequences.
    """

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        es_identity = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', a, es_identity)

        # Exactly one einsum tree is expected.
        tree, = find_sub_einsumtree(PseudoNode(out))
        # Use distinct names instead of rebinding `out` to the pseudo-node.
        out_pnode, in_nodes = tree
        new_out = fuse_einsums(out_pnode.node, in_nodes)
        assert tree_eq(out_pnode.node, new_out, [a])
Exemplo n.º 6
0
def test_einsum(backendopt):
    """Fuse the tree returned by ``get_tree()`` and compare it with the original."""
    for datatype in backendopt:
        T.set_backend(datatype)

        leaves, z = get_tree()
        fused = fuse_einsums(z, leaves)
        assert tree_eq(z, fused, leaves)
Exemplo n.º 7
0
def test_einsum_fuse_only_identity(backendopt):
    """Fuse an einsum tree whose operands are all ``ad.identity`` nodes.

    No variables appear in the graph, so ``tree_eq`` receives an empty
    input list.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        eye_product = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', ad.identity(3), eye_product)

        # Exactly one (output, inputs) tree is expected.
        [(out_pnode, in_nodes)] = find_sub_einsumtree(PseudoNode(out))
        fused = fuse_einsums(out_pnode.node, in_nodes)
        assert tree_eq(out_pnode.node, fused, [])
Exemplo n.º 8
0
def test_einsum_simple_rewrite(backendopt):
    """
        Rewrite a single two-operand einsum with ``fuse_einsums`` and check
        that the rewritten node matches the original.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a1 = ad.Variable(name="a1", shape=[3, 2])
        a2 = ad.Variable(name="a2", shape=[2, 3])
        operands = (a1, a2)

        x = ad.einsum('ik,kj->ij', *operands)
        # Fresh lists for each call, as in a list literal per call.
        rewritten = fuse_einsums(x, list(operands))
        assert tree_eq(x, rewritten, list(operands))
Exemplo n.º 9
0
def test_einsum_subtree_clone(backendopt):
    r"""
        [Subtree clone]
        This case is rather subtle.
        We want to auto fuse
            A   B   C   D
            |    \ /    |
            |     es    |
            |    /  \   |
            |  /      \ |
            es         es
              \       /
                  +

        Here es is einsum.

        BC is shared by two einsums (ABC and BCD) that meet only at the final
        addition, so two separate einsum subtrees are expected. Each one is
        fused and spliced back, and the graph must evaluate to the same value
        as before the rewrite.

        The docstring is raw so the backslashes in the diagram are not parsed
        as (invalid) escape sequences.
    """

    for datatype in backendopt:
        T.set_backend(datatype)
        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])
        d = ad.Variable(name="d", shape=[3, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3

        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3

        BCD = ad.einsum('jk, ki->ji', BC, d)  # 3x3

        out = ABC + BCD

        input_nodes = [a, b, c, d]
        generated_feed_dict = gen_dict(input_nodes)

        executor = ad.Executor([out])
        out_val, = executor.run(feed_dict=generated_feed_dict)

        with OutputInjectedModeP(find_topo_sort_p([PseudoNode(out)])):
            trees = find_sub_einsumtree(PseudoNode(out))
            assert len(trees) == 2
            for tree in trees:
                out_node, in_nodes = tree
                new_z = fuse_einsums(out_node.node, in_nodes)
                replace_node(out_node, new_z)

        # Build a fresh executor for the rewritten graph. The check then
        # cannot depend on any state the first executor may have kept about
        # the pre-rewrite graph.
        executor = ad.Executor([out])
        new_out_val, = executor.run(feed_dict=generated_feed_dict)

        assert float_eq(out_val, new_out_val)
Exemplo n.º 10
0
def test_fuse_subgraph(backendopt):
    """Fuse only the top of an einsum chain, keeping ``ab`` as a leaf.

    The chain is ``((a @ b) @ c) @ d``, with each step written as the einsum
    "ab,bc->ac". Fusing ``abcd`` down to the inputs ``[ab, c, d]`` leaves the
    ``a @ b`` einsum in place. The result is compared with the original
    using the raw variables a, b, c, d.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        # Created in the order a, b, c, d, each named after its letter.
        a, b, c, d = (
            ad.Variable(name=label, shape=[2, 2]) for label in "abcd"
        )

        matmul = "ab,bc->ac"
        ab = ad.einsum(matmul, a, b)
        abc = ad.einsum(matmul, ab, c)
        abcd = ad.einsum(matmul, abc, d)

        fused = fuse_einsums(abcd, [ab, c, d])
        assert tree_eq(abcd, fused, [a, b, c, d])