def make_sdfg(implementation, dtype, id=0, in_shape=None, out_shape=None, in_subset="0:n, 0:n", out_subset="0:n, 0:n"):
    """Build an SDFG containing a single linear-algebra Solve library node.

    :param implementation: Name of the expansion implementation to use for the node.
    :param dtype: DaCe data type of the arrays (``dtype.__name__`` is used in the SDFG name).
    :param id: Disambiguating suffix for the SDFG name.
    :param in_shape: Shape of the ``ain`` array; defaults to ``[n, n]``.
    :param out_shape: Shape of the ``bin``/``bout`` arrays; defaults to ``[n, n]``.
    :param in_subset: Subset string for the input memlet.
    :param out_subset: Subset string for the RHS/output memlets.
    :return: The constructed SDFG.

    NOTE: relies on a module-level symbol ``n`` for the default shapes and
    memlet volumes (same as the original definition-time defaults).
    """
    # Avoid mutable default arguments: the previous defaults ([n, n]) were
    # list objects shared across calls and evaluated at definition time.
    if in_shape is None:
        in_shape = [n, n]
    if out_shape is None:
        out_shape = [n, n]
    sdfg = dace.SDFG("linalg_solve_{}_{}_{}".format(implementation, dtype.__name__, id))
    sdfg.add_symbol("n", dace.int64)
    state = sdfg.add_state("dataflow")

    sdfg.add_array("ain", in_shape, dtype)
    sdfg.add_array("bin", out_shape, dtype)
    sdfg.add_array("bout", out_shape, dtype)
    ain = state.add_read("ain")
    bin_node = state.add_read("bin")  # renamed from `bin` to avoid shadowing the builtin
    bout = state.add_write("bout")

    solve_node = Solve("solve")
    solve_node.implementation = implementation

    state.add_memlet_path(ain, solve_node, dst_conn="_ain",
                          memlet=Memlet.simple(ain, in_subset, num_accesses=n * n))
    state.add_memlet_path(bin_node, solve_node, dst_conn="_bin",
                          memlet=Memlet.simple(bin_node, out_subset, num_accesses=n * n))
    state.add_memlet_path(solve_node, bout, src_conn="_bout",
                          memlet=Memlet.simple(bout, out_subset, num_accesses=n * n))
    return sdfg
def make_sdfg(implementation, dtype, id=0, in_shape=None, out_shape=None, in_subset="0:n, 0:n", out_subset="0:n, 0:n", overwrite=False, getri=True):
    """Build an SDFG containing a single matrix-inverse (Inv) library node.

    :param implementation: Name of the expansion implementation to use for the node.
    :param dtype: DaCe data type of the arrays.
    :param id: Disambiguating suffix for the SDFG name.
    :param in_shape: Shape of the ``xin`` array; defaults to ``[n, n]``.
    :param out_shape: Shape of the ``xout`` array; defaults to ``[n, n]``.
    :param in_subset: Subset string for the input memlet.
    :param out_subset: Subset string for the output memlet.
    :param overwrite: If True, the inverse is written back into ``xin``
        (no ``xout`` array is created).
    :param getri: Forwarded to the Inv node's ``use_getri`` option.
    :return: The constructed SDFG.

    NOTE: relies on a module-level symbol ``n`` for the default shapes and
    memlet volumes (same as the original definition-time defaults).
    """
    # Avoid mutable default arguments: the previous defaults ([n, n]) were
    # list objects shared across calls and evaluated at definition time.
    if in_shape is None:
        in_shape = [n, n]
    if out_shape is None:
        out_shape = [n, n]
    sdfg = dace.SDFG("linalg_inv_{}_{}_{}".format(implementation, dtype.__name__, id))
    sdfg.add_symbol("n", dace.int64)
    state = sdfg.add_state("dataflow")

    sdfg.add_array("xin", in_shape, dtype)
    if not overwrite:
        sdfg.add_array("xout", out_shape, dtype)
    xin = state.add_read("xin")
    # When overwriting, the output access node refers to the input array.
    if overwrite:
        xout = state.add_write("xin")
    else:
        xout = state.add_write("xout")

    inv_node = Inv("inv", overwrite_a=overwrite, use_getri=getri)
    inv_node.implementation = implementation

    state.add_memlet_path(xin, inv_node, dst_conn="_ain",
                          memlet=Memlet.simple(xin, in_subset, num_accesses=n * n))
    state.add_memlet_path(inv_node, xout, src_conn="_aout",
                          memlet=Memlet.simple(xout, out_subset, num_accesses=n * n))
    return sdfg
def make_sdfg(implementation, dtype, storage=dace.StorageType.Default):
    """Build an SDFG containing a single lower-triangular Cholesky library node.

    :param implementation: Name of the expansion implementation to use for the node.
    :param dtype: DaCe data type of the ``xin``/``xout`` arrays.
    :param storage: Unused here; accepted for interface compatibility.
    :return: The constructed SDFG, reading ``xin`` and writing ``xout``.
    """
    sym_n = dace.symbol("n", dace.int64)
    graph = dace.SDFG(f"linalg_cholesky_{implementation}_{dtype}")
    dataflow = graph.add_state("dataflow")

    in_desc = graph.add_array("xin", [sym_n, sym_n], dtype)
    out_desc = graph.add_array("xout", [sym_n, sym_n], dtype)
    read_x = dataflow.add_read("xin")
    write_x = dataflow.add_write("xout")

    # Library node computing the lower Cholesky factor.
    node = Cholesky("cholesky", lower=True)
    node.implementation = implementation

    dataflow.add_memlet_path(read_x, node, dst_conn="_a",
                             memlet=Memlet.from_array(*in_desc))
    dataflow.add_memlet_path(node, write_x, src_conn="_b",
                             memlet=Memlet.from_array(*out_desc))
    return graph
def _make_sdfg(node, parent_state, parent_sdfg, implementation):
    """Expand a Cholesky library node into a nested SDFG around a Potrf node.

    The nested SDFG copies ``_a`` into the in/out buffer, runs POTRF on it,
    and zeroes the strict upper triangle of ``_b`` via the ``_uzero_`` map.
    For the 'cuSolverDn' implementation, the input is transposed into a
    temporary ``_bt`` before POTRF and the result transposed back
    (cuSOLVER operates on column-major data — see the two Transpose nodes).

    :param node: The Cholesky library node being expanded.
    :param parent_state: State containing the library node.
    :param parent_sdfg: SDFG containing ``parent_state``.
    :param implementation: Implementation name propagated to the Potrf node.
    :return: The nested SDFG implementing the expansion.
    """
    inp_desc, inp_shape, out_desc, out_shape = node.validate(parent_sdfg, parent_state)
    dtype = inp_desc.dtype
    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))
    ain_arr = sdfg.add_array('_a', inp_shape, dtype=dtype, strides=inp_desc.strides)
    bout_arr = sdfg.add_array('_b', out_shape, dtype=dtype, strides=out_desc.strides)
    # LAPACK-style status output of POTRF (kept transient; not exposed).
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)
    if implementation == 'cuSolverDn':
        # Transposed scratch buffer used as POTRF's in/out operand.
        binout_arr = sdfg.add_array('_bt', inp_shape, dtype=dtype, transient=True)
    else:
        binout_arr = bout_arr
    state = sdfg.add_state("{l}_state".format(l=node.label))

    potrf_node = Potrf('potrf', lower=node.lower)
    potrf_node.implementation = implementation

    # Map that zeroes every element strictly above the diagonal of _b
    # (the tasklet writes 0 where __i < __j, else passes the input through).
    _, me, mx = state.add_mapped_tasklet('_uzero_',
                                         dict(__i="0:%s" % out_shape[0],
                                              __j="0:%s" % out_shape[1]),
                                         dict(_inp=Memlet.simple('_b', '__i, __j')),
                                         '_out = (__i < __j) ? 0 : _inp;',
                                         dict(_out=Memlet.simple('_b', '__i, __j')),
                                         language=dace.dtypes.Language.CPP,
                                         external_edges=True)
    ain = state.add_read('_a')
    if implementation == 'cuSolverDn':
        binout1 = state.add_access('_bt')
        binout2 = state.add_access('_bt')
        # Reuse the access nodes created by add_mapped_tasklet's external
        # edges as the endpoints of the transpose-back path.
        binout3 = state.in_edges(me)[0].src
        bout = state.out_edges(mx)[0].dst
        transpose_ain = Transpose('AT', dtype=dtype)
        transpose_ain.implementation = 'cuBLAS'
        state.add_edge(ain, None, transpose_ain, '_inp', Memlet.from_array(*ain_arr))
        state.add_edge(transpose_ain, '_out', binout1, None, Memlet.from_array(*binout_arr))
        transpose_out = Transpose('BT', dtype=dtype)
        transpose_out.implementation = 'cuBLAS'
        state.add_edge(binout2, None, transpose_out, '_inp', Memlet.from_array(*binout_arr))
        state.add_edge(transpose_out, '_out', binout3, None, Memlet.from_array(*bout_arr))
    else:
        binout1 = state.add_access('_b')
        binout2 = state.in_edges(me)[0].src
        binout3 = state.out_edges(mx)[0].dst
        # Plain copy _a -> _b; POTRF then works on _b in place.
        state.add_nedge(ain, binout1, Memlet.from_array(*ain_arr))
    info = state.add_write('_info')
    state.add_memlet_path(binout1, potrf_node, dst_conn="_xin",
                          memlet=Memlet.from_array(*binout_arr))
    state.add_memlet_path(potrf_node, info, src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(potrf_node, binout2, src_conn="_xout",
                          memlet=Memlet.from_array(*binout_arr))
    return sdfg
def apply(self, sdfg):
    """Inline a single-state nested SDFG into its parent state.

    Moves the nested state's nodes and edges into the parent state, promotes
    nested transients to (renamed) parent-level data, remaps symbols and data
    names, reconnects the surrounding memlet paths, and finally removes the
    nested SDFG node.
    """
    state: SDFGState = sdfg.nodes()[self.state_id]
    nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
    nsdfg: SDFG = nsdfg_node.sdfg
    # This transformation assumes the nested SDFG has exactly one state.
    nstate: SDFGState = nsdfg.nodes()[0]
    nsdfg_scope_entry = state.entry_node(nsdfg_node)
    nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                        if nsdfg_scope_entry is not None else None)

    #######################################################
    # Collect and update top-level SDFG metadata

    # Global/init/exit code
    for loc, code in nsdfg.global_code.items():
        sdfg.append_global_code(code.code, loc)
    for loc, code in nsdfg.init_code.items():
        sdfg.append_init_code(code.code, loc)
    for loc, code in nsdfg.exit_code.items():
        sdfg.append_exit_code(code.code, loc)

    # Constants: copy into parent, warning (not failing) on value clashes.
    for cstname, cstval in nsdfg.constants.items():
        if cstname in sdfg.constants:
            if cstval != sdfg.constants[cstname]:
                warnings.warn('Constant value mismatch for "%s" while '
                              'inlining SDFG. Inner = %s != %s = outer' %
                              (cstname, cstval, sdfg.constants[cstname]))
        else:
            sdfg.add_constant(cstname, cstval)

    # Find original source/destination edges (there is only one edge per
    # connector, according to match)
    inputs: Dict[str, MultiConnectorEdge] = {}
    outputs: Dict[str, MultiConnectorEdge] = {}
    input_set: Dict[str, str] = {}
    output_set: Dict[str, str] = {}
    for e in state.in_edges(nsdfg_node):
        inputs[e.dst_conn] = e
        input_set[e.data.data] = e.dst_conn
    for e in state.out_edges(nsdfg_node):
        outputs[e.src_conn] = e
        output_set[e.data.data] = e.src_conn

    # All transients become transients of the parent (if data already
    # exists, find new name)
    # Mapping from nested transient name to top-level name
    transients: Dict[str, str] = {}
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode):
            datadesc = nsdfg.arrays[node.data]
            if node.data not in transients and datadesc.transient:
                name = sdfg.add_datadesc('%s_%s' % (nsdfg.label, node.data),
                                         datadesc,
                                         find_new_name=True)
                transients[node.data] = name

    # All transients of edges between code nodes are also added to parent
    # (these have no access node inside the nested state).
    for edge in nstate.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            datadesc = nsdfg.arrays[edge.data.data]
            if edge.data.data not in transients and datadesc.transient:
                name = sdfg.add_datadesc('%s_%s' % (nsdfg.label, edge.data.data),
                                         datadesc,
                                         find_new_name=True)
                transients[edge.data.data] = name

    # Collect nodes to add to top-level graph: non-transient boundary access
    # nodes are matched with the outer edges they will replace.
    new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
    new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
    source_accesses = set()
    sink_accesses = set()
    for node in nstate.source_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients):
            new_incoming_edges[node] = inputs[node.data]
            source_accesses.add(node)
    for node in nstate.sink_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients):
            new_outgoing_edges[node] = outputs[node.data]
            sink_accesses.add(node)

    #######################################################
    # Add nested SDFG into top-level SDFG

    # Add nested nodes into original state (boundary accesses are dropped;
    # their role is taken over by the outer edges).
    subgraph = SubgraphView(nstate, [
        n for n in nstate.nodes()
        if n not in (source_accesses | sink_accesses)
    ])
    state.add_nodes_from(subgraph.nodes())
    for edge in subgraph.edges():
        state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn, edge.data)

    #######################################################
    # Replace data on inlined SDFG nodes/edges

    # Replace symbols using invocation symbol mapping
    # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace(symname, '__dacesym_' + symname)
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace('__dacesym_' + symname, symvalue)

    # Replace data names with their top-level counterparts
    repldict = {}
    repldict.update(transients)
    repldict.update({
        k: v.data.data
        for k, v in itertools.chain(inputs.items(), outputs.items())
    })
    for node in subgraph.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in repldict:
            node.data = repldict[node.data]
    for edge in subgraph.edges():
        if edge.data.data in repldict:
            edge.data.data = repldict[edge.data.data]

    #######################################################
    # Reconnect inlined SDFG

    # If a source/sink node is one of the inputs/outputs, reconnect it,
    # replacing memlets in outgoing/incoming paths
    modified_edges = set()
    modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate, state, True)
    modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate, state, False)

    # Modify all other internal edges pertaining to input/output nodes:
    # un-squeeze their memlets to the outer edge's subset.
    for node in subgraph.nodes():
        if isinstance(node, nodes.AccessNode):
            if node.data in input_set or node.data in output_set:
                if node.data in input_set:
                    outer_edge = inputs[input_set[node.data]]
                else:
                    outer_edge = outputs[output_set[node.data]]
                for edge in state.all_edges(node):
                    if (edge not in modified_edges
                            and edge.data.data == node.data):
                        for e in state.memlet_tree(edge):
                            if e.data.data == node.data:
                                e._data = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data)

    # If source/sink node is not connected to a source/destination access
    # node, and the nested SDFG is in a scope, connect to scope with empty
    # memlets
    if nsdfg_scope_entry is not None:
        for node in subgraph.nodes():
            if state.in_degree(node) == 0:
                state.add_edge(nsdfg_scope_entry, None, node, None, Memlet())
            if state.out_degree(node) == 0:
                state.add_edge(node, None, nsdfg_scope_exit, None, Memlet())

    # Replace nested SDFG parents with new SDFG (deeper nested SDFGs now
    # live directly under the top-level state).
    for node in nstate.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = state
            node.sdfg.parent_sdfg = sdfg
            node.sdfg.parent_nsdfg_node = node

    # Remove all unused external inputs/output memlet paths, as well as
    # resulting isolated nodes
    # NOTE(review): `inputs.keys()` holds connector names while
    # `source_accesses` holds AccessNode objects; the set difference may be
    # a no-op — confirm the intended semantics.
    removed_in_edges = self._remove_edge_path(state,
                                              inputs,
                                              set(inputs.keys()) - source_accesses,
                                              reverse=True)
    removed_out_edges = self._remove_edge_path(state,
                                               outputs,
                                               set(outputs.keys()) - sink_accesses,
                                               reverse=False)

    # Re-add in/out edges to first/last nodes in subgraph
    order = [
        x for x in nx.topological_sort(nstate._nx)
        if isinstance(x, nodes.AccessNode)
    ]
    for edge in removed_in_edges:
        # Find first access node that refers to this edge
        node = next(n for n in order if n.data == edge.data.data)
        state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn, edge.data)
    for edge in removed_out_edges:
        # Find last access node that refers to this edge
        node = next(n for n in reversed(order) if n.data == edge.data.data)
        state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn, edge.data)

    #######################################################
    # Remove nested SDFG node
    state.remove_node(nsdfg_node)
