コード例 #1
0
def optimize(node):
    """Optimize a graph with a single output node.

    The graph is distributed and linearized. Every sub-einsum-tree returned by
    ``find_sub_einsumtree`` is fused, stripped of identity nodes and regenerated
    by ``generate_optimal_tree``. Finally every einsum expression is rewritten
    into canonical form and duplicated nodes are merged.

    Args:
        node: The output node.
    Returns:
        node: The newly generated node.
    """
    node = distribute_tree(node)
    linearize(node)

    all_nodes = find_topo_sort([node])
    ret_node = PseudoNode(node)
    with OutputInjectedMode(all_nodes):
        trees = find_sub_einsumtree(ret_node)
        for out_node_p, in_nodes in trees:
            new_z = fuse_einsums(out_node_p.node, in_nodes)
            prune_identity_nodes(new_z)
            new_z = generate_optimal_tree(new_z)
            replace_node(out_node_p, new_z)

    # Keep the root in its own name: the loops below must not rebind it, or
    # the returned node would silently depend on topological-sort order.
    output_node = declone(ret_node.node)

    # Canonicalize einsum expressions so that equal einsums look identical,
    # which dedup needs in order to merge them.
    for graph_node in find_topo_sort([output_node]):
        if isinstance(graph_node, ad.EinsumNode):
            rewrite_einsum_expr(graph_node)

    # Re-apply inputs on every non-leaf node. The topological order is
    # recomputed because rewriting can reorder einsum inputs.
    for graph_node in find_topo_sort([output_node]):
        if graph_node.inputs != []:
            graph_node.set_inputs(graph_node.inputs)

    dedup(output_node)
    return output_node
コード例 #2
0
def generate_sequential_optimal_tree(einsum_nodes, input_nodes):
    """
    Rebuild einsum expressions so they follow a shared dimension-tree order.

    Parameters
    ----------
    einsum_nodes : list
        Einsum nodes to recompute using the dimension tree. The i-th node is
        paired with the i-th entry of ``input_nodes``.
    input_nodes : list
        Input nodes whose contraction inside ``einsum_nodes`` follows the
        sequence from the end of the list towards its start.

    Returns
    -------
    list
        Einsum nodes computing the same values as ``einsum_nodes`` but
        contracted in the dimension-tree order. When both lists have exactly
        one element, ``einsum_nodes`` itself is returned unchanged.

    Examples
    --------
    >>> einsum_node_A = ad.einsum("abcd,bm,cm,dm->am", X, B, C, D)
    >>> einsum_node_B = ad.einsum("abcd,am,cm,dm->bm", X, A, C, D)
    >>> einsum_node_C = ad.einsum("abcd,am,bm,dm->cm", X, A, B, D)
    >>> dt = generate_sequential_optimal_tree([einsum_node_A, einsum_node_B, einsum_node_C], [A, B, C])
    >>> dt
    [ad.einsum('bm,abm->am', B, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
    ad.einsum('am,abm->bm', A, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
    ad.einsum('am,bm,abcm->cm', A, B, ad.einsum('abcd,dm->abcm', X, D)),
    ]
    (einsum strings may be different)
    """
    if len(einsum_nodes) == 1 and len(input_nodes) == 1:
        return einsum_nodes

    def _contraction_order(position, target):
        # Inputs after `position` are contracted first, last-to-first, then
        # the inputs before `position` in list order; keep only the ones that
        # actually feed `target`.
        after = list(reversed(input_nodes[position + 1:]))
        before = input_nodes[:position]
        return [n for n in after + before if n in target.inputs]

    new_nodes = [
        generate_optimal_tree_w_constraint(target,
                                           _contraction_order(pos, target))
        for pos, target in enumerate(einsum_nodes)
    ]

    # generate_optimal_tree_w_constraint can leave einsum strings in a
    # non-canonical form; canonicalize again so dedup can merge equal nodes.
    topo_nodes = find_topo_sort(new_nodes)
    with OutputInjectedMode(topo_nodes):
        for graph_node in topo_nodes:
            if isinstance(graph_node, ad.EinsumNode):
                rewrite_einsum_expr(graph_node)
            if graph_node.inputs != []:
                graph_node.set_inputs(graph_node.inputs)

    dedup(*new_nodes)
    remove_transposes(find_topo_sort(new_nodes))
    return new_nodes
