def test_tree_distribution_w_add_output(dist_op, backendopt):
    """
    Test C * (A + B) + D * (E + F) = (C * A + C * B) + (D * E + D * F)
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 3])
        c = ad.Variable(name="c", shape=[3, 3])
        d = ad.Variable(name="d", shape=[3, 3])
        e = ad.Variable(name="e", shape=[3, 3])
        f = ad.Variable(name="f", shape=[3, 3])

        out1 = ad.einsum('ik,kj->ij', c, dist_op(a, b))
        out2 = ad.einsum('ik,kj->ij', d, dist_op(e, f))
        output = dist_op(out1, out2)
        new_output = distribute_tree(output)
        assert isinstance(new_output, dist_op)

        for input_node in new_output.inputs:
            assert isinstance(input_node, dist_op)

        assert tree_eq(output, new_output, [a, b, c, d, e, f])

def test_einsum_fuse_graph(backendopt):
    """
    [Fuse einsum used twice]
    This case is rather subtle. We want to auto fuse

        A   B   C
        |    \ /
        |     es
        |    /|
        |   / |
        es    |
         \    |
          es

    Here es is einsum.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3
        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3
        out = ad.einsum('jk, ki->ji', ABC, BC)  # 3x3

        linearize(out)
        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        new_z = fuse_einsums(out.node, ins)

        assert tree_eq(out.node, new_z, [a, b, c])

def test_tree_distribution_two_layers(dist_op, backendopt):
    """
    [Distributive] ((A + B) * G) * C will produce AGC + BGC

    Note that (A + B) * G is contracted first.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        g = ad.Variable(name="g", shape=[2, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        interm = ad.einsum('ik, kk->ik', dist_op(a, b), g)
        output = ad.einsum('ik,kj->ij', interm, c)
        new_output = distribute_tree(output)
        assert isinstance(new_output, dist_op)

        assert tree_eq(output, new_output, [a, b, c, g])

def _distribute(binary_op_node, output):
    """
    Distribute the operations. E.g. (A + B) * C = A * C + B * C

    Currently only considers the case where the binary_op is a plus node.

    Args:
        binary_op_node: This is the plus node.
        output: This is the (A + B) * C node.
    Returns:
        The new output node with the distributed computation.
    """
    assert isinstance(binary_op_node, ad.DistributiveNode)
    assert isinstance(output, ad.EinsumNode)
    assert binary_op_node in output.inputs

    # Find the children; the binary op should have exactly two.
    A, B = binary_op_node.inputs
    AC_seq = [
        tmp if tmp.name != binary_op_node.name else A for tmp in output.inputs
    ]
    BC_seq = [
        tmp if tmp.name != binary_op_node.name else B for tmp in output.inputs
    ]
    AC = ad.einsum(output.einsum_subscripts, *AC_seq)
    BC = ad.einsum(output.einsum_subscripts, *BC_seq)
    return type(binary_op_node)(AC, BC)

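# A minimal usage sketch for _distribute (not part of the test suite). It
# assumes a distributive constructor, written here as `plus_op`, standing in
# for whatever the `dist_op` fixture above supplies; the helper name
# `_distribute_example` is illustrative only.
def _distribute_example(plus_op):
    a = ad.Variable(name="a", shape=[3, 3])
    b = ad.Variable(name="b", shape=[3, 3])
    c = ad.Variable(name="c", shape=[3, 3])
    ab = plus_op(a, b)                        # the (A + B) node
    output = ad.einsum('ik,kj->ij', ab, c)    # (A + B) * C
    # _distribute rewrites this into einsum(A, C) + einsum(B, C).
    return _distribute(ab, output)
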
def test_einsum_multiuse(backendopt):
    """
    Test manual fuse.

        A    B   inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

    Note that here we assume A is split into 2 vars by some other operations.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        a_copy = ad.Variable(name="a2", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a_copy, c)

        # New graph
        out_new = fuse_einsums(output, [a, a_copy, b])

        assert tree_eq(output, out_new, [a, a_copy, b])

def test_einsum_multiuse_auto_copy(backendopt):
    """
    Test autolinearization and auto fuse.

        A    B   inputs
        |\   |
        | \  |
        |  \ |
        |   C
        |  /
        | /
        output

    Next: we would need to autoprune.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        linearize(output)
        all_nodes = find_topo_sort([output])
        cloned_nodes = [
            tmp for tmp in all_nodes if isinstance(tmp, ad.CloneNode)
        ]

        out_new = fuse_einsums(output, [*cloned_nodes, b])

        # Test that every input is now fused.
        assert all([not isinstance(x, ad.EinsumNode) for x in out_new.inputs])
        assert tree_eq(output, out_new, [*cloned_nodes, b])

def test_get_transpose_indices_dup():
    a = ad.Variable(name='a', shape=[2, 2])
    h = ad.Variable(name='h', shape=[2, 2, 2])

    out1 = ad.einsum("ad,bc,ecd->abe", a, a, h)
    out2 = ad.einsum("ac,bd,ecd->eab", a, a, h)
    trans = get_transpose_indices(out1, out2)
    assert trans == [2, 0, 1] or trans == [2, 1, 0]

def test_collapse_symmetric_expr_complex():
    """
    out1:
        A1 - a - A2 - b - A3
        |        |        |
        c        d        e
        |        |        |
        H1 - f - H2 - g - H3
        |        |        |
        h        i        j

    out2:
        a        b        c
        |        |        |
        H1 - d - H2 - e - H3
        |        |        |
        f        g        h
        |        |        |
        A1 - i - A2 - j - A3
    """
    H1 = ad.Variable(name="H1", shape=[2, 2, 2], symmetry=[[0, 2]])
    H2 = ad.Variable(name="H2", shape=[2, 2, 2, 2], symmetry=[[0, 2]])
    H3 = ad.Variable(name="H3", shape=[2, 2, 2], symmetry=[[0, 1]])
    A1 = ad.Variable(name="A1", shape=[2, 2])
    A2 = ad.Variable(name="A2", shape=[2, 2, 2])
    A3 = ad.Variable(name="A3", shape=[2, 2])

    out1 = ad.einsum("ca,dab,eb,cfh,dgif,ejg->hij", A1, A2, A3, H1, H2, H3)
    out2 = ad.einsum("fi,gij,hj,adf,begd,che->abc", A1, A2, A3, H1, H2, H3)
    collapse_symmetric_expr(out1, out2)

    assert out1.name == out2.name

def test_executor_dependent(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[3, 3])
        B = ad.Variable(name="B", shape=[3, 3])

        AA = ad.einsum('ab,ab->', A, A)
        BB = ad.einsum('ab,ab->', B, B)
        AB = ad.einsum('ab,ab->', A, B)

        out_A = AA + AB
        out_B = AB + AA

        executor = ad.Executor({out_A, out_B})

        data = gen_dict([A, B])
        A_val, = executor.run(feed_dict=data,
                              reset_graph=False,
                              out_nodes=[out_A])

        data2 = gen_dict([A])
        data2.update({B: data[B]})
        B_val, = executor.run(feed_dict=data2, out_nodes=[out_B])

        # This checks that A's value is not reused in the computation of B_val.
        assert A_val != B_val

def test_hvp2(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x = ad.Variable(name="x", shape=[3, 1])
        H = ad.Variable(name="H", shape=[3, 3])
        v = ad.Variable(name="v", shape=[3, 1])
        y = ad.sum(
            ad.einsum("ab,bc->ac", ad.einsum("ab,bc->ac", ad.transpose(x), H),
                      x))

        grad_x, = ad.gradients(y, [x])
        Hv, = ad.hvp(output_node=y, node_list=[x], vector_list=[v])

        executor = ad.Executor([y, grad_x, Hv])
        x_val = T.tensor([[1.], [2.], [3]])  # 3x1
        v_val = T.tensor([[1.], [2.], [3]])  # 3x1
        H_val = T.tensor([[2., 0., 0.], [0., 2., 0.], [0., 0., 2.]])  # 3x3

        y_val, grad_x_val, Hv_val = executor.run(feed_dict={
            x: x_val,
            H: H_val,
            v: v_val
        })

        # For y = x^T H x: grad_x = (H + H^T) x = 2 H x for the symmetric H
        # used here, and the Hessian is H + H^T = 2 H, so Hv = 2 H v.
        expected_yval = T.sum(T.transpose(x_val) @ H_val @ x_val)
        expected_grad_x_val = 2 * H_val @ x_val
        expected_hv_val = T.tensor([[4.], [8.], [12.]])

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x_val, expected_grad_x_val)
        assert T.array_equal(Hv_val, expected_hv_val)

def test_einsum_fuse_w_identity(backendopt):
    """
    [Fuse einsum with multiple identities]
    We want to fuse

        A  identity  identity
        |     \      /
        |      \    /
        |       \  /
        |        es
        |       /
        |      /
        |     /
        |    /
        es

    Here es is einsum.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])

        es_identity = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', a, es_identity)

        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        new_out = fuse_einsums(out.node, ins)

        assert tree_eq(out.node, new_out, [a])

def test_prune_identity(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        a1 = ad.Variable(name="a1", shape=[3, 3])
        a2 = ad.Variable(name="a2", shape=[3, 3])
        i1 = ad.identity(3)
        i2 = ad.identity(3)
        i3 = ad.identity(3)

        out = ad.einsum("ab,cd,ac,be,ef->abdf", a1, a2, i1, i2, i3)
        prune_identity_nodes(out)
        """
        Explanation of the einsum above:
        The identity node i1 enforces that indices a and c refer to the same
        dimension, so we can remove i1 and rewrite the expression as
        ad.einsum("ab,ad,be,ef->abdf", a1, a2, i2, i3).
        We can also combine i2 and i3, because e is only a dummy index, and
        rewrite the expression as ad.einsum("ab,ad,bf->abdf", a1, a2, i2).
        """
        out_expect = ad.einsum("ab,ad,bf->abdf", a1, a2, i2)
        assert len(out.inputs) == 3
        assert tree_eq(out, out_expect, [a1, a2])

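# A quick numpy sanity check (independent of the ad library) of the index
# reasoning in test_prune_identity: contracting with identities over the dummy
# indices c and e matches the pruned expression. Illustrative only; the helper
# name below is not part of the test suite.
def _identity_pruning_numpy_check():
    import numpy as np
    a1 = np.random.rand(3, 3)
    a2 = np.random.rand(3, 3)
    i = np.eye(3)
    full = np.einsum("ab,cd,ac,be,ef->abdf", a1, a2, i, i, i)
    pruned = np.einsum("ab,ad,bf->abdf", a1, a2, i)
    return np.allclose(full, pruned)  # expected to be True
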
def test_prune_inv_different_num_inputs_no_pruning(backendopt):
    A = ad.Variable(name="A", shape=[2, 2])
    inv_input = ad.einsum('ab,bc->ac', A, A)
    output = ad.einsum('ab,bc->ac', ad.tensorinv(inv_input, ind=1), A)
    new_output = prune_inv_node(output)
    assert new_output is output

def test_get_common_ancestor_intermediate_leaves(backendopt):
    a = ad.Variable(name="a", shape=[2, 2])
    b = ad.Variable(name="b", shape=[2, 2])
    c = ad.einsum("ab,bc->ac", a, b)
    d = ad.einsum("ab,ab->ab", c, c)
    ancestor = get_common_ancestor(d, d.inputs, c)
    assert ancestor == d

def test_prune_inv_set_not_match(backendopt):
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    inv = ad.tensorinv(ad.einsum('ab,bc->ac', A, B), ind=1)
    output = ad.einsum('ab,bc->ac', inv, A)
    new_output = prune_inv_node(output)
    assert new_output is output

def test_cannot_collapse_expr():
    h = ad.Variable(name="h", shape=[2, 2, 2, 2])
    a = ad.Variable(name="a", shape=[2, 2])
    out1 = ad.einsum("ijkl,ik->jl", h, a)
    out2 = ad.einsum("ijkl,jl->ik", h, a)
    collapse_symmetric_expr(out1, out2)
    assert out1.name != out2.name

def test_collapse_expr_w_identity():
    a = ad.Variable(name="a", shape=[2, 2])
    I = ad.identity(2)
    out1 = ad.einsum("ab,bc->ac", a, I)
    out2 = ad.einsum("ab,cb->ac", a, I)
    collapse_symmetric_expr(out1, out2)
    assert out1.name == out2.name

def test_collapse_symmetry_w_multiple_contraction_ind():
    H = ad.Variable(name="H", shape=[2, 2], symmetry=[[0, 1]])
    x1 = ad.Variable(name="x1", shape=[2])
    x2 = ad.Variable(name="x2", shape=[2])
    inner1 = ad.einsum("ab,a,b->", H, x1, x2)
    inner2 = ad.einsum("ab,b,a->", H, x1, x2)
    collapse_symmetric_expr(inner1, inner2)
    assert inner1.name == inner2.name

def test_prune_inv_nonmatmul_no_pruning(backendopt):
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    inv_input = ad.einsum('ab,bc->ac', A, B)
    # inv(inv_input) * inv_input.T, cannot be pruned
    output = ad.einsum('ac,ab,bc->ac', ad.tensorinv(inv_input, ind=1), A, B)
    new_output = prune_inv_node(output)
    assert new_output is output

def test_einsum_rewrite_duplicate_input(backendopt):
    a = ad.Variable(name="a", shape=[3, 2])
    x = ad.einsum('ca,cb->ab', a, a)
    y = ad.einsum('cb,ca->ab', a, a)
    rewrite_einsum_expr(x)
    rewrite_einsum_expr(y)
    assert x.einsum_subscripts == y.einsum_subscripts

def test_einsum_equal_repeated_transpose(backendopt):
    A = ad.Variable(name="A", shape=[3, 5])
    x = ad.einsum('or,ob->br', A, A)
    y = ad.einsum('eb,ed->bd', A, A)
    uf1 = rewrite_einsum_expr(x)
    uf2 = rewrite_einsum_expr(y)
    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs

def test_remove_transposes_multiple_trans():
    a = ad.Variable(name="a", shape=[2, 2, 2, 2])

    intermediate1 = ad.einsum("abcd->dcba", a)
    intermediate2 = ad.einsum("abcd->abdc", a)

    ret1 = ad.einsum("dcba->badc", intermediate1)
    ret2 = ad.einsum("abdc->badc", intermediate2)

    remove_transposes(find_topo_sort([ret1, ret2]))
    assert ret1.name == ret2.name

def test_einsum_fuse_only_identity(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        es_identity = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', ad.identity(3), es_identity)

        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        new_out = fuse_einsums(out.node, ins)

        assert tree_eq(out.node, new_out, [])

def test_einsum_equal(backendopt):
    a1 = ad.Variable(name="a1", shape=[3, 2])
    a2 = ad.Variable(name="a2", shape=[2, 3])

    x = ad.einsum('ik,kj->ij', a1, a2)
    y = ad.einsum('ml,sm->sl', a2, a1)

    rewrite_einsum_expr(x)
    rewrite_einsum_expr(y)
    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs

def test_prune_inv_nodes_transpose(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])

        inv_input = ad.einsum('ab,bc->ca', A, B)
        # inv(inv_input.T) @ inv_input.T
        output = ad.einsum('ca,cd,de->ae', ad.tensorinv(inv_input, ind=1), A,
                           B)
        new_output = prune_inv_node(output)
        assert isinstance(new_output, ad.IdentityNode)
        assert tree_eq(output, new_output, [A, B], tol=1e-6)

def test_get_common_ancestor_w_inv(backendopt):
    A = ad.Variable(name="A", shape=[3, 3])
    X = ad.Variable(name="X", shape=[3, 3, 3])

    inv = ad.tensorinv(ad.einsum("ab,ac->bc", A, A), ind=1)
    einsum_node = ad.einsum('abc,ad,ce->bce', X, A, inv)
    opt_einsum = generate_optimal_tree(einsum_node)
    sub_einsum = get_common_ancestor(opt_einsum, einsum_node.inputs, A)
    # sub_einsum should be ad.einsum('ad,abc->dbc', A, X), and shouldn't
    # include the inv node.
    assert sorted(get_all_inputs(sub_einsum),
                  key=lambda node: node.name) == sorted(
                      [A, X], key=lambda node: node.name)

def test_einsum_equal_repeated_transpose(backendopt):
    A = ad.Variable(name="A", shape=[3, 3])
    B = ad.Variable(name="B", shape=[3, 3])

    x = ad.einsum("ac,ba,bc->", A, A, B)
    y = ad.einsum("ba,ac,bc->", A, A, B)

    uf1 = rewrite_einsum_expr(x)
    uf2 = rewrite_einsum_expr(y)
    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs

def test_simplify_symmetric_einsum_expr(backendopt):
    H = ad.Variable(name="H", shape=[2, 2], symmetry=[[0, 1]])
    x1 = ad.Variable(name="x1", shape=[2])
    x2 = ad.Variable(name="x2", shape=[2])

    inner1 = ad.einsum("ab,a,b->", H, x1, x2)
    inner2 = ad.einsum("ab,b,a->", H, x1, x2)
    out = 0.5 * inner1 + 0.5 * inner2

    newout_simplify = simplify(out)
    # The simplified output should be either ad.einsum("ab,a,b->", H, x1, x2)
    # or ad.einsum("ab,b,a->", H, x1, x2).
    assert isinstance(newout_simplify, ad.EinsumNode)

def test_simplify_optimize_w_tail_einsum(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        out = ad.einsum("ab,bc->ac", A,
                        ad.einsum("ab,bc->ac", ad.identity(2),
                                  ad.identity(2)))
        newout_optimize = optimize(out)
        newout_simplify = simplify(out)
        assert newout_optimize == A
        assert newout_simplify == A

def test_einsum_equal_uf_assign_order(backendopt):
    A = ad.Variable(name="A", shape=[3, 3])
    B = ad.Variable(name="B", shape=[3, 3])
    I = ad.identity(10)

    x = ad.einsum('pb,or,ob,pr,st->srtb', B, A, A, B, I)
    y = ad.einsum('eb,ed,fb,fd,ac->abcd', A, A, B, B, I)

    uf1 = rewrite_einsum_expr(x)
    uf2 = rewrite_einsum_expr(y)
    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs