def test_prune_identity(backendopt):
    """Identity inputs of a fused einsum are pruned down to the one still needed."""
    for backend in backendopt:
        T.set_backend(backend)

        lhs = ad.Variable(name="a1", shape=[3, 3])
        rhs = ad.Variable(name="a2", shape=[3, 3])
        eye_ac = ad.identity(3)
        eye_be = ad.identity(3)
        eye_ef = ad.identity(3)

        pruned = ad.einsum("ab,cd,ac,be,ef->abdf", lhs, rhs, eye_ac, eye_be,
                           eye_ef)
        prune_identity_nodes(pruned)

        # Explanation of the einsum above:
        # The identity node eye_ac means that a and c should be the same dim,
        # so it can be dropped and the expression rewritten as
        #   ad.einsum("ab,ad,be,ef->abdf", a1, a2, i2, i3).
        # eye_be and eye_ef can then be combined because e is not otherwise
        # used, which gives
        #   ad.einsum("ab,ad,bf->abdf", a1, a2, i2).
        expected = ad.einsum("ab,ad,bf->abdf", lhs, rhs, eye_be)

        assert len(pruned.inputs) == 3
        assert tree_eq(pruned, expected, [lhs, rhs])
def test_einsum_fuse_w_identity(backendopt):
    r"""
        [Fuse einsum with multiple identities]
        We want to fuse
            A   identity  identity
            |    \       /
            |     \     /
            |      \   /
            |       es
            |     /
            |    /
            |   /
            |  /
            es
        Here es is einsum.
    """
    # The docstring is raw because the diagram contains backslashes that are
    # not valid escape sequences in a regular string literal.
    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        es_identity = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', a, es_identity)

        # Exactly one einsum sub-tree is expected; unpacking enforces that.
        tree, = find_sub_einsumtree(PseudoNode(out))
        p_out, p_ins = tree
        new_out = fuse_einsums(p_out.node, p_ins)

        # The fused node must evaluate to the same value as the unfused tree.
        assert tree_eq(p_out.node, new_out, [a])
def test_simplify_optimize_w_tail_einsum():
    """optimize() and simplify() both reduce A @ (I @ I) to A itself."""
    # NOTE(review): a test with this same name (taking `backendopt`) appears
    # further down in the visible source; if both live in one module the later
    # definition shadows this one -- confirm.
    mat = ad.Variable(name="A", shape=[2, 2])
    eye_product = ad.einsum("ab,bc->ac", ad.identity(2), ad.identity(2))
    expr = ad.einsum("ab,bc->ac", mat, eye_product)

    optimized = optimize(expr)
    simplified = simplify(expr)

    assert optimized == mat
    assert simplified == mat
def test_einsum_fuse_only_identity(backendopt):
    """Fusing an einsum tree built only from identity nodes keeps its value."""
    for backend in backendopt:
        T.set_backend(backend)

        inner = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        outer = ad.einsum('ai,ij->aj', ad.identity(3), inner)

        # Exactly one einsum sub-tree is expected; unpacking enforces that.
        (subtree,) = find_sub_einsumtree(PseudoNode(outer))
        p_root, p_leaves = subtree
        fused = fuse_einsums(p_root.node, p_leaves)

        # No variables need to be fed: every leaf is an identity.
        assert tree_eq(p_root.node, fused, [])
def test_simplify_optimize_w_tail_einsum(backendopt):
    """For every backend, optimize() and simplify() reduce A @ (I @ I) to A."""
    for backend in backendopt:
        T.set_backend(backend)

        mat = ad.Variable(name="A", shape=[2, 2])
        eye_product = ad.einsum("ab,bc->ac", ad.identity(2), ad.identity(2))
        expr = ad.einsum("ab,bc->ac", mat, eye_product)

        optimized = optimize(expr)
        simplified = simplify(expr)

        assert optimized == mat
        assert simplified == mat
def test_prune_identity_w_dup(backendopt):
    """A chain of identities after a repeated variable is pruned away entirely."""
    for backend in backendopt:
        T.set_backend(backend)

        mat = ad.Variable(name="a1", shape=[3, 3])
        eye_cd = ad.identity(3)
        eye_de = ad.identity(3)
        eye_ef = ad.identity(3)

        pruned = ad.einsum("ab,bc,cd,de,ef->af", mat, mat, eye_cd, eye_de,
                           eye_ef)
        prune_identity_nodes(pruned)

        # Only the two occurrences of the variable should remain as inputs.
        expected = ad.einsum("ab,bc->ac", mat, mat)
        assert len(pruned.inputs) == 2
        assert tree_eq(pruned, expected, [mat])
