Example #1
0
def test_einsum_multitier(backendopt):
    """Fusing every einsum subtree of a two-tier graph must keep its value.

    Four trees from ``get_tree`` are summed pairwise and the two sums are
    contracted with an einsum. The output is evaluated once before and once
    after each subtree reported by ``find_sub_einsumtree`` has been replaced
    by its fused counterpart.

    Args:
        backendopt: Iterable of backend identifiers passed to T.set_backend.
    """
    for backend_name in backendopt:
        T.set_backend(backend_name)

        leaves_a, tree_a = get_tree("set1")
        leaves_b, tree_b = get_tree("set2")
        lhs = tree_a + tree_b

        leaves_c, tree_c = get_tree("set3")
        leaves_d, tree_d = get_tree("set4")
        rhs = tree_c + tree_d

        out = ad.einsum("ij, jk->ik", lhs, rhs)
        feed = gen_dict(leaves_a + leaves_b + leaves_c + leaves_d)

        # Reference value, computed on the untouched graph.
        (expected,) = ad.Executor([out]).run(feed_dict=feed)

        with OutputInjectedModeP(find_topo_sort_p([PseudoNode(out)])):
            for root_p, leaf_nodes in find_sub_einsumtree(PseudoNode(out)):
                fused = fuse_einsums(root_p.node, leaf_nodes)
                replace_node(root_p, fused)

        # Value of the same output node after the in-place replacement.
        (actual,) = ad.Executor([out]).run(feed_dict=feed)

        assert float_eq(expected, actual)
Example #2
0
def test_einsum_fuse_graph(backendopt):
    r"""
        [Fuse einsum used twice]
        This case is rather subtle.
        We want to auto fuse
            A   B   C
            |    \ /
            |     es
            |    /|
            |  /  |
            es    |
              \   |
                es
        Here es is einsum.

        Args:
            backendopt: Iterable of backend identifiers passed to
                T.set_backend.
    """
    # NOTE: the docstring is a raw string because the diagram contains
    # backslashes that would otherwise be invalid escape sequences.

    for datatype in backendopt:
        T.set_backend(datatype)
        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3

        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3

        out = ad.einsum('jk, ki->ji', ABC, BC)  # 3x3

        linearize(out)
        # Exactly one einsum subtree is expected; unpacking fails otherwise.
        tree, = find_sub_einsumtree(PseudoNode(out))
        out_p, ins = tree
        new_z = fuse_einsums(out_p.node, ins)

        assert tree_eq(out_p.node, new_z, [a, b, c])
Example #3
0
def test_einsum_multiuse(backendopt):
    r"""
        Test manual fuse.
        A    B   inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

        Note that here we assume A is split into 2 vars by some other operations.

        Args:
            backendopt: Iterable of backend identifiers passed to
                T.set_backend.
    """
    # NOTE: the docstring is a raw string because the diagram contains
    # backslashes that would otherwise be invalid escape sequences.

    for datatype in backendopt:
        T.set_backend(datatype)

        # a and a_copy stand in for the two uses of A in the diagram.
        a = ad.Variable(name="a1", shape=[3, 2])
        a_copy = ad.Variable(name="a2", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a_copy, c)
        # New graph
        out_new = fuse_einsums(output, [a, a_copy, b])
        assert tree_eq(output, out_new, [a, a_copy, b])
Example #4
0
def test_einsum_multiuse_auto_copy(backendopt):
    r"""
        Test autolinearization and auto fuse.
        A    B   inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

        Next: we would need to autoprune.

        Args:
            backendopt: Iterable of backend identifiers passed to
                T.set_backend.
    """
    # NOTE: the docstring is a raw string because the diagram contains
    # backslashes that would otherwise be invalid escape sequences.

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        linearize(output)
        # Collect the clone nodes present in the graph after linearize.
        all_nodes = find_topo_sort([output])
        cloned_nodes = [
            tmp for tmp in all_nodes if isinstance(tmp, ad.CloneNode)
        ]

        out_new = fuse_einsums(output, [*cloned_nodes, b])
        # Test that every input is now fused (no einsum left as an input).
        assert all(not isinstance(x, ad.EinsumNode) for x in out_new.inputs)

        assert tree_eq(output, out_new, [*cloned_nodes, b])
Example #5
0
def test_einsum_fuse_w_identity(backendopt):
    r"""
        [Fuse einsum with multiple identities]
        We want to fuse
            A   identity  identity
            |    \       /
            |     \     /
            |      \   /
            |       es
            |     /
            |    /
            |   /
            |  /
            es
        Here es is einsum.

        Args:
            backendopt: Iterable of backend identifiers passed to
                T.set_backend.
    """
    # NOTE: the docstring is a raw string because the diagram contains
    # backslashes that would otherwise be invalid escape sequences.

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        es_identity = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', a, es_identity)

        # Exactly one einsum subtree is expected; unpacking fails otherwise.
        tree, = find_sub_einsumtree(PseudoNode(out))
        out_p, ins = tree
        new_out = fuse_einsums(out_p.node, ins)
        assert tree_eq(out_p.node, new_out, [a])
