def apply(self, sdfg: SDFG) -> None:
    """Redirect a tasklet's stream write through a new transient stream.

    The matched tasklet writes a Stream container that leaves the inner map
    (``map_exit``) and then the outer map (``outer_map_exit``). This method:

    * creates a transient stream ``trans_<name>`` (a fresh name is picked if
      taken) with the dtype and storage of the original container and a
      capacity equal to the over-approximated first dimension of the inner
      map's bounding box;
    * renames the tasklet's output memlet to that stream;
    * if ``self.with_buffer`` is set, adds a transient array
      ``strans_<name>`` of the same length, fed from the new stream;
    * replaces the original ``map_exit -> outer_map_exit`` edge with
      ``map_exit -> stream [-> array] -> outer_map_exit``.

    :param sdfg: The SDFG containing the matched state (``self.state_id``).
    :raises ValueError: If no Stream-typed edge from ``map_exit`` to
                        ``outer_map_exit`` exists, or the tasklet has no
                        output edge writing that container. Both checks run
                        before the SDFG is modified.
    """
    graph = sdfg.nodes()[self.state_id]
    tasklet = graph.nodes()[self.subgraph[StreamTransient.tasklet]]
    map_exit = graph.nodes()[self.subgraph[StreamTransient.map_exit]]
    outer_map_exit = graph.nodes()[self.subgraph[
        StreamTransient.outer_map_exit]]

    # Find the edge carrying a Stream container out of the inner map.
    # TODO: What if there's more than one? Only the first match is handled.
    edge = next((e for e in graph.out_edges(map_exit)
                 if e.dst == outer_map_exit
                 and isinstance(sdfg.arrays[e.data.data], data.Stream)),
                None)
    if edge is None:
        raise ValueError('StreamTransient: no stream edge found between the '
                         'inner and outer map exits')
    memlet = edge.data
    dataname = memlet.data

    # The tasklet output writing the same container gets redirected. Failing
    # here (instead of falling back to the last edge) avoids renaming an
    # unrelated memlet.
    tasklet_memlet = next((e.data for e in graph.out_edges(tasklet)
                           if e.data.data == dataname), None)
    if tasklet_memlet is None:
        raise ValueError('StreamTransient: tasklet has no output edge '
                         f'writing "{dataname}"')

    desc = sdfg.arrays[dataname]
    # Capacity of the new stream (and length of the optional buffer array):
    # the over-approximated first dimension of the inner map's bounding box.
    bbox = map_exit.map.range.bounding_box_size()
    capacity = symbolic.overapproximate(bbox[0])

    # Create the new node: Temporary stream and an access node
    newname, _ = sdfg.add_stream('trans_' + dataname,
                                 desc.dtype,
                                 capacity,
                                 storage=desc.storage,
                                 transient=True,
                                 find_new_name=True)
    snode = graph.add_access(newname)

    to_stream_mm = copy.deepcopy(memlet)
    to_stream_mm.data = snode.data
    tasklet_memlet.data = snode.data

    if self.with_buffer:
        # Add a transient array fed from the new stream; it (rather than the
        # stream) is then connected to the outer map exit.
        newname_arr, _ = sdfg.add_transient('strans_' + dataname,
                                            [capacity],
                                            desc.dtype,
                                            find_new_name=True)
        anode = graph.add_access(newname_arr)
        to_array_mm = copy.deepcopy(memlet)
        to_array_mm.data = anode.data
        graph.add_edge(snode, None, anode, None, to_array_mm)
    else:
        anode = snode

    # Reconnect, assuming one edge to the stream
    graph.remove_edge(edge)
    graph.add_edge(map_exit, edge.src_conn, snode, None, to_stream_mm)
    graph.add_edge(anode, None, outer_map_exit, edge.dst_conn, memlet)