コード例 #3
0
def test_einsum_rewrite_duplicate_input(backendopt):
    """Check that einsums differing only in the order of identical operands rewrite alike."""
    var_a = ad.Variable(name="a", shape=[3, 2])

    lhs = ad.einsum('ca,cb->ab', var_a, var_a)
    rhs = ad.einsum('cb,ca->ab', var_a, var_a)

    for expr in (lhs, rhs):
        rewrite_einsum_expr(expr)

    assert lhs.einsum_subscripts == rhs.einsum_subscripts
コード例 #4
0
def test_einsum_equal_repeated_transpose(backendopt):
    """Contracting A with itself over axis 0, spelled with different letters,
    must rewrite to identical subscripts and inputs."""
    # NOTE(review): a second test with this exact name appears later in this
    # listing. If both live in one module, the later definition replaces this
    # one and pytest never runs it; give one of them a distinct name.
    A = ad.Variable(name="A", shape=[3, 5])

    x = ad.einsum('or,ob->br', A, A)
    y = ad.einsum('eb,ed->bd', A, A)

    # The return values were previously bound to unused locals; only the
    # in-place rewrite matters here.
    rewrite_einsum_expr(x)
    rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs
コード例 #5
0
def test_einsum_equal_repeated_transpose(backendopt):
    """Swapping the positions of the two A operands in a scalar-output einsum
    must not change the rewritten subscripts or input order."""
    # NOTE(review): a test with this exact name is also defined earlier in
    # this listing. If both live in one module, this definition replaces the
    # earlier one; give one of them a distinct name.
    A = ad.Variable(name="A", shape=[3, 3])
    B = ad.Variable(name="B", shape=[3, 3])

    x = ad.einsum("ac,ba,bc->", A, A, B)
    y = ad.einsum("ba,ac,bc->", A, A, B)

    # Rewriting happens in place; the return values are not needed.
    rewrite_einsum_expr(x)
    rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs
コード例 #6
0
def test_einsum_equal(backendopt):
    """The same matmul, written with swapped operands and renamed indices,
    must rewrite to identical subscripts and inputs."""
    left = ad.Variable(name="a1", shape=[3, 2])
    right = ad.Variable(name="a2", shape=[2, 3])

    exprs = [
        ad.einsum('ik,kj->ij', left, right),
        ad.einsum('ml,sm->sl', right, left),
    ]
    for expr in exprs:
        rewrite_einsum_expr(expr)

    first, second = exprs
    assert first.einsum_subscripts == second.einsum_subscripts
    assert first.inputs == second.inputs
コード例 #7
0
def test_einsum_equal_uf_assign_order(backendopt):
    """Equivalent multi-operand einsums, whose operands and index letters are
    listed in different orders, must rewrite to identical subscripts and
    inputs."""
    A = ad.Variable(name="A", shape=[3, 3])
    B = ad.Variable(name="B", shape=[3, 3])
    # Renamed from the ambiguous single letter `I` (PEP 8 / E741).
    eye = ad.identity(10)

    x = ad.einsum('pb,or,ob,pr,st->srtb', B, A, A, B, eye)
    y = ad.einsum('eb,ed,fb,fd,ac->abcd', A, A, B, B, eye)

    # Rewriting happens in place; the return values are not needed.
    rewrite_einsum_expr(x)
    rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs
コード例 #8
0
def test_rewrite_expr(backendopt):
    """
        Renaming the index letters of a matmul einsum must not change the
        subscripts produced by rewrite_einsum_expr.
    """
    lhs_var = ad.Variable(name="a1", shape=[3, 2])
    rhs_var = ad.Variable(name="a2", shape=[2, 3])

    original = ad.einsum('ik,kj->ij', lhs_var, rhs_var)
    renamed = ad.einsum('sm,ml->sl', lhs_var, rhs_var)

    for expr in (original, renamed):
        rewrite_einsum_expr(expr)
    assert original.einsum_subscripts == renamed.einsum_subscripts