def test_duplicate_codegen():
    """Regression test: a value read by two tasklets must be generated once.

    Builds D = C, E = A + D, F = B + D by hand and checks both consumers of
    D observe the same value (E = F = 2 for all-ones inputs).
    """
    # The graph is constructed manually because the Python frontend would
    # not produce the specific node ordering this test depends on.
    sdfg = dace.SDFG("dup")
    state = sdfg.add_state()

    c_task = state.add_tasklet("c_task", inputs={"c"}, outputs={"d"}, code='d = c')
    e_task = state.add_tasklet("e_task", inputs={"a", "d"}, outputs={"e"}, code="e = a + d")
    f_task = state.add_tasklet("f_task", inputs={"b", "d"}, outputs={"f"}, code="f = b + d")

    # Scalar-sized arrays A..F; keep their descriptors for the memlets.
    descs = {name: sdfg.add_array(name, [1], dace.float32)[1] for name in "ABCDEF"}

    read_a = state.add_read("A")
    read_b = state.add_read("B")
    read_c = state.add_read("C")
    access_d = state.add_access("D")
    write_e = state.add_write("E")
    write_f = state.add_write("F")

    # Edge insertion order is significant for this test — do not reorder.
    state.add_edge(read_c, None, c_task, "c", Memlet.from_array("C", descs["C"]))
    state.add_edge(c_task, "d", access_d, None, Memlet.from_array("D", descs["D"]))
    state.add_edge(read_a, None, e_task, "a", Memlet.from_array("A", descs["A"]))
    state.add_edge(read_b, None, f_task, "b", Memlet.from_array("B", descs["B"]))
    state.add_edge(access_d, None, f_task, "d", Memlet.from_array("D", descs["D"]))
    state.add_edge(access_d, None, e_task, "d", Memlet.from_array("D", descs["D"]))
    state.add_edge(e_task, "e", write_e, None,
                   Memlet.from_array("E", descs["E"], wcr="lambda x, y: x + y"))
    state.add_edge(f_task, "f", write_f, None,
                   Memlet.from_array("F", descs["F"], wcr="lambda x, y: x + y"))

    # Host buffers (distinct names so they do not shadow the access nodes).
    a_host = np.array([1], dtype=np.float32)
    b_host = np.array([1], dtype=np.float32)
    c_host = np.array([1], dtype=np.float32)
    d_host = np.array([1], dtype=np.float32)
    e_host = np.zeros_like(a_host)
    f_host = np.zeros_like(a_host)

    sdfg(A=a_host, B=b_host, C=c_host, D=d_host, E=e_host, F=f_host)

    assert e_host[0] == 2
    assert f_host[0] == 2
def backward(
        forward_node: Node, context: BackwardContext,
        given_gradients: typing.List[typing.Optional[str]],
        required_gradients: typing.List[typing.Optional[str]]
) -> typing.Tuple[Node, BackwardResult]:
    """Differentiate a reduce node.

    For a Sum reduction, the gradient of the output is simply broadcast
    (scattered) back over the reduced axes into the input gradient.

    :param forward_node: The forward-pass reduce node (its ``wcr`` determines
        the reduction type).
    :param context: Backward context holding the forward/backward SDFG/states.
    :param given_gradients: Names of outputs whose gradients are available
        (must be exactly one).
    :param required_gradients: Names of inputs whose gradients are required
        (must be exactly one).
    :return: Tuple of (nested SDFG node implementing the backward pass,
        BackwardResult with the gradient connector names).
    :raises AutoDiffException: On malformed nodes or unsupported reductions.
    """
    reduction_type = detect_reduction_type(forward_node.wcr)

    # Fix: corrected "recieved" -> "received" in both error messages.
    if len(given_gradients) != 1:
        raise AutoDiffException(
            "received invalid SDFG: reduce node {} should have exactly one output edge"
            .format(forward_node))
    if len(required_gradients) != 1:
        raise AutoDiffException(
            "received invalid SDFG: reduce node {} should have exactly one input edge"
            .format(forward_node))

    input_name = next(iter(required_gradients))
    in_desc = in_desc_with_name(forward_node, context.forward_state,
                                context.forward_sdfg, input_name)
    output_name = next(iter(given_gradients))
    out_desc = out_desc_with_name(forward_node, context.forward_state,
                                  context.forward_sdfg, output_name)

    all_axes: typing.List[int] = list(range(len(in_desc.shape)))
    # axes=None means a full reduction over every axis.
    reduce_axes: typing.List[
        int] = all_axes if forward_node.axes is None else forward_node.axes
    non_reduce_axes: typing.List[int] = [
        i for i in all_axes if i not in reduce_axes
    ]

    result = BackwardResult.empty()

    if reduction_type is dtypes.ReductionType.Sum:
        # in this case, we need to simply scatter the grad across the axes
        # that were reduced
        sdfg = SDFG("_reverse_" + str(reduction_type).replace(".", "_") + "_")
        state = sdfg.add_state()

        rev_input_conn_name = "input_gradient"
        rev_output_conn_name = "output_gradient"
        result.required_grad_names[output_name] = rev_output_conn_name
        result.given_grad_names[input_name] = rev_input_conn_name

        # Incoming gradient has the (reduced) output shape; outgoing gradient
        # has the full input shape.
        _, rev_input_arr = sdfg.add_array(rev_input_conn_name,
                                          shape=out_desc.shape,
                                          dtype=out_desc.dtype)
        _, rev_output_arr = sdfg.add_array(rev_output_conn_name,
                                           shape=in_desc.shape,
                                           dtype=in_desc.dtype)

        # Map over the full input index space; read the incoming gradient at
        # the non-reduced indices (or [0] for a full reduction) and
        # accumulate into the outgoing gradient.
        state.add_mapped_tasklet(
            "_distribute_grad_" + str(reduction_type).replace(".", "_") + "_",
            {
                "i" + str(i): "0:{}".format(shape)
                for i, shape in enumerate(in_desc.shape)
            },
            {
                "__in": Memlet.simple(
                    rev_input_conn_name,
                    "0" if forward_node.axes is None else ",".join(
                        "i" + str(i) for i in non_reduce_axes))
            },
            "__out = __in",
            {
                "__out": Memlet.simple(rev_output_conn_name,
                                       ",".join("i" + str(i) for i in all_axes),
                                       wcr_str="lambda x, y: x + y")
            },
            external_edges=True)

        return context.backward_state.add_nested_sdfg(
            sdfg, None, {rev_input_conn_name}, {rev_output_conn_name}), result
    else:
        raise AutoDiffException(
            "Unsupported reduction type '{}'".format(reduction_type))