def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG) -> SDFG:
    """Expand a Reduce node into a nested SDFG with a register accumulator.

    The returned nested SDFG ``reduce`` contains:

    * ``_in`` / ``_out``: arrays shaped like the squeezed input/output
      subsets, keeping the strides of the retained dimensions;
    * ``acc``: a one-element transient with register storage;
    * an outer map ``reduce_output`` over the output elements. For each
      element, tasklet ``reset`` writes ``node.identity`` into ``acc``;
    * an inner sequential map ``reduce_values`` over the reduced axes. Its
      ``identity`` tasklet writes each input value into ``acc[0]`` with
      ``node.wcr`` as write-conflict resolution;
    * a copy of ``acc`` into the matching element of ``_out``.

    The node's outer connectors are renamed to ``_in``/``_out``, and
    MapCollapse is applied repeatedly to the nested SDFG. Only the first
    in-edge and first out-edge of ``node`` are considered.

    :param node: The Reduce node to expand.
    :param state: The state containing ``node``.
    :param sdfg: The SDFG containing ``state``.
    :return: The nested SDFG implementing the reduction.
    :raises ValueError: If ``node.identity`` is None.
    """
    node.validate(sdfg, state)
    # Explicit check rather than `assert`, which is stripped under -O; the
    # identity is emitted verbatim into the reset tasklet code below.
    if node.identity is None:
        raise ValueError('Reduce expansion requires the node to have an '
                         'identity value')

    inedge: graph.MultiConnectorEdge = state.in_edges(node)[0]
    outedge: graph.MultiConnectorEdge = state.out_edges(node)[0]

    # Squeeze copies so the edge memlets are untouched. The lists returned by
    # squeeze() are used below as the indices of the retained dimensions.
    insubset = dcpy(inedge.data.subset)
    isqdim = insubset.squeeze()
    outsubset = dcpy(outedge.data.subset)
    osqdim = outsubset.squeeze()
    # Fix for scalars: a subset whose dimensions all have size 1 retains no
    # dimension index; fall back to dimension 0 (input and output alike).
    if not isqdim:
        isqdim = [0]
    if not osqdim:
        osqdim = [0]
    output_dims = len(outsubset)

    input_data = sdfg.arrays[inedge.data.data]
    output_data = sdfg.arrays[outedge.data.data]

    # Standardize and squeeze axes: no axes means all input dimensions; axes
    # removed by squeezing are dropped.
    axes = node.axes if node.axes else list(range(len(inedge.data.subset)))
    axes = [axis for axis in axes if axis in isqdim]

    # Create nested SDFG
    nsdfg = SDFG('reduce')
    nsdfg.add_array(
        '_in',
        insubset.size(),
        input_data.dtype,
        strides=[s for i, s in enumerate(input_data.strides) if i in isqdim],
        storage=input_data.storage)
    nsdfg.add_array(
        '_out',
        outsubset.size(),
        output_data.dtype,
        strides=[s for i, s in enumerate(output_data.strides) if i in osqdim],
        storage=output_data.storage)
    nsdfg.add_transient('acc', [1], nsdfg.arrays['_in'].dtype,
                        dtypes.StorageType.Register)

    nstate = nsdfg.add_state()

    # Interleave input and output axes to match input memlet
    ictr, octr = 0, 0
    input_subset = []
    for i in isqdim:
        if i in axes:
            input_subset.append('_i%d' % ictr)
            ictr += 1
        else:
            input_subset.append('_o%d' % octr)
            octr += 1

    ome, omx = nstate.add_map(
        'reduce_output', {
            '_o%d' % i: '0:%s' % symstr(sz)
            for i, sz in enumerate(outsubset.size())
        })
    outm = dace.Memlet.simple(
        '_out', ','.join(['_o%d' % i for i in range(output_dims)]))
    inmm = dace.Memlet.simple('_in', ','.join(input_subset))

    # Reset the accumulator to the identity once per output element.
    idt = nstate.add_tasklet('reset', {}, {'o'}, f'o = {node.identity}')
    nstate.add_edge(ome, None, idt, None, dace.Memlet())

    accread = nstate.add_access('acc')
    accwrite = nstate.add_access('acc')
    nstate.add_edge(idt, 'o', accread, None, dace.Memlet('acc'))

    # Add inner map, which corresponds to the range to reduce, containing
    # an identity tasklet. `_i<k>` iterates the k-th smallest reduced axis,
    # matching the order in which `input_subset` was built above.
    ime, imx = nstate.add_map(
        'reduce_values', {
            '_i%d' % i: '0:%s' % symstr(insubset.size()[isqdim.index(axis)])
            for i, axis in enumerate(sorted(axes))
        },
        schedule=dtypes.ScheduleType.Sequential)

    # Add identity tasklet for reduction. Input 'a' is not read by the code;
    # it connects the tasklet to the reset accumulator (presumably to order
    # it after the reset). Accumulation happens via the WCR on the write.
    t = nstate.add_tasklet('identity', {'a', 'b'}, {'o'}, 'o = b')

    # Connect everything
    r = nstate.add_read('_in')
    w = nstate.add_write('_out')
    nstate.add_memlet_path(r, ome, ime, t, dst_conn='b', memlet=inmm)
    nstate.add_memlet_path(accread,
                           ime,
                           t,
                           dst_conn='a',
                           memlet=dace.Memlet('acc[0]'))
    nstate.add_memlet_path(t,
                           imx,
                           accwrite,
                           src_conn='o',
                           memlet=dace.Memlet('acc[0]', wcr=node.wcr))
    nstate.add_memlet_path(accwrite, omx, w, memlet=outm)

    # Rename outer connectors and add to node
    inedge._dst_conn = '_in'
    outedge._src_conn = '_out'
    node.add_in_connector('_in')
    node.add_out_connector('_out')

    # Function-local import, presumably to avoid an import cycle at module
    # load time — TODO confirm.
    from dace.transformation import dataflow
    nsdfg.apply_transformations_repeated(dataflow.MapCollapse)

    return nsdfg