def test_collapse_symmetric_expr_complex():
    """Two symmetric layouts of the same 2x3 network collapse to one name.

    out1:
    A1 - a - A2 - b - A3
    |        |        |
    c        d        e
    |        |        |
    H1 - f - H2 - g - H3
    |        |        |
    h        i        j

    out2:
    a        b        c
    |        |        |
    H1 - d - H2 - e - H3
    |        |        |
    f        g        h
    |        |        |
    A1 - i - A2 - j - A3
    """
    H1 = ad.Variable(name="H1", shape=[2, 2, 2], symmetry=[[0, 2]])
    H2 = ad.Variable(name="H2", shape=[2, 2, 2, 2], symmetry=[[0, 2]])
    H3 = ad.Variable(name="H3", shape=[2, 2, 2], symmetry=[[0, 1]])
    # The A tensors get their own names; reusing "H1"/"H2"/"H3" here would
    # give two different variables the same name.
    A1 = ad.Variable(name="A1", shape=[2, 2])
    A2 = ad.Variable(name="A2", shape=[2, 2, 2])
    A3 = ad.Variable(name="A3", shape=[2, 2])

    out1 = ad.einsum("ca,dab,eb,cfh,dgif,ejg->hij", A1, A2, A3, H1, H2, H3)
    out2 = ad.einsum("fi,gij,hj,adf,begd,che->abc", A1, A2, A3, H1, H2, H3)

    collapse_symmetric_expr(out1, out2)
    assert out1.name == out2.name
def test_collapse_expr_w_identity():
    """Contracting with the identity on either of its axes collapses."""
    mat = ad.Variable(name="a", shape=[2, 2])
    eye = ad.identity(2)

    # Same contraction, but the identity is indexed as "bc" in one
    # expression and as "cb" in the other.
    plain = ad.einsum("ab,bc->ac", mat, eye)
    flipped = ad.einsum("ab,cb->ac", mat, eye)

    collapse_symmetric_expr(plain, flipped)
    assert plain.name == flipped.name
def test_cannot_collapse_expr():
    """Without a declared symmetry the two einsums keep distinct names."""
    tensor = ad.Variable(name="h", shape=[2, 2, 2, 2])
    mat = ad.Variable(name="a", shape=[2, 2])

    # ``tensor`` is created without a ``symmetry`` argument.
    contract_ik = ad.einsum("ijkl,ik->jl", tensor, mat)
    contract_jl = ad.einsum("ijkl,jl->ik", tensor, mat)

    collapse_symmetric_expr(contract_ik, contract_jl)
    assert contract_ik.name != contract_jl.name
def test_collapse_symmetry_w_multiple_contraction_ind():
    """A symmetric matrix contracted on both indices collapses."""
    sym_mat = ad.Variable(name="H", shape=[2, 2], symmetry=[[0, 1]])
    vec1 = ad.Variable(name="x1", shape=[2])
    vec2 = ad.Variable(name="x2", shape=[2])

    # Both expressions contract every index of ``sym_mat``; they differ
    # only in which vector is attached to which index.
    scalar_ab = ad.einsum("ab,a,b->", sym_mat, vec1, vec2)
    scalar_ba = ad.einsum("ab,b,a->", sym_mat, vec1, vec2)

    collapse_symmetric_expr(scalar_ab, scalar_ba)
    assert scalar_ab.name == scalar_ba.name
def test_cannot_collapse_symmetric_expr():
    """Cases where names must stay different after a collapse attempt."""
    tensor = ad.Variable(name="h",
                         shape=[2, 2, 2, 2],
                         symmetry=[[0, 1], [2, 3]])
    mat = ad.Variable(name="a", shape=[2, 2])

    # Case 1: the arguments are plain variables, not einsum nodes.
    collapse_symmetric_expr(tensor, mat)
    assert tensor.name != mat.name

    # Case 2: the two einsums receive their inputs in a different order.
    lhs = ad.einsum("ijkl,ik->jl", tensor, mat)
    rhs = ad.einsum("jl,ijkl->ik", mat, tensor)
    collapse_symmetric_expr(lhs, rhs)
    assert lhs.name != rhs.name

    # Case 3: the two einsums produce outputs of different shape.
    lhs = ad.einsum("ijkl,ik->jl", tensor, mat)
    rhs = ad.einsum("ijkl,ik->jkl", tensor, mat)
    collapse_symmetric_expr(lhs, rhs)
    assert lhs.name != rhs.name