def apply(self, _, sdfg: sd.SDFG):
    """Move a for-loop that surrounds a map into the map.

    Nests the map's content into a nested SDFG, replicates the loop inside
    that nested SDFG, then removes the outer guard state and rewires all
    interstate edges so the map's body state takes its place.
    """
    # Obtain loop information
    guard: sd.SDFGState = self.loop_guard
    body: sd.SDFGState = self.loop_begin

    # Obtain iteration variable, range, and stride
    itervar, (start, end, step), _ = find_for_loop(sdfg, guard, body)

    forward_loop = step > 0

    # Locate the (single, per can_be_applied) map in the body state.
    for node in body.nodes():
        if isinstance(node, nodes.MapEntry):
            map_entry = node
        if isinstance(node, nodes.MapExit):
            map_exit = node

    # nest map's content in sdfg
    map_subgraph = body.scope_subgraph(map_entry, include_entry=False, include_exit=False)
    nsdfg = helpers.nest_state_subgraph(sdfg, body, map_subgraph, full_data=True)

    # replicate loop in nested sdfg
    new_before, new_guard, new_after = nsdfg.sdfg.add_loop(
        before_state=None,
        loop_state=nsdfg.sdfg.nodes()[0],
        loop_end_state=None,
        after_state=None,
        loop_var=itervar,
        initialize_expr=f'{start}',
        condition_expr=f'{itervar} <= {end}' if forward_loop else f'{itervar} >= {end}',
        increment_expr=f'{itervar} + {step}' if forward_loop else f'{itervar} - {abs(step)}')

    # remove outer loop
    before_guard_edge = nsdfg.sdfg.edges_between(new_before, new_guard)[0]
    for e in nsdfg.sdfg.out_edges(new_guard):
        if e.dst is new_after:
            guard_after_edge = e
        else:
            guard_body_edge = e

    # Move the assignments that used to ride on the outer loop's edges onto
    # the corresponding edges of the replicated inner loop, then bypass and
    # delete the outer guard state.
    for body_inedge in sdfg.in_edges(body):
        if body_inedge.src is guard:
            guard_body_edge.data.assignments.update(body_inedge.data.assignments)
        sdfg.remove_edge(body_inedge)
    for body_outedge in sdfg.out_edges(body):
        sdfg.remove_edge(body_outedge)
    for guard_inedge in sdfg.in_edges(guard):
        before_guard_edge.data.assignments.update(guard_inedge.data.assignments)
        guard_inedge.data.assignments = {}
        sdfg.add_edge(guard_inedge.src, body, guard_inedge.data)
        sdfg.remove_edge(guard_inedge)
    for guard_outedge in sdfg.out_edges(guard):
        if guard_outedge.dst is body:
            guard_body_edge.data.assignments.update(guard_outedge.data.assignments)
        else:
            guard_after_edge.data.assignments.update(guard_outedge.data.assignments)
        # The loop condition now lives inside the nested SDFG; the outer
        # edge becomes unconditional.
        guard_outedge.data.condition = CodeBlock("1")
        sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data)
        sdfg.remove_edge(guard_outedge)
    sdfg.remove_node(guard)
    if itervar in nsdfg.symbol_mapping:
        del nsdfg.symbol_mapping[itervar]
    if itervar in sdfg.symbols:
        del sdfg.symbols[itervar]

    # Add missing data/symbols
    for s in nsdfg.sdfg.free_symbols:
        if s in nsdfg.symbol_mapping:
            continue
        if s in sdfg.symbols:
            nsdfg.symbol_mapping[s] = s
        elif s in sdfg.arrays:
            # A free "symbol" that is actually data: route it into the
            # nested SDFG through the map as a non-transient input.
            desc = sdfg.arrays[s]
            access = body.add_access(s)
            conn = nsdfg.sdfg.add_datadesc(s, copy.deepcopy(desc))
            nsdfg.sdfg.arrays[s].transient = False
            nsdfg.add_in_connector(conn)
            body.add_memlet_path(access,
                                 map_entry,
                                 nsdfg,
                                 memlet=Memlet.from_array(s, desc),
                                 dst_conn=conn)
        else:
            raise NotImplementedError(f"Free symbol {s} is neither a symbol nor data.")

    # Drop symbol mappings that the nested SDFG no longer uses.
    to_delete = set()
    for s in nsdfg.symbol_mapping:
        if s not in nsdfg.sdfg.free_symbols:
            to_delete.add(s)
    for s in to_delete:
        del nsdfg.symbol_mapping[s]

    # propagate scope for correct volumes
    scope_tree = ScopeTree(map_entry, map_exit)
    scope_tree.parent = ScopeTree(None, None)
    # The first execution helps remove appearances of symbols
    # that are now defined only in the nested SDFG in memlets.
    propagation.propagate_memlets_scope(sdfg, body, scope_tree)

    for s in to_delete:
        if helpers.is_symbol_unused(sdfg, s):
            sdfg.remove_symbol(s)

    # Tighten the nested SDFG's accesses now that the loop moved inside.
    from dace.transformation.interstate import RefineNestedAccess
    transformation = RefineNestedAccess()
    transformation.setup_match(sdfg, 0, sdfg.node_id(body),
                               {RefineNestedAccess.nsdfg: body.node_id(nsdfg)}, 0)
    transformation.apply(body, sdfg)

    # Second propagation for refined accesses.
    propagation.propagate_memlets_scope(sdfg, body, scope_tree)