Example #6
0
def optimize(node):
    """Optimize a graph with a single output node.

    Steps, in order: distribute the tree, linearize it, fuse every einsum
    subtree found under the output (pruning identity nodes and regenerating
    the tree with generate_optimal_tree), declone the result, rewrite the
    expression of each remaining einsum node, re-set every node's inputs,
    and finally dedup the graph.

    Args:
        node: The output node.
    Returns:
        node: The newly generated node.
    """
    node = distribute_tree(node)
    linearize(node)

    all_nodes = find_topo_sort([node])
    ret_node = PseudoNode(node)
    with OutputInjectedMode(all_nodes):
        trees = find_sub_einsumtree(ret_node)
        for out_node_p, in_nodes in trees:
            new_z = fuse_einsums(out_node_p.node, in_nodes)
            prune_identity_nodes(new_z)
            new_z = generate_optimal_tree(new_z)
            replace_node(out_node_p, new_z)

    # Keep the output in its own name: the loops below must not overwrite it,
    # otherwise the returned node depends on the ordering of find_topo_sort.
    out_node = declone(ret_node.node)

    for inner in find_topo_sort([out_node]):
        if isinstance(inner, ad.EinsumNode):
            rewrite_einsum_expr(inner)

    # Re-assign each node's own inputs; presumably this refreshes state
    # derived from the inputs after the rewrites above -- confirm against
    # set_inputs.
    for inner in find_topo_sort([out_node]):
        if inner.inputs != []:
            inner.set_inputs(inner.inputs)

    dedup(out_node)
    return out_node
Example #7
0
def test_einsum(backendopt):
    """Fuse the default tree from ``get_tree`` and compare it with the original.

    Args:
        backendopt: Iterable of backend identifiers passed to T.set_backend.
    """
    for backend_name in backendopt:
        T.set_backend(backend_name)

        leaves, root = get_tree()
        fused_root = fuse_einsums(root, leaves)

        assert tree_eq(root, fused_root, leaves)
Example #8
0
def test_einsum_fuse_only_identity(backendopt):
    """Fuse a two-level einsum graph whose inputs are all identity nodes.

    Args:
        backendopt: Iterable of backend identifiers passed to T.set_backend.
    """
    for backend_name in backendopt:
        T.set_backend(backend_name)

        inner = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        outer = ad.einsum('ai,ij->aj', ad.identity(3), inner)

        # Exactly one subtree is expected; the unpacking fails otherwise.
        (subtree,) = find_sub_einsumtree(PseudoNode(outer))
        root_p, leaf_nodes = subtree
        fused = fuse_einsums(root_p.node, leaf_nodes)

        # No free input variables: the comparison is made with an empty list.
        assert tree_eq(root_p.node, fused, [])
Example #9
0
def test_einsum_simple_rewrite(backendopt):
    """Fusing a single einsum over two variables must give an equal tree.

    Args:
        backendopt: Iterable of backend identifiers passed to T.set_backend.
    """
    for backend_name in backendopt:
        T.set_backend(backend_name)

        lhs = ad.Variable(name="a1", shape=[3, 2])
        rhs = ad.Variable(name="a2", shape=[2, 3])
        operands = [lhs, rhs]

        product = ad.einsum('ik,kj->ij', lhs, rhs)
        rewritten = fuse_einsums(product, operands)

        assert tree_eq(product, rewritten, [lhs, rhs])
