Example #1
0
def test_einsum_multitier(backendopt):
    """Fusing all einsum subtrees of a two-tier graph must not change its value.

    Four trees from ``get_tree`` are summed pairwise, the two sums are
    contracted with ``ij, jk->ik``, and the result is evaluated. Every
    einsum subtree found by ``find_sub_einsumtree`` is then fused and
    swapped in with ``replace_node``, and the graph is evaluated again.
    The two evaluations must agree (per ``float_eq``), for every backend.
    """
    for backend in backendopt:
        T.set_backend(backend)

        # Left operand of the contraction: "set1" tree + "set2" tree.
        (nodes_a, z_a), (nodes_b, z_b) = get_tree("set1"), get_tree("set2")
        left = z_a + z_b

        # Right operand of the contraction: "set3" tree + "set4" tree.
        (nodes_c, z_c), (nodes_d, z_d) = get_tree("set3"), get_tree("set4")
        right = z_c + z_d

        out = ad.einsum("ij, jk->ik", left, right)
        feed_dict = gen_dict(nodes_a + nodes_b + nodes_c + nodes_d)

        [before] = ad.Executor([out]).run(feed_dict=feed_dict)

        # Output relationships are injected for the duration of the rewrite.
        with OutputInjectedModeP(find_topo_sort_p([PseudoNode(out)])):
            for root_p, leaves in find_sub_einsumtree(PseudoNode(out)):
                fused = fuse_einsums(root_p.node, leaves)
                replace_node(root_p, fused)

        [after] = ad.Executor([out]).run(feed_dict=feed_dict)

        assert float_eq(before, after)
Example #2
0
    def fuse_all_einsums(node):
        """Fuse every connected einsum subtree reachable from ``node``.

        ``node``'s graph is first linearized in place. Each einsum subtree
        reported by ``find_sub_einsumtree`` is then replaced by its fused
        einsum, with identity nodes pruned from the fused result. The
        (possibly replaced) root is passed through ``declone`` and returned.
        """
        linearize(node)
        root = PseudoNode(node)
        topo = find_topo_sort_p([root])
        with OutputInjectedModeP(topo):
            for subtree_out, subtree_inputs in find_sub_einsumtree(root):
                fused = fuse_einsums(subtree_out.node, subtree_inputs)
                prune_identity_nodes(fused)
                replace_node(subtree_out, fused)

        # replace_node may have swapped the root, so read it back from the pnode.
        return declone(root.node)