def _make_sdfg(node, parent_state, parent_sdfg, implementation):
    """Expand a matrix-inverse library node into GETRF (LU) + GETRI (inverse).

    When ``node.overwrite`` is set, the inverse is computed in place in
    ``_ain``; otherwise the input is first copied into a separate ``_aout``
    array and all factorization/inversion work happens there.

    :param node: The Inv library node being expanded.
    :param parent_state: State containing the library node.
    :param parent_sdfg: SDFG containing ``parent_state``.
    :param implementation: Implementation name propagated to the Getrf/Getri nodes.
    :return: The nested SDFG implementing the expansion.
    """
    arr_desc = node.validate(parent_sdfg, parent_state)
    if node.overwrite:
        in_shape, in_dtype, in_strides, n = arr_desc
    else:
        (in_shape, in_dtype, in_strides, out_shape, out_dtype, out_strides,
         n) = arr_desc
    dtype = in_dtype

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    # `a_arr` always refers to the descriptor the GETRF/GETRI nodes operate on.
    a_arr = sdfg.add_array('_ain', in_shape, dtype=in_dtype, strides=in_strides)
    if not node.overwrite:
        ain_arr = a_arr
        a_arr = sdfg.add_array('_aout', out_shape, dtype=out_dtype, strides=out_strides)
    # LAPACK-style pivot indices and status output (kept transient).
    ipiv_arr = sdfg.add_array('_pivots', [n], dtype=dace.int32, transient=True)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)

    state = sdfg.add_state("{l}_state".format(l=node.label))

    getrf_node = Getrf('getrf')
    getrf_node.implementation = implementation
    getri_node = Getri('getri')
    getri_node.implementation = implementation

    if node.overwrite:
        # In-place: read, intermediate, and write access nodes all on _ain.
        ain = state.add_read('_ain')
        ainout = state.add_access('_ain')
        aout = state.add_write('_ain')
    else:
        # Out-of-place: copy _ain into _aout first, then work on _aout.
        a = state.add_read('_ain')
        ain = state.add_read('_aout')
        ainout = state.add_access('_aout')
        aout = state.add_write('_aout')
        state.add_nedge(a, ain, Memlet.from_array(*ain_arr))
    ipiv = state.add_access('_pivots')
    info1 = state.add_write('_info')
    info2 = state.add_write('_info')

    # GETRF: LU-factorize in place, emitting pivots and status.
    state.add_memlet_path(ain, getrf_node, dst_conn="_xin",
                          memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(getrf_node, info1, src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrf_node, ipiv, src_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrf_node, ainout, src_conn="_xout",
                          memlet=Memlet.from_array(*a_arr))
    # GETRI: compute the inverse from the LU factors and pivots.
    state.add_memlet_path(ainout, getri_node, dst_conn="_xin",
                          memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(ipiv, getri_node, dst_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getri_node, info2, src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getri_node, aout, src_conn="_xout",
                          memlet=Memlet.from_array(*a_arr))
    return sdfg
def _make_sdfg_getrs(node, parent_state, parent_sdfg, implementation):
    """Expand a matrix-inverse library node into GETRF + GETRS.

    The inverse is obtained by LU-factorizing the input (GETRF) and then
    solving A X = I (GETRS) with an identity right-hand side generated by
    the ``_eye_`` mapped tasklet.

    :param node: The Inv library node being expanded.
    :param parent_state: State containing the library node.
    :param parent_sdfg: SDFG containing ``parent_state``.
    :param implementation: Implementation name propagated to the Getrf/Getrs nodes.
    :return: The nested SDFG implementing the expansion.
    """
    arr_desc = node.validate(parent_sdfg, parent_state)
    if node.overwrite:
        in_shape, in_dtype, in_strides, n = arr_desc
    else:
        (in_shape, in_dtype, in_strides, out_shape, out_dtype, out_strides,
         n) = arr_desc
    dtype = in_dtype

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    # `a_arr` always refers to the descriptor GETRF factorizes; `b_arr` is
    # the right-hand-side / solution descriptor for GETRS.
    a_arr = sdfg.add_array('_ain', in_shape, dtype=in_dtype, strides=in_strides)
    if not node.overwrite:
        ain_arr = a_arr
        a_arr = sdfg.add_array('_ainout', [n, n], dtype=in_dtype, transient=True)
        b_arr = sdfg.add_array('_aout', out_shape, dtype=out_dtype, strides=out_strides)
    else:
        b_arr = sdfg.add_array('_b', [n, n], dtype=dtype, transient=True)
    # LAPACK-style pivot indices and status output (kept transient).
    ipiv_arr = sdfg.add_array('_pivots', [n], dtype=dace.int32, transient=True)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)

    state = sdfg.add_state("{l}_state".format(l=node.label))

    getrf_node = Getrf('getrf')
    getrf_node.implementation = implementation
    getrs_node = Getrs('getrs')
    getrs_node.implementation = implementation

    if node.overwrite:
        # In-place: factorize _ain directly; solve into transient _b and
        # copy the solution back over _ain at the end.
        ain = state.add_read('_ain')
        ainout = state.add_access('_ain')
        aout = state.add_write('_ain')
        bin_name = '_b'
        bout = state.add_write('_b')
        state.add_nedge(bout, aout, Memlet.from_array(*a_arr))
    else:
        # Out-of-place: copy _ain into transient _ainout for factorization;
        # the solution lands directly in _aout.
        a = state.add_read('_ain')
        ain = state.add_read('_ainout')
        ainout = state.add_access('_ainout')
        # aout = state.add_write('_aout')
        state.add_nedge(a, ain, Memlet.from_array(*ain_arr))
        bin_name = '_aout'
        bout = state.add_access('_aout')

    # Build the identity right-hand side: _out = 1 on the diagonal, else 0.
    _, _, mx = state.add_mapped_tasklet('_eye_',
                                        dict(i="0:n", j="0:n"), {},
                                        '_out = (i == j) ? 1 : 0;',
                                        dict(_out=Memlet.simple(bin_name, 'i, j')),
                                        language=dace.dtypes.Language.CPP,
                                        external_edges=True)
    # Access node written by the _eye_ map; serves as GETRS's RHS input.
    bin = state.out_edges(mx)[0].dst
    ipiv = state.add_access('_pivots')
    info1 = state.add_write('_info')
    info2 = state.add_write('_info')

    # GETRF: LU-factorize, emitting pivots and status.
    state.add_memlet_path(ain, getrf_node, dst_conn="_xin",
                          memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(getrf_node, info1, src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrf_node, ipiv, src_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrf_node, ainout, src_conn="_xout",
                          memlet=Memlet.from_array(*a_arr))
    # GETRS: solve A X = I using the LU factors and pivots.
    state.add_memlet_path(ainout, getrs_node, dst_conn="_a",
                          memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(bin, getrs_node, dst_conn="_rhs_in",
                          memlet=Memlet.from_array(*b_arr))
    state.add_memlet_path(ipiv, getrs_node, dst_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrs_node, info2, src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrs_node, bout, src_conn="_rhs_out",
                          memlet=Memlet.from_array(*b_arr))
    return sdfg
def apply(self, state: SDFGState, sdfg: SDFG) -> nodes.AccessNode:
    """Convert reads/writes of an access node into FIFO stream accesses.

    Groups the node's memlet paths into components by identical map ranges
    and canonical access expressions, replaces each path's memlets with
    shape-1 stream accesses, and builds a read (or write) component — a map
    nest plus tasklet — that moves data between memory and the streams.
    With ``use_memory_buffering``, a Gearbox node additionally converts
    between vectorized memory accesses and scalar kernel accesses, and the
    global array is re-typed/reshaped to the vector type.

    :param state: State containing the access node.
    :param sdfg: SDFG containing ``state``.
    :return: List of the stream access nodes created for the components.
        (NOTE(review): annotation says ``nodes.AccessNode`` but a list is
        returned — confirm.)
    """
    dnode: nodes.AccessNode = self.access
    # expr_index 0 = streaming a read (outgoing edges), else a write.
    if self.expr_index == 0:
        edges = state.out_edges(dnode)
    else:
        edges = state.in_edges(dnode)

    # To understand how many components we need to create, all map ranges
    # throughout memlet paths must match exactly. We thus create a
    # dictionary of unique ranges
    mapping: Dict[Tuple[subsets.Range],
                  List[gr.MultiConnectorEdge[mm.Memlet]]] = defaultdict(list)
    ranges = {}
    for edge in edges:
        mpath = state.memlet_path(edge)
        ranges[edge] = _collect_map_ranges(state, mpath)
        mapping[tuple(r[1] for r in ranges[edge])].append(edge)

    # Collect all edges with the same memory access pattern
    components_to_create: Dict[
        Tuple[symbolic.SymbolicType],
        List[gr.MultiConnectorEdge[mm.Memlet]]] = defaultdict(list)
    for edges_with_same_range in mapping.values():
        for edge in edges_with_same_range:
            # Get memlet path and innermost edge
            mpath = state.memlet_path(edge)
            innermost_edge = copy.deepcopy(mpath[-1] if self.expr_index == 0 else mpath[0])

            # Store memlets of the same access in the same component
            expr = _canonicalize_memlet(innermost_edge.data, ranges[edge])
            components_to_create[expr].append((innermost_edge, edge))
    components = list(components_to_create.values())

    # Split out components that have dependencies between them to avoid
    # deadlocks
    if self.expr_index == 0:
        ccs_to_add = []
        for i, component in enumerate(components):
            edges_to_remove = set()
            for cedge in component:
                # If another component edge can reach this one, isolate it
                # into its own component.
                if any(
                        nx.has_path(state.nx, o[1].dst, cedge[1].dst)
                        for o in component if o is not cedge):
                    ccs_to_add.append([cedge])
                    edges_to_remove.add(cedge)
            if edges_to_remove:
                components[i] = [
                    c for c in component if c not in edges_to_remove
                ]
        components.extend(ccs_to_add)
    # End of split

    desc = sdfg.arrays[dnode.data]

    # Create new streams of shape 1
    streams = {}
    mpaths = {}
    for edge in edges:

        if self.use_memory_buffering:

            arrname = str(self.access)

            # Add gearbox
            total_size = edge.data.volume
            vector_size = int(self.memory_buffering_target_bytes / desc.dtype.bytes)

            # Warn (do not fail) when divisibility by the vector size cannot
            # be checked symbolically.
            if not is_int(sdfg.arrays[dnode.data].shape[-1]):
                warnings.warn(
                    "Using the MemoryBuffering transformation is potential unsafe since {sym} is not an integer. There should be no issue if {sym} % {vec} == 0"
                    .format(sym=sdfg.arrays[dnode.data].shape[-1],
                            vec=vector_size))

            for i in sdfg.arrays[dnode.data].strides:
                if not is_int(i):
                    warnings.warn(
                        "Using the MemoryBuffering transformation is potential unsafe since {sym} is not an integer. There should be no issue if {sym} % {vec} == 0"
                        .format(sym=i, vec=vector_size))

            if self.expr_index == 0:  # Read
                edges = state.out_edges(dnode)
                gearbox_input_type = dtypes.vector(desc.dtype, vector_size)
                gearbox_output_type = desc.dtype
                gearbox_read_volume = total_size / vector_size
                gearbox_write_volume = total_size
            else:  # Write
                edges = state.in_edges(dnode)
                gearbox_input_type = desc.dtype
                gearbox_output_type = dtypes.vector(desc.dtype, vector_size)
                gearbox_read_volume = total_size
                gearbox_write_volume = total_size / vector_size

            input_gearbox_name, input_gearbox_newdesc = sdfg.add_stream(
                "gearbox_input",
                gearbox_input_type,
                buffer_size=self.buffer_size,
                storage=self.storage,
                transient=True,
                find_new_name=True)

            output_gearbox_name, output_gearbox_newdesc = sdfg.add_stream(
                "gearbox_output",
                gearbox_output_type,
                buffer_size=self.buffer_size,
                storage=self.storage,
                transient=True,
                find_new_name=True)

            read_to_gearbox = state.add_read(input_gearbox_name)
            write_from_gearbox = state.add_write(output_gearbox_name)

            gearbox = Gearbox(total_size / vector_size)

            state.add_node(gearbox)

            state.add_memlet_path(read_to_gearbox,
                                  gearbox,
                                  dst_conn="from_memory",
                                  memlet=Memlet(input_gearbox_name + "[0]",
                                                volume=gearbox_read_volume))
            state.add_memlet_path(gearbox,
                                  write_from_gearbox,
                                  src_conn="to_kernel",
                                  memlet=Memlet(output_gearbox_name + "[0]",
                                                volume=gearbox_write_volume))

            # The memory side uses one gearbox stream; the kernel side the other.
            if self.expr_index == 0:
                streams[edge] = input_gearbox_name
                name = output_gearbox_name
                newdesc = output_gearbox_newdesc
            else:
                streams[edge] = output_gearbox_name
                name = input_gearbox_name
                newdesc = input_gearbox_newdesc
        else:
            # Qualify name to avoid name clashes if memory interfaces are not decoupled for Xilinx
            stream_name = "stream_" + dnode.data
            name, newdesc = sdfg.add_stream(stream_name,
                                            desc.dtype,
                                            buffer_size=self.buffer_size,
                                            storage=self.storage,
                                            transient=True,
                                            find_new_name=True)
            streams[edge] = name

            # Add these such that we can easily use output_gearbox_name and
            # input_gearbox_name without using if statements
            output_gearbox_name = name
            input_gearbox_name = name

        mpath = state.memlet_path(edge)
        mpaths[edge] = mpath

        # Replace memlets in path with stream access
        for e in mpath:
            e.data = mm.Memlet(data=name,
                               subset='0',
                               other_subset=e.data.other_subset)
            if isinstance(e.src, nodes.NestedSDFG):
                e.data.dynamic = True
                _streamify_recursive(e.src, e.src_conn, newdesc)
            if isinstance(e.dst, nodes.NestedSDFG):
                e.data.dynamic = True
                _streamify_recursive(e.dst, e.dst_conn, newdesc)

        # Replace access node and memlet tree with one access
        if self.expr_index == 0:
            replacement = state.add_read(output_gearbox_name)
            state.remove_edge(edge)
            state.add_edge(replacement, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
        else:
            replacement = state.add_write(input_gearbox_name)
            state.remove_edge(edge)
            state.add_edge(edge.src, edge.src_conn, replacement, edge.dst_conn, edge.data)

    if self.use_memory_buffering:

        arrname = str(self.access)
        vector_size = int(self.memory_buffering_target_bytes / desc.dtype.bytes)

        # Vectorize access to global array.
        dtype = sdfg.arrays[arrname].dtype
        sdfg.arrays[arrname].dtype = dtypes.vector(dtype, vector_size)
        new_shape = list(sdfg.arrays[arrname].shape)
        contigidx = sdfg.arrays[arrname].strides.index(1)
        new_shape[contigidx] /= vector_size
        try:
            # Keep the shape an int when the division is exact and numeric;
            # symbolic shapes are left as-is.
            new_shape[contigidx] = int(new_shape[contigidx])
        except TypeError:
            pass
        sdfg.arrays[arrname].shape = new_shape

        # Change strides
        new_strides: List = list(sdfg.arrays[arrname].strides)
        for i in range(len(new_strides)):
            if i == len(new_strides) - 1:  # Skip last dimension since it is always 1
                continue
            new_strides[i] = new_strides[i] / vector_size
        sdfg.arrays[arrname].strides = new_strides

        post_state = get_post_state(sdfg, state)

        # NOTE(review): `!= None` — idiomatic form would be `is not None`.
        if post_state != None:
            # Change subset in the post state such that the correct amount
            # of memory is copied back from the device
            for e in post_state.edges():
                if e.data.data == self.access.data:
                    new_subset = list(e.data.subset)
                    i, j, k = new_subset[-1]
                    new_subset[-1] = (i, (j + 1) / vector_size - 1, k)
                    e.data = mm.Memlet(data=str(e.src),
                                       subset=subsets.Range(new_subset))

    # Make read/write components
    ionodes = []
    for component in components:

        # Pick the first edge as the edge to make the component from
        innermost_edge, outermost_edge = component[0]
        mpath = mpaths[outermost_edge]
        mapname = streams[outermost_edge]
        innermost_edge.data.other_subset = None

        # Get edge data and streams
        if self.expr_index == 0:
            opname = 'read'
            path = [e.dst for e in mpath[:-1]]
            rmemlets = [(dnode, '__inp', innermost_edge.data)]
            wmemlets = []
            for i, (_, edge) in enumerate(component):
                name = streams[edge]
                ionode = state.add_write(name)
                ionodes.append(ionode)
                wmemlets.append(
                    (ionode, '__out%d' % i, mm.Memlet(data=name, subset='0')))
            code = '\n'.join('__out%d = __inp' % i
                             for i in range(len(component)))
        else:
            # More than one input stream might mean a data race, so we only
            # address the first one in the tasklet code
            if len(component) > 1:
                warnings.warn(
                    f'More than one input found for the same index for {dnode.data}'
                )
            opname = 'write'
            path = [state.entry_node(e.src) for e in reversed(mpath[1:])]
            wmemlets = [(dnode, '__out', innermost_edge.data)]
            rmemlets = []
            for i, (_, edge) in enumerate(component):
                name = streams[edge]
                ionode = state.add_read(name)
                ionodes.append(ionode)
                rmemlets.append(
                    (ionode, '__inp%d' % i, mm.Memlet(data=name, subset='0')))
            code = '__out = __inp0'

        # Create map structure for read/write component
        maps = []
        for entry in path:
            map: nodes.Map = entry.map

            ranges = [(p, (r[0], r[1], r[2]))
                      for p, r in zip(map.params, map.range)]

            # Change ranges of map
            if self.use_memory_buffering:

                # Find edges from/to map
                edge_subset = [
                    a_tuple[0]
                    for a_tuple in list(innermost_edge.data.subset)
                ]

                # Change range of map: the innermost (contiguous) dimension
                # is shrunk by the vector size.
                if isinstance(edge_subset[-1], symbol) and str(
                        edge_subset[-1]) == map.params[-1]:
                    if not is_int(ranges[-1][1][1]):
                        warnings.warn(
                            "Using the MemoryBuffering transformation is potential unsafe since {sym} is not an integer. There should be no issue if {sym} % {vec} == 0"
                            .format(sym=ranges[-1][1][1].args[1],
                                    vec=vector_size))

                    ranges[-1] = (ranges[-1][0],
                                  (ranges[-1][1][0],
                                   (ranges[-1][1][1] + 1) / vector_size - 1,
                                   ranges[-1][1][2]))
                elif isinstance(edge_subset[-1], sympy.core.add.Add):
                    for arg in edge_subset[-1].args:
                        if isinstance(
                                arg, symbol) and str(arg) == map.params[-1]:
                            if not is_int(ranges[-1][1][1]):
                                warnings.warn(
                                    "Using the MemoryBuffering transformation is potential unsafe since {sym} is not an integer. There should be no issue if {sym} % {vec} == 0"
                                    .format(sym=ranges[-1][1][1].args[1],
                                            vec=vector_size))

                            ranges[-1] = (ranges[-1][0], (
                                ranges[-1][1][0],
                                (ranges[-1][1][1] + 1) / vector_size - 1,
                                ranges[-1][1][2]))

            maps.append(
                state.add_map(f'__s{opname}_{mapname}', ranges, map.schedule))

        tasklet = state.add_tasklet(
            f'{opname}_{mapname}',
            {m[1] for m in rmemlets},
            {m[1] for m in wmemlets},
            code,
        )
        for node, cname, memlet in rmemlets:
            state.add_memlet_path(node,
                                  *(me for me, _ in maps),
                                  tasklet,
                                  dst_conn=cname,
                                  memlet=memlet)
        for node, cname, memlet in wmemlets:
            state.add_memlet_path(tasklet,
                                  *(mx for _, mx in reversed(maps)),
                                  node,
                                  src_conn=cname,
                                  memlet=memlet)

    return ionodes
def apply(self, sdfg: SDFG):
    """Nest the selected subgraph of states into a new SDFG and launch it
    from a single map scheduled as ``ScheduleType.GPU_Persistent``.

    The subgraph's entry/exit states are located, guard/launch states are
    inserted around the region as needed, the states are moved into a fresh
    kernel SDFG, and the kernel SDFG is re-inserted as a nested SDFG inside
    a persistent-schedule map in the launch state.

    :param sdfg: The SDFG containing the subgraph to nest.
    """
    subgraph = self.subgraph_view(sdfg)

    entry_states_in, entry_states_out = self.get_entry_states(
        sdfg, subgraph)
    _, exit_states_out = self.get_exit_states(sdfg, subgraph)

    # Exactly one entry state inside the subgraph is expected; the outer
    # entry/exit states are optional (None when the subgraph touches the
    # SDFG boundary).
    entry_state_in = entry_states_in.pop()
    entry_state_out = entry_states_out.pop() \
        if len(entry_states_out) > 0 else None
    exit_state_out = exit_states_out.pop() \
        if len(exit_states_out) > 0 else None

    launch_state = None
    entry_guard_state = None
    exit_guard_state = None

    # generate entry guard state if needed:
    # when the incoming interstate edge carries assignments, they must be
    # evaluated outside the kernel, so the edge is split through a guard
    # state (condition stays before the guard, assignments after it).
    if self.include_in_assignment and entry_state_out is not None:
        entry_edge = sdfg.edges_between(entry_state_out, entry_state_in)[0]
        if len(entry_edge.data.assignments) > 0:
            entry_guard_state = sdfg.add_state(
                label='{}kernel_entry_guard'.format(
                    self.kernel_prefix + '_' if self.kernel_prefix != '' else ''))

            sdfg.add_edge(entry_state_out, entry_guard_state,
                          InterstateEdge(entry_edge.data.condition))
            sdfg.add_edge(entry_guard_state, entry_state_in,
                          InterstateEdge(None, entry_edge.data.assignments))
            sdfg.remove_edge(entry_edge)

            # Update SubgraphView to include the new guard state
            new_node_list = subgraph.nodes()
            new_node_list.append(entry_guard_state)
            subgraph = SubgraphView(sdfg, new_node_list)

            launch_state = sdfg.add_state_before(
                entry_guard_state,
                label='{}kernel_launch'.format(
                    self.kernel_prefix + '_' if self.kernel_prefix != '' else ''))

    # generate exit guard state
    if exit_state_out is not None:
        exit_guard_state = sdfg.add_state_before(
            exit_state_out,
            label='{}kernel_exit_guard'.format(
                self.kernel_prefix + '_' if self.kernel_prefix != '' else ''))

        # Update SubgraphView to include the new guard state
        new_node_list = subgraph.nodes()
        new_node_list.append(exit_guard_state)
        subgraph = SubgraphView(sdfg, new_node_list)

        if launch_state is None:
            launch_state = sdfg.add_state_before(
                exit_state_out,
                label='{}kernel_launch'.format(
                    self.kernel_prefix + '_' if self.kernel_prefix != '' else ''))

    # If the launch state doesn't exist at this point then there is no other
    # states outside of the kernel, so create a stand alone launch state
    # NOTE(review): `entry_state_in` is obtained via ``.pop()`` above and so
    # cannot be None here; this assertion looks like it intended
    # ``entry_state_out`` — verify.
    if launch_state is None:
        assert (entry_state_in is None and exit_state_out is None)
        launch_state = sdfg.add_state(label='{}kernel_launch'.format(
            self.kernel_prefix + '_' if self.kernel_prefix != '' else ''))

    # create sdfg for kernel and fill it with states and edges from
    # ssubgraph dfg will be nested at the end
    kernel_sdfg = SDFG(
        '{}kernel'.format(self.kernel_prefix +
                          '_' if self.kernel_prefix != '' else ''))

    edges = subgraph.edges()
    for edge in edges:
        kernel_sdfg.add_edge(edge.src, edge.dst, edge.data)

    # Setting entry node in nested SDFG if no entry guard was created
    if entry_guard_state is None:
        kernel_sdfg.start_state = kernel_sdfg.node_id(entry_state_in)

    for state in subgraph:
        state.parent = kernel_sdfg

    # remove the now nested nodes from the outer sdfg and make sure the
    # launch state is properly connected to remaining states
    sdfg.remove_nodes_from(subgraph.nodes())

    if entry_state_out is not None \
            and len(sdfg.edges_between(entry_state_out, launch_state)) == 0:
        sdfg.add_edge(entry_state_out, launch_state, InterstateEdge())

    if exit_state_out is not None \
            and len(sdfg.edges_between(launch_state, exit_state_out)) == 0:
        sdfg.add_edge(launch_state, exit_state_out, InterstateEdge())

    # Handle data for kernel: collect every container accessed inside it
    kernel_data = set(node.data for state in kernel_sdfg
                      for node in state.nodes()
                      if isinstance(node, nodes.AccessNode))

    # move Streams and Register data into the nested SDFG
    # normal data will be added as kernel argument
    kernel_args = []
    for data in kernel_data:
        if (isinstance(sdfg.arrays[data], dace.data.Stream) or
            (isinstance(sdfg.arrays[data], dace.data.Array)
             and sdfg.arrays[data].storage == StorageType.Register)):
            kernel_sdfg.add_datadesc(data, sdfg.arrays[data])
            del sdfg.arrays[data]
        else:
            # Arguments cross the nested-SDFG boundary, so they must be
            # non-transient with default storage inside the kernel.
            copy_desc = copy.deepcopy(sdfg.arrays[data])
            copy_desc.transient = False
            copy_desc.storage = StorageType.Default
            kernel_sdfg.add_datadesc(data, copy_desc)
            kernel_args.append(data)

    # read only data will be passed as input, writeable data will be passed
    # as 'output' otherwise kernel cannot write to data
    kernel_args_read = set()
    kernel_args_write = set()
    for data in kernel_args:
        data_accesses_read_only = [
            node.access == dtypes.AccessType.ReadOnly
            for state in kernel_sdfg for node in state
            if isinstance(node, nodes.AccessNode) and node.data == data
        ]
        if all(data_accesses_read_only):
            kernel_args_read.add(data)
        else:
            kernel_args_write.add(data)

    # Kernel SDFG is complete at this point
    if self.validate:
        kernel_sdfg.validate()

    # Filling launch state with nested SDFG, map and access nodes.
    # The single-iteration map carries the persistent GPU schedule.
    map_entry, map_exit = launch_state.add_map(
        '{}kernel_launch_map'.format(
            self.kernel_prefix + '_' if self.kernel_prefix != '' else ''),
        dict(ignore='0'),
        schedule=ScheduleType.GPU_Persistent,
    )

    nested_sdfg = launch_state.add_nested_sdfg(
        kernel_sdfg,
        sdfg,
        kernel_args_read,
        kernel_args_write,
    )

    # Create and connect read only data access nodes
    for arg in kernel_args_read:
        read_node = launch_state.add_read(arg)
        launch_state.add_memlet_path(read_node,
                                     map_entry,
                                     nested_sdfg,
                                     dst_conn=arg,
                                     memlet=Memlet.from_array(
                                         arg, sdfg.arrays[arg]))

    # Create and connect writable data access nodes
    for arg in kernel_args_write:
        write_node = launch_state.add_write(arg)
        launch_state.add_memlet_path(nested_sdfg,
                                     map_exit,
                                     write_node,
                                     src_conn=arg,
                                     memlet=Memlet.from_array(
                                         arg, sdfg.arrays[arg]))

    # Transformation is done
    if self.validate:
        sdfg.validate()
def apply(self, sdfg: SDFG):
    """Inline a single-state nested SDFG into its parent state.

    Moves all nodes of the nested SDFG's only state into the parent state,
    renames/merges data descriptors, reconnects memlet paths through the
    former nested-SDFG boundary, and finally removes the nested SDFG node.

    :param sdfg: The parent SDFG on which the transformation is applied.
    """
    state: SDFGState = sdfg.nodes()[self.state_id]
    nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
    nsdfg: SDFG = nsdfg_node.sdfg
    # This transformation handles nested SDFGs with exactly one state
    nstate: SDFGState = nsdfg.nodes()[0]

    # Propagate a non-default schedule down into the nested SDFG before
    # its nodes are merged into the parent
    if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
        infer_types.set_default_schedule_and_storage_types(
            nsdfg, nsdfg_node.schedule)

    nsdfg_scope_entry = state.entry_node(nsdfg_node)
    nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                        if nsdfg_scope_entry is not None else None)

    #######################################################
    # Collect and update top-level SDFG metadata

    # Global/init/exit code
    for loc, code in nsdfg.global_code.items():
        sdfg.append_global_code(code.code, loc)
    for loc, code in nsdfg.init_code.items():
        sdfg.append_init_code(code.code, loc)
    for loc, code in nsdfg.exit_code.items():
        sdfg.append_exit_code(code.code, loc)

    # Constants: copy up, warning (but not failing) on conflicting values
    for cstname, cstval in nsdfg.constants.items():
        if cstname in sdfg.constants:
            if cstval != sdfg.constants[cstname]:
                warnings.warn('Constant value mismatch for "%s" while '
                              'inlining SDFG. Inner = %s != %s = outer' %
                              (cstname, cstval, sdfg.constants[cstname]))
        else:
            sdfg.add_constant(cstname, cstval)

    # Find original source/destination edges (there is only one edge per
    # connector, according to match)
    inputs: Dict[str, MultiConnectorEdge] = {}
    outputs: Dict[str, MultiConnectorEdge] = {}
    input_set: Dict[str, str] = {}
    output_set: Dict[str, str] = {}
    for e in state.in_edges(nsdfg_node):
        inputs[e.dst_conn] = e
        input_set[e.data.data] = e.dst_conn
    for e in state.out_edges(nsdfg_node):
        outputs[e.src_conn] = e
        output_set[e.data.data] = e.src_conn

    # Access nodes that need to be reshaped: inner arrays with more
    # dimensions than the outer subset, or with incompatible strides
    reshapes: Set[str] = set()
    for aname, array in nsdfg.arrays.items():
        if array.transient:
            continue
        edge = None
        if aname in inputs:
            edge = inputs[aname]
            if len(array.shape) > len(edge.data.subset):
                reshapes.add(aname)
                continue
        if aname in outputs:
            edge = outputs[aname]
            if len(array.shape) > len(edge.data.subset):
                reshapes.add(aname)
                continue
        if edge is not None and not InlineSDFG._check_strides(
                array.strides, sdfg.arrays[edge.data.data].strides,
                edge.data, nsdfg_node):
            reshapes.add(aname)

    # Replace symbols using invocation symbol mapping
    # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace(symname, '__dacesym_' + symname)
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace('__dacesym_' + symname, symvalue)

    # All transients become transients of the parent (if data already
    # exists, find new name)
    # Mapping from nested transient name to top-level name
    transients: Dict[str, str] = {}
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode):
            datadesc = nsdfg.arrays[node.data]
            if node.data not in transients and datadesc.transient:
                name = sdfg.add_datadesc('%s_%s' % (nsdfg.label, node.data),
                                         datadesc,
                                         find_new_name=True)
                transients[node.data] = name

    # All transients of edges between code nodes are also added to parent
    for edge in nstate.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            if edge.data.data is not None:
                datadesc = nsdfg.arrays[edge.data.data]
                if edge.data.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, edge.data.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[edge.data.data] = name

    # Collect nodes to add to top-level graph
    new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
    new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

    # Boundary access nodes (non-transient, non-reshaped) will be merged
    # with the outer access nodes rather than copied in
    source_accesses = set()
    sink_accesses = set()
    for node in nstate.source_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients
                and node.data not in reshapes):
            new_incoming_edges[node] = inputs[node.data]
            source_accesses.add(node)
    for node in nstate.sink_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients
                and node.data not in reshapes):
            new_outgoing_edges[node] = outputs[node.data]
            sink_accesses.add(node)

    #######################################################
    # Replace data on inlined SDFG nodes/edges

    # Replace data names with their top-level counterparts
    repldict = {}
    repldict.update(transients)
    repldict.update({
        k: v.data.data
        for k, v in itertools.chain(inputs.items(), outputs.items())
    })

    # Add views whenever reshapes are necessary
    for dname in reshapes:
        desc = nsdfg.arrays[dname]
        # To avoid potential confusion, rename protected __return keyword
        if dname.startswith('__return'):
            newname = f'{nsdfg.name}_ret{dname[8:]}'
        else:
            newname = dname

        newname, _ = sdfg.add_view(newname,
                                   desc.shape,
                                   desc.dtype,
                                   storage=desc.storage,
                                   strides=desc.strides,
                                   offset=desc.offset,
                                   debuginfo=desc.debuginfo,
                                   allow_conflicts=desc.allow_conflicts,
                                   total_size=desc.total_size,
                                   alignment=desc.alignment,
                                   may_alias=desc.may_alias,
                                   find_new_name=True)
        repldict[dname] = newname

    # Apply the renaming to all access nodes and memlets of the inner state
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in repldict:
            node.data = repldict[node.data]
    for edge in nstate.edges():
        if edge.data.data in repldict:
            edge.data.data = repldict[edge.data.data]

    # Add extra access nodes for out/in view nodes
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in reshapes:
            if nstate.in_degree(node) > 0 and nstate.out_degree(node) > 0:
                # Such a node has to be in the output set
                edge = outputs[node.data]

                # Redirect outgoing edges through access node
                out_edges = list(nstate.out_edges(node))
                anode = nstate.add_access(edge.data.data)
                vnode = nstate.add_access(node.data)
                nstate.add_nedge(node, anode, edge.data)
                nstate.add_nedge(anode, vnode, edge.data)
                for e in out_edges:
                    nstate.remove_edge(e)
                    nstate.add_edge(vnode, e.src_conn, e.dst, e.dst_conn,
                                    e.data)

    #######################################################
    # Add nested SDFG into top-level SDFG

    # Add nested nodes into original state (boundary access nodes excluded;
    # they are merged with the existing outer access nodes)
    subgraph = SubgraphView(nstate, [
        n for n in nstate.nodes()
        if n not in (source_accesses | sink_accesses)
    ])
    state.add_nodes_from(subgraph.nodes())
    for edge in subgraph.edges():
        state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                       edge.data)

    #######################################################
    # Reconnect inlined SDFG

    # If a source/sink node is one of the inputs/outputs, reconnect it,
    # replacing memlets in outgoing/incoming paths
    modified_edges = set()
    modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                               state, True)
    modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                               state, False)

    # Reshape: add connections to viewed data
    self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
                              True)
    self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
                              False)

    # Modify all other internal edges pertaining to input/output nodes
    for node in subgraph.nodes():
        if isinstance(node, nodes.AccessNode):
            if node.data in input_set or node.data in output_set:
                if node.data in input_set:
                    outer_edge = inputs[input_set[node.data]]
                else:
                    outer_edge = outputs[output_set[node.data]]

                for edge in state.all_edges(node):
                    if (edge not in modified_edges
                            and edge.data.data == node.data):
                        for e in state.memlet_tree(edge):
                            if e.data.data == node.data:
                                # Offset inner memlets by the outer subset
                                e._data = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data)

    # If source/sink node is not connected to a source/destination access
    # node, and the nested SDFG is in a scope, connect to scope with empty
    # memlets
    if nsdfg_scope_entry is not None:
        for node in subgraph.nodes():
            if state.in_degree(node) == 0:
                state.add_edge(nsdfg_scope_entry, None, node, None,
                               Memlet())
            if state.out_degree(node) == 0:
                state.add_edge(node, None, nsdfg_scope_exit, None, Memlet())

    # Replace nested SDFG parents with new SDFG
    for node in nstate.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = state
            node.sdfg.parent_sdfg = sdfg
            node.sdfg.parent_nsdfg_node = node

    # Remove all unused external inputs/output memlet paths, as well as
    # resulting isolated nodes
    # NOTE(review): the difference mixes connector-name strings with access
    # -node sets; presumably relies on set difference semantics — verify.
    removed_in_edges = self._remove_edge_path(state,
                                              inputs,
                                              set(inputs.keys()) -
                                              source_accesses,
                                              reverse=True)
    removed_out_edges = self._remove_edge_path(state,
                                               outputs,
                                               set(outputs.keys()) -
                                               sink_accesses,
                                               reverse=False)

    # Re-add in/out edges to first/last nodes in subgraph
    order = [
        x for x in nx.topological_sort(nstate._nx)
        if isinstance(x, nodes.AccessNode)
    ]
    for edge in removed_in_edges:
        # Find first access node that refers to this edge
        node = next(n for n in order if n.data == edge.data.data)
        state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                       edge.data)
    for edge in removed_out_edges:
        # Find last access node that refers to this edge
        node = next(n for n in reversed(order)
                    if n.data == edge.data.data)
        state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                       edge.data)

    #######################################################
    # Remove nested SDFG node
    state.remove_node(nsdfg_node)
