def test_tree_distribution_two_layers(dist_op, backendopt):
    """
    [Distributive] ((A + B) * G) * C will produce AGC + BGC.

    Note that (A + B) * G is contracted first, so the sum node sits two
    einsum layers below the output.
    """
    for datatype in backendopt:
        if datatype == "taco":
            # '..,kk,..->..' is not supported in taco
            continue

        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        g = ad.Variable(name="g", shape=[2, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        # Subscripts are written without whitespace, like every other einsum
        # in these tests, instead of relying on the einsum parser to ignore
        # spaces.
        interm = ad.einsum('ik,kk->ik', dist_op(a, b), g)
        output = ad.einsum('ik,kj->ij', interm, c)

        new_output = distribute_tree(output)
        assert isinstance(new_output, dist_op)
        assert tree_eq(output, new_output, [a, b, c, g])
def test_tree_distribution_ppE(dist_op, backendopt):
    """
    [Distributive] ((A + B) + C) * G will produce AG + BG + CG.

    The sum (A + B) is itself an operand of the outer sum (A + B) + C,
    i.e. the einsum's sum operand is a nested sum.
    """
    # Taco cannot evaluate the '..,kk,..->..' subscript pattern, so that
    # backend is filtered out before the loop body runs.
    supported_backends = (name for name in backendopt if name != "taco")
    for backend in supported_backends:
        T.set_backend(backend)

        var_a = ad.Variable(name="a", shape=[3, 2])
        var_b = ad.Variable(name="b", shape=[3, 2])
        var_c = ad.Variable(name="c", shape=[3, 2])
        var_g = ad.Variable(name="g", shape=[2, 2])

        nested_sum = dist_op(dist_op(var_a, var_b), var_c)
        output = ad.einsum('ik,kk->ik', nested_sum, var_g)

        distributed = distribute_tree(output)
        assert isinstance(distributed, dist_op)
        assert tree_eq(output, distributed, [var_a, var_b, var_c, var_g])
def test_tree_distribution_w_add_output(dist_op, backendopt):
    """
    [Distributive] C * (A + B) + D * (E + F)
        = (C * A + C * B) + (D * E + D * F)

    The output node is itself a sum of two einsums, each of which has a sum
    operand. After distribution, both inputs of the top-level sum must be
    sums as well.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 3])
        c = ad.Variable(name="c", shape=[3, 3])
        d = ad.Variable(name="d", shape=[3, 3])
        e = ad.Variable(name="e", shape=[3, 3])
        f = ad.Variable(name="f", shape=[3, 3])

        out1 = ad.einsum('ik,kj->ij', c, dist_op(a, b))  # C * (A + B)
        out2 = ad.einsum('ik,kj->ij', d, dist_op(e, f))  # D * (E + F)
        output = dist_op(out1, out2)

        new_output = distribute_tree(output)
        assert isinstance(new_output, dist_op)
        for input_node in new_output.inputs:
            assert isinstance(input_node, dist_op)
        assert tree_eq(output, new_output, [a, b, c, d, e, f])
def test_tree_distribution_order(dist_op, backendopt):
    """
    [Distributive] C * (A + B) = C * A + C * B.

    Here the sum node is the second einsum operand rather than the first.
    """
    for backend in backendopt:
        T.set_backend(backend)

        var_a = ad.Variable(name="a", shape=[3, 2])
        var_b = ad.Variable(name="b", shape=[3, 2])
        var_c = ad.Variable(name="c", shape=[2, 3])

        summed = dist_op(var_a, var_b)
        output = ad.einsum('ik,kj->ij', var_c, summed)

        distributed = distribute_tree(output)
        assert isinstance(distributed, dist_op)
        assert tree_eq(output, distributed, [var_a, var_b, var_c])
def test_tree_distribution_mim(dist_op, backendopt):
    """
    [Distributive] (A + B) * G * (C + D) will produce AGC + AGD + BGC + BGD.

    Note that G must be in the middle. We do the following:

        (A + B) * G * (C + D) = A * G * (C + D) + B * G * (C + D)
                              = AGC + AGD + BGC + BGD

    Mim: man in the middle.
    """
    for datatype in backendopt:
        if datatype == "taco":
            # '..,kk,..->..' is not supported in taco
            continue

        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        g = ad.Variable(name="g", shape=[2, 2])
        c = ad.Variable(name="c", shape=[2, 3])
        d = ad.Variable(name="d", shape=[2, 3])

        add_nodeab = dist_op(a, b)
        add_nodecd = dist_op(c, d)
        output = ad.einsum('ik,kk,kj->ij', add_nodeab, g, add_nodecd)

        new_output = distribute_tree(output)
        assert isinstance(new_output, dist_op)
        # Both halves of the outer sum are themselves sums (one per term of
        # the second distribution step).
        for node in new_output.inputs:
            assert isinstance(node, dist_op)
        assert tree_eq(output, new_output, [a, b, c, d, g])
def test_tree_distribution_four_terms(dist_op, backendopt):
    r"""
    [Distributive] (A + B) * (C + D)

        A   B     C   D      inputs
         \  |     |  /
          \ |     | /
           \|     |/
          A + B   C + D
              \   /
               \ /
              output

    will produce AC + AD + BC + BD.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])
        d = ad.Variable(name="d", shape=[2, 3])

        dist_nodeab = dist_op(a, b)
        dist_nodecd = dist_op(c, d)
        output = ad.einsum('ik,kj->ij', dist_nodeab, dist_nodecd)

        # Idea:
        #   (A + B) * (C + D) = A * (C + D) + B * (C + D)
        # then distribute A * (C + D) and B * (C + D) in turn, so the result
        # is a sum whose two inputs are both sums.
        new_output = distribute_tree(output)
        assert isinstance(new_output, dist_op)
        add1, add2 = new_output.inputs
        assert isinstance(add1, dist_op)
        assert isinstance(add2, dist_op)
        assert tree_eq(output, new_output, [a, b, c, d])
def test_tree_distribution(dist_op, backendopt):
    r"""
    [Distributive] An einsum graph like

        A   B   C     inputs
         \  |   |
          \ |   |
           \|   |
          A + B |
             \  |
              \ |
               \|
              output

    will produce

        A   C   B     inputs
         \  |\  |
          \ | \ |
           \|  \|
           AC   BC
             \  |
              \ |
               \|
              output (AC + BC)

    In words, (A + B) * C = A * C + B * C, where * is an einsum node.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        add_node = dist_op(a, b)
        output = ad.einsum('ik,kj->ij', add_node, c)

        new_output = distribute_tree(output)
        assert isinstance(new_output, dist_op)
        assert tree_eq(output, new_output, [a, b, c])