def apply(self, sdfg: SDFG):
    """Inline a single-state nested SDFG node into its parent state.

    Moves the nodes and edges of the (single) state inside the nested SDFG
    into the parent state, renaming transients, reconnecting input/output
    memlet paths, adding views where reshaping is required, and finally
    removing the nested SDFG node itself.

    :param sdfg: The top-level SDFG containing the nested SDFG to inline.
    """
    state: SDFGState = sdfg.nodes()[self.state_id]
    nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
    nsdfg: SDFG = nsdfg_node.sdfg
    # Single-state assumption: only the first (and only) nested state is used.
    nstate: SDFGState = nsdfg.nodes()[0]
    # Propagate a non-default schedule of the nested SDFG node inward before
    # its contents are merged into the parent.
    if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
        infer_types.set_default_schedule_and_storage_types(
            nsdfg, nsdfg_node.schedule)

    # Surrounding scope (e.g., map) of the nested SDFG node, if any
    nsdfg_scope_entry = state.entry_node(nsdfg_node)
    nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                        if nsdfg_scope_entry is not None else None)

    #######################################################
    # Collect and update top-level SDFG metadata

    # Global/init/exit code
    for loc, code in nsdfg.global_code.items():
        sdfg.append_global_code(code.code, loc)
    for loc, code in nsdfg.init_code.items():
        sdfg.append_init_code(code.code, loc)
    for loc, code in nsdfg.exit_code.items():
        sdfg.append_exit_code(code.code, loc)

    # Constants (warn on conflicting values; outer value wins)
    for cstname, cstval in nsdfg.constants.items():
        if cstname in sdfg.constants:
            if cstval != sdfg.constants[cstname]:
                warnings.warn('Constant value mismatch for "%s" while '
                              'inlining SDFG. Inner = %s != %s = outer' %
                              (cstname, cstval, sdfg.constants[cstname]))
        else:
            sdfg.add_constant(cstname, cstval)

    # Find original source/destination edges (there is only one edge per
    # connector, according to match)
    inputs: Dict[str, MultiConnectorEdge] = {}
    outputs: Dict[str, MultiConnectorEdge] = {}
    input_set: Dict[str, str] = {}
    output_set: Dict[str, str] = {}
    for e in state.in_edges(nsdfg_node):
        inputs[e.dst_conn] = e
        input_set[e.data.data] = e.dst_conn
    for e in state.out_edges(nsdfg_node):
        outputs[e.src_conn] = e
        output_set[e.data.data] = e.src_conn

    # Access nodes that need to be reshaped: inner array has more dimensions
    # than the outer memlet subset, or incompatible strides
    reshapes: Set[str] = set()
    for aname, array in nsdfg.arrays.items():
        if array.transient:
            continue
        edge = None
        if aname in inputs:
            edge = inputs[aname]
            if len(array.shape) > len(edge.data.subset):
                reshapes.add(aname)
                continue
        if aname in outputs:
            edge = outputs[aname]
            if len(array.shape) > len(edge.data.subset):
                reshapes.add(aname)
                continue
        if edge is not None and not InlineSDFG._check_strides(
                array.strides, sdfg.arrays[edge.data.data].strides,
                edge.data, nsdfg_node):
            reshapes.add(aname)

    # Replace symbols using invocation symbol mapping
    # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace(symname, '__dacesym_' + symname)
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace('__dacesym_' + symname, symvalue)

    # All transients become transients of the parent (if data already
    # exists, find new name)
    # Mapping from nested transient name to top-level name
    transients: Dict[str, str] = {}
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode):
            datadesc = nsdfg.arrays[node.data]
            if node.data not in transients and datadesc.transient:
                name = sdfg.add_datadesc('%s_%s' % (nsdfg.label, node.data),
                                         datadesc,
                                         find_new_name=True)
                transients[node.data] = name

    # All transients of edges between code nodes are also added to parent
    for edge in nstate.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            if edge.data.data is not None:
                datadesc = nsdfg.arrays[edge.data.data]
                if edge.data.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc(
                        '%s_%s' % (nsdfg.label, edge.data.data),
                        datadesc,
                        find_new_name=True)
                    transients[edge.data.data] = name

    # Collect nodes to add to top-level graph. Source/sink access nodes of
    # non-transient, non-reshaped data are dropped and their edges rerouted.
    new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
    new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
    source_accesses = set()
    sink_accesses = set()
    for node in nstate.source_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients
                and node.data not in reshapes):
            new_incoming_edges[node] = inputs[node.data]
            source_accesses.add(node)
    for node in nstate.sink_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients
                and node.data not in reshapes):
            new_outgoing_edges[node] = outputs[node.data]
            sink_accesses.add(node)

    #######################################################
    # Replace data on inlined SDFG nodes/edges

    # Replace data names with their top-level counterparts
    repldict = {}
    repldict.update(transients)
    repldict.update({
        k: v.data.data
        for k, v in itertools.chain(inputs.items(), outputs.items())
    })

    # Add views whenever reshapes are necessary
    for dname in reshapes:
        desc = nsdfg.arrays[dname]
        # To avoid potential confusion, rename protected __return keyword
        if dname.startswith('__return'):
            newname = f'{nsdfg.name}_ret{dname[8:]}'
        else:
            newname = dname
        newname, _ = sdfg.add_view(newname,
                                   desc.shape,
                                   desc.dtype,
                                   storage=desc.storage,
                                   strides=desc.strides,
                                   offset=desc.offset,
                                   debuginfo=desc.debuginfo,
                                   allow_conflicts=desc.allow_conflicts,
                                   total_size=desc.total_size,
                                   alignment=desc.alignment,
                                   may_alias=desc.may_alias,
                                   find_new_name=True)
        repldict[dname] = newname

    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in repldict:
            node.data = repldict[node.data]
    for edge in nstate.edges():
        if edge.data.data in repldict:
            edge.data.data = repldict[edge.data.data]

    # Add extra access nodes for out/in view nodes
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in reshapes:
            if nstate.in_degree(node) > 0 and nstate.out_degree(node) > 0:
                # Such a node has to be in the output set
                edge = outputs[node.data]

                # Redirect outgoing edges through access node
                out_edges = list(nstate.out_edges(node))
                anode = nstate.add_access(edge.data.data)
                vnode = nstate.add_access(node.data)
                nstate.add_nedge(node, anode, edge.data)
                nstate.add_nedge(anode, vnode, edge.data)
                for e in out_edges:
                    nstate.remove_edge(e)
                    nstate.add_edge(vnode, e.src_conn, e.dst, e.dst_conn,
                                    e.data)

    #######################################################
    # Add nested SDFG into top-level SDFG

    # Add nested nodes into original state
    subgraph = SubgraphView(nstate, [
        n for n in nstate.nodes()
        if n not in (source_accesses | sink_accesses)
    ])
    state.add_nodes_from(subgraph.nodes())
    for edge in subgraph.edges():
        state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                       edge.data)

    #######################################################
    # Reconnect inlined SDFG

    # If a source/sink node is one of the inputs/outputs, reconnect it,
    # replacing memlets in outgoing/incoming paths
    modified_edges = set()
    modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                               state, True)
    modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                               state, False)

    # Reshape: add connections to viewed data
    self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
                              True)
    self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
                              False)

    # Modify all other internal edges pertaining to input/output nodes
    for node in subgraph.nodes():
        if isinstance(node, nodes.AccessNode):
            if node.data in input_set or node.data in output_set:
                if node.data in input_set:
                    outer_edge = inputs[input_set[node.data]]
                else:
                    outer_edge = outputs[output_set[node.data]]

                for edge in state.all_edges(node):
                    if (edge not in modified_edges
                            and edge.data.data == node.data):
                        # Offset inner memlets to the outer data container
                        # (direct write to the private edge payload)
                        for e in state.memlet_tree(edge):
                            if e.data.data == node.data:
                                e._data = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data)

    # If source/sink node is not connected to a source/destination access
    # node, and the nested SDFG is in a scope, connect to scope with empty
    # memlets
    if nsdfg_scope_entry is not None:
        for node in subgraph.nodes():
            if state.in_degree(node) == 0:
                state.add_edge(nsdfg_scope_entry, None, node, None, Memlet())
            if state.out_degree(node) == 0:
                state.add_edge(node, None, nsdfg_scope_exit, None, Memlet())

    # Replace nested SDFG parents with new SDFG
    for node in nstate.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = state
            node.sdfg.parent_sdfg = sdfg
            node.sdfg.parent_nsdfg_node = node

    # Remove all unused external inputs/output memlet paths, as well as
    # resulting isolated nodes
    # NOTE(review): `inputs.keys()` holds connector name strings while
    # `source_accesses` holds AccessNode objects, so this set difference
    # never removes anything — confirm whether that is intended.
    removed_in_edges = self._remove_edge_path(state,
                                              inputs,
                                              set(inputs.keys()) -
                                              source_accesses,
                                              reverse=True)
    removed_out_edges = self._remove_edge_path(state,
                                               outputs,
                                               set(outputs.keys()) -
                                               sink_accesses,
                                               reverse=False)

    # Re-add in/out edges to first/last nodes in subgraph
    order = [
        x for x in nx.topological_sort(nstate._nx)
        if isinstance(x, nodes.AccessNode)
    ]
    for edge in removed_in_edges:
        # Find first access node that refers to this edge
        node = next(n for n in order if n.data == edge.data.data)
        state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                       edge.data)
    for edge in removed_out_edges:
        # Find last access node that refers to this edge
        node = next(n for n in reversed(order)
                    if n.data == edge.data.data)
        state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                       edge.data)

    #######################################################
    # Remove nested SDFG node
    state.remove_node(nsdfg_node)
