def test_tree_distribution_ppE(dist_op, backendopt):
    """Distributing ((A + B) + C) * G should yield AG + BG + CG.

    The inner sum (A + B) is itself an operand of the outer sum
    (A + B) + C, so distribution must recurse through both levels.
    """
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[3, 2])
        g = ad.Variable(name="g", shape=[2, 2])

        out = ad.einsum('ik,kk->ik', dist_op(dist_op(a, b), c), g)
        distributed = distribute_tree(out)

        assert isinstance(distributed, dist_op)
        assert tree_eq(out, distributed, [a, b, c, g])
def test_einsum_fuse_graph(backendopt):
    """Fuse an einsum whose result is consumed twice.

        A    B   C
        |     \ /
        |     es
        |    /|
        |   / |
        | es  |
        |  \  |
          es

    Here ``es`` denotes an einsum node; BC feeds both ABC and the
    final contraction, so fusion must cope with the shared subtree.
    """
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3
        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3
        out = ad.einsum('jk, ki->ji', ABC, BC)  # 3x3

        linearize(out)
        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        new_z = fuse_einsums(out.node, ins)

        assert tree_eq(out.node, new_z, [a, b, c])
def test_tree_distribution_w_add_output(dist_op, backendopt):
    """C * (A + B) + F * (D + E) == (C*A + C*B) + (F*D + F*E)."""
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 3])
        c = ad.Variable(name="c", shape=[3, 3])
        d = ad.Variable(name="d", shape=[3, 3])
        e = ad.Variable(name="e", shape=[3, 3])
        f = ad.Variable(name="f", shape=[3, 3])

        first = ad.einsum('ik,kj->ij', c, dist_op(a, b))
        second = ad.einsum('ik,kj->ij', d, dist_op(e, f))
        out = dist_op(first, second)

        distributed = distribute_tree(out)
        assert isinstance(distributed, dist_op)
        # Both branches of the result should themselves be distributed sums.
        for child in distributed.inputs:
            assert isinstance(child, dist_op)
        assert tree_eq(out, distributed, [a, b, c, d, e, f])
def test_tree_distribution_two_layers(dist_op, backendopt):
    """Distributing ((A + B) * G) * C should yield AGC + BGC.

    (A + B) * G is contracted first, so distribution has to
    propagate through an intermediate einsum node.
    """
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        g = ad.Variable(name="g", shape=[2, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        inner = ad.einsum('ik, kk->ik', dist_op(a, b), g)
        out = ad.einsum('ik,kj->ij', inner, c)

        distributed = distribute_tree(out)
        assert isinstance(distributed, dist_op)
        assert tree_eq(out, distributed, [a, b, c, g])
def test_einsum_fuse_w_identity(backendopt):
    """Fuse an einsum chain that contains multiple identity inputs.

        A  identity identity
        |     \    /
        |      \  /
        |       es
        |      /
        |     /
        |    /
         es

    Here ``es`` denotes an einsum node.
    """
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[3, 3])
        es_identity = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', a, es_identity)

        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        fused = fuse_einsums(out.node, ins)
        assert tree_eq(out.node, fused, [a])
def test_prune_identity(backendopt):
    """Redundant identity inputs are pruned out of an einsum in place."""
    for backend in backendopt:
        T.set_backend(backend)

        a1 = ad.Variable(name="a1", shape=[3, 3])
        a2 = ad.Variable(name="a2", shape=[3, 3])
        i1 = ad.identity(3)
        i2 = ad.identity(3)
        i3 = ad.identity(3)

        out = ad.einsum("ab,cd,ac,be,ef->abdf", a1, a2, i1, i2, i3)
        prune_identity_nodes(out)
        # The identity i1 says dims a and c are equal, so it can be dropped,
        # rewriting the expression as einsum("ab,ad,be,ef->abdf", a1, a2, i2, i3).
        # i2 and i3 can then be merged because index e is a dummy index,
        # leaving einsum("ab,ad,bf->abdf", a1, a2, i2).
        out_expect = ad.einsum("ab,ad,bf->abdf", a1, a2, i2)

        assert len(out.inputs) == 3
        assert tree_eq(out, out_expect, [a1, a2])
def test_einsum_multiuse_auto_copy(backendopt):
    """Auto-linearize then auto-fuse a graph where A is used twice.

        A    B  inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

    A follow-up step (not covered here) would be auto-pruning.
    """
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])
        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        linearize(output)
        clones = [
            node for node in find_topo_sort([output])
            if isinstance(node, ad.CloneNode)
        ]
        out_new = fuse_einsums(output, [*clones, b])

        # After fusion no input of the result should itself be an einsum.
        assert not any(isinstance(x, ad.EinsumNode) for x in out_new.inputs)
        assert tree_eq(output, out_new, [*clones, b])