def apply_pass(
        self, sdfg: SDFG,
        pipeline_results: Dict[str, Any]) -> Optional[Dict[SDFGState, Set[str]]]:
    """
    Removes unreachable dataflow throughout SDFG states.

    :param sdfg: The SDFG to modify.
    :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass
                             results as ``{Pass subclass name: returned object from pass}``. If not run in a
                             pipeline, an empty dictionary is expected.
    :return: A dictionary mapping states to removed data descriptor names, or None if nothing changed.
    """
    # Depends on the following analysis passes:
    #  * State reachability
    #  * Read/write access sets per state
    reachable: Dict[SDFGState,
                    Set[SDFGState]] = pipeline_results['StateReachability']
    access_sets: Dict[SDFGState, Tuple[Set[str],
                                       Set[str]]] = pipeline_results['AccessSets']
    # NOTE(review): despite the ``Set[str]`` return annotation, the values
    # inserted below are graph Node objects (``dead_nodes`` and removed
    # access nodes), not descriptor name strings — verify intent.
    result: Dict[SDFGState, Set[str]] = defaultdict(set)

    # Traverse SDFG backwards so later removals expose more dead writes
    for state in reversed(list(cfg.stateorder_topological_sort(sdfg))):
        #############################################
        # Analysis
        #############################################

        # Compute states where memory will no longer be read
        writes = access_sets[state][1]
        descendants = reachable[state]
        descendant_reads = set().union(*(access_sets[succ][0]
                                         for succ in descendants))
        no_longer_used: Set[str] = set(data for data in writes
                                       if data not in descendant_reads)

        # Compute dead nodes
        dead_nodes: List[nodes.Node] = []

        # Propagate deadness backwards within a state
        for node in sdutil.dfs_topological_sort(state, reverse=True):
            if self._is_node_dead(node, sdfg, state, dead_nodes,
                                  no_longer_used):
                dead_nodes.append(node)

        # Scope exit nodes are only dead if their corresponding entry nodes are
        live_nodes = set()
        for node in dead_nodes:
            if isinstance(node, nodes.ExitNode) and state.entry_node(
                    node) not in dead_nodes:
                live_nodes.add(node)
        dead_nodes = dtypes.deduplicate(
            [n for n in dead_nodes if n not in live_nodes])

        if not dead_nodes:
            continue

        # Remove nodes while preserving scopes
        scopes_to_reconnect: Set[nodes.Node] = set()
        for node in state.nodes():
            # Look for scope exits that will be disconnected
            if isinstance(node, nodes.ExitNode) and node not in dead_nodes:
                if any(n in dead_nodes for n in state.predecessors(node)):
                    scopes_to_reconnect.add(node)

        # Two types of scope disconnections may occur:
        # 1. Two scope exits will no longer be connected
        # 2. A predecessor of dead nodes is in a scope and not connected to its exit
        # Case (1) is taken care of by ``remove_memlet_path``
        # Case (2) is handled below
        # Reconnect scopes
        if scopes_to_reconnect:
            schildren = state.scope_children()
            for exit_node in scopes_to_reconnect:
                entry_node = state.entry_node(exit_node)
                for node in schildren[entry_node]:
                    if node is exit_node:
                        continue
                    if isinstance(node, nodes.EntryNode):
                        node = state.exit_node(node)
                    # If node will be disconnected from exit node, add an empty memlet
                    if all(succ in dead_nodes
                           for succ in state.successors(node)):
                        state.add_nedge(node, exit_node, Memlet())

        #############################################
        # Removal
        #############################################
        predecessor_nsdfgs: Dict[nodes.NestedSDFG,
                                 Set[str]] = defaultdict(set)
        for node in dead_nodes:
            # Remove memlet paths and connectors pertaining to dead nodes
            for e in state.in_edges(node):
                mtree = state.memlet_tree(e)
                for leaf in mtree.leaves():
                    # Keep track of predecessors of removed nodes for connector pruning
                    if isinstance(leaf.src, nodes.NestedSDFG):
                        predecessor_nsdfgs[leaf.src].add(leaf.src_conn)
                    state.remove_memlet_path(leaf)

            # Remove the node itself as necessary
            state.remove_node(node)

        result[state].update(dead_nodes)

        # Remove isolated access nodes after elimination
        access_nodes = set(state.data_nodes())
        for node in access_nodes:
            if state.degree(node) == 0:
                state.remove_node(node)
                result[state].add(node)

        # Prune now-dead connectors
        for node, dead_conns in predecessor_nsdfgs.items():
            for conn in dead_conns:
                # If removed connector belonged to a nested SDFG, and no other input connector shares name,
                # make nested data transient (dead dataflow elimination would remove internally as necessary)
                if conn not in node.in_connectors:
                    node.sdfg.arrays[conn].transient = True

        # Update read sets for the predecessor states to reuse
        access_nodes -= result[state]
        access_node_names = set(n.data for n in access_nodes
                                if state.out_degree(n) > 0)
        access_sets[state] = (access_node_names, access_sets[state][1])

    return result or None