Example #10
0
def generate_optimal_tree(node, path=None):
    """Generates the descendants of the optimal path.

    The einsum ``node`` is split into a chain of smaller einsum nodes that
    follows the contraction list returned by ``contract_path``.

    Args:
        node: The einsum node we are interested about. None of the nodes
            returned by get_all_inputs(node) may itself be an einsum node.
        path: If specified, will be used to generate tree (forwarded to
            contract_path as ``optimize``). Must be non-empty when given.
    Returns:
        final_node: The newly generated node.
    """
    # Imported locally rather than at module level -- presumably to avoid a
    # circular import between the graph_ops modules; confirm before moving.
    from graph_ops.graph_optimizer import fuse_einsums
    from graph_ops.graph_dedup import declone
    from graph_ops.graph_transformer import linearize

    assert isinstance(node, ad.EinsumNode)
    leaves = get_all_inputs(node)
    for leaf in leaves:
        assert (not isinstance(leaf, ad.EinsumNode))

    path_kwargs = {"einsum_call": True}
    if path is not None:
        assert len(path) > 0
        path_kwargs["optimize"] = path
    _, contract_list = contract_path(node.einsum_subscripts, *node.inputs,
                                     **path_kwargs)

    # Working operand list. The indices of every contraction step refer to
    # positions in this list as it stands at that step.
    remaining = list(node.inputs)
    final_node = None
    for contract in contract_list:
        indices, _, subscript, _, _ = contract
        input_nodes = [remaining[i] for i in indices]
        new_node = ad.einsum(subscript, *input_nodes)
        # Drop the consumed operands by position, highest index first so the
        # lower positions stay valid. Removing by value (list.remove) would
        # delete the first equal element instead, which reorders the list
        # when the same node appears more than once.
        for position in sorted(indices, reverse=True):
            del remaining[position]
        remaining.append(new_node)
        final_node = new_node

    # opt_einsum sometimes generate the path where the last node
    # is just transposing the indices. If this happens, then just merge
    # this node with its input node.
    if len(final_node.inputs) == 1:
        # To handle the case where duplicated inputs exist
        linearize(final_node)
        in_node = final_node.inputs[0]
        final_node = fuse_einsums(final_node, in_node.inputs)
        final_node = declone(final_node)
    return final_node
Example #11
0
    def fuse_all_einsums(node):
        """Fuse every einsum subtree reachable from ``node``.

        The graph is linearized, each subtree reported by
        ``find_sub_einsumtree`` is replaced by its fused einsum (with
        identity nodes pruned), and the result is decloned.

        Args:
            node: The output node of the graph.
        Returns:
            The node returned by ``declone`` for the rewritten graph.
        """
        linearize(node)
        root_p = PseudoNode(node)
        with OutputInjectedModeP(find_topo_sort_p([root_p])):
            for subtree_root_p, leaf_nodes in find_sub_einsumtree(root_p):
                fused = fuse_einsums(subtree_root_p.node, leaf_nodes)
                prune_identity_nodes(fused)
                replace_node(subtree_root_p, fused)

        return declone(root_p.node)
Example #12
0
def test_einsum_subtree_clone(backendopt):
    r"""
        [Subtree clone]
        This case is rather subtle.
        We want to auto fuse
            A   B   C   D
            |    \ /    |
            |     es    |
            |    /  \   |
            |  /      \ |
            es         es
              \       /
                  +

        Here es is einsum.

        Args:
            backendopt: Iterable of backend identifiers passed to
                T.set_backend.
    """
    # NOTE: the docstring is a raw string because the diagram contains
    # backslashes that would otherwise be invalid escape sequences.

    for datatype in backendopt:
        T.set_backend(datatype)
        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])
        d = ad.Variable(name="d", shape=[3, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3

        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3

        BCD = ad.einsum('jk, ki->ji', BC, d)  # 3x3

        out = ABC + BCD

        input_nodes = [a, b, c, d]
        generated_feed_dict = gen_dict(input_nodes)

        # Reference value, computed before any subtree is replaced.
        executor = ad.Executor([out])
        out_val, = executor.run(feed_dict=generated_feed_dict)

        with OutputInjectedModeP(find_topo_sort_p([PseudoNode(out)])):
            trees = find_sub_einsumtree(PseudoNode(out))
            # One subtree per einsum feeding the final addition.
            assert len(trees) == 2
            for tree in trees:
                out_node, in_nodes = tree
                new_z = fuse_einsums(out_node.node, in_nodes)
                replace_node(out_node, new_z)

        # The same executor is reused on the rewritten graph.
        new_out_val, = executor.run(feed_dict=generated_feed_dict)

        assert float_eq(out_val, new_out_val)
Example #13
0
def test_fuse_subgraph(backendopt):
    """Fuse the top of a three-einsum chain, treating ``ab`` as an input.

    The fused result is still compared against the original over the four
    leaf variables.

    Args:
        backendopt: Iterable of backend identifiers passed to T.set_backend.
    """
    matmul = "ab,bc->ac"

    for backend_name in backendopt:
        T.set_backend(backend_name)

        a = ad.Variable(name="a", shape=[2, 2])
        b = ad.Variable(name="b", shape=[2, 2])
        c = ad.Variable(name="c", shape=[2, 2])
        d = ad.Variable(name="d", shape=[2, 2])

        ab = ad.einsum(matmul, a, b)
        abc = ad.einsum(matmul, ab, c)
        abcd = ad.einsum(matmul, abc, d)

        # Fusion stops at ``ab``: it is passed as an input of the subgraph.
        fused = fuse_einsums(abcd, [ab, c, d])

        assert tree_eq(abcd, fused, [a, b, c, d])