class Vectorization(pattern_matching.Transformation):
    """ Implements the vectorization transformation.

        Vectorization matches when all the input and output memlets of a
        tasklet inside a map access the inner-most loop variable in their last
        dimension. The transformation changes the step of the inner-most loop
        to be equal to the length of the vector and vectorizes the memlets.
  """

    vector_len = Property(desc="Vector length", dtype=int, default=4)
    propagate_parent = Property(desc="Propagate vector length through "
                                "parent SDFGs",
                                dtype=bool,
                                default=False)
    strided_map = Property(desc="Use strided map range (jump by vector length)"
                           " instead of modifying memlets",
                           dtype=bool,
                           default=False)

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        return [
            sdutil.node_path_graph(Vectorization._map_entry,
                                   Vectorization._tasklet,
                                   Vectorization._map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if every non-empty, non-stream memlet adjacent to the
            tasklet uses the inner-most map parameter only in its last
            dimension (and at least one memlet uses it there). WCR memlets
            and memlets that are already vectorized (veclen > 1) reject the
            match. """
        map_entry = graph.nodes()[candidate[Vectorization._map_entry]]
        tasklet = graph.nodes()[candidate[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        found = False

        # Check if all edges, adjacent to the tasklet,
        # use the parameter in their last dimension.
        for _src, _, _dest, _, memlet in graph.all_edges(tasklet):

            # Cases that do not matter for vectorization
            if memlet.data is None:  # Empty memlets
                continue
            if isinstance(sdfg.arrays[memlet.data], data.Stream):  # Streams
                continue

            # Vectorization can not be applied in WCR
            if memlet.wcr is not None:
                return False

            try:
                subset = memlet.subset
                veclen = memlet.veclen
            except AttributeError:
                return False

            if subset is None:
                return False

            try:
                if veclen > symbolic.pystr_to_symbolic('1'):
                    return False

                last_dim = subset.dims() - 1
                for idx, expr in enumerate(subset):
                    # Ranges are (begin, end, step) tuples; indices are scalars
                    components = expr if isinstance(expr, tuple) else (expr, )
                    for component in components:
                        symbols = symbolic.pystr_to_symbolic(
                            component).free_symbols
                        if param in symbols:
                            if idx == last_dim:
                                found = True
                            else:
                                return False
            except TypeError:  # cannot determine truth value of Relational
                return False

        return found

    @staticmethod
    def match_to_str(graph, candidate):
        map_entry = candidate[Vectorization._map_entry]
        tasklet = candidate[Vectorization._tasklet]
        map_exit = candidate[Vectorization._map_exit]

        return ' -> '.join(
            str(node) for node in [map_entry, tasklet, map_exit])

    def apply(self, sdfg):
        """ Rewrites the inner-most map dimension and sets the vector length
            of the memlets adjacent to the tasklet (and of their memlet
            trees) to ``vector_len``. """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = graph.nodes()[self.subgraph[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        # Create new vector size.
        vector_size = self.vector_len

        # Change the step of the inner-most dimension.
        dim_from, dim_to, dim_step = map_entry.map.range[-1]
        if self.strided_map:
            map_entry.map.range[-1] = (dim_from, dim_to, vector_size)
        else:
            # NOTE(review): the start index is not rescaled, which presumably
            # assumes dim_from == 0 and (dim_to + 1) divisible by
            # vector_size -- confirm.
            map_entry.map.range[-1] = (dim_from,
                                       (dim_to + 1) / vector_size - 1,
                                       dim_step)

        # TODO: Postamble and/or preamble non-vectorized map

        # Vectorize memlets adjacent to the tasklet.
        processed_edges = set()
        for edge in graph.all_edges(tasklet):
            memlet = edge.data
            if memlet.data is None:  # Empty memlets
                continue

            # Collect the free symbols of the last dimension of the subset
            lastindex = memlet.subset[-1]
            components = (lastindex
                          if isinstance(lastindex, tuple) else (lastindex, ))
            symbols = set()
            for component in components:
                symbols.update(
                    symbolic.pystr_to_symbolic(component).free_symbols)

            if param not in symbols:
                continue

            # propagate vector length inside this SDFG
            for e in graph.memlet_tree(edge):
                e.data.veclen = vector_size
                if not self.strided_map and e not in processed_edges:
                    e.data.subset.replace({param: vector_size * param})
                    processed_edges.add(e)

            # propagate to the parent (TODO: handle multiple level of nestings)
            if self.propagate_parent and sdfg.parent is not None:
                path = graph.memlet_path(edge)
                source_edge = path[0]
                sink_edge = path[-1]

                # Find parent Nested SDFG node
                parent_node = next(n for n in sdfg.parent.nodes()
                                   if isinstance(n, nodes.NestedSDFG)
                                   and n.sdfg.name == sdfg.name)

                # continue in propagating the vector length following the
                # path that arrives to source_edge or starts from sink_edge
                for pe in sdfg.parent.all_edges(parent_node):
                    if (str(pe.dst_conn) == str(source_edge.src)
                            or str(pe.src_conn) == str(sink_edge.dst)):
                        for ppe in sdfg.parent.memlet_tree(pe):
                            ppe.data.veclen = vector_size
                            if (not self.strided_map
                                    and ppe not in processed_edges):
                                ppe.data.subset.replace(
                                    {param: vector_size * param})
                                processed_edges.add(ppe)

        return
class Vectorization(transformation.Transformation):
    """ Implements the vectorization transformation.

        Vectorization matches when all the input and output memlets of a
        tasklet inside a map access the inner-most loop variable only in
        their contiguous (stride-1) dimension. The transformation changes the
        step of the inner-most loop to be equal to the length of the vector,
        vectorizes the tasklet connectors and memlets, and optionally emits
        non-vectorized preamble/postamble maps for leftover iterations.
  """

    vector_len = Property(desc="Vector length", dtype=int, default=4)
    propagate_parent = Property(desc="Propagate vector length through "
                                "parent SDFGs",
                                dtype=bool,
                                default=False)
    strided_map = Property(desc="Use strided map range (jump by vector length)"
                           " instead of modifying memlets",
                           dtype=bool,
                           default=True)
    preamble = Property(
        dtype=bool,
        default=None,
        allow_none=True,
        desc='Force creation or skipping a preamble map without vectors')
    postamble = Property(
        dtype=bool,
        default=None,
        allow_none=True,
        desc='Force creation or skipping a postamble map without vectors')

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        return [
            sdutil.node_path_graph(Vectorization._map_entry,
                                   Vectorization._tasklet,
                                   Vectorization._map_exit)
        ]

    def can_be_applied(self, graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if every non-empty, non-stream memlet adjacent to the
            tasklet uses the inner-most map parameter only in a stride-1
            dimension, at least one memlet does so, and no adjacent connector
            is already a vector or pointer type. """
        map_entry = graph.nodes()[candidate[Vectorization._map_entry]]
        tasklet = graph.nodes()[candidate[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        found = False

        # Strided maps cannot be vectorized
        if map_entry.map.range[-1][2] != 1 and self.strided_map:
            return False

        # Check if all edges, adjacent to the tasklet,
        # use the parameter in their contiguous dimension.
        for e, conntype in graph.all_edges_and_connectors(tasklet):

            # Cases that do not matter for vectorization
            if e.data.data is None:  # Empty memlets
                continue
            if isinstance(sdfg.arrays[e.data.data], data.Stream):  # Streams
                continue

            # NOTE: WCR memlets are not rejected here; apply() handles
            # vector-to-scalar WCR edges explicitly.

            subset = e.data.subset
            array = sdfg.arrays[e.data.data]

            # If already vectorized or a pointer, do not apply
            if isinstance(conntype, (dtypes.vector, dtypes.pointer)):
                return False

            try:
                for idx, expr in enumerate(subset):
                    # Ranges are (begin, end, step) tuples; indices are scalars
                    components = expr if isinstance(expr, tuple) else (expr, )
                    for component in components:
                        symbols = symbolic.pystr_to_symbolic(
                            component).free_symbols
                        if param in symbols:
                            if array.strides[idx] == 1:
                                found = True
                            else:
                                return False
            except TypeError:  # cannot determine truth value of Relational
                return False

        return found

    @staticmethod
    def match_to_str(graph, candidate):
        map_entry = candidate[Vectorization._map_entry]
        tasklet = candidate[Vectorization._tasklet]
        map_exit = candidate[Vectorization._map_exit]

        return ' -> '.join(
            str(node) for node in [map_entry, tasklet, map_exit])

    def apply(self, sdfg: SDFG):
        """ Vectorizes the matched map/tasklet.

            Steps: decide on preamble/postamble maps, compute the new
            inner-most range, replicate the scope for the preamble/postamble
            if needed, vectorize tasklet connectors and memlet subsets, and
            (if ``propagate_parent``) convert the accessed data descriptors
            to vector types up through the parent SDFGs. """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = graph.nodes()[self.subgraph[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        # Create new vector size.
        vector_size = self.vector_len
        dim_from, dim_to, dim_skip = map_entry.map.range[-1]

        # Determine whether to create preamble or postamble maps.
        # The comparisons with True/False are kept because the operands may
        # be symbolic expressions.
        if self.preamble is not None:
            create_preamble = self.preamble
        else:
            create_preamble = not ((dim_from % vector_size == 0) == True
                                   or dim_from == 0)
        if self.postamble is not None:
            create_postamble = self.postamble
        else:
            if isinstance(dim_to, symbolic.SymExpr):
                create_postamble = ((
                    (dim_to.approx + 1) % vector_size == 0) == False)
            else:
                create_postamble = (((dim_to + 1) % vector_size == 0) == False)

        # Determine new range for vectorized map. Strided maps keep element
        # indices and jump by vector_size; non-strided maps iterate over
        # whole vectors.
        if self.strided_map:
            new_range = [dim_from, dim_to - vector_size + 1, vector_size]
        else:
            new_range = [
                dim_from // vector_size, ((dim_to + 1) // vector_size) - 1,
                dim_skip
            ]

        # Create preamble non-vectorized map (replacing the original map)
        if create_preamble:
            old_scope = graph.scope_subgraph(map_entry, True, True)
            new_scope: ScopeSubgraphView = replicate_scope(
                sdfg, graph, old_scope)
            new_begin = dim_from + (vector_size - (dim_from % vector_size))
            map_entry.map.range[-1] = (dim_from, new_begin - 1, dim_skip)
            # Replace map_entry with the replicated scope (so that the preamble
            # will usually come first in topological sort)
            map_entry = new_scope.entry
            tasklet = new_scope.nodes()[old_scope.nodes().index(tasklet)]
            if self.strided_map:
                new_range[0] = new_begin
            else:
                # Non-strided maps count in vectors (memlets are rewritten to
                # vector_len * index below), so convert the element index
                # new_begin (a multiple of vector_size) into vector units.
                new_range[0] = new_begin // vector_size

        # Create postamble non-vectorized map
        if create_postamble:
            new_scope: ScopeSubgraphView = replicate_scope(
                sdfg, graph, graph.scope_subgraph(map_entry, True, True))
            dim_to_ex = dim_to + 1
            new_scope.entry.map.range[-1] = (dim_to_ex -
                                             (dim_to_ex % vector_size), dim_to,
                                             dim_skip)

        # Change the step of the inner-most dimension.
        map_entry.map.range[-1] = tuple(new_range)

        # Vectorize connectors adjacent to the tasklet.
        for edge in graph.all_edges(tasklet):
            connectors = (tasklet.in_connectors
                          if edge.dst == tasklet else tasklet.out_connectors)
            conn = edge.dst_conn if edge.dst == tasklet else edge.src_conn

            if edge.data.data is None:  # Empty memlets
                continue
            desc = sdfg.arrays[edge.data.data]
            contigidx = desc.strides.index(1)

            lastindex = edge.data.subset[contigidx]
            if isinstance(lastindex, tuple):
                newlist = [(rb, re, rs) for rb, re, rs in edge.data.subset]
                symbols = set()
                for indd in lastindex:
                    symbols.update(
                        symbolic.pystr_to_symbolic(indd).free_symbols)
            else:
                newlist = [(rb, rb, 1) for rb in edge.data.subset]
                symbols = symbolic.pystr_to_symbolic(lastindex).free_symbols

            oldtype = connectors[conn]
            if oldtype is None or oldtype.type is None:
                oldtype = desc.dtype

            # Vector to scalar WCR edge: change connector and continue
            if (edge.data.subset.num_elements() == 1
                    and edge.data.wcr is not None):
                connectors[conn] = dtypes.vector(oldtype, vector_size)
                continue

            if str(param) not in map(str, symbols):
                continue

            # Vectorize connector, if not already vectorized
            if isinstance(oldtype, dtypes.vector):
                continue

            connectors[conn] = dtypes.vector(oldtype, vector_size)

            # Modify memlet subset to match vector length
            rb = newlist[contigidx][0]
            if self.strided_map:
                if self.propagate_parent:
                    newlist[contigidx] = (rb / self.vector_len,
                                          rb / self.vector_len, 1)
                else:
                    newlist[contigidx] = (rb, rb + self.vector_len - 1, 1)
            else:
                if self.propagate_parent:
                    newlist[contigidx] = (rb, rb, 1)
                else:
                    newlist[contigidx] = (self.vector_len * rb,
                                          self.vector_len * rb +
                                          self.vector_len - 1, 1)
            edge.data.subset = subsets.Range(newlist)
            edge.data.volume = vector_size

        # Vector length propagation using data descriptors, recursive traversal
        # outwards
        if self.propagate_parent:
            for edge in graph.all_edges(tasklet):
                # Empty memlets have no data descriptor to convert
                if edge.data.data is None:
                    continue
                cursdfg = sdfg
                curedge = edge
                while cursdfg is not None:
                    arrname = curedge.data.data
                    dtype = cursdfg.arrays[arrname].dtype

                    # Change type and shape to vector
                    if not isinstance(dtype, dtypes.vector):
                        cursdfg.arrays[arrname].dtype = dtypes.vector(
                            dtype, vector_size)
                        new_shape = list(cursdfg.arrays[arrname].shape)
                        contigidx = cursdfg.arrays[arrname].strides.index(1)
                        new_shape[contigidx] /= vector_size
                        try:
                            new_shape[contigidx] = int(new_shape[contigidx])
                        except TypeError:
                            pass
                        cursdfg.arrays[arrname].shape = new_shape

                    propagation.propagate_memlets_sdfg(cursdfg)

                    # Find matching edge in parent
                    nsdfg = cursdfg.parent_nsdfg_node
                    if nsdfg is None:
                        break
                    tstate = cursdfg.parent
                    curedge = ([
                        e for e in tstate.in_edges(nsdfg)
                        if e.dst_conn == arrname
                    ] + [
                        e for e in tstate.out_edges(nsdfg)
                        if e.src_conn == arrname
                    ])[0]
                    cursdfg = cursdfg.parent_sdfg
def __init__(self, *args, **kwargs):
    """ Creates this instance's entry, tasklet and exit nodes, then defers
        to the base-class initializer with all arguments unchanged. """
    entry_node, tasklet_node, exit_node = (nodes.EntryNode(),
                                           nodes.Tasklet('_'),
                                           nodes.ExitNode())
    self._entry = entry_node
    self._tasklet = tasklet_node
    self._exit = exit_node
    super().__init__(*args, **kwargs)
class StreamTransient(transformation.Transformation):
    """ Implements the StreamTransient transformation, which adds a transient
        and stream nodes between nested maps that lead to a stream. The
        transient then acts as a local buffer.
    """

    with_buffer = Property(dtype=bool,
                           default=True,
                           desc="Use an intermediate buffer for accumulation")

    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        return [
            sdutil.node_path_graph(StreamTransient._tasklet,
                                   StreamTransient._map_exit,
                                   StreamTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the inner map exit has an edge carrying a stream
            into the outer map exit. """
        map_exit = graph.nodes()[candidate[StreamTransient._map_exit]]
        outer_map_exit = graph.nodes()[candidate[
            StreamTransient._outer_map_exit]]

        # Check if there is a streaming output. The destination and
        # empty-memlet checks come first so that empty memlets (data is None)
        # never index sdfg.arrays.
        for _src, _, dest, _, memlet in graph.out_edges(map_exit):
            if (dest == outer_map_exit and memlet.data is not None
                    and isinstance(sdfg.arrays[memlet.data], data.Stream)):
                return True

        return False

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[StreamTransient._tasklet]
        map_exit = candidate[StreamTransient._map_exit]
        outer_map_exit = candidate[StreamTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg: SDFG):
        """ Inserts a transient stream (and, if ``with_buffer``, a transient
            array fed from it) between the inner and outer map exits.

            Raises:
                ValueError: if no stream edge connects the two map exits.
        """
        graph = sdfg.nodes()[self.state_id]
        tasklet = graph.nodes()[self.subgraph[StreamTransient._tasklet]]
        map_exit = graph.nodes()[self.subgraph[StreamTransient._map_exit]]
        outer_map_exit = graph.nodes()[self.subgraph[
            StreamTransient._outer_map_exit]]

        # Locate the stream edge between the two map exits before touching
        # the SDFG, so that a missing edge does not leave it half-modified.
        # TODO: What if there's more than one?
        edge = next(
            (e for e in graph.out_edges(map_exit)
             if e.dst == outer_map_exit and e.data.data is not None
             and isinstance(sdfg.arrays[e.data.data], data.Stream)), None)
        if edge is None:
            raise ValueError('StreamTransient: no stream edge found between '
                             'the matched map exits')
        memlet = edge.data

        # Tasklet output memlet writing the same data; falls back to the last
        # out-edge of the tasklet if none matches.
        tasklet_memlet = None
        for e in graph.out_edges(tasklet):
            tasklet_memlet = e.data
            if tasklet_memlet.data == memlet.data:
                break

        bbox = map_exit.map.range.bounding_box_size()
        bbox_approx = [symbolic.overapproximate(dim) for dim in bbox]
        dataname = memlet.data

        # Create the new node: Temporary stream and an access node
        newname, _ = sdfg.add_stream('trans_' + dataname,
                                     sdfg.arrays[memlet.data].dtype,
                                     1,
                                     bbox_approx[0], [1],
                                     transient=True,
                                     find_new_name=True)
        snode = graph.add_access(newname)

        to_stream_mm = copy.deepcopy(memlet)
        to_stream_mm.data = snode.data
        tasklet_memlet.data = snode.data

        if self.with_buffer:
            newname_arr, _ = sdfg.add_transient('strans_' + dataname,
                                                [bbox_approx[0]],
                                                sdfg.arrays[memlet.data].dtype,
                                                find_new_name=True)
            anode = graph.add_access(newname_arr)
            to_array_mm = copy.deepcopy(memlet)
            to_array_mm.data = anode.data
            graph.add_edge(snode, None, anode, None, to_array_mm)
        else:
            anode = snode

        # Reconnect, assuming one edge to the stream
        graph.remove_edge(edge)
        graph.add_edge(map_exit, edge.src_conn, snode, None, to_stream_mm)
        graph.add_edge(anode, None, outer_map_exit, edge.dst_conn, memlet)

        return
def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
    """ Distributes the matched map over MPI processes.

        Map ranges are rewritten to per-process local ranges. Scalar inputs
        are broadcast with ``Bcast``. Array inputs are scattered, and array
        outputs gathered, with the PBLAS block-cyclic library nodes through
        local transient arrays.

        Raises:
            NotImplementedError: for inputs/outputs that are not access nodes
                of supported descriptors, for scalar outputs, or when the
                memlet subset does not cover the whole array.
    """
    map_entry = self.map_entry
    map_exit = graph.exit_node(map_entry)

    # NOTE(review): presumably 'commsize' is the number of ranks and Px/Py
    # the process-grid dimensions -- confirm.
    sz = dace.symbol('commsize', dtype=dace.int32, integer=True, positive=True)
    Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True)
    Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True)

    from dace.data import _prod

    # NOTE: Maps with step in their ranges are currently not supported
    if len(map_entry.map.params) == 2:
        # 2D maps: each dimension is split over Px and Py respectively
        params = map_entry.map.params
        ranges = [None] * 2
        b, e, _ = map_entry.map.range[0]
        ranges[0] = (0, (e - b + 1) / Px - 1, 1)
        b, e, _ = map_entry.map.range[1]
        ranges[1] = (0, (e - b + 1) / Py - 1, 1)
    else:
        # Other maps are flattened into a single '__iflat' dimension
        # split over 'commsize'
        params = ['__iflat']
        sizes = map_entry.map.range.size_exact()
        total_size = _prod(sizes)
        ranges = [(0, (total_size) / sz - 1, 1)]

    # Transient scalar holding the root rank (0) for broadcasts
    root_name = sdfg.temp_data_name()
    sdfg.add_scalar(root_name, dace.int32, transient=True)
    root_node = graph.add_access(root_name)
    root_tasklet = graph.add_tasklet('_set_root_', {}, {'__out'}, '__out = 0')
    graph.add_edge(root_tasklet, '__out', root_node, None,
                   dace.Memlet.simple(root_name, '0'))

    from dace.libraries.mpi import Bcast
    from dace.libraries.pblas import BlockCyclicScatter, BlockCyclicGather

    # Collect inputs; each must be an access node whose memlet covers the
    # whole array
    inputs = set()
    for src, _, _, _, m in graph.in_edges(map_entry):
        if not isinstance(src, nodes.AccessNode):
            raise NotImplementedError
        desc = src.desc(sdfg)
        if not isinstance(desc, (data.Scalar, data.Array)):
            raise NotImplementedError
        if list(desc.shape) != m.src_subset.size_exact():
            # Second attempt
            # TODO: We need a solution for symbols not matching
            if str(list(desc.shape)) != str(m.src_subset.size_exact()):
                raise NotImplementedError
        inputs.add(src)

    for inp in inputs:
        desc = inp.desc(sdfg)

        if isinstance(desc, data.Scalar):
            # Broadcast the scalar and route the map input through the result
            local_access = graph.add_access(inp.data)
            bcast_node = Bcast('_Bcast_')

            graph.add_edge(inp, None, bcast_node, '_inbuffer',
                           dace.Memlet.from_array(inp.data, desc))
            graph.add_edge(root_node, None, bcast_node, '_root',
                           dace.Memlet.simple(root_name, '0'))
            graph.add_edge(bcast_node, '_outbuffer', local_access, None,
                           dace.Memlet.from_array(inp.data, desc))

            for e in graph.edges_between(inp, map_entry):
                graph.add_edge(local_access, None, map_entry, e.dst_conn,
                               dace.Memlet.from_array(inp.data, desc))
                graph.remove_edge(e)

        elif isinstance(desc, data.Array):
            # Scatter the (2D) array into a local transient block
            local_name, local_arr = sdfg.add_temp_transient(
                [(desc.shape[0]) // Px, (desc.shape[1]) // Py],
                dtype=desc.dtype,
                storage=desc.storage)
            local_access = graph.add_access(local_name)

            bsizes_name, bsizes_arr = sdfg.add_temp_transient((2, ),
                                                              dtype=dace.int32)
            bsizes_access = graph.add_access(bsizes_name)
            # NOTE(review): constructed directly instead of via
            # graph.add_tasklet; it enters the state only through add_edge.
            bsizes_tasklet = nodes.Tasklet(
                '_set_bsizes_', {}, {'__out'},
                "__out[0] = {x}; __out[1] = {y}".format(
                    x=(desc.shape[0]) // Px, y=(desc.shape[1]) // Py))
            graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None,
                           dace.Memlet.from_array(bsizes_name, bsizes_arr))

            gdesc_name, gdesc_arr = sdfg.add_temp_transient((9, ),
                                                            dtype=dace.int32)
            gdesc_access = graph.add_access(gdesc_name)

            ldesc_name, ldesc_arr = sdfg.add_temp_transient((9, ),
                                                            dtype=dace.int32)
            ldesc_access = graph.add_access(ldesc_name)

            scatter_node = BlockCyclicScatter('_Scatter_')
            graph.add_edge(inp, None, scatter_node, '_inbuffer',
                           dace.Memlet.from_array(inp.data, desc))
            graph.add_edge(bsizes_access, None, scatter_node, '_block_sizes',
                           dace.Memlet.from_array(bsizes_name, bsizes_arr))
            graph.add_edge(scatter_node, '_outbuffer', local_access, None,
                           dace.Memlet.from_array(local_name, local_arr))
            graph.add_edge(scatter_node, '_gdescriptor', gdesc_access, None,
                           dace.Memlet.from_array(gdesc_name, gdesc_arr))
            graph.add_edge(scatter_node, '_ldescriptor', ldesc_access, None,
                           dace.Memlet.from_array(ldesc_name, ldesc_arr))

            for e in graph.edges_between(inp, map_entry):
                graph.add_edge(local_access, None, map_entry, e.dst_conn,
                               dace.Memlet.from_array(local_name, local_arr))
                graph.remove_edge(e)

            # Inner memlets now refer to the local block
            for e in graph.out_edges(map_entry):
                if e.data.data == inp.data:
                    e.data.data = local_name

        else:
            raise NotImplementedError

    # Collect outputs; each must be an access node of an array whose memlet
    # covers the whole array
    outputs = set()
    for _, _, dst, _, m in graph.out_edges(map_exit):
        if not isinstance(dst, nodes.AccessNode):
            raise NotImplementedError
        desc = dst.desc(sdfg)
        if not isinstance(desc, data.Array):
            raise NotImplementedError
        try:
            if list(desc.shape) != m.dst_subset.size_exact():
                # Second attempt
                # TODO: We need a solution for symbols not matching
                if str(list(desc.shape)) != str(m.dst_subset.size_exact()):
                    raise NotImplementedError
        except AttributeError:
            if list(desc.shape) != m.subset.size_exact():
                # Second attempt
                # TODO: We need a solution for symbols not matching
                if str(list(desc.shape)) != str(m.subset.size_exact()):
                    raise NotImplementedError
        outputs.add(dst)

    for out in outputs:
        desc = out.desc(sdfg)
        if isinstance(desc, data.Scalar):
            raise NotImplementedError
        elif isinstance(desc, data.Array):
            # Gather the local transient block back into the output array
            local_name, local_arr = sdfg.add_temp_transient(
                [(desc.shape[0]) // Px, (desc.shape[1]) // Py],
                dtype=desc.dtype,
                storage=desc.storage)
            local_access = graph.add_access(local_name)

            bsizes_name, bsizes_arr = sdfg.add_temp_transient((2, ),
                                                              dtype=dace.int32)
            bsizes_access = graph.add_access(bsizes_name)
            bsizes_tasklet = nodes.Tasklet(
                '_set_bsizes_', {}, {'__out'},
                "__out[0] = {x}; __out[1] = {y}".format(
                    x=(desc.shape[0]) // Px, y=(desc.shape[1]) // Py))
            graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None,
                           dace.Memlet.from_array(bsizes_name, bsizes_arr))

            scatter_node = BlockCyclicGather('_Gather_')
            graph.add_edge(local_access, None, scatter_node, '_inbuffer',
                           dace.Memlet.from_array(local_name, local_arr))
            graph.add_edge(bsizes_access, None, scatter_node, '_block_sizes',
                           dace.Memlet.from_array(bsizes_name, bsizes_arr))
            graph.add_edge(scatter_node, '_outbuffer', out, None,
                           dace.Memlet.from_array(out.data, desc))

            for e in graph.edges_between(map_exit, out):
                graph.add_edge(map_exit, e.src_conn, local_access, None,
                               dace.Memlet.from_array(local_name, local_arr))
                graph.remove_edge(e)

            # Inner memlets now refer to the local block
            for e in graph.in_edges(map_exit):
                if e.data.data == out.data:
                    e.data.data = local_name

        else:
            raise NotImplementedError

    map_entry.map.params = params
    map_entry.map.range = subsets.Range(ranges)
class OnTheFlyMapFusion(Transformation):
    """ Fuses a producer map into a consumer map through an intermediate
        array access node.

        For every distinct offset at which the second map reads the
        intermediate array, the contents of the first map are copied and
        connected inside the second map. Each copy writes into a fresh scalar
        transient instead of the array. The first map and the intermediate
        access node are then removed.
    """

    _first_map_entry = nodes.MapEntry(nodes.Map('', [], []))
    _first_tasklet = nodes.Tasklet('')
    _first_map_exit = nodes.MapExit(nodes.Map('', [], []))
    _array_access = nodes.AccessNode('')
    _second_map_entry = nodes.MapEntry(nodes.Map('', [], []))
    _second_tasklet = nodes.Tasklet('')

    @staticmethod
    def expressions():
        return [
            sdutils.node_path_graph(OnTheFlyMapFusion._first_map_entry,
                                    OnTheFlyMapFusion._first_tasklet,
                                    OnTheFlyMapFusion._first_map_exit,
                                    OnTheFlyMapFusion._array_access,
                                    OnTheFlyMapFusion._second_map_entry,
                                    OnTheFlyMapFusion._second_tasklet)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the first map has a single output connector and
            the intermediate array is transient, is written once and is read
            once. """
        first_map_exit = graph.node(
            candidate[OnTheFlyMapFusion._first_map_exit])
        array_access = graph.node(candidate[OnTheFlyMapFusion._array_access])

        if len(first_map_exit.in_connectors) != 1:
            return False

        if (graph.in_degree(array_access) != 1
                or graph.out_degree(array_access) != 1):
            return False

        # apply() deletes the intermediate access node, so the written data
        # would be lost unless the array is a transient.
        if not sdfg.arrays[array_access.data].transient:
            return False

        return True

    @staticmethod
    def _memlet_offsets(base_memlet, offset_memlet):
        """ Compute subset offset of `offset_memlet` relative to `base_memlet`.
        """
        def offset(base_range, offset_range):
            b0, e0, s0 = base_range
            b1, e1, s1 = offset_range
            assert e1 - e0 == b1 - b0 and s0 == s1
            return int(e1 - e0)

        return tuple(
            offset(b, o) for b, o in zip(base_memlet.subset.ranges,
                                         offset_memlet.subset.ranges))

    @staticmethod
    def _update_map_connectors(state, array_access, first_map_entry,
                               second_map_entry):
        """ Remove unused connector (of the to-be-replaced array) from second
            map entry, add new connectors to second map entry for the inputs
            used in the first map's tasklets.
        """
        # Remove edges and connectors from arrays access to second map entry
        for edge in state.edges_between(array_access, second_map_entry):
            state.remove_edge_and_connectors(edge)
        state.remove_node(array_access)

        # Add new connectors to second map
        # TODO: implement for the general case with random naming
        for edge in state.in_edges(first_map_entry):
            if second_map_entry.add_in_connector(edge.dst_conn):
                state.add_edge(edge.src, edge.src_conn, second_map_entry,
                               edge.dst_conn, edge.data)

    @staticmethod
    def _read_offsets(state, array_name, first_map_exit, second_map_entry):
        """ Compute offsets of read accesses in second map.

            Removes the second map's edges that read ``array_name`` (and
            their out connectors) and returns them grouped by their offset
            relative to the first map's single write memlet.
        """
        # Get output memlet of first tasklet
        output_edges = state.in_edges(first_map_exit)
        assert len(output_edges) == 1
        write_memlet = output_edges[0].data

        # Find read offsets by looping over second map entry connectors
        offsets = defaultdict(list)
        for edge in state.out_edges(second_map_entry):
            if edge.data.data == array_name:
                second_map_entry.remove_out_connector(edge.src_conn)
                state.remove_edge(edge)
                offset = OnTheFlyMapFusion._memlet_offsets(
                    write_memlet, edge.data)
                offsets[offset].append(edge)

        return offsets

    @staticmethod
    def _copy_first_map_contents(state, first_map_entry, first_map_exit):
        """ Deep-copy every node strictly inside the first map, plus every
            edge touching those nodes, into ``state``; return the new nodes.
        """
        inner_nodes = list(
            state.all_nodes_between(first_map_entry, first_map_exit) -
            {first_map_entry})
        new_nodes = [copy.deepcopy(node) for node in inner_nodes]
        for node in new_nodes:
            state.add_node(node)

        # Map each original inner node directly to its copy
        old_to_new = dict(zip(inner_nodes, new_nodes))

        for edge in state.edges():
            if edge.src in old_to_new or edge.dst in old_to_new:
                src = old_to_new.get(edge.src, edge.src)
                dst = old_to_new.get(edge.dst, edge.dst)
                state.add_edge(src, edge.src_conn, dst, edge.dst_conn,
                               copy.deepcopy(edge.data))

        return new_nodes

    def _replicate_first_map(self, sdfg, array_access, first_map_entry,
                             first_map_exit, second_map_entry):
        """ Replicate tasklet of first map for each read access in second map.
        """
        state = sdfg.node(self.state_id)
        array_name = array_access.data
        array = sdfg.arrays[array_name]

        read_offsets = self._read_offsets(state, array_name, first_map_exit,
                                          second_map_entry)

        # Replicate first map tasklets once for each read offset access and
        # connect them to other tasklets accordingly
        for offset, edges in read_offsets.items():
            copied_nodes = self._copy_first_map_contents(
                state, first_map_entry, first_map_exit)
            tmp_name = sdfg.temp_data_name()
            sdfg.add_scalar(tmp_name, array.dtype, transient=True)
            tmp_access = state.add_access(tmp_name)

            for node in copied_nodes:
                # Outputs of the copy go to the scalar temporary
                for edge in state.edges_between(node, first_map_exit):
                    state.add_edge(edge.src, edge.src_conn, tmp_access, None,
                                   dace.Memlet(tmp_name))
                    state.remove_edge(edge)

                # Inputs of the copy come from the second map, shifted by the
                # read offset
                for edge in state.edges_between(first_map_entry, node):
                    memlet = copy.deepcopy(edge.data)
                    memlet.subset.offset(list(offset), negative=False)
                    second_map_entry.add_out_connector(edge.src_conn)
                    state.add_edge(second_map_entry, edge.src_conn, node,
                                   edge.dst_conn, memlet)
                    state.remove_edge(edge)

            for edge in edges:
                state.add_edge(tmp_access, None, edge.dst, edge.dst_conn,
                               dace.Memlet(tmp_name))

    def apply(self, sdfg: dace.SDFG):
        state = sdfg.node(self.state_id)
        first_map_entry = state.node(self.subgraph[self._first_map_entry])
        first_map_exit = state.node(self.subgraph[self._first_map_exit])
        array_access = state.node(self.subgraph[self._array_access])
        second_map_entry = state.node(self.subgraph[self._second_map_entry])

        self._update_map_connectors(state, array_access, first_map_entry,
                                    second_map_entry)

        self._replicate_first_map(sdfg, array_access, first_map_entry,
                                  first_map_exit, second_map_entry)

        # Remove the original first map (entry, contents and exit)
        state.remove_nodes_from(
            state.all_nodes_between(first_map_entry, first_map_exit)
            | {first_map_exit})
def expansion(node, state: SDFGState, sdfg: SDFG):
    """ Expand an ONNX node into a C++ tasklet that runs the operator through
        the ONNX Runtime ``ExecutableKernel`` C API.

        Global code declares a kernel and a kernel context for this node. The
        SDFG init code creates the context, registers one typed input/output
        per connector plus the node attributes, and creates the kernel. The
        exit code releases both. The tasklet body wraps every connector in an
        ``OrtValue``, binds it to the kernel and calls
        ``ExecutableKernel_Compute``.

        If some connector's data lives in storage the chosen execution side
        cannot access, or is a scalar while running on the GPU, the tasklet
        is wrapped in a nested SDFG. The nested SDFG copies that data through
        a transient of the required storage.

        :param node: The ONNX library node to expand.
        :param state: The state containing ``node``.
        :param sdfg: The SDFG containing ``state``.
        :return: The tasklet, or a nested SDFG wrapping it when copies are
                 required.
        :raises NotImplementedError: For unsupported schedules or data
                                     descriptor types.
        :raises ValueError: For unsupported storage or data types.
    """
    # Extract input and output array views (as generated by memlets)
    inputs, outputs = _get_inputs_and_outputs(sdfg, state, node)

    unique_id = "{}_{}_{}_{}".format(clean_onnx_name(node.name),
                                     sdfg.sdfg_id, sdfg.node_id(state),
                                     state.node_id(node))
    _add_ort_init_code(sdfg)

    sdfg.append_global_code(
        "OrtExecutableKernel *__ort_kernel_{};\n".format(unique_id))
    sdfg.append_global_code(
        "OrtExecutableKernelContext *__ort_context_{};\n".format(unique_id))

    # Opens a C++ scope; it is closed by the "end setup" init code below.
    sdfg.append_init_code("""
    {{
    // Setup for {name}
    __ort_check_status(__ort_api->CreateExecutableKernelContext("{name}", "{op_type}", &__ort_context_{name}));
    """.format(name=unique_id, op_type=node.schema.name))

    # check if ORT supports CUDA for this node
    ##########################################

    # Default: all parameters are on CPU if we execute using cpu
    outputs_on_host = [True for _ in range(len(outputs))]
    inputs_on_host = [True for _ in range(len(inputs))]

    actual_node_schedule = node.schedule
    if node.schedule in (ScheduleType.CPU_Multicore, ScheduleType.Default):
        provider_index = 0
    elif node.schedule == ScheduleType.GPU_Device:
        provider_index = 1
        try:
            # the ith position indicates whether the ith output is in host
            # memory
            inputs_on_host, outputs_on_host = check_op(sdfg,
                                                       state,
                                                       node,
                                                       cuda=True)
        except ONNXOpValidationError as e:
            # fallback to CPU
            print("Falling back to CPU for node {}. Reason:\n{}".format(
                node.name, str(e)))
            provider_index = 0
            actual_node_schedule = ScheduleType.Default
    else:
        raise NotImplementedError(
            "ORT expansion for schedule '{}' is not implemented".format(
                node.schedule))

    # check if we need to insert device copies
    ##########################################

    # maps the connectors for which a copy will be required to the storage
    # type required to be connected to the tasklet
    input_copy_required = defaultdict(dict)
    output_copy_required = defaultdict(dict)

    assert len(node.iter_outputs_in_onnx_order(state)) == len(outputs_on_host)
    assert len(node.iter_inputs_in_onnx_order(state)) == len(inputs_on_host)

    def mark_required_copies(edges, on_host_flags, copy_required,
                             connector_of):
        """ Record in ``copy_required`` (keyed by connector) which copies
            the kernel needs for each edge. Shared by inputs and outputs.
        """
        for edge, on_host in zip(edges, on_host_flags):
            array = sdfg.arrays[edge.data.data]
            conn = connector_of(edge)
            accessing_schedule = (ScheduleType.Default
                                  if on_host else ScheduleType.GPU_Device)
            is_device_mismatch = not can_access(accessing_schedule,
                                                array.storage)

            if (isinstance(array, dt.Scalar)
                    and actual_node_schedule == ScheduleType.GPU_Device):
                # ORT kernels expect scalars to be cudaMalloced. We will copy
                # during expansion to enforce this
                is_device_mismatch = True
                copy_required[conn]['copy_to_array'] = True

            if is_device_mismatch:
                # we need to insert a copy
                copy_required[conn]['storage'] = (StorageType.Default
                                                  if on_host else
                                                  StorageType.GPU_Global)

    # check outputs
    mark_required_copies(node.iter_outputs_in_onnx_order(state),
                         outputs_on_host, output_copy_required,
                         lambda edge: edge.src_conn)
    # check inputs (same thing again)
    mark_required_copies(node.iter_inputs_in_onnx_order(state),
                         inputs_on_host, input_copy_required,
                         lambda edge: edge.dst_conn)

    # When any copy is needed, the tasklet is wrapped in a nested SDFG and
    # its connectors get a "_conn_" prefix (see the end of this function).
    needs_nested_sdfg = bool(output_copy_required or input_copy_required)

    # begin codegen
    ##########################################
    tasklet_setup_code = ""
    tasklet_code = ""
    tasklet_cleanup_code = ""

    reversed_onnx_dtype_map = {
        v: k
        for k, v in ONNX_DTYPES_TO_DACE_TYPE_CLASS.items()
    }

    # emit code for inputs and outputs
    ##########################################
    in_connectors = {}
    out_connectors = {}

    for edge, is_input in node.iter_edges(state):

        parameter_name = edge.dst_conn if is_input else edge.src_conn

        if needs_nested_sdfg:
            edge_connector_name = "_conn_" + parameter_name
        else:
            edge_connector_name = parameter_name

        input_output_string = "input" if is_input else "output"
        connector_dict = in_connectors if is_input else out_connectors
        memlet = edge.data
        desc = sdfg.arrays[memlet.data]
        sdfg.append_init_code("""
        // Add parameter {parameter_name}
        __ort_check_status(__ort_api->ExecutableKernelContext_Add{input_output_string}(__ort_context_{id}, ONNX_TENSOR_ELEMENT_DATA_TYPE_{type_string}));
        """.format(id=unique_id,
                   type_string=reversed_onnx_dtype_map[desc.dtype].upper(),
                   parameter_name=parameter_name,
                   input_output_string=input_output_string.capitalize()))

        ort_value_name = "ort_value_{input_output_string}_{parameter_name}".format(
            input_output_string=input_output_string,
            parameter_name=parameter_name)

        copy_to_array = (
            (parameter_name in output_copy_required
             and 'copy_to_array' in output_copy_required[parameter_name])
            or (parameter_name in input_copy_required
                and 'copy_to_array' in input_copy_required[parameter_name]))

        if desc.storage == StorageType.Default:
            mem_info = "__ort_cpu_mem_info"
        elif desc.storage == StorageType.GPU_Global:
            mem_info = "__ort_cuda_mem_info"
        elif desc.storage == StorageType.CPU_Pinned:
            mem_info = "__ort_cuda_pinned_mem_info"
        else:
            raise ValueError(
                "Unsupported storage type {} for input to ONNX node".format(
                    desc.storage))

        if (isinstance(desc, dt.Scalar) and
                # when copying to array, the ort value is not a scalar but an array
                not copy_to_array):

            tasklet_setup_code += """
            OrtValue* {ort_value_name};
            __ort_check_status(__ort_api->CreateTensorWithDataAsOrtValue(
                {mem_info},
                &{edge_connector_name},
                {data_size} * sizeof({ctype}),
                nullptr,
                0,
                ONNX_TENSOR_ELEMENT_DATA_TYPE_{type_str},
                &{ort_value_name}
            ));
            """.format(
                input_output_string=input_output_string,
                mem_info=mem_info,
                edge_connector_name=edge_connector_name,
                data_size=reduce(lambda x, y: x * y, desc.shape),
                ctype=desc.dtype.ctype,
                type_str=reversed_onnx_dtype_map[desc.dtype].upper(),
                ort_value_name=ort_value_name)
            connector_dict[parameter_name] = None

        elif isinstance(desc, dt.Array) or copy_to_array:

            # when we copy a scalar to an array, that scalar ofc has shape []
            dims = [] if copy_to_array else desc.shape

            # setup dims array
            tasklet_setup_code += """
            int64_t {input_output_string}_{parameter_name}_dims[{dims_size}] = {{{dims}}};
            """.format(input_output_string=input_output_string,
                       parameter_name=parameter_name,
                       dims_size=len(dims),
                       dims=", ".join(str(s) for s in dims))

            connector_dict[parameter_name] = dace.pointer(desc.dtype)
            data = "const_cast < void * > (reinterpret_cast < const void * > ({}))".format(
                edge_connector_name)

            tasklet_setup_code += """
            OrtValue* {ort_value_name};
            __ort_check_status(__ort_api->CreateTensorWithDataAsOrtValue(
                {mem_info},
                {data},
                {data_size} * sizeof({ctype}),
                {input_output_string}_{parameter_name}_dims,
                {dims_size},
                ONNX_TENSOR_ELEMENT_DATA_TYPE_{type_str},
                &{ort_value_name}
            ));
            """.format(
                input_output_string=input_output_string,
                data=data,
                mem_info=mem_info,
                parameter_name=parameter_name,
                data_size=reduce(lambda x, y: x * y, desc.shape),
                ctype=desc.dtype.ctype,
                dims_size=len(dims),
                type_str=reversed_onnx_dtype_map[desc.dtype].upper(),
                ort_value_name=ort_value_name)
        else:
            raise NotImplementedError(
                "Data-descriptor type {} not supported for ONNX nodes".format(
                    type(desc)))

        tasklet_code += (
            "__ort_check_status(__ort_api->ExecutableKernel_Set{input_output_string_capital}("
            "__ort_kernel_{unique_id}, {position}, {ort_value_name}));\n"
        ).format(input_output_string_capital=input_output_string.capitalize(),
                 ort_value_name=ort_value_name,
                 unique_id=unique_id,
                 position=get_position(node.schema, is_input, parameter_name))

        tasklet_cleanup_code += "__ort_api->ReleaseValue(ort_value_{input_output_string}_{parameter_name});\n".format(
            input_output_string=input_output_string,
            parameter_name=parameter_name)

    sdfg.append_init_code("// Setup attributes\n")

    for name, attr in node.schema.attributes.items():
        if hasattr(node, name):
            sdfg.append_init_code(
                _gen_attr_init_code("__ort_context_{}".format(unique_id),
                                    attr, getattr(node, name)))

    sdfg.prepend_exit_code(
        "__ort_api->ReleaseExecutableKernelContext(__ort_context_{});\n".
        format(unique_id))
    sdfg.prepend_exit_code(
        "__ort_api->ReleaseExecutableKernel(__ort_kernel_{});\n".format(
            unique_id))

    # NOTE(review): this prints to stderr on every kernel launch; looks like
    # leftover debug output -- confirm it is intended.
    tasklet_code += 'fprintf(stderr, "Launching {}\\n");\n'.format(unique_id)
    tasklet_code += "__ort_check_status(__ort_api->ExecutableKernel_Compute(__ort_kernel_{}));\n".format(
        unique_id)

    sdfg.append_init_code(
        "__ort_check_status(__ort_api->CreateExecutableKernel("
        "__ort_session, __ort_context_{id}, /*provider_index=*/{provider_index}, &__ort_kernel_{id}));\n"
        .format(provider_index=provider_index, id=unique_id))
    sdfg.append_init_code("}} // end setup for context_{}".format(unique_id))

    tasklet_code = tasklet_setup_code + tasklet_code + tasklet_cleanup_code
    tasklet = nd.Tasklet('onnx_code',
                         in_connectors,
                         out_connectors,
                         tasklet_code,
                         language=dace.dtypes.Language.CPP)
    tasklet.environments = {"ONNXRuntime"}

    if not needs_nested_sdfg:
        return tasklet

    nsdfg = dace.SDFG("nested_{}".format(unique_id))
    nstate = nsdfg.add_state()
    ntasklet = deepcopy(tasklet)

    # add a prefix to connectors to prevent shadowing of array names
    ntasklet.in_connectors = {
        "_conn_" + k: v
        for k, v in tasklet.in_connectors.items()
    }
    ntasklet.out_connectors = {
        "_conn_" + k: v
        for k, v in tasklet.out_connectors.items()
    }

    nstate.add_node(ntasklet)

    for edge, is_input in node.iter_edges(state):

        parameter_name = edge.dst_conn if is_input else edge.src_conn
        conn_name = "_conn_" + parameter_name

        memlet = edge.data
        desc = sdfg.arrays[memlet.data]

        if not isinstance(desc, (dt.Array, dt.Scalar)):
            raise ValueError(
                "Unsupported data type {} connected to an ONNX tasklet".format(
                    type(desc)))

        # add the original array
        original_desc = deepcopy(desc)
        original_desc.transient = False
        nsdfg.add_datadesc(parameter_name, original_desc)

        copy_required = (input_copy_required
                         if is_input else output_copy_required)

        if parameter_name not in copy_required:
            if is_input:
                access = nstate.add_read(parameter_name)
                nstate.add_edge(access, None, ntasklet, conn_name,
                                nsdfg.get_array_memlet(parameter_name))
            else:
                access = nstate.add_write(parameter_name)
                nstate.add_edge(ntasklet, conn_name, access, None,
                                nsdfg.get_array_memlet(parameter_name))
            continue

        copy_options = copy_required[parameter_name]

        # The copy transient is named after the connector (unique per node),
        # not after the outer array. Otherwise one array connected to
        # several connectors would register "copy_<array>" twice and fail.
        copy_name = "copy_" + parameter_name

        # add the copy of the descriptor
        if 'copy_to_array' in copy_options:
            copy_desc = dt.Array(shape=[1], dtype=desc.dtype)
        else:
            copy_desc = deepcopy(desc)

        copy_desc.transient = True
        copy_desc.storage = copy_options['storage']
        nsdfg.add_datadesc(copy_name, copy_desc)

        nmemlet = deepcopy(memlet)
        nmemlet.data = copy_name
        if is_input:
            access = nstate.add_read(parameter_name)
            access_copy = nstate.add_access(copy_name)
            nstate.add_edge(access, None, access_copy, None,
                            nsdfg.get_array_memlet(copy_name))
            nstate.add_edge(access_copy, None, ntasklet, conn_name, nmemlet)
        else:
            access = nstate.add_write(parameter_name)
            access_copy = nstate.add_access(copy_name)
            nstate.add_edge(ntasklet, conn_name, access_copy, None, nmemlet)
            nstate.add_edge(access_copy, None, access, None,
                            nsdfg.get_array_memlet(copy_name))

    return nsdfg
class MapReduceFusion(pm.Transformation):
    """ Implements the map-reduce-fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else.
    """

    no_init = Property(dtype=bool,
                       default=False,
                       desc='If enabled, does not create initialization states '
                       'for reduce nodes with identity')

    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')

    import dace.libraries.standard as stdlib  # Avoid import loop
    _reduce = stdlib.Reduce()

    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        return [
            sdutil.node_path_graph(MapReduceFusion._tasklet,
                                   MapReduceFusion._tmap_exit,
                                   MapReduceFusion._in_array,
                                   MapReduceFusion._reduce,
                                   MapReduceFusion._out_array)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        """ Check whether the matched map and reduction can be fused.

            Requires that the intermediate array is written only by the map
            exit and read only by the reduce node. The tasklet must write the
            array through the map exit, with no WCR or the same WCR as the
            reduction. Unless ``permissive`` is set, the array must not be
            accessed elsewhere. The map's write subset must equal the
            reduction's read subset.
        """
        tmap_exit = graph.nodes()[candidate[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[candidate[MapReduceFusion._reduce]]
        tasklet = graph.nodes()[candidate[MapReduceFusion._tasklet]]

        # Make sure that the array is only accessed by the map and the reduce
        if any(e.src != tmap_exit for e in graph.in_edges(in_array)):
            return False
        if any(e.dst != reduce_node for e in graph.out_edges(in_array)):
            return False

        # The tasklet itself must write the intermediate array through the
        # map exit. Without a default, ``next`` would raise StopIteration
        # here instead of rejecting the match.
        tmem = next((e.data for e in graph.edges_between(tasklet, tmap_exit)
                     if e.data.data == in_array.data), None)
        if tmem is None:
            return False

        # Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if not permissive and (len([
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == in_array.data
        ]) > 1 or in_array.data in sdfg.shared_transients()):
            return False

        # If memlet already has WCR and it is different from reduce node,
        # do not match
        if tmem.wcr is not None and tmem.wcr != reduce_node.wcr:
            return False

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[MapReduceFusion._tasklet]
        map_exit = candidate[MapReduceFusion._tmap_exit]
        reduce = candidate[MapReduceFusion._reduce]

        return ' -> '.join(str(node) for node in [tasklet, map_exit, reduce])

    def apply(self, sdfg: SDFG):
        """ Replace the intermediate array and the reduce node with a WCR
            memlet from the map exit to the reduction output.

            The tasklet's write memlet is redirected to the output array,
            given the reduction's WCR, and has its reduced dimensions
            dropped. Unless ``no_init`` is set and the reduction has an
            identity, an initialization state is added before the state.
        """
        graph = sdfg.nodes()[self.state_id]
        tmap_exit = graph.nodes()[self.subgraph[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[self.subgraph[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[self.subgraph[MapReduceFusion._reduce]]
        out_array = graph.nodes()[self.subgraph[MapReduceFusion._out_array]]

        # The intermediate array and the reduce node are removed
        nodes_to_remove = [in_array, reduce_node]

        # Edge (tasklet -> map exit) that writes the intermediate array
        memlet_edge = next((edge for edge in graph.in_edges(tmap_exit)
                            if edge.data.data == in_array.data), None)
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        # Find which indices should be removed from new memlet
        input_edge = graph.in_edges(reduce_node)[0]
        axes = reduce_node.axes or list(range(len(input_edge.data.subset)))
        array_edge = graph.out_edges(reduce_node)[0]

        # Delete relevant edges and nodes
        graph.remove_nodes_from(nodes_to_remove)

        # Delete relevant data descriptors
        for node in nodes_to_remove:
            if isinstance(node, nodes.AccessNode):
                # try to delete it
                try:
                    sdfg.remove_data(node.data)
                # will raise ValueError if the datadesc is used somewhere else
                except ValueError:
                    pass

        # Filter out reduced dimensions from subset
        filtered_subset = [
            dim for i, dim in enumerate(memlet_edge.data.subset)
            if i not in axes
        ]
        if not filtered_subset:  # Output is a scalar
            filtered_subset = [(0, 0, 1)]

        # Modify edge from tasklet to map exit
        memlet_edge.data.data = out_array.data
        memlet_edge.data.wcr = reduce_node.wcr
        memlet_edge.data.subset = type(
            memlet_edge.data.subset)(filtered_subset)

        # Add edge from map exit to output array ("IN_x" -> "OUT_x")
        graph.add_edge(
            memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:],
            array_edge.dst, array_edge.dst_conn,
            Memlet.simple(array_edge.data.data,
                          array_edge.data.subset,
                          num_accesses=array_edge.data.num_accesses,
                          wcr_str=reduce_node.wcr))

        # Add initialization state as necessary
        if not self.no_init and reduce_node.identity is not None:
            init_state = sdfg.add_state_before(graph)
            init_state.add_mapped_tasklet(
                'freduce_init',
                [('o%d' % i, '%s:%s:%s' % (r[0], r[1] + 1, r[2]))
                 for i, r in enumerate(array_edge.data.subset)], {},
                'out = %s' % reduce_node.identity, {
                    'out':
                    Memlet.simple(
                        array_edge.data.data, ','.join([
                            'o%d' % i
                            for i in range(len(array_edge.data.subset))
                        ]))
                },
                external_edges=True)
class MapWCRFusion(pm.Transformation):
    """ Implements the map expanded-reduce fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else, and the
        reduction is divided to two maps with a WCR, denoting partial
        reduction.
    """

    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')

    _rmap_in_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_in_tasklet = nodes.Tasklet('_')
    _rmap_in_cr = nodes.MapExit(nodes.Map("", [], []))
    _rmap_out_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_out_exit = nodes.MapExit(nodes.Map("", [], []))
    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        # Single pattern: map, then partial reduction of axes (an outer map
        # enclosing an inner map).
        pattern = (
            MapWCRFusion._tasklet,
            MapWCRFusion._tmap_exit,
            MapWCRFusion._in_array,
            MapWCRFusion._rmap_out_entry,
            MapWCRFusion._rmap_in_entry,
            MapWCRFusion._rmap_in_tasklet,
            MapWCRFusion._rmap_in_cr,
            MapWCRFusion._rmap_out_exit,
            MapWCRFusion._out_array,
        )
        return [sdutil.node_path_graph(*pattern)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        """ Decide whether the matched map and two-map reduction can be fused.

            Rejects the match when the intermediate array has producers other
            than the first map's exit or consumers other than the outer
            reduction map entry. It also rejects when the first in-edge of
            the inner reduction map's exit has no WCR, or when (unless
            ``permissive``) the array appears in another access node or is a
            shared transient. Finally, the map's write subset and the
            reduction's read subset must be equal.
        """
        state_nodes = graph.nodes()
        first_exit = state_nodes[candidate[MapWCRFusion._tmap_exit]]
        transient = state_nodes[candidate[MapWCRFusion._in_array]]
        reduction_entry = state_nodes[candidate[MapWCRFusion._rmap_out_entry]]

        # The array may only connect the first map and the reduction
        producers = graph.in_edges(transient)
        consumers = graph.out_edges(transient)
        if any(e.src != first_exit for e in producers):
            return False
        if any(e.dst != reduction_entry for e in consumers):
            return False

        # The inner reduction map must write with a WCR
        inner_exit = state_nodes[candidate[MapWCRFusion._rmap_in_cr]]
        if graph.in_edges(inner_exit)[0].data.wcr is None:
            return False

        # Unless permissive, the transient must not be used anywhere else,
        # neither in this state nor in other states
        if not permissive:
            same_data_accesses = [
                n for n in state_nodes
                if isinstance(n, nodes.AccessNode) and n.data == transient.data
            ]
            if (len(same_data_accesses) > 1
                    or transient.data in sdfg.shared_transients()):
                return False

        # The written and read ranges of the transient must agree
        if producers[0].data.subset != consumers[0].data.subset:
            return False
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        shown = (
            candidate[MapWCRFusion._tasklet],
            candidate[MapWCRFusion._tmap_exit],
            candidate[MapWCRFusion._rmap_in_cr],
        )
        return ' -> '.join(str(node) for node in shown)

    def apply(self, sdfg):
        """ Collapse the two reduction maps into one, then fuse the first map
            with the collapsed map.
        """
        state = sdfg.node(self.state_id)

        # Step 1: collapse outer and inner reduction maps
        collapse_match = {
            MapCollapse._outer_map_entry:
            self.subgraph[MapWCRFusion._rmap_out_entry],
            MapCollapse._inner_map_entry:
            self.subgraph[MapWCRFusion._rmap_in_entry],
        }
        collapse = MapCollapse(self.sdfg_id, self.state_id, collapse_match, 0)
        collapsed_entry, _ = collapse.apply(sdfg)

        # Step 2: fuse the producing map with the collapsed reduction map
        fusion_match = {
            MapFusion.first_map_exit: self.subgraph[MapWCRFusion._tmap_exit],
            MapFusion.second_map_entry: state.node_id(collapsed_entry),
        }
        fusion = MapFusion(self.sdfg_id, self.state_id, fusion_match, 0)
        fusion.apply(sdfg)
def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG, **kwargs):
    """ Expand an NCCL Reduce library node into a C++ tasklet that calls
        ``ncclReduce`` on ``__dace_current_stream`` and then synchronizes
        that stream.

        Arrays named by free symbols of the root expression are given
        SDFG-level allocation lifetime, in this SDFG and in its parent (if
        any).

        :param node: The Reduce library node (connectors ``_inbuffer`` and
                     ``_outbuffer``).
        :param state: The state containing ``node``.
        :param sdfg: The SDFG containing ``state``.
        :return: The generated tasklet.
        :raises ValueError: If a buffer connector is not connected, or the
                            input/output data is not in global GPU memory.
        :raises NotImplementedError: For vector data types (veclen > 1).
    """
    node.validate(sdfg, state)

    input_edge = next(
        (e for e in state.in_edges(node) if e.dst_conn == '_inbuffer'), None)
    output_edge = next(
        (e for e in state.out_edges(node) if e.src_conn == '_outbuffer'),
        None)
    if input_edge is None or output_edge is None:
        raise ValueError('NCCL Reduce requires connected "_inbuffer" and '
                         '"_outbuffer" connectors.')

    input_dims = input_edge.data.subset.size_exact()

    input_data = sdfg.arrays[input_edge.data.data]
    output_data = sdfg.arrays[output_edge.data.data]

    # Verify that data is on the GPU
    if input_data.storage is not dtypes.StorageType.GPU_Global:
        raise ValueError('Input of NCCL Reduce must reside '
                         'in global GPU memory.')
    if output_data.storage is not dtypes.StorageType.GPU_Global:
        raise ValueError('Output of NCCL Reduce must reside '
                         'in global GPU memory.')

    root = node.root
    rootstr = str(root)
    # Arrays referenced by the root expression must outlive this call.
    # A top-level SDFG has no parent, so guard before touching it.
    parent_sdfg = sdfg.parent_sdfg
    for fs in root.free_symbols:
        if fs.name in sdfg.arrays:
            sdfg.arrays[fs.name].lifetime = dtypes.AllocationLifetime.SDFG
        if parent_sdfg is not None and fs.name in parent_sdfg.arrays:
            parent_sdfg.arrays[
                fs.name].lifetime = dtypes.AllocationLifetime.SDFG

    redtype = node.reduction_type
    redtype = nutil.NCCL_SUPPORTED_OPERATIONS[redtype]
    wcr_str = str(redtype)
    wcr_str = wcr_str[wcr_str.find('.') + 1:]  # Skip "NcclReductionType."

    nccl_dtype_str = nutil.Nccl_dtypes(input_data.dtype.base_type)
    count_str = "*".join(str(e) for e in input_dims)

    if input_data.dtype.veclen > 1:
        raise NotImplementedError(
            'NCCL Reduce does not support vector data types (veclen > 1).')

    code = (f"ncclReduce(_inbuffer, _outbuffer, {count_str}, "
            f"{nccl_dtype_str}, {wcr_str}, {rootstr}, "
            f"__state->ncclCommunicators->at(__dace_cuda_device), "
            f"__dace_current_stream)")
    if Config.get('compiler', 'build_type') == 'Debug':
        code = '''DACE_NCCL_CHECK(''' + code + ''');\n'''
    else:
        code = code + ''';\n'''

    if Config.get_bool('debugprint'):
        code = (
            f'''printf("{str(node)}: begin; dev,peer: %d, %d\\n", __dace_cuda_device, {rootstr});\n'''
            + code +
            f'''printf("{str(node)}: end; dev,peer: %d, %d\\n\\n", __dace_cuda_device, {rootstr});\n'''
        )

    code += """\ncudaStreamSynchronize(__dace_current_stream);"""

    tasklet = nodes.Tasklet(node.name + "_" + wcr_str,
                            node.in_connectors,
                            node.out_connectors,
                            code,
                            location=node.location,
                            language=dtypes.Language.CPP,
                            library_expansion_symbols=set(
                                map(str, root.free_symbols)))

    return tasklet