def prune_identity_nodes(einsum_node):
    """
    Reduce the number of identity nodes in the einsum_node's inputs.

    The update is done in place: the node's einsum_subscripts and inputs are
    overwritten and nothing is returned. Indices that are linked through an
    identity input are merged into a single representative index. An identity
    input is only re-introduced when two output positions would otherwise
    share the same index.

    Args:
        einsum_node: A fused einsum node. If it is not an ad.EinsumNode the
            function returns immediately without touching it.
    """
    if not isinstance(einsum_node, ad.EinsumNode):
        return

    uf_str, p_outnode, p_innodes = generate_einsum_info(einsum_node)

    p_identity_nodes = [
        pnode for pnode in p_innodes
        if isinstance(pnode.node, ad.IdentityNode)
    ]
    # Everything that is not an identity is kept; new identity nodes may be
    # appended to this list below.
    p_updated_inputs = [
        pnode for pnode in p_innodes if pnode not in p_identity_nodes
    ]

    # each disjoint set in uf_identity represents the indices
    # linked by identity node
    whole_str = p_outnode.subscript + "".join(
        pnode.subscript for pnode in p_innodes)
    uf_identity = UF(list(whole_str))
    for pnode in p_identity_nodes:
        uf_identity.connect(pnode.subscript[0], pnode.subscript[1])

    # replace subscripts of the kept inputs by the root chars
    for pnode in p_updated_inputs:
        pnode.subscript = "".join(
            uf_identity.root(char) for char in pnode.subscript)

    output_indices_set = set()
    out_sub_list = []
    for i, char in enumerate(p_outnode.subscript):
        uf_root_char = uf_identity.root(char)
        if uf_root_char in output_indices_set:
            # we cannot assign the same char to two indices in the
            # output. Therefore, assign a new char, and add one
            # identity node to the inputs to show the constraint.
            new_char = uf_str.cg.getchar()
            out_sub_list.append(new_char)
            p_updated_inputs.append(
                PseudoNode(node=ad.identity(einsum_node.shape[i]),
                           subscript=f"{uf_root_char}{new_char}"))
        else:
            # directly assign the root char to the subscripts
            out_sub_list.append(uf_root_char)
            output_indices_set.add(uf_root_char)
    p_outnode.subscript = "".join(out_sub_list)

    new_input_subs = [pnode.subscript for pnode in p_updated_inputs]
    new_subscripts = ",".join(new_input_subs) + "->" + p_outnode.subscript
    einsum_node.einsum_subscripts = new_subscripts
    einsum_node.set_inputs([pnode.node for pnode in p_updated_inputs])
def test_collapse_expr_w_identity():
    """a @ I and a @ I.T get the same name after collapse_symmetric_expr."""
    mat = ad.Variable(name="a", shape=[2, 2])
    eye = ad.identity(2)

    plain = ad.einsum("ab,bc->ac", mat, eye)
    transposed = ad.einsum("ab,cb->ac", mat, eye)

    collapse_symmetric_expr(plain, transposed)

    assert plain.name == transposed.name
def test_einsum_equal_uf_assign_order():
    """Two differently-labelled einsums end up with the same subscripts and inputs after rewriting."""
    mat_a = ad.Variable(name="A", shape=[3, 3])
    mat_b = ad.Variable(name="B", shape=[3, 3])
    eye = ad.identity(10)

    first = ad.einsum('pb,or,ob,pr,st->srtb', mat_b, mat_a, mat_a, mat_b, eye)
    second = ad.einsum('eb,ed,fb,fd,ac->abcd', mat_a, mat_a, mat_b, mat_b, eye)

    # The assertions below inspect the nodes themselves, so the return values
    # of the rewrite calls are not needed here.
    rewrite_einsum_expr(first)
    rewrite_einsum_expr(second)

    assert first.einsum_subscripts == second.einsum_subscripts
    assert first.inputs == second.inputs
def test_simplify_inv_w_identity(backendopt):
    """simplify() turns tensorinv(identity) inside an einsum into a plain identity input."""
    for backend in backendopt:
        T.set_backend(backend)

        mat = ad.Variable(name="A", shape=[2, 2])
        inv_eye = ad.tensorinv(ad.identity(3))
        expr = ad.einsum("ab,cd->acbd", mat, inv_eye)

        simplified = simplify(expr)

        assert isinstance(simplified, ad.EinsumNode)
        assert isinstance(simplified.inputs[1], ad.IdentityNode)
        assert tree_eq(expr, simplified, [mat], tol=1e-6)
def test_s2s_w_constants(backendopt):
    """Generated forward code for A @ I returns A when run on a concrete tensor."""
    for backend in backendopt:
        T.set_backend(backend)

        mat = ad.Variable(name="A", shape=[2, 2])
        eye = ad.identity(2)
        product = ad.einsum("ab,bc->ac", mat, eye)

        mat_val = T.tensor([[1., 2.], [3., 4.]])

        # Emit source for the forward pass, import it, and run it.
        generator = SourceToSource()
        source = generator.forward([product],
                                   function_name='fwd',
                                   backend=backend)
        module = import_code(source)
        (result,) = module.fwd([mat_val])

        assert T.array_equal(mat_val, result)
