def test_einsum_fuse_graph(backendopt):
    r"""
    [Fuse einsum used twice]

    This case is rather subtle.
    We want to automatically fuse

        A   B   C
        |    \ /
        |     es
        |    /|
        |   / |
          es  |
            \ |
             es

    Here es is einsum. BC is consumed by two einsums, so the graph is
    linearized first so that a single einsum tree can be extracted.

    The docstring is raw because the diagram contains backslashes, which
    would otherwise be invalid escape sequences.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3
        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3
        out = ad.einsum('jk, ki->ji', ABC, BC)  # 3x3

        linearize(out)
        tree, = find_sub_einsumtree(PseudoNode(out))
        out_p, in_nodes = tree
        new_z = fuse_einsums(out_p.node, in_nodes)
        assert tree_eq(out_p.node, new_z, [a, b, c])
def test_einsum_multitier(backendopt):
    """Fuse every einsum subtree in a graph mixing additions and einsums.

    Four trees from ``get_tree`` are summed pairwise and the two sums are
    contracted with an einsum. The graph is evaluated, each einsum subtree
    is fused and substituted in place, and the graph is re-evaluated. The
    two results must match.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        # Tree construction order is kept exactly as-is (set1, set2, add,
        # set3, set4, add) in case get_tree depends on creation order.
        nodes_a, tree_a = get_tree("set1")
        nodes_b, tree_b = get_tree("set2")
        left_sum = tree_a + tree_b
        nodes_c, tree_c = get_tree("set3")
        nodes_d, tree_d = get_tree("set4")
        right_sum = tree_c + tree_d
        out = ad.einsum("ij, jk->ik", left_sum, right_sum)

        feed = gen_dict(nodes_a + nodes_b + nodes_c + nodes_d)
        reference_val, = ad.Executor([out]).run(feed_dict=feed)

        with OutputInjectedModeP(find_topo_sort_p([PseudoNode(out)])):
            subtrees = find_sub_einsumtree(PseudoNode(out))
            for root_p, leaf_nodes in subtrees:
                fused = fuse_einsums(root_p.node, leaf_nodes)
                replace_node(root_p, fused)

        # A fresh executor evaluates the rewritten graph.
        rewritten_val, = ad.Executor([out]).run(feed_dict=feed)
        assert float_eq(reference_val, rewritten_val)
def test_einsum_fuse_w_identity(backendopt):
    r"""
    [Fuse einsum with multiple identities]

    We want to fuse

        A   identity   identity
        |        \       /
        |         \     /
        |          \   /
        |            es
        |           /
        |          /
        |         /
        |        /
            es

    Here es is einsum.

    The docstring is raw because the diagram contains backslashes, which
    would otherwise be invalid escape sequences.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        es_identity = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', a, es_identity)

        tree, = find_sub_einsumtree(PseudoNode(out))
        out_p, in_nodes = tree
        new_out = fuse_einsums(out_p.node, in_nodes)
        assert tree_eq(out_p.node, new_out, [a])
def optimize(node):
    """Optimize a graph with a single output node.

    Steps, in order:
      1. ``distribute_tree`` and ``linearize`` the graph.
      2. Extract every einsum subtree.
      3. Fuse each subtree, prune identity nodes from the fused node, and
         replace it with the result of ``generate_optimal_tree``.
      4. ``declone`` the graph.
      5. Rewrite the subscripts of every einsum node.
      6. Re-set the inputs of every non-leaf node.
      7. ``dedup`` the graph.

    Args:
        node: The output node.
    Returns:
        node: The newly generated node.
    """
    node = distribute_tree(node)
    linearize(node)

    all_nodes = find_topo_sort([node])
    ret_node = PseudoNode(node)
    with OutputInjectedMode(all_nodes):
        trees = find_sub_einsumtree(ret_node)
        for out_node_p, in_nodes in trees:
            new_z = fuse_einsums(out_node_p.node, in_nodes)
            prune_identity_nodes(new_z)
            new_z = generate_optimal_tree(new_z)
            replace_node(out_node_p, new_z)

    # Hold the graph output in its own name. The loops below must not reuse
    # it as a loop variable, or the later topo sort / dedup / return would
    # act on whichever node the previous loop ended on.
    out_node = declone(ret_node.node)

    for einsum_candidate in find_topo_sort([out_node]):
        if isinstance(einsum_candidate, ad.EinsumNode):
            rewrite_einsum_expr(einsum_candidate)

    for graph_node in find_topo_sort([out_node]):
        if graph_node.inputs != []:
            # Re-assign the node's own inputs; presumably this refreshes
            # state derived from them after the rewrites above — TODO confirm.
            graph_node.set_inputs(graph_node.inputs)

    dedup(out_node)
    return out_node
def test_einsum_fuse_only_identity(backendopt):
    """Fuse an einsum tree whose leaves are all identity nodes.

    The graph contains no variables, so ``tree_eq`` is given an empty list
    of input nodes.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        identity_product = ad.einsum('ik,kj->ij', ad.identity(3),
                                     ad.identity(3))
        out = ad.einsum('ai,ij->aj', ad.identity(3), identity_product)

        # Exactly one einsum subtree is expected; unpack it in one step.
        (root_p, leaf_nodes), = find_sub_einsumtree(PseudoNode(out))
        fused = fuse_einsums(root_p.node, leaf_nodes)
        assert tree_eq(root_p.node, fused, [])
def fuse_all_einsums(node):
    """Fuse every einsum subtree reachable from ``node``.

    ``linearize`` is applied to the graph first; its return value is unused,
    so it presumably works in place. Within ``OutputInjectedModeP`` over the
    graph's pseudo nodes, each einsum subtree is fused, identity nodes are
    pruned from the fused node, and the subtree root is replaced by it.

    Args:
        node: The output node of the graph.
    Returns:
        The result of ``declone`` applied to the (possibly replaced) output.
    """
    linearize(node)
    root_p = PseudoNode(node)
    with OutputInjectedModeP(find_topo_sort_p([root_p])):
        for subtree_root_p, subtree_inputs in find_sub_einsumtree(root_p):
            fused = fuse_einsums(subtree_root_p.node, subtree_inputs)
            prune_identity_nodes(fused)
            replace_node(subtree_root_p, fused)
    return declone(root_p.node)
def test_einsum_subtree_clone(backendopt):
    r"""
    [Subtree clone]

    This case is rather subtle.
    We want to automatically fuse

        A   B   C   D
        |    \ /    |
        |     es    |
        |    /  \   |
        |   /    \  |
          es      es
            \    /
               +

    Here es is einsum. BC feeds two einsums, so two separate einsum
    subtrees are expected under the final addition.

    The docstring is raw because the diagram contains backslashes, which
    would otherwise be invalid escape sequences.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])
        d = ad.Variable(name="d", shape=[3, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3
        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3
        BCD = ad.einsum('jk, ki->ji', BC, d)  # 3x3

        out = ABC + BCD

        input_nodes = [a, b, c, d]
        generated_feed_dict = gen_dict(input_nodes)

        executor = ad.Executor([out])
        out_val, = executor.run(feed_dict=generated_feed_dict)

        with OutputInjectedModeP(find_topo_sort_p([PseudoNode(out)])):
            trees = find_sub_einsumtree(PseudoNode(out))
            assert len(trees) == 2
            for out_node, in_nodes in trees:
                new_z = fuse_einsums(out_node.node, in_nodes)
                replace_node(out_node, new_z)

        # Build a new executor, as test_einsum_multitier does, so that the
        # rewritten graph is evaluated rather than one prepared beforehand.
        executor = ad.Executor([out])
        new_out_val, = executor.run(feed_dict=generated_feed_dict)
        assert float_eq(out_val, new_out_val)
def test_einsum_find_subtree_after_linearization(backendopt):
    r"""
    An einsum graph like

        A    B   inputs
        |\   |
        | \  |
        |  \ |
        |    C
        |   /
        |  /
        output

    is linearized into an einsum graph like

        A    B   inputs
        |\   |
        | A1 |
        |  \ |
        A2   C
        |   /
        |  /
        output

    The subtree inputs must then be [A1, A2, B] rather than [A, B].

    The docstring is raw because the diagrams contain backslashes, which
    would otherwise be invalid escape sequences.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])
        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        # Sanity check: the original graph evaluates with generated inputs.
        feed_dict = gen_dict([a, b])
        executor = ad.Executor([output])
        out_val, = executor.run(feed_dict=feed_dict)

        # New graph
        linearize(output)
        tree, = find_sub_einsumtree(PseudoNode(output))
        _, in_nodes = tree
        assert len(in_nodes) == 3