def test_kronecker_product_non_even(backendopt):
    A = ad.Variable(name="A", shape=[4, 4, 2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    out = ad.einsum("abcd,ef->abcdef", A, B)
    # The split at ind=2 cuts across A's indices (a, b land in the row group,
    # c, d in the column group), so the inverse cannot be decomposed into
    # per-factor inverses and must be returned unchanged.
    inv = ad.tensorinv(out, ind=2)
    newinv = optimize_inverse(inv)
    assert inv is newinv
def test_kronecker_product_nondecomposable(backendopt):
    A = ad.Variable(name="A", shape=[2, 3])
    B = ad.Variable(name="B", shape=[3, 2])
    out = ad.einsum("ab,cd->acbd", A, B)
    # A and B are rectangular, so neither factor is invertible on its own and
    # the product inverse cannot be rewritten; the node is returned unchanged.
    inv = ad.tensorinv(out)
    newinv = optimize_inverse(inv)
    assert inv is newinv
def test_kronecker_product_repeated_inputs(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)
        A = ad.Variable(name="A", shape=[2, 2])
        out = ad.einsum("ab,cd->acbd", A, A)
        inv = ad.tensorinv(out)
        newinv = optimize_inverse(inv)
        assert isinstance(newinv, ad.EinsumNode)
        for node in newinv.inputs:
            assert isinstance(node, ad.TensorInverseNode)
        assert tree_eq(inv, newinv, [A], tol=1e-5)
def test_inv_multiple_decomposition(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)
        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])
        C = ad.Variable(name="C", shape=[2, 2])
        out = ad.einsum("ab,cd,ef->acebdf", A, B, C)
        inv = ad.tensorinv(out)
        newinv = optimize_inverse(inv)
        assert isinstance(newinv, ad.EinsumNode)
        for node in newinv.inputs:
            assert isinstance(node, ad.TensorInverseNode)
        assert len(newinv.inputs) == 3
        assert tree_eq(inv, newinv, [A, B, C], tol=1e-5)
def test_high_dim_inv(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)
        A = ad.Variable(name="A", shape=[2, 2, 2, 2])
        B = ad.Variable(name="B", shape=[2, 2, 2, 2])
        out = ad.einsum("aceg,dbhf->abcdefgh", A, B)
        inv = ad.tensorinv(out)
        # T.einsum('aceg,bdfh->abcdefgh',
        #          T.tensorinv(T.einsum('aceg->aceg', A), ind=2),
        #          T.tensorinv(T.einsum('dbhf->bdfh', B), ind=2))
        newinv = optimize_inverse(inv)
        assert isinstance(newinv, ad.EinsumNode)
        for node in newinv.inputs:
            assert isinstance(node, ad.TensorInverseNode)
        assert tree_eq(inv, newinv, [A, B], tol=1e-6)
def test_complex_product_inv(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)
        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])
        C = ad.Variable(name="C", shape=[2, 2])
        D = ad.Variable(name="D", shape=[2, 2])
        out = ad.einsum("ab,bc,de,ef->adcf", A, B, C, D)
        inv = ad.tensorinv(out)
        # T.einsum('ac,df->adcf',
        #          T.tensorinv(T.einsum('ab,bc->ac', A, B), ind=1),
        #          T.tensorinv(T.einsum('de,ef->df', C, D), ind=1))
        newinv = optimize_inverse(inv)
        assert isinstance(newinv, ad.EinsumNode)
        for node in newinv.inputs:
            assert isinstance(node, ad.TensorInverseNode)
        assert tree_eq(inv, newinv, [A, B, C, D], tol=1e-5)
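# The decompositions checked above rest on the Kronecker-product inverse
# identity (A kron B)^{-1} = A^{-1} kron B^{-1}. The sketch below is a
# minimal, self-contained NumPy check of that identity; it is illustrative
# only, assumes NumPy is available, and is not part of the original suite.
def _kron_inverse_identity_sketch():
    import numpy as np
    # Shift by the identity to keep both factors comfortably invertible.
    A = np.random.rand(2, 2) + 2 * np.eye(2)
    B = np.random.rand(2, 2) + 2 * np.eye(2)
    lhs = np.linalg.inv(np.kron(A, B))
    rhs = np.kron(np.linalg.inv(A), np.linalg.inv(B))
    assert np.allclose(lhs, rhs)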
def simplify(output_node):
    """Simplify a graph with a single output node.

    The simplified form will distribute selected operations (+), and fuse all
    connected einsums.

    Args:
        output_node: The output node.
    Returns:
        node: The newly generated node.
    """
    def fuse_all_einsums(node):
        linearize(node)
        ret_node = PseudoNode(node)
        all_pnodes = find_topo_sort_p([ret_node])
        with OutputInjectedModeP(all_pnodes):
            trees = find_sub_einsumtree(ret_node)
            for tree in trees:
                out_node_p, in_nodes = tree
                new_z = fuse_einsums(out_node_p.node, in_nodes)
                prune_identity_nodes(new_z)
                replace_node(out_node_p, new_z)
        node = declone(ret_node.node)
        return node

    output_node = distribute_graph_w_linearize(output_node)
    output_node = fuse_all_einsums(output_node)

    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    # optimize inverse
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that we can collapse the add node.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.TensorInverseNode):
                new_inv_node = optimize_inverse(node)
                replace_node(pnode, new_inv_node)

    # fuse again
    output_node = output_pnode.node
    output_node = fuse_all_einsums(output_node)

    # prune the orthonormal matmuls
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                new_node = prune_orthonormal_matmuls(node)
                replace_node(pnode, new_node)

    # prune inverse nodes
    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                new_node = prune_inv_node(node)
                replace_node(pnode, new_node)

    # prune the scalar nodes and remove unnecessary identity nodes
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                prune_identity_nodes(node)
                new_node = prune_scalar_nodes(node)
                replace_node(pnode, new_node)

    # collapse symmetric expressions
    all_pnodes = find_topo_sort_p([output_pnode])
    for i in range(len(all_pnodes)):
        for j in range(i):
            collapse_symmetric_expr(all_pnodes[i].node, all_pnodes[j].node)

    sympy_input_types = (ad.DistributiveNode, ad.ScalarNode, ad.MulNode)
    # sympy_simplify the distributed nodes
    if isinstance(output_node, ad.DistributiveNode):
        sympy_inputs = []
        all_nodes = find_topo_sort([output_node])
        for node in all_nodes:
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that they can be reduced by sympy.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if not isinstance(node, sympy_input_types):
                sympy_inputs.append(node)
        output_node = sympy_simplify(output_node, sympy_inputs)

    return output_node
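# Usage sketch (not part of the original module): simplify applied to a pair
# of nested einsums, which fuse_all_einsums is expected to collapse into a
# single einsum over A, B, C. Variable and einsum come from the ad module
# used throughout this package; the concrete expression is illustrative only.
def _simplify_usage_sketch():
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    C = ad.Variable(name="C", shape=[2, 2])
    inner = ad.einsum("ab,bc->ac", A, B)    # A @ B
    out = ad.einsum("ab,bc->ac", inner, C)  # (A @ B) @ C
    return simplify(out)  # expected to fuse into one einsum over A, B, C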