def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
    """Expand a Reduce node into a device-wide CUB-based reduction.

    Emits a host-callable wrapper around ``cub::DeviceReduce`` (full
    reduction) or ``cub::DeviceSegmentedReduce`` (reduction over the last
    axes), allocating CUB temporary storage in the CUDA init code and
    freeing it in the exit code, and returns a CPP tasklet that invokes
    the wrapper.

    :param node: The Reduce library node being expanded.
    :param state: The state containing the node.
    :param sdfg: The SDFG containing the state.
    :return: A tasklet implementing the reduction call.
    :raises ValueError: If no identity value is given, or data does not
                        reside in GPU-global/pinned memory.
    :raises NotImplementedError: For reductions over non-trailing axes.
    """
    node.validate(sdfg, state)
    input_edge: graph.MultiConnectorEdge = state.in_edges(node)[0]
    output_edge: graph.MultiConnectorEdge = state.out_edges(node)[0]
    input_dims = len(input_edge.data.subset)
    output_dims = len(output_edge.data.subset)
    input_data = sdfg.arrays[input_edge.data.data]
    output_data = sdfg.arrays[output_edge.data.data]

    # Setup all locations in which code will be written
    cuda_globalcode = CodeIOStream()
    cuda_initcode = CodeIOStream()
    cuda_exitcode = CodeIOStream()
    host_globalcode = CodeIOStream()
    host_localcode = CodeIOStream()
    output_memlet = output_edge.data

    # Try to autodetect reduction type
    redtype = detect_reduction_type(node.wcr)

    node_id = state.node_id(node)
    state_id = sdfg.node_id(state)
    idstr = '{sdfg}_{state}_{node}'.format(sdfg=sdfg.name,
                                           state=state_id,
                                           node=node_id)

    if node.out_connectors:
        # FIX: dict views are not iterators; next() on them raises
        # TypeError. Wrap in iter() to take the first connector type.
        dtype = next(iter(node.out_connectors.values()))
    else:
        dtype = sdfg.arrays[output_memlet.data].dtype
    output_type = dtype.ctype

    if node.identity is None:
        raise ValueError('For device reduce nodes, initial value must be '
                         'specified')

    # Create a functor or use an existing one for reduction
    if redtype == dtypes.ReductionType.Custom:
        body, [arg1, arg2] = unparse_cr_split(sdfg, node.wcr)
        cuda_globalcode.write(
            """
struct __reduce_{id} {{
template <typename T>
    DACE_HDFI T operator()(const T &{arg1}, const T &{arg2}) const {{
        {contents}
    }}
}};""".format(id=idstr, arg1=arg1, arg2=arg2, contents=body), sdfg,
            state_id, node_id)
        reduce_op = ', __reduce_' + idstr + '(), ' + symstr(node.identity)
    elif redtype in ExpandReduceCUDADevice._SPECIAL_RTYPES:
        # Min/max-style reductions map to dedicated CUB kernels
        reduce_op = ''
    else:
        credtype = 'dace::ReductionType::' + str(
            redtype)[str(redtype).find('.') + 1:]
        reduce_op = ((', dace::_wcr_fixed<%s, %s>()' %
                      (credtype, output_type)) + ', ' +
                     symstr(node.identity))

    # Obtain some SDFG-related information
    input_memlet = input_edge.data
    reduce_shape = input_memlet.subset.bounding_box_size()
    num_items = ' * '.join(symstr(s) for s in reduce_shape)
    input = (input_memlet.data + ' + ' +
             cpp_array_expr(sdfg, input_memlet, with_brackets=False))
    output = (output_memlet.data + ' + ' +
              cpp_array_expr(sdfg, output_memlet, with_brackets=False))

    input_dims = input_memlet.subset.dims()
    output_dims = output_memlet.subset.data_dims()

    reduce_all_axes = (node.axes is None or len(node.axes) == input_dims)
    if reduce_all_axes:
        reduce_last_axes = False
    else:
        reduce_last_axes = sorted(node.axes) == list(
            range(input_dims - len(node.axes), input_dims))

    if (not reduce_all_axes) and (not reduce_last_axes):
        raise NotImplementedError(
            'Multiple axis reductions not supported on GPUs. Please use '
            'the pure expansion or make reduce axes the last in the array.')

    # Verify that data is on the GPU
    if input_data.storage not in [
            dtypes.StorageType.GPU_Global, dtypes.StorageType.CPU_Pinned
    ]:
        raise ValueError('Input of GPU reduction must either reside '
                         ' in global GPU memory or pinned CPU memory')
    if output_data.storage not in [
            dtypes.StorageType.GPU_Global, dtypes.StorageType.CPU_Pinned
    ]:
        raise ValueError('Output of GPU reduction must either reside '
                         ' in global GPU memory or pinned CPU memory')

    # Determine reduction type
    kname = (ExpandReduceCUDADevice._SPECIAL_RTYPES[redtype]
             if redtype in ExpandReduceCUDADevice._SPECIAL_RTYPES else
             'Reduce')

    # Create temp memory for this GPU
    cuda_globalcode.write(
        """
        void *__cub_storage_{sdfg}_{state}_{node} = NULL;
        size_t __cub_ssize_{sdfg}_{state}_{node} = 0;
    """.format(sdfg=sdfg.name, state=state_id, node=node_id), sdfg,
        state_id, node)

    if reduce_all_axes:
        reduce_type = 'DeviceReduce'
        reduce_range = num_items
        reduce_range_def = 'size_t num_items'
        reduce_range_use = 'num_items'
        reduce_range_call = num_items
    elif reduce_last_axes:
        num_reduce_axes = len(node.axes)
        not_reduce_axes = reduce_shape[:-num_reduce_axes]
        reduce_axes = reduce_shape[-num_reduce_axes:]

        num_segments = ' * '.join([symstr(s) for s in not_reduce_axes])
        segment_size = ' * '.join([symstr(s) for s in reduce_axes])

        reduce_type = 'DeviceSegmentedReduce'
        iterator = 'dace::stridedIterator({size})'.format(size=segment_size)
        reduce_range = '{num}, {it}, {it} + 1'.format(num=num_segments,
                                                      it=iterator)
        reduce_range_def = 'size_t num_segments, size_t segment_size'
        iterator_use = 'dace::stridedIterator(segment_size)'
        reduce_range_use = 'num_segments, {it}, {it} + 1'.format(
            it=iterator_use)
        reduce_range_call = '%s, %s' % (num_segments, segment_size)

    # Call CUB to get the storage size, allocate and free it
    cuda_initcode.write(
        """
        cub::{reduce_type}::{kname}(nullptr, __cub_ssize_{sdfg}_{state}_{node},
                                    ({intype}*)nullptr, ({outtype}*)nullptr, {reduce_range}{redop});
        cudaMalloc(&__cub_storage_{sdfg}_{state}_{node}, __cub_ssize_{sdfg}_{state}_{node});
""".format(sdfg=sdfg.name,
           state=state_id,
           node=node_id,
           reduce_type=reduce_type,
           reduce_range=reduce_range,
           redop=reduce_op,
           intype=input_data.dtype.ctype,
           outtype=output_data.dtype.ctype,
           kname=kname), sdfg, state_id, node)

    cuda_exitcode.write(
        'cudaFree(__cub_storage_{sdfg}_{state}_{node});'.format(
            sdfg=sdfg.name, state=state_id, node=node_id), sdfg, state_id,
        node)

    # Write reduction function definition
    cuda_globalcode.write("""
DACE_EXPORTED void __dace_reduce_{id}({intype} *input, {outtype} *output, {reduce_range_def}, cudaStream_t stream);
void __dace_reduce_{id}({intype} *input, {outtype} *output, {reduce_range_def}, cudaStream_t stream)
{{
cub::{reduce_type}::{kname}(__cub_storage_{id}, __cub_ssize_{id},
                            input, output, {reduce_range_use}{redop}, stream);
}}
""".format(id=idstr,
           intype=input_data.dtype.ctype,
           outtype=output_data.dtype.ctype,
           reduce_type=reduce_type,
           reduce_range_def=reduce_range_def,
           reduce_range_use=reduce_range_use,
           kname=kname,
           redop=reduce_op))

    # Write reduction function definition in caller file
    host_globalcode.write(
        """
DACE_EXPORTED void __dace_reduce_{id}({intype} *input, {outtype} *output, {reduce_range_def}, cudaStream_t stream);
""".format(id=idstr,
           reduce_range_def=reduce_range_def,
           intype=input_data.dtype.ctype,
           outtype=output_data.dtype.ctype), sdfg, state_id, node)

    # Call reduction function where necessary
    host_localcode.write(
        '__dace_reduce_{id}({input}, {output}, {reduce_range_call}, __dace_current_stream);'
        .format(id=idstr,
                input=input,
                output=output,
                reduce_range_call=reduce_range_call))

    # Make tasklet
    tnode = dace.nodes.Tasklet('reduce',
                               {'_in': dace.pointer(input_data.dtype)},
                               {'_out': dace.pointer(output_data.dtype)},
                               host_localcode.getvalue(),
                               language=dace.Language.CPP)

    # Add the rest of the code
    sdfg.append_global_code(host_globalcode.getvalue())
    sdfg.append_global_code(cuda_globalcode.getvalue(), 'cuda')
    sdfg.append_init_code(cuda_initcode.getvalue(), 'cuda')
    sdfg.append_exit_code(cuda_exitcode.getvalue(), 'cuda')

    # Rename outer connectors and add to node
    input_edge._dst_conn = '_in'
    output_edge._src_conn = '_out'
    node.add_in_connector('_in')
    node.add_out_connector('_out')

    return tnode