def simplify(output_node):
    """Simplify a graph that has a single output node.

    The graph is first distributed (selected operations, i.e. +) and all
    connected einsums are fused. A fixed sequence of passes then runs, in
    this order: inverse optimization, a second einsum fusion, pruning of
    orthonormal matmuls, pruning of inverse nodes, pruning of identity and
    scalar nodes, collapsing of symmetric expressions and, when the output
    is a distributive node, a sympy simplification.

    Args:
        output_node: The output node of the graph.

    Returns:
        The node resulting from the passes above.
    """

    def fuse_all_einsums(root):
        """Linearize from *root*, fuse each einsum subtree, then declone."""
        linearize(root)
        root_p = PseudoNode(root)
        pnodes = find_topo_sort_p([root_p])
        with OutputInjectedModeP(pnodes):
            for tree_out_p, tree_inputs in find_sub_einsumtree(root_p):
                fused = fuse_einsums(tree_out_p.node, tree_inputs)
                prune_identity_nodes(fused)
                replace_node(tree_out_p, fused)
        return declone(root_p.node)

    def refresh_inputs(node):
        """Call ``set_inputs`` with the node's own inputs unless empty."""
        if node.inputs != []:
            node.set_inputs(node.inputs)

    def rewrite_einsums(root_p, rewrite):
        """Replace every einsum reachable from *root_p* by ``rewrite(node)``.

        Nodes are visited in the order given by ``find_topo_sort_p``; each
        node has its inputs refreshed before it is inspected.
        """
        pnodes = find_topo_sort_p([root_p])
        with OutputInjectedModeP(pnodes):
            for pnode in pnodes:
                current = pnode.node
                refresh_inputs(current)
                if isinstance(current, ad.EinsumNode):
                    replace_node(pnode, rewrite(current))

    def prune_identity_then_scalar(node):
        """Prune identity nodes on *node*, then return its scalar-pruned form."""
        prune_identity_nodes(node)
        return prune_scalar_nodes(node)

    output_node = fuse_all_einsums(distribute_graph_w_linearize(output_node))

    # Pass: optimize inverse nodes.
    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            current = pnode.node
            if isinstance(current, ad.EinsumNode):
                # Give identical einsum nodes the same name so that the
                # add node can be collapsed.
                rewrite_einsum_expr(current)
            refresh_inputs(current)
            if isinstance(current, ad.TensorInverseNode):
                replace_node(pnode, optimize_inverse(current))

    # Pass: fuse the einsums once more.
    output_node = fuse_all_einsums(output_pnode.node)

    # Pass: prune the orthonormal matmuls.
    # NOTE(review): this walks from the PseudoNode created before the
    # second fusion, not from a wrapper around the freshly fused
    # ``output_node`` -- confirm that this is intended.
    rewrite_einsums(output_pnode, prune_orthonormal_matmuls)

    # Pass: prune inverse nodes, starting from a fresh wrapper.
    output_pnode = PseudoNode(output_node)
    rewrite_einsums(output_pnode, prune_inv_node)

    # Pass: prune scalar nodes and remove unnecessary identity nodes.
    rewrite_einsums(output_pnode, prune_identity_then_scalar)

    # Pass: collapse symmetric expressions, pairing every node with each
    # node that precedes it in the topological order.
    all_pnodes = find_topo_sort_p([output_pnode])
    for idx, later in enumerate(all_pnodes):
        for earlier in all_pnodes[:idx]:
            collapse_symmetric_expr(later.node, earlier.node)

    # NOTE(review): ``output_node`` is not re-read from ``output_pnode``
    # after the prune passes -- verify that a replaced root is still
    # reflected in the value used below.
    if not isinstance(output_node, ad.DistributiveNode):
        return output_node

    # Pass: sympy-simplify the distributed nodes.
    sympy_input_types = (ad.DistributiveNode, ad.ScalarNode, ad.MulNode)
    sympy_inputs = []
    for node in find_topo_sort([output_node]):
        if isinstance(node, ad.EinsumNode):
            # Give identical einsum nodes the same name so that sympy can
            # reduce them.
            rewrite_einsum_expr(node)
        refresh_inputs(node)
        if not isinstance(node, sympy_input_types):
            sympy_inputs.append(node)
    return sympy_simplify(output_node, sympy_inputs)