def test_einsum_gen_corner_case(backendopt):
    """Binarization of a 5-tensor ring that defeats numpy's path finder.

    Note: Numpy contraction path cannot find the opt path for this
    expression; it outputs the same expression as the input.

        -------- E --------
        |    |    |    |
        a    b    c    d
        |    |    |    |
        A - e - B - f - C - g - D
        |    |    |    |
        h    i    j    k
    """
    dim = 5
    A = ad.Variable(name="A", shape=[dim] * 3)
    B = ad.Variable(name="B", shape=[dim] * 4)
    C = ad.Variable(name="C", shape=[dim] * 4)
    D = ad.Variable(name="D", shape=[dim] * 3)
    E = ad.Variable(name="E", shape=[dim] * 4)

    expr = ad.einsum('aeh,bfie,cgjf,dgk,abcd->hijk', A, B, C, D, E)
    binarized = generate_optimal_tree(expr)

    # Every non-leaf node of the rebuilt tree must be a binary contraction.
    internal = (n for n in find_topo_sort([binarized])
                if not isinstance(n, ad.VariableNode))
    assert all(len(n.inputs) == 2 for n in internal)
def test_get_common_ancestor(backendopt):
    """Common-ancestor extraction on a 5-tensor network (A appears twice).

    The network and indices positions are as follows:

            g - A
                |
        c   d   e
        |   |   |
        X1 - a - X2 - b - X3
        |   |   |
        h   i   j
                |
            l - A
    """
    A = ad.Variable(name="A", shape=[3, 2])
    X1 = ad.Variable(name="X1", shape=[3, 2, 2])
    X2 = ad.Variable(name="X2", shape=[3, 3, 2, 2])
    X3 = ad.Variable(name="X3", shape=[3, 2, 2])

    einsum_node = ad.einsum('lj,ge,bej,abdi,ach->cdhigl', A, A, X3, X2, X1)
    tree = generate_optimal_tree(einsum_node)
    ancestor = get_common_ancestor(tree, einsum_node.inputs, A)

    # The sub-einsum rooted at the ancestor should pull in both copies of A
    # plus X3, and nothing else.
    by_name = lambda n: n.name
    assert (sorted(get_all_inputs(ancestor), key=by_name)
            == sorted([A, A, X3], key=by_name))
def optimize(node):
    """Optimize a graph with a single output node.

    Args:
        node: The output node.
    Returns:
        node: The newly generated node.
    """
    # Push multiplications over additions, then flatten chained einsums so
    # sub-einsum trees can be detected below.
    node = distribute_tree(node)
    linearize(node)
    all_nodes = find_topo_sort([node])
    # PseudoNode wraps the output so replace_node can swap the root in place.
    ret_node = PseudoNode(node)
    with OutputInjectedMode(all_nodes):
        trees = find_sub_einsumtree(ret_node)
        for tree in trees:
            out_node_p, in_nodes = tree
            # Fuse each maximal einsum subtree into a single einsum, drop
            # identity tensors, then re-binarize it along an optimal path.
            new_z = fuse_einsums(out_node_p.node, in_nodes)
            prune_identity_nodes(new_z)
            new_z = generate_optimal_tree(new_z)
            replace_node(out_node_p, new_z)
    # Remove clone nodes introduced by linearize; ret_node.node is the
    # (possibly replaced) root.
    node = declone(ret_node.node)
    all_nodes = find_topo_sort([node])
    # Canonicalize every einsum subscript string in the graph.
    for node in all_nodes:
        if isinstance(node, ad.EinsumNode):
            rewrite_einsum_expr(node)
    # NOTE(review): both loops re-bind the name `node`; after this loop
    # `node` is the LAST element of the topo sort, which dedup/return rely
    # on being the output node — presumably find_topo_sort guarantees the
    # root is last. Confirm before restructuring.
    for node in find_topo_sort([node]):
        if node.inputs != []:
            # set_inputs with the same inputs presumably refreshes cached
            # state (e.g. names/hashes) after the rewrite — TODO confirm.
            node.set_inputs(node.inputs)
    dedup(node)
    return node
def test_einsum_gen_custom(backendopt):
    """Optimal-tree generation with an explicitly supplied contraction path."""
    for _ in backendopt:
        a = ad.Variable(name="a", shape=[2, 2])
        b = ad.Variable(name="b", shape=[2, 5])
        c = ad.Variable(name="c", shape=[5, 2])

        chain = ad.einsum('ij,jk,kl->il', a, b, c)
        # Contract (b, c) first, then combine with a.
        rebuilt = generate_optimal_tree(chain, path=[(1, 2), (0, 1)])

        assert tree_eq(chain, rebuilt, [a, b, c])
def test_einsum_gen(backendopt):
    """A 5-operand einsum is rebuilt as an equivalent binary tree."""
    for _ in backendopt:
        dim = 10
        C = ad.Variable(name="C", shape=[dim, dim])
        I = ad.Variable(name="I", shape=[dim] * 4)

        expr = ad.einsum('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)
        rebuilt = generate_optimal_tree(expr)

        # Value is unchanged and the new root is a binary contraction.
        assert tree_eq(expr, rebuilt, [C, I])
        assert len(rebuilt.inputs) == 2
def test_get_common_ancestor_w_inv(backendopt):
    """get_common_ancestor must not absorb a tensor-inverse node.

    The sub-einsum should be ad.einsum('ad,abc->dbc', A, X) and must
    exclude the inv node.
    """
    A = ad.Variable(name="A", shape=[3, 3])
    X = ad.Variable(name="X", shape=[3, 3, 3])
    inv = ad.tensorinv(ad.einsum("ab,ac->bc", A, A), ind=1)

    einsum_node = ad.einsum('abc,ad,ce->bce', X, A, inv)
    tree = generate_optimal_tree(einsum_node)
    ancestor = get_common_ancestor(tree, einsum_node.inputs, A)

    by_name = lambda n: n.name
    assert (sorted(get_all_inputs(ancestor), key=by_name)
            == sorted([A, X], key=by_name))
def test_get_common_ancestor_simple(backendopt):
    """Common-ancestor extraction on a small 3-tensor network.

    The network and indices positions are as follows:

            g - A
                |
        d       e
        |       |
        X1 - b - X2
        |       |
        i       j
    """
    A = ad.Variable(name="A", shape=[3, 2])
    X1 = ad.Variable(name="X1", shape=[3, 4, 4])
    X2 = ad.Variable(name="X2", shape=[3, 2, 2])

    einsum_node = ad.einsum('ge,bdi,bej->gdij', A, X1, X2)
    tree = generate_optimal_tree(einsum_node)
    ancestor = get_common_ancestor(tree, einsum_node.inputs, A)

    # Only A and X2 share the contracted index 'e'; X1 stays outside.
    by_name = lambda n: n.name
    assert (sorted(get_all_inputs(ancestor), key=by_name)
            == sorted([A, X2], key=by_name))