Пример #1
0
def generate_sequential_optimal_tree(einsum_nodes, input_nodes):
    """
    Regenerate einsum expressions so that they follow the dimension tree.

    Parameters
    ----------
    einsum_nodes : list
        Einsum nodes to be recomputed according to the dimension tree.
    input_nodes : list
        Input nodes whose contraction inside ``einsum_nodes`` follows the
        order from the end of the list to the start.

    Returns
    -------
    list
        Einsum nodes that compute the same results as ``einsum_nodes`` while
        following the dimension-tree contraction order.

    Examples
    --------
    >>> einsum_node_A = ad.einsum("abcd,bm,cm,dm->am", X, B, C, D)
    >>> einsum_node_B = ad.einsum("abcd,am,cm,dm->bm", X, A, C, D)
    >>> einsum_node_C = ad.einsum("abcd,am,bm,dm->cm", X, A, B, D)
    >>> dt = generate_sequential_optimal_tree([einsum_node_A, einsum_node_B, einsum_node_C], [A, B, C])
    >>> dt
    [ad.einsum('bm,abm->am', B, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
    ad.einsum('am,abm->bm', A, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
    ad.einsum('am,bm,abcm->cm', A, B, ad.einsum('abcd,dm->abcm', X, D)),
    ]
    (einsum strings may be different)
    """
    # Nothing to restructure for a single node with a single input.
    if len(einsum_nodes) == 1 and len(input_nodes) == 1:
        return einsum_nodes

    tree_nodes = []
    for idx, einsum_node in enumerate(einsum_nodes):
        # Inputs after position idx are contracted first (from the list end
        # backwards), followed by the inputs before idx; input idx is skipped.
        order = list(reversed(input_nodes[idx + 1:])) + input_nodes[:idx]
        # Keep only the nodes that actually feed this einsum.
        order = [inp for inp in order if inp in einsum_node.inputs]
        tree_nodes.append(
            generate_optimal_tree_w_constraint(einsum_node, order))

    # generate_optimal_tree_w_constraint may leave some einsum strings in a
    # non-canonical form, so every node is rewritten again before dedup.
    topo_nodes = find_topo_sort(tree_nodes)
    with OutputInjectedMode(topo_nodes):
        for graph_node in topo_nodes:
            if isinstance(graph_node, ad.EinsumNode):
                rewrite_einsum_expr(graph_node)
            if graph_node.inputs != []:
                graph_node.set_inputs(graph_node.inputs)

    dedup(*tree_nodes)
    remove_transposes(find_topo_sort(tree_nodes))
    return tree_nodes
Пример #2
0
def test_einsum_rewrite_duplicate_input(backendopt):
    """Equivalent einsums whose operands are the same node rewrite identically."""
    var = ad.Variable(name="a", shape=[3, 2])

    lhs = ad.einsum('ca,cb->ab', var, var)
    rhs = ad.einsum('cb,ca->ab', var, var)

    for expr in (lhs, rhs):
        rewrite_einsum_expr(expr)

    assert lhs.einsum_subscripts == rhs.einsum_subscripts
Пример #3
0
def test_einsum_equal_repeated_transpose(backendopt):
    """
    Two equivalent contractions of a repeated input, written with different
    index letters, must rewrite to the same subscripts and inputs.
    """
    A = ad.Variable(name="A", shape=[3, 5])

    x = ad.einsum('or,ob->br', A, A)
    y = ad.einsum('eb,ed->bd', A, A)

    # The rewrite is checked through the nodes' attributes below, so the
    # return values are not kept.
    rewrite_einsum_expr(x)
    rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs
Пример #4
0
def test_einsum_equal_repeated_transpose(backendopt):
    """
    Equivalent scalar contractions of (A, A, B) whose two A operands appear
    with swapped subscripts must rewrite to the same subscripts and inputs.
    """
    A = ad.Variable(name="A", shape=[3, 3])
    B = ad.Variable(name="B", shape=[3, 3])

    x = ad.einsum("ac,ba,bc->", A, A, B)
    y = ad.einsum("ba,ac,bc->", A, A, B)

    # The rewrite is checked through the nodes' attributes below, so the
    # return values are not kept.
    rewrite_einsum_expr(x)
    rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs
Пример #5
0
def test_einsum_equal(backendopt):
    """A matmul einsum and its operand-swapped equivalent rewrite identically."""
    left = ad.Variable(name="a1", shape=[3, 2])
    right = ad.Variable(name="a2", shape=[2, 3])

    expr_a = ad.einsum('ik,kj->ij', left, right)
    expr_b = ad.einsum('ml,sm->sl', right, left)

    for expr in (expr_a, expr_b):
        rewrite_einsum_expr(expr)

    assert expr_a.einsum_subscripts == expr_b.einsum_subscripts
    assert expr_a.inputs == expr_b.inputs
Пример #6
0
def test_einsum_equal_uf_assign_order(backendopt):
    """
    Equivalent einsums whose operands and index letters are listed in
    different orders must rewrite to the same subscripts and inputs.
    """
    A = ad.Variable(name="A", shape=[3, 3])
    B = ad.Variable(name="B", shape=[3, 3])
    identity = ad.identity(10)

    x = ad.einsum('pb,or,ob,pr,st->srtb', B, A, A, B, identity)
    y = ad.einsum('eb,ed,fb,fd,ac->abcd', A, A, B, B, identity)

    # The rewrite is checked through the nodes' attributes below, so the
    # return values are not kept.
    rewrite_einsum_expr(x)
    rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs
Пример #7
0
def test_rewrite_expr(backendopt):
    """
        Two einsums that differ only in index letters rewrite to the same
        subscripts.
    """
    first = ad.Variable(name="a1", shape=[3, 2])
    second = ad.Variable(name="a2", shape=[2, 3])

    exprs = [
        ad.einsum('ik,kj->ij', first, second),
        ad.einsum('sm,ml->sl', first, second),
    ]
    for expr in exprs:
        rewrite_einsum_expr(expr)

    original_form, renamed_form = exprs
    assert original_form.einsum_subscripts == renamed_form.einsum_subscripts
Пример #8
0
def prune_single_inv_node(einsum_node, inv_node):
    """
    Prune the inv_node in the einsum node if the conditions are met.

    Note:
    1. The node can only be optimized when the input of inv is an einsum node.
    2. Only the case where the split-out nodes are different from the remaining
        ones is supported.
        For example: ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, C) will be
        optimized to ad.einsum("ab,bc->ac", C, ad.identity()),
        but we cannot optimize ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, B).
    3. Pruning only happens when the inv node and the selected einsum input are
        combined as a matrix product (both 2-d, sharing exactly one index).

    Parameters
    ----------
    einsum_node: The fused einsum node
    inv_node: the input inv node to be pruned

    Returns
    -------
    If the einsum_node cannot be optimized, then return the input einsum_node.
    If it can be optimized, return the optimized einsum node, where the inv
    node and the matching einsum input are replaced by an identity node.

    """
    from graph_ops.graph_transformer import rewrite_einsum_expr
    from graph_ops.graph_generator import split_einsum

    inv_node_input = inv_node.inputs[0]
    if not isinstance(inv_node_input, ad.EinsumNode):
        logger.info("inv input is not einsum node, can't prune inv")
        return einsum_node

    if not set(inv_node_input.inputs).issubset(set(einsum_node.inputs)):
        logger.info(
            "inv inputs is not subset of einsum node inputs, can't prune inv")
        return einsum_node

    einsum_inputs_in_inv = [
        n for n in einsum_node.inputs if n in inv_node_input.inputs
    ]
    if len(einsum_inputs_in_inv) < len(inv_node_input.inputs):
        logger.info(
            "number of inv inputs is more than corresponding einsum inputs, can't prune inv"
        )
        return einsum_node

    # NOTE(review): assumes split_einsum keeps the listed inputs as direct
    # operands and groups the others (the inv input's operands) into an
    # einsum input of the returned node — confirm against split_einsum.
    split_einsum_node = split_einsum(
        einsum_node,
        list(set(einsum_node.inputs) - set(inv_node_input.inputs)))

    # Assign pseudo nodes and chars
    in_subs, out_subs, _ = _parse_einsum_input(
        (split_einsum_node.einsum_subscripts, *split_einsum_node.inputs))
    in_subs_list = in_subs.split(',')

    updated_p_in_nodes = []
    p_einsum_input = None
    p_inv_input = None
    for node, subscript in zip(split_einsum_node.inputs, in_subs_list):
        p_node = PseudoNode(node=node, subscript=subscript)
        if isinstance(node, ad.EinsumNode):
            # The last einsum input is the pruning candidate. An earlier one
            # is still an operand of the contraction and must not be dropped.
            if p_einsum_input is not None:
                updated_p_in_nodes.append(p_einsum_input)
            p_einsum_input = p_node
        elif node is inv_node:
            # Likewise, keep any extra occurrence of the inv node.
            if p_inv_input is not None:
                updated_p_in_nodes.append(p_inv_input)
            p_inv_input = p_node
        else:
            updated_p_in_nodes.append(p_node)

    if p_einsum_input is None or p_inv_input is None:
        logger.info(
            "split einsum has no einsum input or no inv input, can't prune inv")
        return einsum_node

    # Chars shared by both operands are contracted; the rest are kept. Both
    # strings are built in order of first appearance (not by iterating a set)
    # so the generated subscripts are deterministic across runs.
    contract_char = "".join(c for c in dict.fromkeys(p_einsum_input.subscript)
                            if c in p_inv_input.subscript)
    uncontract_str = "".join(
        c for c in dict.fromkeys(p_einsum_input.subscript +
                                 p_inv_input.subscript)
        if c not in contract_char)

    if not (len(p_einsum_input.subscript) == 2 and len(p_inv_input.subscript)
            == 2 and len(contract_char) == 1 and len(uncontract_str) == 2):
        # this is not a matmul. Just return the initial node
        logger.info(
            "the op between inv input and the selected einsum is not matmul, can't prune inv"
        )
        return einsum_node

    if p_einsum_input.subscript[0] == p_inv_input.subscript[
            0] or p_einsum_input.subscript[1] == p_inv_input.subscript[1]:
        # the str is like "ab,ac", and one einsum needs to be transposed to compare
        p_in_subs, p_out_subs, _ = _parse_einsum_input(
            (p_einsum_input.node.einsum_subscripts,
             *p_einsum_input.node.inputs))
        einsum_input = ad.einsum(
            f"{p_in_subs}->{p_out_subs[1]}{p_out_subs[0]}",
            *p_einsum_input.node.inputs)
    else:
        einsum_input = p_einsum_input.node

    # Canonicalize both expressions so that equal computations get equal names.
    rewrite_einsum_expr(einsum_input)
    rewrite_einsum_expr(inv_node_input)

    if einsum_input.name != inv_node_input.name:
        logger.info(
            "inv input and the selected einsum have different expressions, can't prune inv"
        )
        return einsum_node

    # prune the inv node: inv(M) contracted with M becomes an identity over
    # the two uncontracted indices.
    updated_p_in_nodes = updated_p_in_nodes + [
        PseudoNode(node=ad.identity(inv_node_input.shape[0]),
                   subscript=uncontract_str)
    ]

    return generate_new_einsum(updated_p_in_nodes, out_subs)