def test_dimension_tree_w_identity():
    """The sequential optimal tree matches the original einsums when one factor is an identity."""
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.identity(2)
    C = ad.Variable(name="C", shape=[2, 2])
    X = ad.Variable(name="X", shape=[2, 2, 2])

    # One einsum per factor, each contracting X with the other two factors.
    einsum_node_A = ad.einsum("abc,bm,cm->am", X, B, C)
    einsum_node_B = ad.einsum("abc,am,cm->bm", X, A, C)
    einsum_node_C = ad.einsum("abc,am,bm->cm", X, A, B)
    originals = [einsum_node_A, einsum_node_B, einsum_node_C]

    dt = generate_sequential_optimal_tree(originals, [A, B, C])

    assert tree_eq(dt[0], einsum_node_A, [A, C, X])
    assert tree_eq(dt[1], einsum_node_B, [A, C, X])
    assert tree_eq(dt[2], einsum_node_C, [A, B, X])
def prune_single_inv_node(einsum_node, inv_node):
    """
    Prune the inv_node in the einsum node if the conditions are met.

    Note:
    1. can only optimize the node when the input of inv is an einsum node.
    2. only supports the case when the split nodes are different from the
       remaining ones.
        For example:
        ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, C)
        will be optimized to ad.einsum("ab,bc->ac", C, ad.identity()),
        but we cannot optimize
        ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, B).

    Parameters
    ----------
    einsum_node: The fused einsum node
    inv_node: the input inv node to be pruned

    Returns
    -------
    If the einsum_node cannot be optimized, then return the input einsum_node.
    If it can be optimized, return the optimized einsum node.

    """
    from autohoot.einsum_graph.expr_generator import rewrite_einsum_expr
    from autohoot.graph_ops.optimal_tree import split_einsum

    inv_node_input = inv_node.inputs[0]
    if not isinstance(inv_node_input, ad.EinsumNode):
        logger.info("inv input is not einsum node, can't prune inv")
        return einsum_node

    if not set(inv_node_input.inputs).issubset(set(einsum_node.inputs)):
        logger.info(
            "inv inputs is not subset of einsum node inputs, can't prune inv")
        return einsum_node

    einsum_inputs_in_inv = [
        n for n in einsum_node.inputs if n in inv_node_input.inputs
    ]
    if len(einsum_inputs_in_inv) < len(inv_node_input.inputs):
        logger.info(
            "number of inv inputs is more than corresponding einsum inputs, can't prune inv"
        )
        return einsum_node

    # Split on the inputs that do not also feed the inv input.
    split_einsum_node = split_einsum(
        einsum_node,
        list(set(einsum_node.inputs) - set(inv_node_input.inputs)))

    # Assign pseudo nodes and chars
    in_subs, out_subs, _ = parse_einsum_input(
        (split_einsum_node.einsum_subscripts, *split_einsum_node.inputs))
    in_subs_list = in_subs.split(',')

    p_einsum_input = None
    p_inv_input = None
    updated_p_in_nodes = []
    for node, subscript in zip(split_einsum_node.inputs, in_subs_list):
        p_node = PseudoNode(node=node, subscript=subscript)
        if isinstance(node, ad.EinsumNode):
            p_einsum_input = p_node
        elif node is inv_node:
            p_inv_input = p_node
        else:
            updated_p_in_nodes.append(p_node)

    if p_einsum_input is None or p_inv_input is None:
        # Without both operands there is nothing to cancel; bail out the same
        # way as the other "cannot optimize" cases instead of failing on an
        # unbound local below.
        logger.info(
            "can't find both the inv node and an einsum input after the split, can't prune inv"
        )
        return einsum_node

    # Chars shared by the two operands vs. chars appearing in only one of them.
    contract_char = "".join(
        set(p_einsum_input.subscript) & set(p_inv_input.subscript))
    uncontract_str = "".join(
        set("".join([p_einsum_input.subscript, p_inv_input.subscript])) -
        set(contract_char))

    is_matmul = (len(p_einsum_input.subscript) == 2
                 and len(p_inv_input.subscript) == 2
                 and len(contract_char) == 1 and len(uncontract_str) == 2)
    if not is_matmul:
        # this is not a matmul. Just return the initial node
        logger.info(
            "the op between inv input and the selected einsum is not matmul, can't prune inv"
        )
        return einsum_node

    if (p_einsum_input.subscript[0] == p_inv_input.subscript[0]
            or p_einsum_input.subscript[1] == p_inv_input.subscript[1]):
        # the str is like "ab,ac", and one einsum needs to be transposed to compare
        p_in_subs, p_out_subs, _ = parse_einsum_input(
            (p_einsum_input.node.einsum_subscripts,
             *p_einsum_input.node.inputs))
        einsum_input = ad.einsum(
            f"{p_in_subs}->{p_out_subs[1]}{p_out_subs[0]}",
            *p_einsum_input.node.inputs)
    else:
        einsum_input = p_einsum_input.node

    # Rewrite both expressions so that their names can be compared below.
    rewrite_einsum_expr(einsum_input)
    rewrite_einsum_expr(inv_node_input)

    if einsum_input.name != inv_node_input.name:
        logger.info(
            "inv input and the selected einsum have different expressions, can't prune inv"
        )
        return einsum_node

    # prune the inv node: the inv/einsum pair is replaced by one identity over
    # the two uncontracted indices.
    updated_p_in_nodes = updated_p_in_nodes + [
        PseudoNode(node=ad.identity(inv_node_input.shape[0]),
                   subscript=uncontract_str)
    ]

    return generate_new_einsum(updated_p_in_nodes, out_subs)