def test_einsum_multiuse(backendopt):
    """Manually fuse a graph where A appears twice.

        A    B  inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

    Here A is assumed to have already been split into two variables
    (a and a_copy) by some other transformation.
    """
    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a1", shape=[3, 2])
        a_copy = ad.Variable(name="a2", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a_copy, c)

        # Build the fused graph.
        fused = fuse_einsums(output, [a, a_copy, b])
        assert tree_eq(output, fused, [a, a_copy, b])
def test_dimension_tree_w_identity():
    """Dimension-tree construction when one factor (B) is an identity."""
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.identity(2)
    C = ad.Variable(name="C", shape=[2, 2])
    X = ad.Variable(name="X", shape=[2, 2, 2])

    einsum_node_A = ad.einsum("abc,bm,cm->am", X, B, C)
    einsum_node_B = ad.einsum("abc,am,cm->bm", X, A, C)
    einsum_node_C = ad.einsum("abc,am,bm->cm", X, A, B)

    dt = generate_sequential_optimal_tree(
        [einsum_node_A, einsum_node_B, einsum_node_C], [A, B, C])

    assert tree_eq(dt[0], einsum_node_A, [A, C, X])
    assert tree_eq(dt[1], einsum_node_B, [A, C, X])
    assert tree_eq(dt[2], einsum_node_C, [A, B, X])
def test_einsum(backendopt):
    """fuse_einsums on the canonical tree fixture preserves the value."""
    for backend in backendopt:
        T.set_backend(backend)
        inputs, z = get_tree()
        fused = fuse_einsums(z, inputs)
        assert tree_eq(z, fused, inputs)
def test_einsum_gen_custom(backendopt):
    """generate_optimal_tree honors a user-supplied contraction path."""
    for datatype in backendopt:
        # Set the backend so each configured datatype is actually exercised.
        # (This call was missing: the loop previously re-ran every iteration
        # on whatever backend happened to be active, unlike sibling tests.)
        T.set_backend(datatype)
        a = ad.Variable(name="a", shape=[2, 2])
        b = ad.Variable(name="b", shape=[2, 5])
        c = ad.Variable(name="c", shape=[5, 2])

        output = ad.einsum('ij,jk,kl->il', a, b, c)
        # Contract (b, c) first, then the result with a.
        new_output = generate_optimal_tree(output, path=[(1, 2), (0, 1)])
        assert tree_eq(output, new_output, [a, b, c])
def test_tensordot(backendopt):
    """ad.tensordot matches the equivalent explicit einsum contraction."""
    for backend in backendopt:
        T.set_backend(backend)
        a = ad.Variable(name="a", shape=[3, 3, 3, 3])
        b = ad.Variable(name="b", shape=[3, 3, 3, 3])

        via_tensordot = ad.tensordot(a, b, axes=[[1, 3], [0, 1]])
        via_einsum = ad.einsum("abcd,bdef->acef", a, b)
        assert tree_eq(via_tensordot, via_einsum, [a, b])
