コード例 #1
0
def test_einsum_gen_corner_case(backendopt):
    """Check that every non-variable node of the generated tree has two inputs.

    Note: Numpy contraction path cannot find the opt path for this expression.
        It will output the same expression as the input.
    --------    E    --------
    |       |       |       |
    a       b       c       d
    |       |       |       |
    A - e - B - f - C - g - D
    |       |       |       |
    h       i       j       k
    |       |       |       |
    """
    dim = 5
    A = ad.Variable(name="A", shape=[dim] * 3)
    B = ad.Variable(name="B", shape=[dim] * 4)
    C = ad.Variable(name="C", shape=[dim] * 4)
    D = ad.Variable(name="D", shape=[dim] * 3)
    E = ad.Variable(name="E", shape=[dim] * 4)

    network = ad.einsum('aeh,bfie,cgjf,dgk,abcd->hijk', A, B, C, D, E)
    optimized = generate_optimal_tree(network)

    for visited in find_topo_sort([optimized]):
        # Variable nodes are skipped; only the remaining nodes are checked.
        if isinstance(visited, ad.VariableNode):
            continue
        assert len(visited.inputs) == 2
コード例 #2
0
def test_get_common_ancestor(backendopt):
    """Check the inputs of the common ancestor found for A in a 3-site chain.

        The network and indices positions are as follows:

                      g - A
                          |
        c        d        e
        |        |        |
        X1 - a - X2 - b - X3
        |        |        |
        h        i        j
                          |
                      l - A

    """

    def node_name(node):
        # Sort key: order nodes by their name attribute.
        return node.name

    A = ad.Variable(name="A", shape=[3, 2])

    X1 = ad.Variable(name="X1", shape=[3, 2, 2])
    X2 = ad.Variable(name="X2", shape=[3, 3, 2, 2])
    X3 = ad.Variable(name="X3", shape=[3, 2, 2])

    fused = ad.einsum('lj,ge,bej,abdi,ach->cdhigl', A, A, X3, X2, X1)
    optimized = generate_optimal_tree(fused)
    ancestor = get_common_ancestor(optimized, fused.inputs, A)

    # The ancestor's inputs are expected to be both occurrences of A plus X3.
    found = sorted(get_all_inputs(ancestor), key=node_name)
    wanted = sorted([A, A, X3], key=node_name)
    assert found == wanted
コード例 #3
0
def optimize(node):
    """Optimize a graph with a single output node.

    The graph is passed through ``distribute_tree`` and ``linearize``. Every
    einsum sub-tree reported by ``find_sub_einsumtree`` is then fused into a
    single einsum, pruned of identity nodes, rebuilt with
    ``generate_optimal_tree`` and put back with ``replace_node``. Finally the
    result is decloned, the expression of each einsum node is rewritten, the
    inputs of every node are re-set and ``dedup`` is applied.

    Args:
        node: The output node.
    Returns:
        node: The newly generated node.
    """
    node = distribute_tree(node)
    linearize(node)

    all_nodes = find_topo_sort([node])
    # The output is wrapped so it can be read back through ret_node.node
    # after the sub-trees have been replaced.
    ret_node = PseudoNode(node)
    with OutputInjectedMode(all_nodes):
        for out_node_p, in_nodes in find_sub_einsumtree(ret_node):
            new_z = fuse_einsums(out_node_p.node, in_nodes)
            prune_identity_nodes(new_z)
            new_z = generate_optimal_tree(new_z)
            replace_node(out_node_p, new_z)

    # Keep the output node under its own name: the loops below must not
    # rebind it, otherwise the returned value would depend on the order in
    # which find_topo_sort yields the nodes.
    out_node = declone(ret_node.node)
    for graph_node in find_topo_sort([out_node]):
        if isinstance(graph_node, ad.EinsumNode):
            rewrite_einsum_expr(graph_node)

    # The topological order is recomputed here, as in the original code.
    for graph_node in find_topo_sort([out_node]):
        if graph_node.inputs != []:
            # NOTE(review): presumably refreshes state derived from the
            # inputs -- confirm against set_inputs.
            graph_node.set_inputs(graph_node.inputs)

    dedup(out_node)
    return out_node
コード例 #4
0
def test_einsum_gen_custom(backendopt):
    """Check that a tree built from an explicit path passes ``tree_eq``."""
    # Runs once per entry of the backend fixture; the entry itself is unused.
    for _ in backendopt:
        lhs = ad.Variable(name="a", shape=[2, 2])
        mid = ad.Variable(name="b", shape=[2, 5])
        rhs = ad.Variable(name="c", shape=[5, 2])

        chain = ad.einsum('ij,jk,kl->il', lhs, mid, rhs)
        custom_path = [(1, 2), (0, 1)]
        rebuilt = generate_optimal_tree(chain, path=custom_path)
        assert tree_eq(chain, rebuilt, [lhs, mid, rhs])
コード例 #5
0
def test_einsum_gen(backendopt):
    """Check the generated tree for a five-operand einsum.

    The generated tree must pass ``tree_eq`` against the original expression
    and its root must have exactly two inputs.
    """
    # Runs once per entry of the backend fixture; the entry itself is unused.
    for _ in backendopt:
        dim = 10
        C = ad.Variable(name="C", shape=[dim] * 2)
        I = ad.Variable(name="I", shape=[dim] * 4)

        original = ad.einsum('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)
        rebuilt = generate_optimal_tree(original)

        assert tree_eq(original, rebuilt, [C, I])
        assert len(rebuilt.inputs) == 2
コード例 #6
0
def test_get_common_ancestor_w_inv(backendopt):
    """Check the common ancestor of A when the einsum also takes a tensorinv input."""

    def node_name(node):
        # Sort key: order nodes by their name attribute.
        return node.name

    A = ad.Variable(name="A", shape=[3, 3])
    X = ad.Variable(name="X", shape=[3, 3, 3])
    gram = ad.einsum("ab,ac->bc", A, A)
    inv = ad.tensorinv(gram, ind=1)

    fused = ad.einsum('abc,ad,ce->bce', X, A, inv)
    optimized = generate_optimal_tree(fused)
    ancestor = get_common_ancestor(optimized, fused.inputs, A)

    # ancestor should be ad.einsum('ad,abc->dbc',A,X), and shouldn't include
    # the inv node.
    found = sorted(get_all_inputs(ancestor), key=node_name)
    wanted = sorted([A, X], key=node_name)
    assert found == wanted
コード例 #7
0
def test_get_common_ancestor_simple(backendopt):
    """Check the inputs of the common ancestor found for A in a 2-site chain.

        The network and indices positions are as follows:

             g - A
                 |
        d        e
        |        |
        X1 - b - X2
        |        |
        i        j
    """

    def node_name(node):
        # Sort key: order nodes by their name attribute.
        return node.name

    A = ad.Variable(name="A", shape=[3, 2])

    X1 = ad.Variable(name="X1", shape=[3, 4, 4])
    X2 = ad.Variable(name="X2", shape=[3, 2, 2])

    fused = ad.einsum('ge,bdi,bej->gdij', A, X1, X2)
    optimized = generate_optimal_tree(fused)
    ancestor = get_common_ancestor(optimized, fused.inputs, A)

    # The ancestor's inputs are expected to be A and X2.
    found = sorted(get_all_inputs(ancestor), key=node_name)
    wanted = sorted([A, X2], key=node_name)
    assert found == wanted