Example #3
0
def test_einsum_subtree_clone(backendopt):
    r"""
        [Subtree clone]
        This case is rather subtle.
        We want to auto fuse
            A   B   C   D
            |    \ /    |
            |     es    |
            |    /  \   |
            |  /      \ |
            es         es
              \       /
                  +

        Here es is einsum.

        The einsum of B and C feeds both outer einsums, so the two einsum
        trees under + share a subtree. find_sub_einsumtree is expected to
        report exactly two trees, and fusing them must not change the value.
    """
    # The docstring is raw because the diagram's backslashes would otherwise
    # be invalid escape sequences (DeprecationWarning / SyntaxWarning).
    for datatype in backendopt:
        T.set_backend(datatype)
        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])
        d = ad.Variable(name="d", shape=[3, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3

        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3

        BCD = ad.einsum('jk, ki->ji', BC, d)  # 3x3

        out = ABC + BCD

        input_nodes = [a, b, c, d]
        generated_feed_dict = gen_dict(input_nodes)

        executor = ad.Executor([out])
        out_val, = executor.run(feed_dict=generated_feed_dict)

        with OutputInjectedModeP(find_topo_sort_p([PseudoNode(out)])):
            trees = find_sub_einsumtree(PseudoNode(out))
            assert len(trees) == 2
            for tree in trees:
                out_node, in_nodes = tree
                new_z = fuse_einsums(out_node.node, in_nodes)
                replace_node(out_node, new_z)

        # Evaluate the rewritten graph with a fresh executor, so the result
        # cannot depend on anything the first executor may have captured
        # from the pre-rewrite graph.
        executor = ad.Executor([out])
        new_out_val, = executor.run(feed_dict=generated_feed_dict)

        assert float_eq(out_val, new_out_val)
Example #4
0
def distribute_tree(output):
    """Distribute the add operands of a tree of einsum and add nodes.

    NOTE: the output node should be a linearized node; behavior is
    undefined if the tree contains other kinds of nodes.

    Args:
        output: The output of a tree.

    Returns:
        output: a newly generated node with add operands distributed.

    Idea:
        1. Construct the output tree.
        2. Find a binary (distributive) op whose consumers are all einsums.
        3. Distribute it into each consuming einsum.
        4. Repeat 1-3 until no such op remains.
    """
    def _next_distributable(pnodes):
        # First DistributiveNode, in the given order, that has at least one
        # consumer and whose consumers are all einsum nodes; None otherwise.
        for candidate in (pn.node for pn in pnodes):
            if not isinstance(candidate, ad.DistributiveNode):
                continue
            if len(candidate.outputs) < 1:
                continue
            if all(isinstance(c, ad.EinsumNode) for c in candidate.outputs):
                return candidate
        return None

    while True:
        pnodes = find_topo_sort_p([PseudoNode(output)])
        with OutputInjectedModeP(pnodes):
            dist_op = _next_distributable(pnodes)
            if dist_op is None:
                break
            for consumer in dist_op.outputs:
                if isinstance(consumer, ad.DistributiveNode):
                    continue
                assert isinstance(consumer, ad.EinsumNode)
                distributed = _distribute(dist_op, consumer)
                replace_node(PseudoNode(consumer), distributed)
                # The PseudoNode above is throwaway, so track a replaced
                # root explicitly.
                if consumer == output:
                    output = distributed

    # Re-setting the inputs is needed for source generation.
    output.set_inputs(output.inputs)
    return output
Example #5
0
def dedup(*nodes):
    """Remove the duplicate nodes with same name.

    The graph reachable from ``nodes`` is walked in topological order. For
    each node name, the node seen last in that order is kept as the
    representative. Every other node with the same name is replaced by the
    representative via ``replace_node``. The update is in place and nothing
    is returned.

    Args:
        nodes: One or many nodes.
    """
    assert len(nodes) > 0

    topo_order = find_topo_sort_p([PseudoNode(n) for n in nodes])
    with OutputInjectedModeP(topo_order):
        # Later entries overwrite earlier ones, so the last occurrence of each
        # name in topological order becomes that name's representative.
        unique_nodes_map = {ptmp.node.name: ptmp.node for ptmp in topo_order}
        unique_nodes = set(unique_nodes_map.values())

        for ptmp in topo_order:
            tmp = ptmp.node
            if tmp not in unique_nodes:
                replace_node(ptmp, unique_nodes_map[tmp.name])
Example #6
0
def linearize(output_node):
    """Linearize a graph by adding clone nodes for computation optimization.

    For every node with more than one consumer, each consumer's input slots
    that refer to that node (matched by ``name``) are rewired to a fresh
    ``copy_tree`` of it. Every such slot gets its own copy.

    Args:
        output_node: A single node.
    Returns:
        None. Update is inplace.

    NOTE: If you ever need to debug this function, the generated name is
        inconsistent because of the added edges.
    """
    # Need to create new nodes for whichever node that has 2 or more outgoing edges.
    all_pnodes = find_topo_sort_p([PseudoNode(output_node)])
    # Inject outputs relationship.
    with OutputInjectedModeP(all_pnodes):
        for pn in all_pnodes:
            n = pn.node
            if len(n.outputs) > 1:
                # dict.fromkeys de-duplicates consumers like set() does, but it
                # keeps first-seen order, so clones are created and wired in a
                # reproducible order rather than in hash order. It also
                # snapshots the list before set_inputs rewires any edges.
                for n_o in dict.fromkeys(n.outputs):
                    n_o.set_inputs([
                        tmp if tmp.name != n.name else copy_tree(n)
                        for tmp in n_o.inputs
                    ])
Example #7
0
def simplify(output_node):
    """Simplify a graph with a single output node.

    The simplified form will distribute selected operations (+) and fuse
    all connected einsums. Passes, in order:

    1. distribute (+) via ``distribute_graph_w_linearize`` and fuse einsums;
    2. rewrite einsum expressions and optimize ``TensorInverseNode``s;
    3. fuse all connected einsums again;
    4. prune orthonormal matmuls, then inverse nodes, then identity and
       scalar nodes (einsum nodes only);
    5. collapse symmetric expressions over every pair of graph nodes;
    6. if the output is a ``DistributiveNode``, simplify it with sympy.

    Args:
        output_node: The output node.
    Returns:
        node: The newly generated node.
    """
    def fuse_all_einsums(node):
        linearize(node)
        ret_node = PseudoNode(node)
        all_pnodes = find_topo_sort_p([ret_node])
        with OutputInjectedModeP(all_pnodes):
            trees = find_sub_einsumtree(ret_node)
            for tree in trees:
                out_node_p, in_nodes = tree
                new_z = fuse_einsums(out_node_p.node, in_nodes)
                prune_identity_nodes(new_z)
                replace_node(out_node_p, new_z)

        node = declone(ret_node.node)
        return node

    def rewrite_einsums(root_pnode, rewrite):
        # Walk the graph under root_pnode in topological order. Re-set the
        # inputs of every non-leaf node, and replace each einsum node with
        # rewrite(node). A replaced root is reflected in root_pnode.node.
        all_pnodes = find_topo_sort_p([root_pnode])
        with OutputInjectedModeP(all_pnodes):
            for pnode in all_pnodes:
                node = pnode.node
                if node.inputs != []:
                    node.set_inputs(node.inputs)
                if isinstance(node, ad.EinsumNode):
                    replace_node(pnode, rewrite(node))

    def prune_identity_and_scalar_nodes(node):
        # Identity pruning mutates node in place; scalar pruning returns the
        # replacement node.
        prune_identity_nodes(node)
        return prune_scalar_nodes(node)

    output_node = distribute_graph_w_linearize(output_node)
    output_node = fuse_all_einsums(output_node)

    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    # optimize inverse
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that we can collapse the add node.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.TensorInverseNode):
                new_inv_node = optimize_inverse(node)
                replace_node(pnode, new_inv_node)

    # fuse again (starting from the root as possibly replaced above)
    output_node = fuse_all_einsums(output_pnode.node)
    # Re-wrap the fused root. The old output_pnode still refers to the
    # pre-fusion node, and the passes below must walk the fused graph.
    output_pnode = PseudoNode(output_node)

    # prune the orthonormal matmuls
    rewrite_einsums(output_pnode, prune_orthonormal_matmuls)
    # prune inverse nodes
    rewrite_einsums(output_pnode, prune_inv_node)
    # prune the scalar nodes and remove unnecessary identity nodes
    rewrite_einsums(output_pnode, prune_identity_and_scalar_nodes)

    # Any of the passes above may have replaced the root inside output_pnode.
    output_node = output_pnode.node

    # collapse symmetric expressions
    all_pnodes = find_topo_sort_p([output_pnode])
    for i, pnode_i in enumerate(all_pnodes):
        for pnode_j in all_pnodes[:i]:
            collapse_symmetric_expr(pnode_i.node, pnode_j.node)

    sympy_input_types = (ad.DistributiveNode, ad.ScalarNode, ad.MulNode)
    # sympy_simplify the distributed nodes
    if isinstance(output_node, ad.DistributiveNode):
        sympy_inputs = []
        all_nodes = find_topo_sort([output_node])
        for node in all_nodes:
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that they can be reduced by sympy.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if not isinstance(node, sympy_input_types):
                sympy_inputs.append(node)
        output_node = sympy_simplify(output_node, sympy_inputs)

    return output_node