def test_einsum_gen(backendopt):
    """generate_optimal_tree turns a 5-operand einsum into a pairwise tree."""
    for datatype in backendopt:
        # Set the backend so each configured datatype is actually exercised.
        # (This call was missing: the loop previously re-ran every iteration
        # on whatever backend happened to be active, unlike sibling tests.)
        T.set_backend(datatype)
        N = 10
        C = ad.Variable(name="C", shape=[N, N])
        I = ad.Variable(name="I", shape=[N, N, N, N])

        output = ad.einsum('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)
        new_output = generate_optimal_tree(output)
        assert tree_eq(output, new_output, [C, I])
        # The optimized tree contracts pairwise, so its root has 2 inputs.
        assert len(new_output.inputs) == 2
def test_simple_dmrg_tree():
    """Sequential optimal tree for a 3-site DMRG-like network.

    The network and index positions are as follows:

        A1 - f - A2 - g - A3
        |        |        |
        c        d        e
        |        |        |
        X1 - a - X2 - b - X3
        |        |        |
        h        i        j
        |        |        |
        A1 - k - A2 - l - A3
    """
    A1 = ad.Variable(name="A1", shape=[3, 2])
    A2 = ad.Variable(name="A2", shape=[3, 3, 2])
    A3 = ad.Variable(name="A3", shape=[3, 2])
    X1 = ad.Variable(name="X1", shape=[3, 2, 2])
    X2 = ad.Variable(name="X2", shape=[3, 3, 2, 2])
    X3 = ad.Variable(name="X3", shape=[3, 2, 2])

    einsum_node_A1 = ad.einsum("ach,abdi,bej,fgd,kli,ge,lj->fckh", X1, X2, X3,
                               A2, A2, A3, A3)
    einsum_node_A2 = ad.einsum("ach,abdi,bej,fc,kh,ge,lj->fgdkli", X1, X2, X3,
                               A1, A1, A3, A3)
    einsum_node_A3 = ad.einsum("ach,abdi,bej,fc,kh,fgd,kli->gelj", X1, X2, X3,
                               A1, A1, A2, A2)

    dt = generate_sequential_optimal_tree(
        [einsum_node_A1, einsum_node_A2, einsum_node_A3], [A1, A2, A3])

    assert tree_eq(dt[0], einsum_node_A1, [X1, X2, X3, A1, A1, A2, A2, A3, A3])
    assert tree_eq(dt[1], einsum_node_A2, [X1, X2, X3, A1, A1, A2, A2, A3, A3])

    # In the correct contraction path only X3 is contracted with A3;
    # all other X nodes are contracted later.
    einsum_inputs = [
        node for node in find_topo_sort(dt)
        if isinstance(node, ad.EinsumNode)
    ]

    def by_name(node):
        return node.name

    assert sorted(einsum_inputs[0].inputs, key=by_name) == sorted(
        [A3, A3, X3], key=by_name)
def test_mps(backendopt):
    """An MPS graph of length 4 equals its explicit einsum contraction."""
    for backend in backendopt:
        T.set_backend(backend)
        mps_graph = MpsGraph.create(4, ranks=[5, 6, 7])
        # Constructed but not invoked; kept to mirror the original setup.
        executor = ad.Executor([mps_graph.output])

        expect_mps = ad.einsum('ab,acd,cef,eg->bdfg', *mps_graph.inputs)
        assert tree_eq(mps_graph.output, expect_mps, mps_graph.inputs)
def test_mpo(backendopt):
    """An MPO graph of length 4 equals its explicit einsum contraction."""
    for backend in backendopt:
        T.set_backend(backend)
        mpo_graph = MpoGraph.create(4, ranks=[5, 6, 7])
        # Constructed but not invoked; kept to mirror the original setup.
        executor = ad.Executor([mpo_graph.output])

        expect_mpo = ad.einsum('abc,adef,dghi,gjk->behjcfik',
                               *mpo_graph.inputs)
        assert tree_eq(mpo_graph.output, expect_mpo, mpo_graph.inputs)
def test_mul_by_const_float(backendopt):
    """Multiplying by int 5 and float 5.0 builds equivalent graphs."""
    for backend in backendopt:
        T.set_backend(backend)
        x = ad.Variable(name="x", shape=[3])

        via_int = ad.sum(5 * x)
        via_float = ad.sum(5.0 * x)

        assert via_int.name == via_float.name
        assert tree_eq(via_int, via_float, [x])
def test_split_einsum_dup(backendopt):
    """split_einsum groups the duplicated input B into one sub-einsum."""
    for datatype in backendopt:
        # Set the backend so each configured datatype is actually exercised.
        # (This call was missing: the loop previously re-ran every iteration
        # on whatever backend happened to be active, unlike sibling tests.)
        T.set_backend(datatype)
        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])

        einsum_node = ad.einsum("ab,bc,cd->ad", A, B, B)
        split_input_nodes = [A]
        new_einsum = split_einsum(einsum_node, split_input_nodes)

        assert len(new_einsum.inputs) == 2  # A, einsum(B, B)
        assert tree_eq(new_einsum, einsum_node, [A, B])