def insert_sdfg_element(sdfg_str, type, parent_uuid, edge_a_uuid):
    """Insert a new element into a JSON-serialized SDFG and return the result.

    :param sdfg_str: JSON string of the SDFG to modify.
    :param type: Kind of element to insert (e.g. 'SDFGState', 'Map', 'Edge');
                 a 'LibraryNode|dotted.path' pair selects a library node class.
    :param parent_uuid: UUID of the element the new element is added to
                        (or, for edges, the edge endpoint).
    :param edge_a_uuid: UUID of the edge's starting element ('Edge' only).
    :return: Dictionary with the updated serialized SDFG and the UUID(s) of
             the inserted element(s), or an error dictionary on failure.
    """
    loaded = load_sdfg_from_json(sdfg_str)
    sdfg = loaded['sdfg']
    # Fallback result when no branch below produces a UUID
    uuid = 'error'
    parent = find_graph_element_by_uuid(sdfg, parent_uuid)['element']

    # Split off an optional library-node class path ('Type|libname')
    libname = None
    if type is not None and isinstance(type, str):
        pieces = type.split('|')
        if len(pieces) == 2:
            type, libname = pieces

    if type == 'SDFGState':
        # States are added to an SDFG; resolve nested SDFGs to their graph
        if parent is None:
            parent = sdfg
        elif isinstance(parent, nodes.NestedSDFG):
            parent = parent.sdfg
        new_state = parent.add_state()
        uuid = [get_uuid(new_state)]
    elif type == 'AccessNode':
        # Point the access node at the first known array, creating a
        # placeholder array if none exist yet
        arrays = list(parent.parent.arrays.keys())
        if not arrays:
            parent.parent.add_array('tmp', [1], dtype=dtypes.float64)
            arrays = list(parent.parent.arrays.keys())
        access = parent.add_access(arrays[0])
        uuid = [get_uuid(access, parent)]
    elif type == 'Map':
        m_entry, m_exit = parent.add_map('map', dict(i='0:1'))
        uuid = [get_uuid(m_entry, parent), get_uuid(m_exit, parent)]
    elif type == 'Consume':
        c_entry, c_exit = parent.add_consume('consume', ('i', '1'))
        uuid = [get_uuid(c_entry, parent), get_uuid(c_exit, parent)]
    elif type == 'Tasklet':
        tasklet = parent.add_tasklet(name='placeholder',
                                     inputs={'in'},
                                     outputs={'out'},
                                     code='')
        uuid = [get_uuid(tasklet, parent)]
    elif type == 'NestedSDFG':
        # Minimal placeholder nested SDFG with one input and one output
        inner = SDFG('nested_sdfg')
        inner.add_array('in', [1], dtypes.float32)
        inner.add_array('out', [1], dtypes.float32)
        nested = parent.add_nested_sdfg(inner, sdfg, {'in'}, {'out'})
        uuid = [get_uuid(nested, parent)]
    elif type == 'LibraryNode':
        if libname is None:
            return {
                'error': {
                    'message': 'Failed to add library node',
                    'details': 'Must provide a valid library node type',
                },
            }
        libnode_class = pydoc.locate(libname)
        libnode = libnode_class()
        parent.add_node(libnode)
        uuid = [get_uuid(libnode, parent)]
    elif type == 'Edge':
        start_lookup = find_graph_element_by_uuid(sdfg, edge_a_uuid)
        edge_start = start_lookup['element']
        edge_parent = start_lookup['parent']
        if edge_start is None:
            raise ValueError('No edge starting point provided')
        if edge_parent is None:
            edge_parent = sdfg
        if isinstance(edge_parent, SDFGState):
            # Dataflow edge: both endpoints must be dataflow nodes
            if not (isinstance(edge_start, nodes.Node)
                    and isinstance(parent, nodes.Node)):
                return {
                    'error': {
                        'message': 'Failed to add edge',
                        'details': 'Must connect two nodes or two states',
                    },
                }
            edge_parent.add_edge(edge_start, None, parent, None, Memlet())
        elif isinstance(edge_parent, SDFG):
            # Interstate edge: both endpoints must be states
            if not (isinstance(edge_start, SDFGState)
                    and isinstance(parent, SDFGState)):
                return {
                    'error': {
                        'message': 'Failed to add edge',
                        'details': 'Must connect two nodes or two states',
                    },
                }
            edge_parent.add_edge(edge_start, parent, InterstateEdge())
        uuid = ['NONE']

    # Serialize without metadata, then restore the previous setting
    old_meta = disable_save_metadata()
    new_sdfg_str = sdfg.to_json()
    restore_save_metadata(old_meta)

    return {
        'sdfg': new_sdfg_str,
        'uuid': uuid,
    }