def prune_orthonormal_matmuls(einsum_node):
    """
    Remove the matrices of a einsum_node if M @ M.T like structures exist.

    Inputs that are ad.MatrixNode instances with a non-None `orthonormal`
    attribute are grouped by node name. Within each group, every pair that
    satisfies the pruning criteria is replaced by a single identity node over
    the pair's two orthonormal indices.

    Args:
        einsum_node: A fused einsum node.
    Return:
        A new einsum node built from the remaining inputs plus any identity
        nodes that replaced pruned pairs.
    """

    # A map from the orthonormal matrix mode to (orthonormal_index, contraction_index)
    orthonormal_indices_map = {'column': (0, 1), 'row': (1, 0)}

    _, p_outnode, p_innodes = generate_einsum_info(einsum_node)
    # Snapshot of all subscripts (inputs and output), taken before any pruning.
    subs_list = [pnode.subscript for pnode in p_innodes
                 ] + [p_outnode.subscript]

    # Group the orthonormal matrix inputs by the name of the underlying node.
    ortho_pnode_map = {}
    for pnode in p_innodes:
        if (isinstance(pnode.node, ad.MatrixNode)
                and pnode.node.orthonormal is not None):
            ortho_pnode_map.setdefault(pnode.node.name, []).append(pnode)

    for pnodes in ortho_pnode_map.values():
        if len(pnodes) < 2:
            continue

        remaining_pnodes = pnodes
        for pnodes_binary_input in itertools.combinations(pnodes, 2):
            # Skip pairs where one member was already pruned with another.
            if not set(pnodes_binary_input).issubset(set(remaining_pnodes)):
                continue

            pnode_A, pnode_B = pnodes_binary_input
            o_index, c_index = orthonormal_indices_map[
                pnode_A.node.orthonormal]
            # Criteria for the pruning: the o_index of two inputs are different,
            # and the c_index only appear in these two nodes.
            c_char = pnode_A.subscript[c_index]
            c_index_is_equal = c_char == pnode_B.subscript[c_index]
            o_index_not_equal = (pnode_A.subscript[o_index] !=
                                 pnode_B.subscript[o_index])
            if not (c_index_is_equal and o_index_not_equal):
                continue

            num_subs_w_cindex = sum(1 for subs in subs_list if c_char in subs)
            if num_subs_w_cindex != 2:
                continue

            remaining_pnodes = [
                pnode for pnode in remaining_pnodes
                if pnode not in pnodes_binary_input
            ]
            p_innodes = [
                pnode for pnode in p_innodes
                if pnode not in pnodes_binary_input
            ]

            # Replace the pruned pair by an identity over its two o_index chars.
            i_node = ad.identity(pnode_A.node.shape[o_index])
            i_subs = f"{pnode_A.subscript[o_index]}{pnode_B.subscript[o_index]}"
            p_innodes.append(PseudoNode(node=i_node, subscript=i_subs))

    new_input_subs = [pnode.subscript for pnode in p_innodes]
    new_subscripts = ",".join(new_input_subs) + "->" + p_outnode.subscript
    new_inputs = [pnode.node for pnode in p_innodes]

    return ad.einsum(new_subscripts, *new_inputs)
def test_identity():
    """Parsing an identity node's name yields a node with the same name."""
    eye = ad.identity(3)
    parsed = AutodiffParser.parse(eye.name, [])
    assert parsed.name == eye.name
def p_expression_identity(t):
    'expression : IDENTITY LPAREN NUMBER RPAREN'
    # NOTE(review): the string above looks like a yacc-style grammar rule read
    # from the function's docstring (PLY convention) -- keep it unchanged.
    # t[3] is the NUMBER symbol of the rule; the result goes into t[0].
    size = t[3]
    t[0] = ad.identity(size)