def test_einsum_fuse_only_identity(backendopt):
    """Fusing a tree made purely of identity nodes still works."""
    for backend in backendopt:
        T.set_backend(backend)
        es_identity = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', ad.identity(3), es_identity)

        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        fused = fuse_einsums(out.node, ins)
        assert tree_eq(out.node, fused, [])
def test_distribute_dup(dist_op, backendopt):
    """Distribute over a graph where the same sum appears twice."""
    from graph_ops.graph_transformer import distribute_graph_w_linearize
    for backend in backendopt:
        T.set_backend(backend)
        a = ad.Variable(name="a", shape=[3, 3])
        # b is constructed but unused in the expression; kept as-is.
        b = ad.Variable(name="b", shape=[3, 3])
        c = ad.Variable(name="c", shape=[3, 3])

        out = ad.einsum("ab,ab->", dist_op(a, c), dist_op(a, c))
        distributed = distribute_graph_w_linearize(out)
        assert tree_eq(out, distributed, [a, c])
def test_einsum_simple_rewrite(backendopt):
    """Fusing a single einsum is a value-preserving rewrite."""
    for backend in backendopt:
        T.set_backend(backend)
        a1 = ad.Variable(name="a1", shape=[3, 2])
        a2 = ad.Variable(name="a2", shape=[2, 3])

        x = ad.einsum('ik,kj->ij', a1, a2)
        rewritten = fuse_einsums(x, [a1, a2])
        assert tree_eq(x, rewritten, [a1, a2])
def test_prune_inv_nodes_transpose(backendopt):
    """inv(M.T) @ M.T collapses to an identity node after pruning."""
    for datatype in backendopt:
        # Set the backend so each configured datatype is actually exercised.
        # (This call was missing: the loop previously re-ran every iteration
        # on whatever backend happened to be active, unlike sibling tests.)
        T.set_backend(datatype)
        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])

        inv_input = ad.einsum('ab,bc->ca', A, B)
        # inv(inv_input.T) @ inv_input.T
        output = ad.einsum('ca,cd,de->ae', ad.tensorinv(inv_input, ind=1), A,
                           B)
        new_output = prune_inv_node(output)

        assert isinstance(new_output, ad.IdentityNode)
        assert tree_eq(output, new_output, [A, B], tol=1e-6)
def test_prune_scalar_nodes(backendopt):
    """A scalar einsum input is pulled out into a mul-by-const node."""
    for backend in backendopt:
        T.set_backend(backend)
        a1 = ad.Variable(name="a1", shape=[3, 3])
        a2 = ad.Variable(name="a2", shape=[3, 3])
        s = ad.scalar(3.)

        out = ad.einsum("ab,,ab->ab", a1, s, a2)
        pruned = prune_scalar_nodes(out)

        assert isinstance(pruned, ad.MulByConstNode)
        assert tree_eq(out, pruned, [a1, a2])
def test_split_einsum(backendopt):
    """split_einsum keeps A, B separate and groups C, D, E into a sub-einsum."""
    for datatype in backendopt:
        # The backendopt parameter was previously accepted but never used;
        # loop over the configured backends like the sibling tests do.
        T.set_backend(datatype)
        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])
        C = ad.Variable(name="C", shape=[2, 2])
        D = ad.Variable(name="D", shape=[2, 2])
        E = ad.Variable(name="E", shape=[2, 2])

        einsum_node = ad.einsum("ab,bc,cd,de,ef->af", A, B, C, D, E)
        split_input_nodes = [A, B]
        new_einsum = split_einsum(einsum_node, split_input_nodes)

        assert len(new_einsum.inputs) == 3  # A, B, einsum(C, D, E)
        assert tree_eq(new_einsum, einsum_node, [A, B, C, D, E])