def make_sdfg(specialize):
    """Builds the parallel FPGA histogram SDFG.

    Constructs a three-state SDFG (host-to-FPGA copy, compute, FPGA-to-host
    copy) whose compute state contains P unrolled processing elements fed by
    streams, a vectorized read module, and a merge/reduce write module.

    NOTE(review): relies on module-level symbols ``P``, ``H``, ``W``,
    ``num_bins``, ``dtype``, ``itype`` and helpers
    ``make_copy_to_fpga_state``/``make_compute_nested_sdfg``/
    ``make_copy_to_host_state`` defined elsewhere in this file. It also uses
    legacy DaCe APIs (``EmptyMemlet``, positional ``Memlet`` constructor,
    ``dace.graph.edges.InterstateEdge``) — presumably written against an
    older DaCe version; confirm before porting.

    :param specialize: If True, bake H and W into the SDFG name as well
                       (sizes are evaluated via ``.get()``).
    :return: The constructed SDFG.
    """

    if specialize:
        sdfg = SDFG("histogram_fpga_parallel_{}_{}x{}".format(
            P.get(), H.get(), W.get()))
    else:
        sdfg = SDFG("histogram_fpga_parallel_{}".format(P.get()))

    copy_to_fpga_state = make_copy_to_fpga_state(sdfg)

    state = sdfg.add_state("compute")

    # Compute module
    nested_sdfg = make_compute_nested_sdfg(state)
    tasklet = state.add_nested_sdfg(nested_sdfg, sdfg, {"A_pipe_in"},
                                    {"hist_pipe_out"})
    A_pipes_out = state.add_stream("A_pipes",
                                   dtype,
                                   shape=(P, ),
                                   transient=True,
                                   storage=StorageType.FPGA_Local)
    A_pipes_in = state.add_stream("A_pipes",
                                  dtype,
                                  shape=(P, ),
                                  transient=True,
                                  storage=StorageType.FPGA_Local)
    hist_pipes_out = state.add_stream("hist_pipes",
                                      itype,
                                      shape=(P, ),
                                      transient=True,
                                      storage=StorageType.FPGA_Local)
    # One processing element per p, fully unrolled on the device
    unroll_entry, unroll_exit = state.add_map(
        "unroll_compute", {"p": "0:P"},
        schedule=dace.ScheduleType.FPGA_Device,
        unroll=True)
    # Empty memlets keep the streams inside the unrolled scope
    state.add_memlet_path(unroll_entry, A_pipes_in, memlet=EmptyMemlet())
    state.add_memlet_path(hist_pipes_out, unroll_exit, memlet=EmptyMemlet())
    state.add_memlet_path(A_pipes_in,
                          tasklet,
                          dst_conn="A_pipe_in",
                          memlet=Memlet.simple(A_pipes_in,
                                               "p",
                                               num_accesses="W*H"))
    state.add_memlet_path(tasklet,
                          hist_pipes_out,
                          src_conn="hist_pipe_out",
                          memlet=Memlet.simple(hist_pipes_out,
                                               "p",
                                               num_accesses="num_bins"))

    # Read module
    a_device = state.add_array("A_device", (H, W),
                               dtype,
                               transient=True,
                               storage=dace.dtypes.StorageType.FPGA_Global)
    # Stride W by P: each iteration reads a vector of P elements
    read_entry, read_exit = state.add_map("read_map", {
        "h": "0:H",
        "w": "0:W:P"
    },
                                          schedule=ScheduleType.FPGA_Device)
    a_val = state.add_array("A_val", (P, ),
                            dtype,
                            transient=True,
                            storage=StorageType.FPGA_Local)
    read_unroll_entry, read_unroll_exit = state.add_map(
        "read_unroll", {"p": "0:P"},
        schedule=ScheduleType.FPGA_Device,
        unroll=True)
    read_tasklet = state.add_tasklet("read", {"A_in"}, {"A_pipe"},
                                     "A_pipe = A_in[p]")
    # Vectorized load of P elements from global memory into the local buffer
    state.add_memlet_path(a_device,
                          read_entry,
                          a_val,
                          memlet=Memlet(a_val,
                                        num_accesses=1,
                                        subset=Indices(["0"]),
                                        vector_length=P.get(),
                                        other_subset=Indices(["h", "w"])))
    state.add_memlet_path(a_val,
                          read_unroll_entry,
                          read_tasklet,
                          dst_conn="A_in",
                          memlet=Memlet.simple(a_val,
                                               "0",
                                               veclen=P.get(),
                                               num_accesses=1))
    state.add_memlet_path(read_tasklet,
                          read_unroll_exit,
                          read_exit,
                          A_pipes_out,
                          src_conn="A_pipe",
                          memlet=Memlet.simple(A_pipes_out, "p"))

    # Write module
    hist_pipes_in = state.add_stream("hist_pipes",
                                     itype,
                                     shape=(P, ),
                                     transient=True,
                                     storage=StorageType.FPGA_Local)
    hist_device_out = state.add_array(
        "hist_device", (num_bins, ),
        itype,
        transient=True,
        storage=dace.dtypes.StorageType.FPGA_Global)
    merge_entry, merge_exit = state.add_map("merge", {"nb": "0:num_bins"},
                                            schedule=ScheduleType.FPGA_Device)
    # Sum the per-PE partial histograms along axis 0
    merge_reduce = state.add_reduce("lambda a, b: a + b", (0, ),
                                    "0",
                                    schedule=ScheduleType.FPGA_Device)
    state.add_memlet_path(hist_pipes_in,
                          merge_entry,
                          merge_reduce,
                          memlet=Memlet.simple(hist_pipes_in,
                                               "0:P",
                                               num_accesses=P))
    state.add_memlet_path(merge_reduce,
                          merge_exit,
                          hist_device_out,
                          memlet=dace.memlet.Memlet.simple(
                              hist_device_out, "nb"))

    copy_to_host_state = make_copy_to_host_state(sdfg)

    # Chain the three states: copy-in -> compute -> copy-out
    sdfg.add_edge(copy_to_fpga_state, state, dace.graph.edges.InterstateEdge())
    sdfg.add_edge(state, copy_to_host_state, dace.graph.edges.InterstateEdge())

    return sdfg