def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
    """Expand a Reduce node into a block-wide ``cub::BlockReduce`` tasklet.

    Only full reductions of register-resident data inside a GPU kernel are
    supported; the number of threads in the block must be a compile-time
    constant so the CUB template can be instantiated.

    :param node: The Reduce library node being expanded.
    :param state: The state containing the node.
    :param sdfg: The SDFG containing the state.
    :return: A tasklet implementing the block-wide reduction.
    :raises ValueError: If not inside a GPU kernel, block size is symbolic,
                        the reduction is partial, data is not in registers,
                        or the reduction type is unsupported.
    """
    node.validate(sdfg, state)
    input_edge: graph.MultiConnectorEdge = state.in_edges(node)[0]
    output_edge: graph.MultiConnectorEdge = state.out_edges(node)[0]
    input_dims = len(input_edge.data.subset)
    input_data = sdfg.arrays[input_edge.data.data]
    output_data = sdfg.arrays[output_edge.data.data]

    # Setup all locations in which code will be written
    cuda_globalcode = CodeIOStream()
    localcode = CodeIOStream()

    # Try to autodetect reduction type
    redtype = detect_reduction_type(node.wcr)

    node_id = state.node_id(node)
    state_id = sdfg.node_id(state)
    idstr = '{sdfg}_{state}_{node}'.format(sdfg=sdfg.name,
                                           state=state_id,
                                           node=node_id)

    # Obtain some SDFG-related information
    input_memlet = input_edge.data
    output_memlet = output_edge.data

    if node.out_connectors:
        # FIX: dict views are not iterators; next() on them raises
        # TypeError. Wrap in iter() to take the first connector type.
        dtype = next(iter(node.out_connectors.values()))
    else:
        dtype = sdfg.arrays[output_memlet.data].dtype
    output_type = dtype.ctype

    if node.identity is None:
        raise ValueError('For device reduce nodes, initial value must be '
                         'specified')

    # Create a functor or use an existing one for reduction
    if redtype == dtypes.ReductionType.Custom:
        body, [arg1, arg2] = unparse_cr_split(sdfg, node.wcr)
        cuda_globalcode.write(
            """
struct __reduce_{id} {{
template <typename T>
    DACE_HDFI T operator()(const T &{arg1}, const T &{arg2}) const {{
        {contents}
    }}
}};""".format(id=idstr, arg1=arg1, arg2=arg2, contents=body), sdfg,
            state_id, node_id)
        reduce_op = ', __reduce_' + idstr + '(), ' + symstr(node.identity)
    elif redtype in ExpandReduceCUDADevice._SPECIAL_RTYPES:
        reduce_op = ''
    else:
        credtype = 'dace::ReductionType::' + str(
            redtype)[str(redtype).find('.') + 1:]
        reduce_op = ((', dace::_wcr_fixed<%s, %s>()' %
                      (credtype, output_type)) + ', ' +
                     symstr(node.identity))

    # Try to obtain the number of threads in the block, or use the default
    # configuration
    block_threads = devicelevel_block_size(sdfg, state, node)
    if block_threads is not None:
        # Flatten the 3-D block size into a thread count
        block_threads = functools.reduce(lambda a, b: a * b, block_threads,
                                         1)

    # Checks
    if block_threads is None:
        raise ValueError('Block-wide GPU reduction must occur within'
                         ' a GPU kernel')
    if issymbolic(block_threads, sdfg.constants):
        raise ValueError('Block size has to be constant for block-wide '
                         'reduction (got %s)' % str(block_threads))
    if (node.axes is not None and len(node.axes) < input_dims):
        raise ValueError(
            'Only full reduction is supported for block-wide reduce,'
            ' please use the pure expansion')
    if (input_data.storage != dtypes.StorageType.Register
            or output_data.storage != dtypes.StorageType.Register):
        raise ValueError(
            'Block-wise reduction only supports GPU register inputs '
            'and outputs')
    if redtype in ExpandReduceCUDABlock._SPECIAL_RTYPES:
        raise ValueError('%s block reduction not supported' % redtype)

    credtype = 'dace::ReductionType::' + str(
        redtype)[str(redtype).find('.') + 1:]
    if redtype == dtypes.ReductionType.Custom:
        redop = '__reduce_%s()' % idstr
    else:
        redop = 'dace::_wcr_fixed<%s, %s>()' % (credtype, output_type)

    # Allocate shared memory for block reduce
    localcode.write("""
    typedef cub::BlockReduce<{type}, {numthreads}> BlockReduce_{id};
    __shared__ typename BlockReduce_{id}::TempStorage temp_storage_{id};
        """.format(id=idstr,
                   type=output_data.dtype.ctype,
                   numthreads=block_threads))

    # NOTE(review): `input` below is computed but the generated code uses
    # `input_memlet.data` directly — confirm which expression is intended.
    input = (input_memlet.data + ' + ' +
             cpp_array_expr(sdfg, input_memlet, with_brackets=False))
    output = cpp_array_expr(sdfg, output_memlet)
    localcode.write("""
        {output} = BlockReduce_{id}(temp_storage_{id}).Reduce({input}, {redop});
        """.format(id=idstr,
                   redop=redop,
                   input=input_memlet.data,
                   output=output))

    # Make tasklet
    tnode = dace.nodes.Tasklet('reduce',
                               {'_in': dace.pointer(input_data.dtype)},
                               {'_out': dace.pointer(output_data.dtype)},
                               localcode.getvalue(),
                               language=dace.Language.CPP)

    # Add the rest of the code
    sdfg.append_global_code(cuda_globalcode.getvalue(), 'cuda')

    # Rename outer connectors and add to node
    input_edge._dst_conn = '_in'
    output_edge._src_conn = '_out'
    node.add_in_connector('_in')
    node.add_out_connector('_out')

    return tnode
