def test_prune_identity(backendopt):
    """prune_identity_nodes folds identity chains into the remaining inputs."""
    for datatype in backendopt:
        T.set_backend(datatype)

        a1 = ad.Variable(name="a1", shape=[3, 3])
        a2 = ad.Variable(name="a2", shape=[3, 3])
        i1, i2, i3 = [ad.identity(3) for _ in range(3)]

        # i1 ties index a to index c, so i1 can be removed and a2 re-indexed:
        #     ad.einsum("ab,ad,be,ef->abdf", a1, a2, i2, i3)
        # Index e only links i2 and i3, so those two merge into one identity:
        #     ad.einsum("ab,ad,bf->abdf", a1, a2, i2)
        out = ad.einsum("ab,cd,ac,be,ef->abdf", a1, a2, i1, i2, i3)
        prune_identity_nodes(out)
        expected = ad.einsum("ab,ad,bf->abdf", a1, a2, i2)

        assert len(out.inputs) == 3
        assert tree_eq(out, expected, [a1, a2])
Example #2
0
def test_einsum_fuse_w_identity(backendopt):
    """
        Fuse an einsum whose second operand is an einsum of two identities.

            A   identity  identity
            |    \       /
            |     \     /
            |      \   /
            |       es
            |     /
            |    /
            |   /
            |  /
            es
        Here es is einsum.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        identity_product = ad.einsum('ik,kj->ij', ad.identity(3),
                                     ad.identity(3))
        out = ad.einsum('ai,ij->aj', a, identity_product)

        (subtree,) = find_sub_einsumtree(PseudoNode(out))
        p_root, p_leaves = subtree
        fused = fuse_einsums(p_root.node, p_leaves)
        assert tree_eq(p_root.node, fused, [a])
Example #3
0
def test_simplify_optimize_w_tail_einsum():
    """Both optimize() and simplify() reduce A @ (I @ I) back to A."""
    A = ad.Variable(name="A", shape=[2, 2])
    identity_product = ad.einsum("ab,bc->ac", ad.identity(2), ad.identity(2))
    out = ad.einsum("ab,bc->ac", A, identity_product)

    optimized, simplified = optimize(out), simplify(out)

    assert optimized == A
    assert simplified == A
Example #4
0
def test_einsum_fuse_only_identity(backendopt):
    """Fusing an einsum tree built purely from identities needs no variables."""
    for datatype in backendopt:
        T.set_backend(datatype)

        identity_product = ad.einsum('ik,kj->ij', ad.identity(3),
                                     ad.identity(3))
        out = ad.einsum('ai,ij->aj', ad.identity(3), identity_product)

        (subtree,) = find_sub_einsumtree(PseudoNode(out))
        p_root, p_leaves = subtree
        fused = fuse_einsums(p_root.node, p_leaves)
        assert tree_eq(p_root.node, fused, [])
def test_simplify_optimize_w_tail_einsum(backendopt):
    """Per backend, optimize() and simplify() reduce A @ (I @ I) to A."""
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        identity_product = ad.einsum("ab,bc->ac", ad.identity(2),
                                     ad.identity(2))
        out = ad.einsum("ab,bc->ac", A, identity_product)

        optimized, simplified = optimize(out), simplify(out)

        assert optimized == A
        assert simplified == A
def test_prune_identity_w_dup(backendopt):
    """An identity chain after a repeated input collapses to a1 @ a1."""
    for datatype in backendopt:
        T.set_backend(datatype)

        a1 = ad.Variable(name="a1", shape=[3, 3])
        identities = [ad.identity(3) for _ in range(3)]

        out = ad.einsum("ab,bc,cd,de,ef->af", a1, a1, *identities)
        prune_identity_nodes(out)
        expected = ad.einsum("ab,bc->ac", a1, a1)

        assert len(out.inputs) == 2
        assert tree_eq(out, expected, [a1])
Example #7
0
def prune_identity_nodes(einsum_node):
    """
        Reduce the number of identity nodes in the einsum_node's inputs.
        Inplace update.

        Every identity input links its two subscript characters. All
        characters linked this way are merged (via union-find) into one
        representative character, which replaces them in the non-identity
        inputs and in the output.

        An identity node is removed when its linked group touches a
        non-identity input or the output. An identity node whose group
        touches neither is a pure trace, e.g. the "cd" in "ab,cd->ab".
        Such a node contributes a scalar factor, so it is kept with its
        original subscript; dropping it would change the einsum's value.

        When the output needs the same representative character twice, a
        fresh character plus one identity node is added to express that
        constraint.

        Args:
            einsum_node: An fused einsum node. Any other node is ignored.
    """
    if not isinstance(einsum_node, ad.EinsumNode):
        return

    uf_str, p_outnode, p_innodes = generate_einsum_info(einsum_node)
    whole_str = p_outnode.subscript + "".join(
        node.subscript for node in p_innodes)

    p_identity_nodes, p_variable_nodes = [], []
    for pnode in p_innodes:
        if isinstance(pnode.node, ad.IdentityNode):
            p_identity_nodes.append(pnode)
        else:
            p_variable_nodes.append(pnode)

    # Each disjoint set in uf_identity holds the indices linked by
    # identity nodes.
    uf_identity = UF(list(whole_str))
    for pnode in p_identity_nodes:
        uf_identity.connect(pnode.subscript[0], pnode.subscript[1])

    # Roots that some non-identity input or the output still refers to.
    # Identity nodes in these groups are fully absorbed by re-indexing.
    anchored_roots = {uf_identity.root(char) for char in p_outnode.subscript}
    for pnode in p_variable_nodes:
        # replace subscripts by the root chars
        sub_list = [uf_identity.root(char) for char in pnode.subscript]
        pnode.subscript = "".join(sub_list)
        anchored_roots.update(sub_list)

    # Identity nodes whose group is not anchored form a closed trace. Keep
    # them unchanged so their scalar factor is preserved. Their characters
    # cannot clash with any root used elsewhere, because the group is
    # disjoint from every anchored group.
    p_dangling_identity_nodes = [
        pnode for pnode in p_identity_nodes
        if uf_identity.root(pnode.subscript[0]) not in anchored_roots
    ]

    # Build a new list so that appending below does not alias
    # p_variable_nodes.
    p_updated_inputs = p_variable_nodes + p_dangling_identity_nodes
    output_indices_set = set()
    out_sub_list = []
    for i, char in enumerate(p_outnode.subscript):
        uf_root_char = uf_identity.root(char)
        if uf_root_char in output_indices_set:
            # we cannot assign the same char to two indices in the
            # output. Therefore, assign a new char, and add one
            # identity node to the inputs to show the constraint.
            new_char = uf_str.cg.getchar()
            out_sub_list.append(new_char)
            p_identity_node = PseudoNode(node=ad.identity(
                einsum_node.shape[i]),
                                         subscript=f"{uf_root_char}{new_char}")
            p_updated_inputs.append(p_identity_node)
        else:
            # directly assign the root char to the subscripts
            out_sub_list.append(uf_root_char)
            output_indices_set.add(uf_root_char)
    p_outnode.subscript = "".join(out_sub_list)

    new_input_subs = [pnode.subscript for pnode in p_updated_inputs]
    new_subscripts = ",".join(new_input_subs) + "->" + p_outnode.subscript
    einsum_node.einsum_subscripts = new_subscripts
    einsum_node.set_inputs([pnode.node for pnode in p_updated_inputs])
Example #8
0
def test_collapse_expr_w_identity():
    """a @ I and a @ I.T collapse to expressions with the same name."""
    a = ad.Variable(name="a", shape=[2, 2])
    eye = ad.identity(2)

    lhs = ad.einsum("ab,bc->ac", a, eye)
    rhs = ad.einsum("ab,cb->ac", a, eye)
    collapse_symmetric_expr(lhs, rhs)

    assert lhs.name == rhs.name
Example #9
0
def test_einsum_equal_uf_assign_order():
    """Equivalent einsums rewrite to identical subscripts and input order."""
    A = ad.Variable(name="A", shape=[3, 3])
    B = ad.Variable(name="B", shape=[3, 3])
    I = ad.identity(10)

    x = ad.einsum('pb,or,ob,pr,st->srtb', B, A, A, B, I)
    y = ad.einsum('eb,ed,fb,fd,ac->abcd', A, A, B, B, I)

    # rewrite_einsum_expr updates each expression in place; its return
    # value is not needed here.
    for expr in (x, y):
        rewrite_einsum_expr(expr)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs
def test_simplify_inv_w_identity(backendopt):
    """simplify() replaces tensorinv(identity) with an identity input."""
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        inv_eye = ad.tensorinv(ad.identity(3))
        out = ad.einsum("ab,cd->acbd", A, inv_eye)

        simplified = simplify(out)

        assert isinstance(simplified, ad.EinsumNode)
        assert isinstance(simplified.inputs[1], ad.IdentityNode)
        assert tree_eq(out, simplified, [A], tol=1e-6)
Example #11
0
def test_s2s_w_constants(backendopt):
    """Generated forward code handles a constant identity operand."""
    for datatype in backendopt:
        T.set_backend(datatype)
        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.einsum("ab,bc->ac", A, ad.identity(2))

        A_val = T.tensor([[1., 2.], [3., 4.]])

        fwd_source = SourceToSource().forward([B],
                                              function_name='fwd',
                                              backend=datatype)
        module = import_code(fwd_source)
        (result,) = module.fwd([A_val])

        assert T.array_equal(A_val, result)
Example #12
0
def test_dimension_tree_w_identity():
    """A sequential dimension tree with an identity factor matches the inputs."""
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.identity(2)
    C = ad.Variable(name="C", shape=[2, 2])
    X = ad.Variable(name="X", shape=[2, 2, 2])

    einsum_node_A = ad.einsum("abc,bm,cm->am", X, B, C)
    einsum_node_B = ad.einsum("abc,am,cm->bm", X, A, C)
    einsum_node_C = ad.einsum("abc,am,bm->cm", X, A, B)

    dt = generate_sequential_optimal_tree(
        [einsum_node_A, einsum_node_B, einsum_node_C], [A, B, C])

    # (reference einsum, variables to feed tree_eq) for each tree output
    expectations = [
        (einsum_node_A, [A, C, X]),
        (einsum_node_B, [A, C, X]),
        (einsum_node_C, [A, B, X]),
    ]
    for idx, (reference, variables) in enumerate(expectations):
        assert tree_eq(dt[idx], reference, variables)
Example #13
0
def prune_single_inv_node(einsum_node, inv_node):
    """
    Prune the inv_node in the einsum node if the required conditions are met.

    Note:
    1. can only optimize the node when the input of inv is an einsum node.
    2. only supports the case when the split nodes are different from the remaining ones.
        For example: ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, C) will be
        optimized to ad.einsum("ab,bc->ac", C, ad.identity()),
        but we cannot optimize ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, B).

    Parameters
    ----------
    einsum_node: The fused einsum node
    inv_node: the input inv node to be pruned

    Returns
    -------
    If the einsum_node cannot be optimized, then return the input einsum_node.
    If it can be optimized, return the optimized einsum node.

    """
    from autohoot.einsum_graph.expr_generator import rewrite_einsum_expr
    from autohoot.graph_ops.optimal_tree import split_einsum

    inv_node_input = inv_node.inputs[0]
    if not isinstance(inv_node_input, ad.EinsumNode):
        logger.info("inv input is not einsum node, can't prune inv")
        return einsum_node

    inv_inputs_set = set(inv_node_input.inputs)
    if not inv_inputs_set.issubset(set(einsum_node.inputs)):
        logger.info(
            "inv inputs is not subset of einsum node inputs, can't prune inv")
        return einsum_node

    einsum_inputs_in_inv = [
        n for n in einsum_node.inputs if n in inv_inputs_set
    ]
    if len(einsum_inputs_in_inv) < len(inv_node_input.inputs):
        logger.info(
            "number of inv inputs is more than corresponding einsum inputs, can't prune inv"
        )
        return einsum_node

    # Inputs that are not consumed by the inv expression stay outside the
    # split. Deduplicate while keeping the input order, so the result does
    # not depend on set iteration order.
    kept_nodes = list(
        dict.fromkeys(n for n in einsum_node.inputs
                      if n not in inv_inputs_set))
    split_einsum_node = split_einsum(einsum_node, kept_nodes)

    # Assign pseudo nodes and chars
    in_subs, out_subs, _ = parse_einsum_input(
        (split_einsum_node.einsum_subscripts, *split_einsum_node.inputs))
    in_subs_list = in_subs.split(',')

    p_einsum_input, p_inv_input = None, None
    updated_p_in_nodes = []
    for node, subscript in zip(split_einsum_node.inputs, in_subs_list):
        if isinstance(node, ad.EinsumNode):
            p_einsum_input = PseudoNode(node=node, subscript=subscript)
        elif node is inv_node:
            p_inv_input = PseudoNode(node=node, subscript=subscript)
        else:
            updated_p_in_nodes.append(
                PseudoNode(node=node, subscript=subscript))

    # Without both the split-off einsum and the inv node there is nothing to
    # compare. Previously this path raised UnboundLocalError.
    if p_einsum_input is None or p_inv_input is None:
        logger.info(
            "could not locate both the split einsum and the inv node, can't prune inv"
        )
        return einsum_node

    contract_set = set(p_einsum_input.subscript) & set(p_inv_input.subscript)
    contract_char = "".join(sorted(contract_set))
    # Sorted only for a deterministic subscript. The identity matrix is
    # symmetric, so the order does not affect the value.
    uncontract_str = "".join(
        sorted(
            set(p_einsum_input.subscript + p_inv_input.subscript) -
            contract_set))

    if not (len(p_einsum_input.subscript) == 2 and len(p_inv_input.subscript)
            == 2 and len(contract_char) == 1 and len(uncontract_str) == 2):
        # this is not a matmul. Just return the initial node
        logger.info(
            "the op between inv input and the selected einsum is not matmul, can't prune inv"
        )
        return einsum_node

    if p_einsum_input.subscript[0] == p_inv_input.subscript[
            0] or p_einsum_input.subscript[1] == p_inv_input.subscript[1]:
        # the str is like "ab,ac", and one einsum needs to be transposed to compare
        p_in_subs, p_out_subs, _ = parse_einsum_input(
            (p_einsum_input.node.einsum_subscripts,
             *p_einsum_input.node.inputs))
        einsum_input = ad.einsum(
            f"{p_in_subs}->{p_out_subs[1]}{p_out_subs[0]}",
            *p_einsum_input.node.inputs)
    else:
        einsum_input = p_einsum_input.node

    # Rewrite both expressions into a canonical form so that equal
    # expressions end up with equal names.
    rewrite_einsum_expr(einsum_input)
    rewrite_einsum_expr(inv_node_input)

    if einsum_input.name != inv_node_input.name:
        logger.info(
            "inv input and the selected einsum have different expressions, can't prune inv"
        )
        return einsum_node

    # prune the inv node: X @ inv(X) becomes an identity over the
    # uncontracted indices.
    updated_p_in_nodes = updated_p_in_nodes + [
        PseudoNode(node=ad.identity(inv_node_input.shape[0]),
                   subscript=uncontract_str)
    ]

    return generate_new_einsum(updated_p_in_nodes, out_subs)
Example #14
0
def prune_orthonormal_matmuls(einsum_node):
    """
    Remove pairs of orthonormal matrices that form an M @ M.T like product.

    Two inputs form a prunable pair when all of the following hold:
      - they are MatrixNodes with the same name and a non-None
        `orthonormal` mode;
      - their subscripts share the character at the contraction position;
      - their subscripts differ at the orthonormal position;
      - the contraction character appears in exactly two subscripts,
        counting the output subscript.
    Each pair is replaced by an identity node over the two orthonormal
    characters.

    Args:
        einsum_node: An fused einsum node.
    Return:
        A new einsum node built from the remaining inputs.
    """
    # orthonormal mode -> (orthonormal_index, contraction_index)
    index_positions = {'column': (0, 1), 'row': (1, 0)}

    _, p_outnode, p_innodes = generate_einsum_info(einsum_node)
    all_subscripts = [p.subscript for p in p_innodes]
    all_subscripts.append(p_outnode.subscript)

    # Group orthonormal inputs by node name, in input order.
    groups = {}
    for pnode in p_innodes:
        matrix = pnode.node
        if isinstance(matrix, ad.MatrixNode) and matrix.orthonormal is not None:
            groups.setdefault(matrix.name, []).append(pnode)

    for group in groups.values():
        if len(group) < 2:
            continue

        # Nodes of this group that have not been paired off yet.
        available = group
        for pair in itertools.combinations(group, 2):
            if not set(pair).issubset(set(available)):
                continue

            first, second = pair
            o_idx, c_idx = index_positions[first.node.orthonormal]
            c_match = first.subscript[c_idx] == second.subscript[c_idx]
            o_differ = first.subscript[o_idx] != second.subscript[o_idx]
            if not c_match or not o_differ:
                continue

            shared_char = first.subscript[c_idx]
            occurrences = sum(1 for subs in all_subscripts
                              if shared_char in subs)
            if occurrences != 2:
                continue

            available = [p for p in available if p not in pair]
            p_innodes = [p for p in p_innodes if p not in pair]

            eye = ad.identity(first.node.shape[o_idx])
            eye_subs = first.subscript[o_idx] + second.subscript[o_idx]
            p_innodes.append(PseudoNode(node=eye, subscript=eye_subs))

    subscripts = (",".join(p.subscript for p in p_innodes) + "->" +
                  p_outnode.subscript)
    return ad.einsum(subscripts, *[p.node for p in p_innodes])
Example #15
0
def test_identity():
    """An identity node's name round-trips through AutodiffParser."""
    eye = ad.identity(3)
    parsed = AutodiffParser.parse(eye.name, [])
    assert parsed.name == eye.name
Example #16
0
 def p_expression_identity(t):
     'expression : IDENTITY LPAREN NUMBER RPAREN'
     # Parser rule for an identity expression. The docstring above is the
     # grammar production itself (yacc/PLY-style rule function — presumably
     # read at runtime by the parser generator), so it must not be edited.
     # Symbol positions follow the production: t[1]=IDENTITY, t[2]=LPAREN,
     # t[3]=NUMBER, t[4]=RPAREN. The NUMBER value is passed to ad.identity,
     # and the resulting node becomes the parsed expression (t[0]).
     t[0] = ad.identity(t[3])