def apply(self, graph: SDFGState, sdfg: SDFG) -> nodes.MapEntry:
    """Tiles the selected map for warp-level execution.

    Wraps the map contents in a new ``__tid``-parameterized thread-block map
    of size ``self.warp_size``, strides/offsets all immediately nested maps
    by the warp size, optionally replicates scopes with multiple dependent
    outputs, converts write-conflict-resolution (WCR) outputs of a single
    element into warp-collaborative reductions, and finally nests the new
    scope into a nested SDFG.

    :param graph: The state containing the map.
    :param sdfg: The SDFG being transformed.
    :return: The newly created warp-tile map entry.
    """
    me = self.mapentry

    # Add new map within map
    mx = graph.exit_node(me)
    new_me, new_mx = graph.add_map('warp_tile',
                                   dict(__tid=f'0:{self.warp_size}'),
                                   dtypes.ScheduleType.GPU_ThreadBlock)
    __tid = symbolic.pystr_to_symbolic('__tid')
    for e in graph.out_edges(me):
        xfh.reconnect_edge_through_map(graph, e, new_me, True)
    for e in graph.in_edges(mx):
        xfh.reconnect_edge_through_map(graph, e, new_mx, False)

    # Stride and offset all internal maps
    maps_to_stride = xfh.get_internal_scopes(graph, new_me, immediate=True)
    for nstate, nmap in maps_to_stride:
        nsdfg = nstate.parent
        nsdfg_node = nsdfg.parent_nsdfg_node

        # Map cannot be partitioned across a warp
        # (``== True`` forces a definite answer from the possibly-symbolic
        # comparison; an undecidable relation falls through — verify)
        if (nmap.range.size()[-1] < self.warp_size) == True:
            continue

        # Make __tid visible inside the nested SDFG, if any
        if nsdfg is not sdfg and nsdfg_node is not None:
            nsdfg_node.symbol_mapping['__tid'] = __tid
            if '__tid' not in nsdfg.symbols:
                nsdfg.add_symbol('__tid', dtypes.int32)
        # Shrink the last dimension's end by __tid and multiply its stride
        # by the warp size; then offset all uses of the parameter by __tid
        nmap.range[-1] = (nmap.range[-1][0],
                          nmap.range[-1][1] - __tid,
                          nmap.range[-1][2] * self.warp_size)
        subgraph = nstate.scope_subgraph(nmap)
        subgraph.replace(nmap.params[-1], f'{nmap.params[-1]} + __tid')
        inner_map_exit = nstate.exit_node(nmap)

        # If requested, replicate maps with multiple dependent maps
        if self.replicate_maps:
            destinations = [
                nstate.memlet_path(edge)[-1].dst
                for edge in nstate.out_edges(inner_map_exit)
            ]

            for dst in destinations:
                # Transformation will not replicate map with more than one
                # output
                if len(destinations) != 1:
                    break
                if not isinstance(dst, nodes.AccessNode):
                    continue  # Not leading to access node
                if not xfh.contained_in(nstate, dst, new_me):
                    continue  # Memlet path goes out of map
                if not nsdfg.arrays[dst.data].transient:
                    continue  # Cannot modify non-transients
                # One replica (with its own data copy) per extra consumer
                for edge in nstate.out_edges(dst)[1:]:
                    rep_subgraph = xfh.replicate_scope(
                        nsdfg, nstate, subgraph)
                    rep_edge = nstate.out_edges(
                        rep_subgraph.sink_nodes()[0])[0]
                    # Add copy of data
                    newdesc = copy.deepcopy(sdfg.arrays[dst.data])
                    newname = nsdfg.add_datadesc(dst.data, newdesc,
                                                 find_new_name=True)
                    newaccess = nstate.add_access(newname)
                    # Redirect edges
                    xfh.redirect_edge(nstate, rep_edge, new_dst=newaccess,
                                      new_data=newname)
                    xfh.redirect_edge(nstate, edge, new_src=newaccess,
                                      new_data=newname)

        # If has WCR, add warp-collaborative reduction on outputs
        for out_edge in nstate.out_edges(inner_map_exit):
            dst = nstate.memlet_path(out_edge)[-1].dst
            if not xfh.contained_in(nstate, dst, new_me):
                # Skip edges going out of map
                continue
            # Skip destinations stored in global memory
            # NOTE(review): original comment said "Skip shared memory", but
            # the check is against ``GPU_Global`` — verify intent.
            if dst.desc(nsdfg).storage == dtypes.StorageType.GPU_Global:
                continue
            if out_edge.data.wcr is not None:
                ctype = nsdfg.arrays[out_edge.data.data].dtype.ctype
                redtype = detect_reduction_type(out_edge.data.wcr)
                if redtype == dtypes.ReductionType.Custom:
                    raise NotImplementedError
                # e.g. 'dace::ReductionType::Sum' from the enum's repr
                credtype = ('dace::ReductionType::' +
                            str(redtype)[str(redtype).find('.') + 1:])

                # One element: tasklet
                if out_edge.data.subset.num_elements() == 1:
                    # Add local access between thread-local and warp reduction
                    name = nsdfg._find_new_name(out_edge.data.data)
                    nsdfg.add_scalar(
                        name,
                        nsdfg.arrays[out_edge.data.data].dtype,
                        transient=True)

                    # Initialize thread-local to global value
                    read = nstate.add_read(out_edge.data.data)
                    write = nstate.add_write(name)
                    edge = nstate.add_nedge(read, write,
                                            copy.deepcopy(out_edge.data))
                    edge.data.wcr = None
                    # Keep the initialization in a separate state
                    xfh.state_fission(nsdfg,
                                      SubgraphView(nstate, [read, write]))

                    # Route the map's output through the new local scalar
                    newnode = nstate.add_access(name)
                    nstate.remove_edge(out_edge)
                    edge = nstate.add_edge(out_edge.src, out_edge.src_conn,
                                           newnode, None,
                                           copy.deepcopy(out_edge.data))
                    for e in nstate.memlet_path(edge):
                        e.data.data = name
                        e.data.subset = subsets.Range([(0, 0, 1)])

                    # Warp-collaborative reduction tasklet (CUDA C++)
                    wrt = nstate.add_tasklet(
                        'warpreduce', {'__a'}, {'__out'},
                        f'__out = dace::warpReduce<{credtype}, {ctype}>::reduce(__a);',
                        dtypes.Language.CPP)
                    nstate.add_edge(newnode, None, wrt, '__a', Memlet(name))
                    # The tasklet result replaces the WCR write
                    out_edge.data.wcr = None
                    nstate.add_edge(wrt, '__out', out_edge.dst, None,
                                    out_edge.data)
                else:  # More than one element: mapped tasklet
                    # Could be a parallel summation
                    # TODO(later): Check if reduction
                    continue
        # End of WCR to warp reduction

    # Make nested SDFG out of new scope
    xfh.nest_state_subgraph(sdfg, graph,
                            graph.scope_subgraph(new_me, False, False))

    return new_me
def _make_sdfg_getrs(node, parent_state, parent_sdfg, implementation):
    """Expands a Solve library node into an SDFG that computes ``A x = B``
    via LU factorization (GETRF) followed by triangular solve (GETRS).

    The graph copies ``_ain``/``_bin`` into transient working buffers,
    factorizes in place, solves, and writes the result to ``_bout``. For the
    'cuSolverDn' implementation, inputs and output are additionally
    transposed with cuBLAS (column-major convention), and the RHS working
    buffer is shaped ``[rhs, n]`` instead of ``[n, rhs]``.

    :param node: The Solve library node being expanded (validated here).
    :param parent_state: The state containing the node.
    :param parent_sdfg: The SDFG containing the node.
    :param implementation: Implementation name for the GETRF/GETRS nodes
                           (e.g. 'cuSolverDn' or 'MKL').
    :return: The expansion SDFG.
    """
    arr_desc = node.validate(parent_sdfg, parent_state)
    (ain_shape, ain_dtype, ain_strides, bin_shape, bin_dtype, bin_strides,
     out_shape, out_dtype, out_strides, n, rhs) = arr_desc
    dtype = ain_dtype

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    ain_arr = sdfg.add_array('_ain',
                             ain_shape,
                             dtype=ain_dtype,
                             strides=ain_strides)
    ainout_arr = sdfg.add_array('_ainout', [n, n],
                                dtype=ain_dtype,
                                transient=True)
    bin_arr = sdfg.add_array('_bin',
                             bin_shape,
                             dtype=bin_dtype,
                             strides=bin_strides)
    # cuSolverDn works on transposed (column-major) data
    binout_shape = [n, rhs]
    if implementation == 'cuSolverDn':
        binout_shape = [rhs, n]
    binout_arr = sdfg.add_array('_binout',
                                binout_shape,
                                dtype=out_dtype,
                                transient=True)
    bout_arr = sdfg.add_array('_bout',
                              out_shape,
                              dtype=out_dtype,
                              strides=out_strides)
    ipiv_arr = sdfg.add_array('_pivots', [n],
                              dtype=dace.int32,
                              transient=True)
    info_arr = sdfg.add_array('_info', [1],
                              dtype=dace.int32,
                              transient=True)

    state = sdfg.add_state("{l}_state".format(l=node.label))

    getrf_node = Getrf('getrf')
    getrf_node.implementation = implementation
    getrs_node = Getrs('getrs')
    getrs_node.implementation = implementation

    ain = state.add_read('_ain')
    ainout1 = state.add_read('_ainout')
    ainout2 = state.add_access('_ainout')
    # Renamed from `bin` to avoid shadowing the builtin
    bin_read = state.add_read('_bin')
    binout1 = state.add_read('_binout')
    # NOTE(review): `binout2` also receives getrs output below; an access
    # node (rather than a read) may be more accurate — verify before change.
    binout2 = state.add_read('_binout')
    bout = state.add_access('_bout')

    if implementation == 'cuSolverDn':
        # Transpose inputs into the working buffers and the solution back out
        transpose_ain = Transpose('AT', dtype=ain_dtype)
        transpose_ain.implementation = 'cuBLAS'
        state.add_edge(ain, None, transpose_ain, '_inp',
                       Memlet.from_array(*ain_arr))
        state.add_edge(transpose_ain, '_out', ainout1, None,
                       Memlet.from_array(*ainout_arr))
        transpose_bin = Transpose('bT', dtype=bin_dtype)
        transpose_bin.implementation = 'cuBLAS'
        state.add_edge(bin_read, None, transpose_bin, '_inp',
                       Memlet.from_array(*bin_arr))
        state.add_edge(transpose_bin, '_out', binout1, None,
                       Memlet.from_array(*binout_arr))
        transpose_out = Transpose('XT', dtype=bin_dtype)
        transpose_out.implementation = 'cuBLAS'
        state.add_edge(binout2, None, transpose_out, '_inp',
                       Memlet.from_array(*binout_arr))
        state.add_edge(transpose_out, '_out', bout, None,
                       Memlet.from_array(*bout_arr))
    else:
        # Plain copies in/out of the working buffers
        state.add_nedge(ain, ainout1, Memlet.from_array(*ain_arr))
        state.add_nedge(bin_read, binout1, Memlet.from_array(*bin_arr))
        state.add_nedge(binout2, bout, Memlet.from_array(*bout_arr))

    ipiv = state.add_access('_pivots')
    info1 = state.add_write('_info')
    info2 = state.add_write('_info')

    # GETRF: factorize A in place, producing pivots and an info code
    state.add_memlet_path(ainout1,
                          getrf_node,
                          dst_conn="_xin",
                          memlet=Memlet.from_array(*ainout_arr))
    state.add_memlet_path(getrf_node,
                          info1,
                          src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrf_node,
                          ipiv,
                          src_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrf_node,
                          ainout2,
                          src_conn="_xout",
                          memlet=Memlet.from_array(*ainout_arr))
    # GETRS: solve using the LU factors and pivots
    state.add_memlet_path(ainout2,
                          getrs_node,
                          dst_conn="_a",
                          memlet=Memlet.from_array(*ainout_arr))
    state.add_memlet_path(binout1,
                          getrs_node,
                          dst_conn="_rhs_in",
                          memlet=Memlet.from_array(*binout_arr))
    state.add_memlet_path(ipiv,
                          getrs_node,
                          dst_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrs_node,
                          info2,
                          src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrs_node,
                          binout2,
                          src_conn="_rhs_out",
                          memlet=Memlet.from_array(*binout_arr))

    return sdfg
def apply(self, state: SDFGState, sdfg: SDFG):
    """Replaces the direct copy edge between ``self.a`` and ``self.b``
    with an explicit element-wise copy map and tasklet."""
    src_desc = self.a.desc(sdfg)
    dst_desc = self.b.desc(sdfg)
    copy_edge = state.edges_between(self.a, self.b)[0]

    # Iterate over whichever side has the higher (or equal) rank
    iterate_src = len(src_desc.shape) >= len(dst_desc.shape)
    if iterate_src:
        copy_shape = copy_edge.data.get_src_subset(copy_edge, state).size()
    else:
        copy_shape = copy_edge.data.get_dst_subset(copy_edge, state).size()

    maprange = {
        f'__i{dim}': (0, extent - 1, 1)
        for dim, extent in enumerate(copy_shape)
    }

    # The iterated side uses the map parameters directly; the other side's
    # index is derived by linearizing/delinearizing its subset
    identity_index = [
        symbolic.pystr_to_symbolic(f'__i{dim}')
        for dim in range(len(copy_shape))
    ]
    if iterate_src:
        src_index = identity_index
        dst_index = self.delinearize_linearize(
            dst_desc, copy_shape,
            copy_edge.data.get_dst_subset(copy_edge, state))
    else:
        src_index = self.delinearize_linearize(
            src_desc, copy_shape,
            copy_edge.data.get_src_subset(copy_edge, state))
        dst_index = identity_index

    src_subset = subsets.Range([(expr, expr, 1) for expr in src_index])
    dst_subset = subsets.Range([(expr, expr, 1) for expr in dst_index])

    # Pick a schedule: GPU arrays get a device map, unless the copy already
    # runs inside a GPU kernel, in which case it stays sequential
    involves_gpu = (src_desc.storage == dtypes.StorageType.GPU_Global
                    or dst_desc.storage == dtypes.StorageType.GPU_Global)
    if not involves_gpu:
        schedule = dtypes.ScheduleType.Default
    elif is_devicelevel_gpu(sdfg, state, self.a):
        schedule = dtypes.ScheduleType.Sequential
    else:
        schedule = dtypes.ScheduleType.GPU_Device

    # Add copy map connected to the original access nodes
    tasklet, _, _ = state.add_mapped_tasklet(
        'copy',
        maprange,
        dict(__inp=Memlet(data=self.a.data, subset=src_subset)),
        '__out = __inp',
        dict(__out=Memlet(data=self.b.data, subset=dst_subset)),
        schedule,
        external_edges=True,
        input_nodes={self.a.data: self.a},
        output_nodes={self.b.data: self.b})

    # Set connector types (due to this transformation appearing in codegen,
    # after connector types have been resolved)
    tasklet.in_connectors['__inp'] = src_desc.dtype
    tasklet.out_connectors['__out'] = dst_desc.dtype

    # Remove old edge
    state.remove_edge(copy_edge)