def test_einsum_fuse_graph(backendopt):
    r"""Fuse an einsum tree in which one intermediate einsum is used twice.

    This case is rather subtle. We want to auto fuse::

        A   B   C
        |    \ /
        |     es
        |    /|
        |  /  |
        es    |
          \   |
            es

    Here ``es`` is einsum. The ``B``/``C`` einsum feeds both the middle
    einsum and the final one.

    The docstring is a raw string because the diagram contains backslashes.

    Args:
        backendopt: Iterable whose items are passed to ``T.set_backend``;
            the body runs once per item.
    """
    for datatype in backendopt:
        T.set_backend(datatype)
        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        bc = ad.einsum('ik, kj->ij', b, c)  # 3x3
        abc = ad.einsum('ik, kj->ij', a, bc)  # 3x3
        out = ad.einsum('jk, ki->ji', abc, bc)  # 3x3

        linearize(out)
        # Single-element unpacking: fails unless exactly one subtree is found.
        tree, = find_sub_einsumtree(PseudoNode(out))
        # Use distinct names so the einsum node `out` is not shadowed.
        root, leaves = tree
        fused = fuse_einsums(root.node, leaves)

        assert tree_eq(root.node, fused, [a, b, c])
def test_einsum_multiuse_auto_copy(backendopt):
    r"""Test autolinearization and auto fuse.

    Graph under test::

        A    B   inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

    Next: we would need to autoprune.

    The docstring is a raw string because the diagram contains backslashes.

    Args:
        backendopt: Iterable whose items are passed to ``T.set_backend``;
            the body runs once per item.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        linearize(output)
        # Collect the CloneNode instances reachable from `output` after
        # linearization; they are used as the fuse inputs together with `b`.
        cloned_nodes = [
            candidate for candidate in find_topo_sort([output])
            if isinstance(candidate, ad.CloneNode)
        ]

        out_new = fuse_einsums(output, [*cloned_nodes, b])
        # Test that every input is now fused: no direct input of the fused
        # node may itself be an einsum node.
        assert not any(
            isinstance(x, ad.EinsumNode) for x in out_new.inputs)

        assert tree_eq(output, out_new, [*cloned_nodes, b])
def test_linearization_multiple_same_output(backendopt):
    r"""Linearize an einsum that takes the same node as both of its inputs.

    An einsum graph like::

        A      inputs
        |\
        | \
        |  \
        |   |
        |  /
        | /
        output

    will produce an einsum graph like::

        A      inputs
        |\
        | \
        |  \
        A1  A2
        |  /
        | /
        output

    The subtree inputs must then be [A1, A2] rather than A.

    The docstring is a raw string: without it, the backslash at the end of a
    diagram line acts as a line continuation and joins the lines.

    Args:
        backendopt: Not referenced by the body of this test.
    """
    x = ad.Variable(name="x", shape=[3])
    y = ad.einsum("i,i->", x, x)
    linearize(y)
    # The einsum must still have two inputs after linearization.
    assert len(y.inputs) == 2
# Example #4
# 0
def generate_optimal_tree(node, path=None):
    """Generates the descendants of the optimal path.

    The einsum described by ``node`` is rebuilt as a sequence of smaller
    einsum nodes, one per entry of the contraction list returned by
    ``contract_path``.

    Args:
        node: The einsum node we are interested about. Must be an
            ``ad.EinsumNode``, and none of the nodes returned by
            ``get_all_inputs(node)`` may be an ``ad.EinsumNode``.
        path: If specified, will be used to generate tree (forwarded to
            ``contract_path`` as ``optimize=path``); must be non-empty.
            When None, ``contract_path`` is called without ``optimize``.
    Returns:
        final_node: The newly generated node.
    """
    from graph_ops.graph_optimizer import fuse_einsums
    from graph_ops.graph_dedup import declone
    from graph_ops.graph_transformer import linearize

    assert isinstance(node, ad.EinsumNode)
    leaves = get_all_inputs(node)
    for leaf in leaves:
        assert (not isinstance(leaf, ad.EinsumNode))

    if path is None:
        _, contract_list = contract_path(node.einsum_subscripts,
                                         *node.inputs,
                                         einsum_call=True)
    else:
        assert len(path) > 0
        _, contract_list = contract_path(node.einsum_subscripts,
                                         *node.inputs,
                                         optimize=path,
                                         einsum_call=True)

    # Working operand list: each contraction consumes some entries (addressed
    # by position) and appends its result at the end.
    operands = list(node.inputs)
    final_node = None
    for indices, _, subscript, _, _ in contract_list:
        input_nodes = [operands[i] for i in indices]
        new_node = ad.einsum(subscript, *input_nodes)
        # Remove the consumed operands by position, highest index first so
        # the remaining positions stay valid. Removing by value
        # (list.remove) would drop the first matching entry, which is the
        # wrong one when the same node occurs several times in the inputs.
        for position in sorted(set(indices), reverse=True):
            del operands[position]
        operands.append(new_node)
        final_node = new_node

    # opt_einsum sometimes generate the path where the last node
    # is just transposing the indices. If this happens, then just merge
    # this node with its input node.
    if len(final_node.inputs) == 1:
        # To handle the case where duplicated inputs exist
        linearize(final_node)
        in_node = final_node.inputs[0]
        final_node = fuse_einsums(final_node, in_node.inputs)
        final_node = declone(final_node)
    return final_node
def test_einsum_find_subtree_after_linearization(backendopt):
    r"""Check the inputs of the einsum subtree found after linearization.

    An einsum graph like::

        A    B   inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

    will produce an einsum graph like::

        A    B   inputs
        |\   |
        | A1 |
        |  \ |
        A2  C
        |  /
        | /
        output

    The subtree inputs must then be [A1, A2, B] rather than A, B.

    The docstring is a raw string because the diagrams contain backslashes.

    Args:
        backendopt: Iterable whose items are passed to ``T.set_backend``;
            the body runs once per item.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        feed_dict = gen_dict([a, b])
        executor = ad.Executor([output])
        # Run the graph once before linearization. The value itself is not
        # inspected; the single-element unpacking only requires that exactly
        # one output comes back.
        _out_val, = executor.run(feed_dict=feed_dict)

        # New graph
        linearize(output)
        tree, = find_sub_einsumtree(PseudoNode(output))
        # tree[1] holds the subtree inputs; three are expected.
        assert (len(tree[1]) == 3)
def test_einsum_multiuse(backendopt):
    r"""Check that linearization leaves the evaluated output unchanged.

    An einsum graph like::

        A    B   inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

    will produce an einsum graph like::

        A    B   inputs
        |\   |
        | A1 |
        |  \ |
        A2  C
        |  /
        | /
        output

    The docstring is a raw string because the diagrams contain backslashes.

    Args:
        backendopt: Iterable whose items are passed to ``T.set_backend``;
            the body runs once per item.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        feed_dict = gen_dict([a, b])

        # Reference value: evaluate the graph before linearization.
        reference_executor = ad.Executor([output])
        out_val, = reference_executor.run(feed_dict=feed_dict)

        # Evaluate again, with a new executor, after linearize(output).
        linearize(output)
        linearized_executor = ad.Executor([output])
        out_new_val, = linearized_executor.run(feed_dict=feed_dict)

        assert T.array_equal(out_val, out_new_val)