コード例 #9
0
def prune_single_inv_node(einsum_node, inv_node):
    """
    Prune the inv_node in the einsum node if the pruning conditions are met.

    Note:
    1. can only optimize the node when the input of inv is an einsum node.
    2. only supports the case when the split nodes are different from the remaining ones.
        For example: ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, C) will be
        optimized to ad.einsum("ab,bc->ac", C, ad.identity()),
        but we cannot optimize ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, B).
    3. the contraction between the inv node and the einsum of its inputs must be
        a single matrix multiplication (two 2-index operands sharing exactly one index).

    Parameters
    ----------
    einsum_node: The fused einsum node
    inv_node: the input inv node to be pruned

    Returns
    -------
    If the einsum_node cannot be optimized, then return the input einsum_node.
    If it can be optimized, return the optimized einsum node. In that node the
    inv node and the matching einsum are replaced by an identity node.

    """
    from autohoot.einsum_graph.expr_generator import rewrite_einsum_expr
    from autohoot.graph_ops.optimal_tree import split_einsum

    inv_node_input = inv_node.inputs[0]
    if not isinstance(inv_node_input, ad.EinsumNode):
        logger.info("inv input is not einsum node, can't prune inv")
        return einsum_node

    if not set(inv_node_input.inputs).issubset(set(einsum_node.inputs)):
        logger.info(
            "inv inputs is not subset of einsum node inputs, can't prune inv")
        return einsum_node

    einsum_inputs_in_inv = [
        n for n in einsum_node.inputs if n in inv_node_input.inputs
    ]
    if len(einsum_inputs_in_inv) < len(inv_node_input.inputs):
        logger.info(
            "number of inv inputs is more than corresponding einsum inputs, can't prune inv"
        )
        return einsum_node

    # Inputs of einsum_node that do not feed the inverted einsum (normally
    # including inv_node itself). These stay direct inputs of the split node;
    # the remaining operands are expected to come back as one einsum input.
    kept_nodes = set(einsum_node.inputs) - set(inv_node_input.inputs)
    split_einsum_node = split_einsum(einsum_node, list(kept_nodes))

    # Assign pseudo nodes and chars
    in_subs, out_subs, _ = parse_einsum_input(
        (split_einsum_node.einsum_subscripts, *split_einsum_node.inputs))
    in_subs_list = in_subs.split(',')

    p_einsum_input = None
    p_inv_input = None
    updated_p_in_nodes = []
    for node, subscript in zip(split_einsum_node.inputs, in_subs_list):
        p_node = PseudoNode(node=node, subscript=subscript)
        # Only an einsum that is not one of the kept inputs can be the grouped
        # einsum of the inv operands. A kept einsum input must stay in the
        # result rather than overwrite (and drop) the grouped one.
        if isinstance(node, ad.EinsumNode) and node not in kept_nodes:
            p_einsum_input = p_node
        elif node is inv_node:
            p_inv_input = p_node
        else:
            updated_p_in_nodes.append(p_node)

    if p_einsum_input is None or p_inv_input is None:
        logger.info(
            "cannot locate the inv node and the einsum of its inputs after splitting, can't prune inv"
        )
        return einsum_node

    einsum_sub = p_einsum_input.subscript
    inv_sub = p_inv_input.subscript
    contract_char = "".join(set(einsum_sub) & set(inv_sub))
    # Output indices of the matmul. Their order is arbitrary (set iteration),
    # which is acceptable because they only label the identity node added
    # below; this presumes ad.identity(n) is square.
    uncontract_str = "".join(set(einsum_sub + inv_sub) - set(contract_char))

    if not (len(einsum_sub) == 2 and len(inv_sub) == 2
            and len(contract_char) == 1 and len(uncontract_str) == 2):
        # this is not a matmul. Just return the initial node
        logger.info(
            "the op between inv input and the selected einsum is not matmul, can't prune inv"
        )
        return einsum_node

    if einsum_sub[0] == inv_sub[0] or einsum_sub[1] == inv_sub[1]:
        # The str is like "ab,ac": both operands are contracted along the same
        # axis position, so inv(M) only cancels the einsum if the einsum is
        # M transposed. Build that transpose to compare it against M.
        p_in_subs, p_out_subs, _ = parse_einsum_input(
            (p_einsum_input.node.einsum_subscripts,
             *p_einsum_input.node.inputs))
        einsum_input = ad.einsum(
            f"{p_in_subs}->{p_out_subs[1]}{p_out_subs[0]}",
            *p_einsum_input.node.inputs)
    else:
        einsum_input = p_einsum_input.node

    # Canonicalize both expressions so that equal computations get equal names.
    rewrite_einsum_expr(einsum_input)
    rewrite_einsum_expr(inv_node_input)

    if einsum_input.name != inv_node_input.name:
        logger.info(
            "inv input and the selected einsum have different expressions, can't prune inv"
        )
        return einsum_node

    # prune the inv node
    updated_p_in_nodes = updated_p_in_nodes + [
        PseudoNode(node=ad.identity(inv_node_input.shape[0]),
                   subscript=uncontract_str)
    ]

    return generate_new_einsum(updated_p_in_nodes, out_subs)