def test_dimension_tree_4d():
    """Dimension tree for a 4-way CP-style contraction of X."""
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    C = ad.Variable(name="C", shape=[2, 2])
    D = ad.Variable(name="D", shape=[2, 2])
    X = ad.Variable(name="X", shape=[2, 2, 2, 2])

    einsum_node_A = ad.einsum("abcd,bm,cm,dm->am", X, B, C, D)
    einsum_node_B = ad.einsum("abcd,am,cm,dm->bm", X, A, C, D)
    einsum_node_C = ad.einsum("abcd,am,bm,dm->cm", X, A, B, D)
    einsum_node_D = ad.einsum("abcd,am,bm,cm->dm", X, A, B, C)

    dt = generate_sequential_optimal_tree(
        [einsum_node_A, einsum_node_B, einsum_node_C, einsum_node_D],
        [A, B, C, D])

    # 5 inputs, 4 outputs, 5 intermediates.
    assert len(find_topo_sort(dt)) == 14
    assert tree_eq(dt[0], einsum_node_A, [A, B, C, D, X])
    assert tree_eq(dt[1], einsum_node_B, [A, B, C, D, X])
    assert tree_eq(dt[2], einsum_node_C, [A, B, C, D, X])
    assert tree_eq(dt[3], einsum_node_D, [A, B, C, D, X])
def test_simplify_inv_w_identity(backendopt):
    """tensorinv of an identity simplifies back to an identity input."""
    for backend in backendopt:
        T.set_backend(backend)
        A = ad.Variable(name="A", shape=[2, 2])

        out = ad.einsum("ab,cd->acbd", A, ad.tensorinv(ad.identity(3)))
        simplified = simplify(out)

        assert isinstance(simplified, ad.EinsumNode)
        assert isinstance(simplified.inputs[1], ad.IdentityNode)
        assert tree_eq(out, simplified, [A], tol=1e-6)
def test_fuse_subgraph(backendopt):
    """Fuse only part of a chain: ab stays an input, abc/abcd fold together."""
    for backend in backendopt:
        T.set_backend(backend)
        a = ad.Variable(name="a", shape=[2, 2])
        b = ad.Variable(name="b", shape=[2, 2])
        c = ad.Variable(name="c", shape=[2, 2])
        d = ad.Variable(name="d", shape=[2, 2])

        ab = ad.einsum("ab,bc->ac", a, b)
        abc = ad.einsum("ab,bc->ac", ab, c)
        abcd = ad.einsum("ab,bc->ac", abc, d)

        fused = fuse_einsums(abcd, [ab, c, d])
        assert tree_eq(abcd, fused, [a, b, c, d])
def test_simplify_inv_w_redundent_einsum(backendopt):
    """A pass-through einsum inside tensorinv is eliminated by simplify."""
    for backend in backendopt:
        T.set_backend(backend)
        A = ad.Variable(name="A", shape=[2, 2])

        out = ad.einsum("ab,cd->abcd", A,
                        ad.tensorinv(ad.einsum("ab->ab", A)))
        simplified = simplify(out)

        # After simplification the inverse acts directly on the variable.
        inv_node = simplified.inputs[1]
        assert isinstance(inv_node.inputs[0], ad.VariableNode)
        assert tree_eq(out, simplified, [A], tol=1e-6)
def test_prune_identity_w_dup(backendopt):
    """Identity pruning works when a non-identity input is duplicated."""
    for backend in backendopt:
        T.set_backend(backend)
        a1 = ad.Variable(name="a1", shape=[3, 3])
        i1 = ad.identity(3)
        i2 = ad.identity(3)
        i3 = ad.identity(3)

        out = ad.einsum("ab,bc,cd,de,ef->af", a1, a1, i1, i2, i3)
        prune_identity_nodes(out)
        out_expect = ad.einsum("ab,bc->ac", a1, a1)

        assert len(out.inputs) == 2
        assert tree_eq(out, out_expect, [a1])
def test_kronecker_product_repeated_inputs(backendopt):
    """The inverse of kron(A, A) factors into inverses of the factors."""
    for backend in backendopt:
        T.set_backend(backend)
        A = ad.Variable(name="A", shape=[2, 2])

        kron = ad.einsum("ab,cd->acbd", A, A)
        inv = ad.tensorinv(kron)
        optimized = optimize_inverse(inv)

        assert isinstance(optimized, ad.EinsumNode)
        for node in optimized.inputs:
            assert isinstance(node, ad.TensorInverseNode)
        assert tree_eq(inv, optimized, [A], tol=1e-5)