def apply(self, outer_state: SDFGState, sdfg: SDFG):
    """Inline a multi-state nested SDFG into the parent SDFG.

    Unlike the single-state inliner, this splices the nested SDFG's entire
    state machine into the parent in place of ``outer_state``: nested states
    are added as top-level states, interstate edges are reconnected around
    the removed state, and transients/symbols/state labels are renamed to
    avoid clashes with the parent.

    :param outer_state: The state containing the nested SDFG node.
    :param sdfg: The parent SDFG.
    :return: The list of states that were inlined into the parent.
    """
    nsdfg_node = self.nested_sdfg
    nsdfg: SDFG = nsdfg_node.sdfg

    # Propagate a non-default schedule of the nested SDFG node inward
    if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
        infer_types.set_default_schedule_and_storage_types(
            nsdfg, nsdfg_node.schedule)

    #######################################################
    # Collect and update top-level SDFG metadata

    # Global/init/exit code
    for loc, code in nsdfg.global_code.items():
        sdfg.append_global_code(code.code, loc)
    for loc, code in nsdfg.init_code.items():
        sdfg.append_init_code(code.code, loc)
    for loc, code in nsdfg.exit_code.items():
        sdfg.append_exit_code(code.code, loc)

    # Environments
    for nstate in nsdfg.nodes():
        for node in nstate.nodes():
            if isinstance(node, nodes.CodeNode):
                node.environments |= nsdfg_node.environments

    # Constants (warn on conflicting values; outer value wins)
    for cstname, cstval in nsdfg.constants.items():
        if cstname in sdfg.constants:
            if cstval != sdfg.constants[cstname]:
                warnings.warn('Constant value mismatch for "%s" while '
                              'inlining SDFG. Inner = %s != %s = outer' %
                              (cstname, cstval, sdfg.constants[cstname]))
        else:
            sdfg.add_constant(cstname, cstval)

    # Symbols: collect all symbols visible at the outer level, including
    # ones newly assigned on interstate edges
    outer_symbols = {str(k): v for k, v in sdfg.symbols.items()}
    for ise in sdfg.edges():
        outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols))

    # Find original source/destination edges (there is only one edge per
    # connector, according to match)
    inputs: Dict[str, MultiConnectorEdge] = {}
    outputs: Dict[str, MultiConnectorEdge] = {}
    input_set: Dict[str, str] = {}
    output_set: Dict[str, str] = {}
    for e in outer_state.in_edges(nsdfg_node):
        inputs[e.dst_conn] = e
        input_set[e.data.data] = e.dst_conn
    for e in outer_state.out_edges(nsdfg_node):
        outputs[e.src_conn] = e
        output_set[e.data.data] = e.src_conn

    # Replace symbols using invocation symbol mapping
    # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
    symbolic.safe_replace(nsdfg_node.symbol_mapping, nsdfg.replace_dict)

    # NOTE: The reshape/view detection from the single-state InlineSDFG
    # transformation is not yet ported to the multistate case (dead
    # commented-out code removed; see InlineSDFG.apply for the logic).

    # Mapping from nested transient name to top-level name
    transients: Dict[str, str] = {}

    # All transients become transients of the parent (if data already
    # exists, find new name)
    for nstate in nsdfg.nodes():
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    new_name = node.data
                    if (new_name in sdfg.arrays
                            or new_name in outer_symbols
                            or new_name in sdfg.constants):
                        # Name clash with outer SDFG: qualify with label
                        new_name = f'{nsdfg.label}_{node.data}'

                    name = sdfg.add_datadesc(new_name,
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                if edge.data.data is not None:
                    datadesc = nsdfg.arrays[edge.data.data]
                    if (edge.data.data not in transients
                            and datadesc.transient):
                        new_name = edge.data.data
                        if (new_name in sdfg.arrays
                                or new_name in outer_symbols
                                or new_name in sdfg.constants):
                            new_name = f'{nsdfg.label}_{edge.data.data}'

                        name = sdfg.add_datadesc(new_name,
                                                 datadesc,
                                                 find_new_name=True)
                        transients[edge.data.data] = name

    #######################################################
    # Replace data on inlined SDFG nodes/edges

    # Replace data names with their top-level counterparts
    repldict = {}
    repldict.update(transients)
    repldict.update({
        k: v.data.data
        for k, v in itertools.chain(inputs.items(), outputs.items())
    })

    symbolic.safe_replace(repldict,
                          lambda m: replace_datadesc_names(nsdfg, m),
                          value_as_string=True)

    # NOTE: View insertion for reshaped data and the corresponding extra
    # access nodes (see InlineSDFG.apply) are likewise not yet ported
    # (dead commented-out code removed).

    # Make unique names for states
    statenames = set(s.label for s in sdfg.nodes())
    for nstate in nsdfg.nodes():
        if nstate.label in statenames:
            newname = data.find_new_name(nstate.label, statenames)
            statenames.add(newname)
            nstate.set_label(newname)

    #######################################################
    # Collect and modify interstate edges as necessary

    # Rename inner interstate-edge assignments that would shadow outer ones
    outer_assignments = set()
    for e in sdfg.edges():
        outer_assignments |= e.data.assignments.keys()

    inner_assignments = set()
    for e in nsdfg.edges():
        inner_assignments |= e.data.assignments.keys()

    assignments_to_replace = inner_assignments & outer_assignments
    sym_replacements: Dict[str, str] = {}
    allnames = set(outer_symbols.keys()) | set(sdfg.arrays.keys())
    for assign in assignments_to_replace:
        newname = data.find_new_name(assign, allnames)
        allnames.add(newname)
        sym_replacements[assign] = newname
    nsdfg.replace_dict(sym_replacements)

    #######################################################
    # Add nested SDFG states into top-level SDFG

    outer_start_state = sdfg.start_state

    sdfg.add_nodes_from(nsdfg.nodes())
    for ise in nsdfg.edges():
        sdfg.add_edge(ise.src, ise.dst, ise.data)

    #######################################################
    # Reconnect inlined SDFG

    source = nsdfg.start_state
    sinks = nsdfg.sink_nodes()

    # Reconnect state machine: predecessors go to the nested start state,
    # every nested sink state connects to the successors
    for e in sdfg.in_edges(outer_state):
        sdfg.add_edge(e.src, source, e.data)
    for e in sdfg.out_edges(outer_state):
        for sink in sinks:
            sdfg.add_edge(sink, e.dst, e.data)

    # Modify start state as necessary
    if outer_start_state is outer_state:
        sdfg.start_state = sdfg.node_id(source)

    # TODO: Modify memlets by offsetting
    # NOTE: Memlet-path reconnection and unsqueezing from the single-state
    # inliner (see InlineSDFG.apply) is not yet implemented here (dead
    # commented-out code removed).

    # Replace nested SDFG parents with new SDFG
    for nstate in nsdfg.nodes():
        nstate.parent = sdfg
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

    #######################################################
    # Remove nested SDFG and state
    sdfg.remove_node(outer_state)

    return nsdfg.nodes()