コード例 #10
0
def simplify(output_node):
    """Simplify a graph with a single output node.

    The simplified form distributes selected operations (+) and fuses all
    connected einsums. It then optimizes inverse nodes, fuses again, and
    prunes orthonormal matmuls, inverse nodes, identity nodes and scalar nodes.
    Finally it collapses symmetric expressions and, when the output is a
    distributive node, simplifies it with sympy.

    Args:
        output_node: The output node.
    Returns:
        node: The newly generated node.
    """
    def fuse_all_einsums(node):
        """Fuse every connected einsum sub-tree below ``node``; return the new root."""
        linearize(node)
        ret_node = PseudoNode(node)
        all_pnodes = find_topo_sort_p([ret_node])
        with OutputInjectedModeP(all_pnodes):
            trees = find_sub_einsumtree(ret_node)
            for out_node_p, in_nodes in trees:
                new_z = fuse_einsums(out_node_p.node, in_nodes)
                prune_identity_nodes(new_z)
                replace_node(out_node_p, new_z)

        return declone(ret_node.node)

    def prune_einsums(root_pnode, prune_fn):
        """Replace every einsum under ``root_pnode`` by ``prune_fn(einsum)``.

        Nodes are visited in topological order, and each non-leaf node has its
        inputs re-set before it is pruned. Returns the (possibly replaced) root
        node.
        """
        all_pnodes = find_topo_sort_p([root_pnode])
        with OutputInjectedModeP(all_pnodes):
            for pnode in all_pnodes:
                node = pnode.node
                if node.inputs != []:
                    node.set_inputs(node.inputs)
                if isinstance(node, ad.EinsumNode):
                    replace_node(pnode, prune_fn(node))
        return root_pnode.node

    def prune_identity_and_scalars(node):
        """Remove unnecessary identity nodes in place, then prune scalar nodes."""
        prune_identity_nodes(node)
        return prune_scalar_nodes(node)

    output_node = distribute_graph_w_linearize(output_node)
    output_node = fuse_all_einsums(output_node)

    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    # optimize inverse
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that we can collapse the add node.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.TensorInverseNode):
                replace_node(pnode, optimize_inverse(node))

    # fuse again
    output_node = fuse_all_einsums(output_pnode.node)

    # Wrap the freshly fused root. The previous output_pnode still points into
    # the pre-fusion graph, so passes walking it would miss the fused nodes.
    output_pnode = PseudoNode(output_node)

    # Prune orthonormal matmuls, then inverse nodes, then identity and scalar
    # nodes. After each pass, re-read the root, since the root itself may have
    # been replaced.
    output_node = prune_einsums(output_pnode, prune_orthonormal_matmuls)
    output_node = prune_einsums(output_pnode, prune_inv_node)
    output_node = prune_einsums(output_pnode, prune_identity_and_scalars)

    # collapse symmetric expressions
    all_pnodes = find_topo_sort_p([output_pnode])
    for i, later_pnode in enumerate(all_pnodes):
        for earlier_pnode in all_pnodes[:i]:
            collapse_symmetric_expr(later_pnode.node, earlier_pnode.node)

    sympy_input_types = (ad.DistributiveNode, ad.ScalarNode, ad.MulNode)
    # sympy_simplify the distributed nodes
    if isinstance(output_node, ad.DistributiveNode):
        sympy_inputs = []
        for node in find_topo_sort([output_node]):
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that they can be reduced by sympy.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if not isinstance(node, sympy_input_types):
                sympy_inputs.append(node)
        output_node = sympy_simplify(output_node, sympy_inputs)

    return output_node