class AccumulateTransient(pattern_matching.Transformation):
    """ Implements the AccumulateTransient transformation, which adds a
        transient data node between nested maps when the inner tasklet writes
        its output with write-conflict resolution (WCR). The transient data
        node then acts as a local accumulator.
    """

    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ Pattern: tasklet -> inner map exit -> outer map exit. """
        return [
            nxutil.node_path_graph(AccumulateTransient._tasklet,
                                   AccumulateTransient._map_exit,
                                   AccumulateTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Applicable if the tasklet has at least one WCR output edge that
            leads directly into the inner map exit. """
        tasklet = graph.nodes()[candidate[AccumulateTransient._tasklet]]
        map_exit = graph.nodes()[candidate[AccumulateTransient._map_exit]]

        return any(memlet.wcr is not None and dest == map_exit
                   for _, _, dest, _, memlet in graph.out_edges(tasklet))

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[AccumulateTransient._tasklet]
        map_exit = candidate[AccumulateTransient._map_exit]
        outer_map_exit = candidate[AccumulateTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg):
        """ Inserts a transient array ``trans_<data>`` between the inner and
            outer map exits and redirects the tasklet's WCR memlet to it.

            :raises RuntimeError: If the WCR memlet or the matching edge from
                                  the inner to the outer map exit cannot be
                                  found.
        """
        graph = sdfg.nodes()[self.state_id]
        tasklet = graph.nodes()[self.subgraph[AccumulateTransient._tasklet]]
        map_exit = graph.nodes()[self.subgraph[AccumulateTransient._map_exit]]
        outer_map_exit = graph.nodes()[self.subgraph[
            AccumulateTransient._outer_map_exit]]

        # WCR memlet from the tasklet into the inner map exit.
        # TODO: What if there's more than one?
        memlet = next((e.data for e in graph.out_edges(tasklet)
                       if e.dst == map_exit and e.data.wcr is not None), None)
        if memlet is None:
            raise RuntimeError('No WCR memlet from tasklet to map exit found')

        # Edge carrying the same data from the inner to the outer map exit;
        # this is the edge that gets rerouted through the new transient.
        edge = next((e for e in graph.out_edges(map_exit)
                     if e.dst == outer_map_exit and e.data.data == memlet.data),
                    None)
        if edge is None:
            raise RuntimeError('No edge from map exit to outer map exit '
                               'carrying the accumulated data found')
        out_memlet = edge.data

        dataname = memlet.data
        desc = sdfg.arrays[dataname]

        # Create a new node with the same size as the output
        sdfg.add_array('trans_' + dataname,
                       desc.shape,
                       desc.dtype,
                       transient=True)
        dnode = nodes.AccessNode('trans_' + dataname)

        to_data_mm = copy.deepcopy(memlet)
        to_data_mm.data = dnode.data
        to_data_mm.num_accesses = memlet.num_elements()

        to_exit_mm = copy.deepcopy(out_memlet)
        to_exit_mm.num_accesses = out_memlet.num_elements()

        # The tasklet now accumulates into the transient
        memlet.data = dnode.data

        # Reconnect, assuming one edge to the output
        graph.remove_edge(edge)
        graph.add_edge(map_exit, edge.src_conn, dnode, None, to_data_mm)
        graph.add_edge(dnode, None, outer_map_exit, edge.dst_conn, to_exit_mm)

    def modifies_graph(self):
        return True
class StreamTransient(pattern_matching.Transformation):
    """ Implements the StreamTransient transformation, which adds a transient
        stream node between nested maps that lead to a stream. The transient
        then acts as a local buffer.
    """

    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ Pattern: tasklet -> inner map exit -> outer map exit. """
        return [
            nxutil.node_path_graph(StreamTransient._tasklet,
                                   StreamTransient._map_exit,
                                   StreamTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Applicable if an edge from the inner map exit to the outer map
            exit carries a stream data container. """
        map_exit = graph.nodes()[candidate[StreamTransient._map_exit]]
        outer_map_exit = graph.nodes()[candidate[
            StreamTransient._outer_map_exit]]

        # Check if there is a streaming output. The destination is tested
        # first so that edges not leading to the outer map exit never trigger
        # a data-descriptor lookup.
        return any(dest == outer_map_exit
                   and isinstance(sdfg.arrays[memlet.data], data.Stream)
                   for _, _, dest, _, memlet in graph.out_edges(map_exit))

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[StreamTransient._tasklet]
        map_exit = candidate[StreamTransient._map_exit]
        outer_map_exit = candidate[StreamTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg):
        """ Inserts a transient stream ``tile_<data>`` between the inner and
            outer map exits and redirects the tasklet's write (if it writes
            that stream directly) to the new stream.

            :raises RuntimeError: If no stream edge between the two map exits
                                  is found.
        """
        graph = sdfg.nodes()[self.state_id]
        tasklet = graph.nodes()[self.subgraph[StreamTransient._tasklet]]
        map_exit = graph.nodes()[self.subgraph[StreamTransient._map_exit]]
        outer_map_exit = graph.nodes()[self.subgraph[
            StreamTransient._outer_map_exit]]

        # Stream edge between the two map exits.
        # TODO: What if there's more than one?
        edge = next((e for e in graph.out_edges(map_exit)
                     if e.dst == outer_map_exit
                     and isinstance(sdfg.arrays[e.data.data], data.Stream)),
                    None)
        if edge is None:
            raise RuntimeError('No stream edge from map exit to outer map '
                               'exit found')
        memlet = edge.data
        dataname = memlet.data

        # Tasklet output writing the same stream, if any. Only a matching
        # memlet is renamed; unrelated tasklet outputs are left untouched.
        tasklet_memlet = next((e.data for e in graph.out_edges(tasklet)
                               if e.data.data == dataname), None)

        bbox = map_exit.map.range.bounding_box_size()
        bbox_approx = [symbolic.overapproximate(dim) for dim in bbox]

        # Create the new node: Temporary stream and an access node
        sdfg.add_stream(
            'tile_' + dataname,
            sdfg.arrays[dataname].dtype,
            1,
            bbox_approx[0],
            [1],
            transient=True,
        )
        snode = nodes.AccessNode('tile_' + dataname)

        to_stream_mm = copy.deepcopy(memlet)
        to_stream_mm.data = snode.data
        if tasklet_memlet is not None:
            tasklet_memlet.data = snode.data

        # Reconnect, assuming one edge to the stream
        graph.remove_edge(edge)
        graph.add_edge(map_exit, None, snode, None, to_stream_mm)
        graph.add_edge(snode, None, outer_map_exit, None, memlet)

    def modifies_graph(self):
        return True
class Vectorization(pattern_matching.Transformation):
    """ Implements the vectorization transformation.

        Vectorization matches when all the input and output memlets of a
        tasklet inside a map access the inner-most loop variable in their last
        dimension. The transformation changes the step of the inner-most loop
        to be equal to the length of the vector and vectorizes the memlets.
    """

    vector_len = Property(desc="Vector length", dtype=int, default=4)

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ Pattern: map entry -> tasklet -> map exit. """
        return [
            nxutil.node_path_graph(Vectorization._map_entry,
                                   Vectorization._tasklet,
                                   Vectorization._map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the inner-most map parameter appears in the last
            dimension of at least one non-stream, non-WCR memlet adjacent to
            the tasklet, does not appear in any other dimension of such
            memlets, and all such memlets have a vector length of 1.
        """
        map_entry = graph.nodes()[candidate[Vectorization._map_entry]]
        tasklet = graph.nodes()[candidate[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        found = False

        # Check if all edges, adjacent to the tasklet,
        # use the parameter in their last dimension.
        for _, _, _, _, memlet in graph.all_edges(tasklet):

            # Cases that do not matter for vectorization
            if isinstance(sdfg.arrays[memlet.data], data.Stream):
                continue
            if memlet.wcr is not None:
                continue

            try:
                subset = memlet.subset
                veclen = memlet.veclen
            except AttributeError:
                return False

            if subset is None:
                return False

            try:
                if veclen > symbolic.pystr_to_symbolic('1'):
                    return False

                for idx, dim in enumerate(subset):
                    # A dimension is either a tuple of expressions or a
                    # single expression; inspect each expression in order.
                    components = dim if isinstance(dim, tuple) else (dim, )
                    for component in components:
                        symbols = symbolic.pystr_to_symbolic(
                            component).free_symbols
                        if param not in symbols:
                            continue
                        # The parameter may only index the last dimension
                        if idx == subset.dims() - 1:
                            found = True
                        else:
                            return False
            except TypeError:
                # cannot determine truth value of Relational
                return False

        return found

    @staticmethod
    def match_to_str(graph, candidate):
        map_entry = candidate[Vectorization._map_entry]
        tasklet = candidate[Vectorization._tasklet]
        map_exit = candidate[Vectorization._map_exit]

        return ' -> '.join(
            str(node) for node in [map_entry, tasklet, map_exit])

    def apply(self, sdfg):
        """ Sets the step of the inner-most map dimension to ``vector_len``
            and sets ``veclen`` on every tasklet-adjacent memlet whose last
            dimension uses the inner-most map parameter. """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = graph.nodes()[self.subgraph[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        # Create new vector size.
        vector_size = self.vector_len

        # Change the step of the inner-most dimension.
        dim_from, dim_to, _ = map_entry.map.range[-1]
        map_entry.map.range[-1] = (dim_from, dim_to, vector_size)

        # Vectorize memlets adjacent to the tasklet.
        for _, _, _, _, memlet in graph.all_edges(tasklet):
            lastindex = memlet.subset[-1]
            components = (lastindex
                          if isinstance(lastindex, tuple) else (lastindex, ))
            symbols = set()
            for component in components:
                symbols.update(
                    symbolic.pystr_to_symbolic(component).free_symbols)

            if param in symbols:
                try:
                    memlet.veclen = vector_size
                except AttributeError:
                    return

        # TODO: Create new map for non-vectorizable part.

    def modifies_graph(self):
        return True
def __init__(self, *args, **kwargs):
    """ Initializes the ``entry``, ``tasklet`` and ``exit`` node attributes,
        sets ``pairs`` to ``None``, and then forwards all arguments to the
        parent constructor. """
    self.entry, self.tasklet, self.exit, self.pairs = (
        nodes.EntryNode(),
        nodes.Tasklet('_'),
        nodes.ExitNode(),
        None,
    )
    super().__init__(*args, **kwargs)
class MapWCRFusion(pm.Transformation):
    """ Implements the map expanded-reduce fusion transformation.

        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else, and the
        reduction is divided into two maps with a WCR, denoting partial
        reduction.
    """

    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')

    _rmap_in_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_in_tasklet = nodes.Tasklet('_')
    _rmap_in_cr = nodes.MapExit(nodes.Map("", [], []))
    _rmap_out_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_out_exit = nodes.MapExit(nodes.Map("", [], []))
    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        return [
            # Map, then partial reduction of axes
            nxutil.node_path_graph(
                MapWCRFusion._tasklet, MapWCRFusion._tmap_exit,
                MapWCRFusion._in_array, MapWCRFusion._rmap_out_entry,
                MapWCRFusion._rmap_in_entry, MapWCRFusion._rmap_in_tasklet,
                MapWCRFusion._rmap_in_cr, MapWCRFusion._rmap_out_exit,
                MapWCRFusion._out_array)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks that the intermediate array is only written by the first
            map and read by the reduction map, that the inner reduction map
            writes with WCR, that (in strict mode) the array is accessed
            nowhere else, and that the write and read subsets match. """
        tmap_exit = graph.nodes()[candidate[MapWCRFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapWCRFusion._in_array]]
        rmap_entry = graph.nodes()[candidate[MapWCRFusion._rmap_out_entry]]

        # Make sure that the array is only accessed by the map and the reduce
        if any(src != tmap_exit
               for src, _, _, _, _ in graph.in_edges(in_array)):
            return False
        if any(dest != rmap_entry
               for _, _, dest, _, _ in graph.out_edges(in_array)):
            return False

        # Make sure that there is a reduction in the second map
        rmap_cr = graph.nodes()[candidate[MapWCRFusion._rmap_in_cr]]
        reduce_edge = graph.in_edges(rmap_cr)[0]
        if reduce_edge.data.wcr is None:
            return False

        # (strict) Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if strict:
            num_access_nodes = sum(
                1 for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == in_array.data)
            if (num_access_nodes > 1
                    or in_array.data in sdfg.shared_transients()):
                return False

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[MapWCRFusion._tasklet]
        map_exit = candidate[MapWCRFusion._tmap_exit]
        reduce = candidate[MapWCRFusion._rmap_in_cr]

        return ' -> '.join(str(node) for node in [tasklet, map_exit, reduce])

    def apply(self, sdfg):
        """ Collapses the two reduction maps with MapCollapse, then fuses the
            first map with the collapsed map using MapFusion. """
        graph = sdfg.node(self.state_id)

        # To apply, collapse the second map and then fuse the two resulting maps
        map_collapse = MapCollapse(
            self.sdfg_id, self.state_id, {
                MapCollapse._outer_map_entry:
                self.subgraph[MapWCRFusion._rmap_out_entry],
                MapCollapse._inner_map_entry:
                self.subgraph[MapWCRFusion._rmap_in_entry]
            }, 0)
        map_entry, _ = map_collapse.apply(sdfg)

        map_fusion = MapFusion(
            self.sdfg_id, self.state_id, {
                MapFusion._first_map_exit:
                self.subgraph[MapWCRFusion._tmap_exit],
                MapFusion._second_map_entry: graph.node_id(map_entry)
            }, 0)
        map_fusion.apply(sdfg)
def __init__(self, *args, **kwargs):
    """ Initializes the ``_entry``, ``_tasklet`` and ``_exit`` node
        attributes, then forwards all arguments to the parent constructor. """
    self._entry, self._tasklet, self._exit = (
        nodes.EntryNode(),
        nodes.Tasklet('_'),
        nodes.ExitNode(),
    )
    super().__init__(*args, **kwargs)
def _replace_indirect_accesses(subset):
    """ Replaces indirect (function-call) accesses in ``subset`` with symbols.

        Every sympy user function found in a dimension (e.g. ``B(i)``) is
        replaced by a symbol named ``index_<function>_<n>``, where ``n`` is
        the position of the distinct argument tuple in the access list.

        :param subset: Subset whose dimensions are single expressions or
                       tuples of expressions.
        :return: A pair ``(accesses, newsubset)``. ``accesses`` is an
                 OrderedDict mapping each function name to the list of
                 distinct argument tuples it is called with. ``newsubset`` is
                 a deep copy of ``subset`` with the replacements applied.
    """
    accesses = OrderedDict()
    newsubset = dcpy(subset)
    for dimidx, dim in enumerate(subset):
        # Range/Index disambiguation
        direct_assignment = not isinstance(dim, tuple)
        components = [dim] if direct_assignment else dim

        for i, r in enumerate(components):
            replaced = r
            changed = False
            for expr in sympy.preorder_traversal(r):
                if not symbolic.is_sympy_userfunction(expr):
                    continue
                fname = expr.func.__name__
                arr_accesses = accesses.setdefault(fname, [])

                # Replace function with symbol (memlet local name to-be),
                # reusing the index of an identical earlier access
                if expr.args in arr_accesses:
                    aindex = arr_accesses.index(expr.args)
                else:
                    arr_accesses.append(expr.args)
                    aindex = len(arr_accesses) - 1
                toreplace = 'index_' + fname + '_' + str(aindex)

                # Accumulate substitutions so that several indirect accesses
                # in one dimension are all replaced, not only the last one
                replaced = replaced.subs(expr, toreplace)
                changed = True

            if not changed:
                continue
            if direct_assignment:
                newsubset[dimidx] = replaced
            else:
                # Tuple dimensions cannot be assigned item-wise; rebuild them
                updated = list(newsubset[dimidx])
                updated[i] = replaced
                newsubset[dimidx] = tuple(updated)

    return accesses, newsubset


def add_indirection_subgraph(sdfg, graph, src, dst, memlet):
    """ Replaces the specified edge in the specified graph with a subgraph that
        implements indirection without nested AST memlet objects.

        :param sdfg: SDFG in which the transient value container
                     ``__<local_name>_value`` is registered.
        :param graph: State to which the indirection tasklet and edges are
                      added.
        :param src: Node from which the index memlets and the full-array
                    memlet originate.
        :param dst: Node that receives the looked-up value on connector
                    ``memlet.local_name``.
        :param memlet: AST memlet containing the indirect accesses.
        :raises TypeError: If ``memlet`` is not an ``astnodes._Memlet``.
    """
    if not isinstance(memlet, astnodes._Memlet):
        raise TypeError("Expected memlet to be astnodes._Memlet")

    # Scheme for multi-array indirection:
    # 1. look for all arrays and accesses, create set of arrays+indices
    #    from which the index memlets will be constructed from
    # 2. each separate array creates a memlet, of which num_accesses = len(set)
    # 3. one indirection tasklet receives them all + original array and
    #    produces the right output index/range memlet
    #########################
    # Step 1
    accesses, newsubset = _replace_indirect_accesses(memlet.subset)

    #########################
    # Step 2
    ind_inputs = {'__ind_' + memlet.local_name}
    ind_outputs = {'lookup'}
    # Add accesses to inputs
    for arrname, arr_accesses in accesses.items():
        for i in range(len(arr_accesses)):
            ind_inputs.add('index_%s_%d' % (arrname, i))

    tasklet = nd.Tasklet("Indirection", ind_inputs, ind_outputs)

    for arrname, arr_accesses in accesses.items():
        # The value is unused; the lookup is kept so that a missing
        # dependency still raises KeyError as before.
        _ = memlet.otherdeps[arrname]
        for i, access in enumerate(arr_accesses):
            # Memlet to load the indirection index
            index_memlet = Memlet(arrname, 1, sbs.Indices(list(access)), 1)
            graph.add_edge(src, None, tasklet, "index_%s_%d" % (arrname, i),
                           index_memlet)

    #########################
    # Step 3
    # Create new tasklet that will perform the indirection
    indirection_ast = ast.parse("lookup = {arr}[{index}]".format(
        arr='__ind_' + memlet.local_name,
        index=', '.join([symbolic.symstr(s) for s in newsubset])))
    # Conserve line number of original indirection code
    tasklet.code = ast.copy_location(indirection_ast.body[0], memlet.ast)

    value_name = '__' + memlet.local_name + '_value'

    # Create transient variable to trigger the indirected load
    if memlet.num_accesses == 1:
        storage = sdfg.add_scalar(value_name,
                                  memlet.data.dtype,
                                  transient=True)
    else:
        storage = sdfg.add_array(value_name,
                                 memlet.data.dtype,
                                 storage=types.StorageType.Default,
                                 transient=True,
                                 shape=memlet.bounding_box_size())
    indirect_range = sbs.Range([(0, s - 1, 1) for s in storage.shape])
    data_node = nd.AccessNode(value_name)

    # Create memlet that depends on the full array that we look up in
    full_range = sbs.Range([(0, s - 1, 1) for s in memlet.data.shape])
    full_memlet = Memlet(memlet.dataname, memlet.num_accesses, full_range,
                         memlet.veclen)
    graph.add_edge(src, None, tasklet, '__ind_' + memlet.local_name,
                   full_memlet)

    # Memlet to store the final value into the transient, and to load it into
    # the tasklet that needs it
    indirect_memlet = Memlet(value_name, memlet.num_accesses, indirect_range,
                             memlet.veclen)
    graph.add_edge(tasklet, 'lookup', data_node, None, indirect_memlet)

    value_memlet = Memlet(value_name, memlet.num_accesses, indirect_range,
                          memlet.veclen)
    graph.add_edge(data_node, None, dst, memlet.local_name, value_memlet)
class MapReduceFusion(pm.Transformation):
    """ Implements the map-reduce-fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else.
    """

    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')
    _reduce = nodes.Reduce('lambda: None', None)
    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        """ Pattern: tasklet -> map exit -> array -> reduce -> array. """
        return [
            nxutil.node_path_graph(MapReduceFusion._tasklet,
                                   MapReduceFusion._tmap_exit,
                                   MapReduceFusion._in_array,
                                   MapReduceFusion._reduce,
                                   MapReduceFusion._out_array)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks that the intermediate array is only written by the map and
            only read by the reduce node, that the tasklet writes it directly
            (with no WCR or with the same WCR as the reduction), that (in
            strict mode) it is accessed nowhere else, and that the write and
            read subsets match. """
        tmap_exit = graph.nodes()[candidate[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[candidate[MapReduceFusion._reduce]]
        tasklet = graph.nodes()[candidate[MapReduceFusion._tasklet]]

        # Make sure that the array is only accessed by the map and the reduce
        if any(src != tmap_exit
               for src, _, _, _, _ in graph.in_edges(in_array)):
            return False
        if any(dest != reduce_node
               for _, _, dest, _, _ in graph.out_edges(in_array)):
            return False

        # Memlet through which the tasklet writes the intermediate array.
        # If there is none, do not match (instead of leaking StopIteration
        # into the caller).
        tmem = next((e.data for e in graph.edges_between(tasklet, tmap_exit)
                     if e.data.data == in_array.data), None)
        if tmem is None:
            return False

        # (strict) Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if strict:
            num_access_nodes = sum(
                1 for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == in_array.data)
            if (num_access_nodes > 1
                    or in_array.data in sdfg.shared_transients()):
                return False

        # If memlet already has WCR and it is different from reduce node,
        # do not match
        if tmem.wcr is not None and tmem.wcr != reduce_node.wcr:
            return False

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[MapReduceFusion._tasklet]
        map_exit = candidate[MapReduceFusion._tmap_exit]
        reduce = candidate[MapReduceFusion._reduce]

        return ' -> '.join(str(node) for node in [tasklet, map_exit, reduce])

    def apply(self, sdfg):
        """ Removes the intermediate array and the reduce node and turns the
            tasklet's output memlet into a WCR memlet on the output array.

            :raises RuntimeError: If no memlet into the map exit carries the
                                  intermediate array.
        """
        graph = sdfg.nodes()[self.state_id]
        tmap_exit = graph.nodes()[self.subgraph[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[self.subgraph[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[self.subgraph[MapReduceFusion._reduce]]
        out_array = graph.nodes()[self.subgraph[MapReduceFusion._out_array]]

        # The intermediate array and the reduce node are removed
        nodes_to_remove = [in_array, reduce_node]

        memlet_edge = next((edge for edge in graph.in_edges(tmap_exit)
                            if edge.data.data == in_array.data), None)
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        # Find which indices should be removed from new memlet. If the reduce
        # node specifies no axes, every dimension of its input is reduced.
        input_edge = graph.in_edges(reduce_node)[0]
        axes = reduce_node.axes or list(range(input_edge.data.subset.dims()))
        array_edge = graph.out_edges(reduce_node)[0]

        # Delete relevant edges and nodes
        graph.remove_nodes_from(nodes_to_remove)

        # Filter out reduced dimensions from subset
        filtered_subset = [
            dim for i, dim in enumerate(memlet_edge.data.subset)
            if i not in axes
        ]
        if len(filtered_subset) == 0:  # Output is a scalar
            filtered_subset = [0]

        # Modify edge from tasklet to map exit
        memlet_edge.data.data = out_array.data
        memlet_edge.data.wcr = reduce_node.wcr
        memlet_edge.data.wcr_identity = reduce_node.identity
        memlet_edge.data.subset = type(
            memlet_edge.data.subset)(filtered_subset)

        # Add edge from map exit to output array
        graph.add_edge(
            memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:], array_edge.dst,
            array_edge.dst_conn,
            Memlet(array_edge.data.data, array_edge.data.num_accesses,
                   array_edge.data.subset, array_edge.data.veclen,
                   reduce_node.wcr, reduce_node.identity))
def _build_dataflow_graph_recurse(sdfg, state, primitives, modules, superEntry,
                                  super_exit):
    """ Recursively converts AST primitives into SDFG nodes and edges.

        For each primitive, this creates its entry/exit nodes (map, consume,
        reduce, tasklet or nested SDFG). It connects the primitive's inputs
        from access nodes or ``superEntry``, recurses into its children, and
        connects its outputs to access nodes or ``super_exit``.

        :param sdfg: The SDFG being built (used for data descriptors and for
                     parenting nested SDFGs).
        :param state: The state to which nodes and edges are added.
        :param primitives: AST primitives to convert; an empty list is
                           replaced by a single empty tasklet.
        :param modules: Passed to ``ModuleInliner`` when cleaning Python
                        tasklet code.
        :param superEntry: Entry node of the enclosing scope, or a falsy value
                           at the top level.
        :param super_exit: Exit node of the enclosing scope, or ``None`` at the
                           top level.
        :return: List of ``(exit node, memlet)`` pairs that the enclosing
                 recursion level must continue routing outward. Returns
                 ``None`` if a ``_ProgramNode`` primitive is encountered.
        :raises ValueError: If a map or consume primitive has no children.
        :raises SyntaxError: If a non-Python tasklet has no external code.
        :raises TypeError: On unsupported primitive types.
        :raises RuntimeError: On top-level outputs whose parent is neither a
                              program nor a control-flow node, or on an
                              unhandled WCR.
    """
    # Array of pairs (exit node, memlet)
    exit_nodes = []

    if len(primitives) == 0:
        # Inject empty tasklets into empty states
        primitives = [astnodes._EmptyTaskletNode("Empty Tasklet", None)]

    for prim in primitives:
        label = prim.name

        # Expand node to get entry and exit points
        if isinstance(prim, astnodes._MapNode):
            if len(prim.children) == 0:
                raise ValueError("Map node expected to have children")
            mapNode = nd.Map(
                label, prim.params, prim.range, is_async=prim.is_async)
            # Add connectors for inputs that exist as array nodes
            entry = nd.MapEntry(
                mapNode,
                _get_input_symbols(prim.inputs, prim.range.free_symbols))
            exit = nd.MapExit(mapNode)
        elif isinstance(prim, astnodes._ConsumeNode):
            if len(prim.children) == 0:
                raise ValueError("Consume node expected to have children")
            consumeNode = nd.Consume(label, (prim.params[1], prim.num_pes),
                                     prim.condition)
            entry = nd.ConsumeEntry(consumeNode)
            exit = nd.ConsumeExit(consumeNode)
        elif isinstance(prim, astnodes._ReduceNode):
            # A reduce node serves as both the entry and the exit point
            rednode = nd.Reduce(prim.ast, prim.axes, prim.identity)
            state.add_node(rednode)
            entry = rednode
            exit = rednode
        elif isinstance(prim, astnodes._TaskletNode):
            if isinstance(prim, astnodes._EmptyTaskletNode):
                tasklet = nd.EmptyTasklet(prim.name)
            else:
                # Remove memlets from tasklet AST
                if prim.language == types.Language.Python:
                    clean_code = MemletRemover().visit(prim.ast)
                    clean_code = ModuleInliner(modules).visit(clean_code)
                else:
                    # Use external code from tasklet definition
                    if prim.extcode is None:
                        raise SyntaxError("Cannot define an intrinsic "
                                          "tasklet without an implementation")
                    clean_code = prim.extcode
                tasklet = nd.Tasklet(
                    prim.name,
                    set(prim.inputs.keys()),
                    set(prim.outputs.keys()),
                    code=clean_code,
                    language=prim.language,
                    code_global=prim.gcode)  # TODO: location=prim.location

            # Need to add the tasklet in case we're in an empty state, where no
            # edge will be drawn to it
            state.add_node(tasklet)
            entry = tasklet
            exit = tasklet

        elif isinstance(prim, astnodes._NestedSDFGNode):
            prim.sdfg.parent = state
            prim.sdfg._parent_sdfg = sdfg
            prim.sdfg.update_sdfg_list([])
            nsdfg = nd.NestedSDFG(prim.name, prim.sdfg,
                                  set(prim.inputs.keys()),
                                  set(prim.outputs.keys()))
            state.add_node(nsdfg)
            entry = nsdfg
            exit = nsdfg

        elif isinstance(prim, astnodes._ProgramNode):
            # NOTE(review): this returns None, while the recursive call below
            # iterates over the result -- confirm a program node can never
            # appear among the children of another primitive.
            return
        elif isinstance(prim, astnodes._ControlFlowNode):
            continue
        else:
            raise TypeError("Node type not implemented: " +
                            str(prim.__class__))

        # Add incoming edges
        for varname, memlet in prim.inputs.items():
            arr = memlet.dataname
            if (prim.parent is not None
                    and memlet.dataname in prim.parent.transients.keys()):
                node = input_node_for_array(state, memlet.dataname)

                # Add incoming edge into transient as well
                # FIXME: A bit hacked?
                if arr in prim.parent.inputs:
                    astmem = prim.parent.inputs[arr]
                    _add_astmemlet_edge(sdfg, state, superEntry, None, node,
                                        None, astmem)

                    # Remove local name from incoming edge to parent
                    prim.parent.inputs[arr].local_name = None
            elif superEntry:
                node = superEntry
            else:
                node = input_node_for_array(state, memlet.dataname)

            # Destination connector inference
            # Connected to a tasklet or a nested SDFG
            dst_conn = (memlet.local_name
                        if isinstance(entry, nd.CodeNode) else None)
            # Connected to a scope as part of its range: strip the 9-character
            # '__DACEIN_' prefix to obtain the connector name
            if str(varname).startswith('__DACEIN_'):
                dst_conn = str(varname)[9:]
            # Handle special case of consume input stream
            if (isinstance(entry, nd.ConsumeEntry)
                    and memlet.data == prim.stream):
                dst_conn = 'IN_stream'

            # If a memlet that covers this input already exists, skip
            # generating this one; otherwise replace memlet with ours.
            # Partially overlapping memlets are merged into the existing edge
            # by bounding-box union.
            skip_incoming_edge = False
            remove_edge = None
            for e in state.edges_between(node, entry):
                if e.data.data != memlet.dataname or dst_conn != e.dst_conn:
                    continue
                if e.data.subset.covers(memlet.subset):
                    skip_incoming_edge = True
                    break
                elif memlet.subset.covers(e.data.subset):
                    remove_edge = e
                    break
                else:
                    print('WARNING: Performing bounding-box union on',
                          memlet.subset, 'and', e.data.subset, '(in)')
                    e.data.subset = sbs.bounding_box_union(
                        e.data.subset, memlet.subset)
                    e.data.num_accesses += memlet.num_accesses
                    skip_incoming_edge = True
                    break

            if remove_edge is not None:
                state.remove_edge(remove_edge)

            if skip_incoming_edge == False:
                _add_astmemlet_edge(sdfg, state, node, None, entry, dst_conn,
                                    memlet)

        # If there are no inputs, generate a dummy edge
        if superEntry and len(prim.inputs) == 0:
            state.add_edge(superEntry, None, entry, None, EmptyMemlet())

        if len(prim.children) > 0:
            # Recurse
            inner_outputs = _build_dataflow_graph_recurse(
                sdfg, state, prim.children, modules, entry, exit)
            # Infer output node for each memlet
            for i, (out_src, mem) in enumerate(inner_outputs):
                # If there is no such array in this primitive's outputs,
                # it's an external array (e.g., a map in a map). In this case,
                # connect to the exit node
                if mem.dataname in prim.outputs:
                    inner_outputs[i] = (out_src, prim.outputs[mem.dataname])
                else:
                    inner_outputs[i] = (out_src, mem)
        else:
            inner_outputs = [(exit, mem) for mem in prim.outputs.values()]

        # Add outgoing edges
        for out_src, astmem in inner_outputs:

            data = astmem.data
            dataname = astmem.dataname

            # If WCR is not none, it needs to be handled in the code. Check for
            # this after, as we only expect it for one distinct case
            wcr_was_handled = astmem.wcr is None

            # TODO: This is convoluted. We should find a more readable
            # way of connecting the outgoing edges.

            if super_exit is None:

                # Assert that we're in a top-level node
                if ((not isinstance(prim.parent, astnodes._ProgramNode)) and
                    (not isinstance(prim.parent, astnodes._ControlFlowNode))):
                    raise RuntimeError("Expected to be at the top node")

                # Looks hacky
                src_conn = (astmem.local_name if isinstance(
                    out_src, (nd.Tasklet, nd.NestedSDFG)) else None)

                # Here we just need to connect memlets directly to their
                # respective data nodes
                out_tgt = output_node_for_array(state, astmem.dataname)

                # If a memlet that covers this output already exists, skip
                # generating this one; otherwise replace memlet with ours
                skip_outgoing_edge = False
                remove_edge = None
                for e in state.edges_between(out_src, out_tgt):
                    if e.data.data != astmem.dataname or src_conn != e.src_conn:
                        continue
                    if e.data.subset.covers(astmem.subset):
                        skip_outgoing_edge = True
                        break
                    elif astmem.subset.covers(e.data.subset):
                        remove_edge = e
                        break
                    else:
                        print('WARNING: Performing bounding-box union on',
                              astmem.subset, 'and', e.data.subset, '(out)')
                        e.data.subset = sbs.bounding_box_union(
                            e.data.subset, astmem.subset)
                        e.data.num_accesses += astmem.num_accesses
                        skip_outgoing_edge = True
                        break

                if skip_outgoing_edge == True:
                    continue
                if remove_edge is not None:
                    state.remove_edge(remove_edge)

                _add_astmemlet_edge(
                    sdfg,
                    state,
                    out_src,
                    src_conn,
                    out_tgt,
                    None,
                    astmem,
                    wcr=astmem.wcr,
                    wcr_identity=astmem.wcr_identity)
                wcr_was_handled = (True if astmem.wcr is not None else
                                   wcr_was_handled)

                # If the program defines another output, connect it too.
                # This refers to the case where we have streams, which
                # must define an input and output, and sometimes this output
                # is defined in pdp.outputs
                if (isinstance(out_tgt, nd.AccessNode)
                        and isinstance(out_tgt.desc(sdfg), dt.Stream)):
                    try:
                        stream_memlet = next(
                            v for k, v in prim.parent.outputs.items()
                            if k == out_tgt.data)
                        stream_output = output_node_for_array(
                            state, stream_memlet.dataname)
                        _add_astmemlet_edge(sdfg, state, out_tgt, None,
                                            stream_output, None, stream_memlet)
                    except StopIteration:  # Stream output not found, skip
                        pass

            else:  # We're in a nest

                if isinstance(prim, astnodes._ScopeNode):
                    # We're a map or a consume node, that needs to connect our
                    # exit to either an array or to the super_exit
                    if data.transient and dataname in prim.parent.transients:
                        # Connect the exit directly
                        # NOTE(review): other branches use astmem.dataname;
                        # confirm data.dataname is the same name here.
                        out_tgt = output_node_for_array(state, data.dataname)
                        _add_astmemlet_edge(sdfg, state, out_src, None,
                                            out_tgt, None, astmem)
                    else:
                        # This is either a transient defined in an outer scope,
                        # or an I/O array, so redirect through the exit node
                        _add_astmemlet_edge(sdfg, state, out_src, None,
                                            super_exit, None, astmem)
                        # Instruct outer recursion layer to continue the route
                        exit_nodes.append((super_exit, astmem))
                elif isinstance(
                        prim, (astnodes._TaskletNode, astnodes._NestedSDFGNode)):
                    # We're a tasklet, and need to connect either to the exit
                    # if the array is I/O or is defined in a scope further out,
                    # or directly to the transient if it's defined locally
                    if dataname in prim.parent.transients:
                        # This is a local transient variable, so connect to it
                        # directly
                        out_tgt = output_node_for_array(state, data.dataname)
                        _add_astmemlet_edge(sdfg, state, out_src,
                                            astmem.local_name, out_tgt, None,
                                            astmem)
                    else:
                        # This is an I/O array, or an outer level transient, so
                        # redirect through the exit node (WCR is attached here)
                        _add_astmemlet_edge(
                            sdfg,
                            state,
                            out_src,
                            astmem.local_name,
                            super_exit,
                            None,
                            astmem,
                            wcr=astmem.wcr,
                            wcr_identity=astmem.wcr_identity)
                        exit_nodes.append((super_exit, astmem))
                        if astmem.wcr is not None:
                            wcr_was_handled = True
                # Sanity check
                else:
                    raise TypeError("Unexpected node type: {}".format(
                        type(out_src).__name__))

            # A WCR output is only expected to be consumed on a path above
            # (scope primitives are exempt from this check)
            if not wcr_was_handled and not isinstance(prim,
                                                      astnodes._ScopeNode):
                raise RuntimeError("Detected unhandled WCR for primitive '{}' "
                                   "of type {}. WCR is only expected for "
                                   "tasklets in a map/consume scope.".format(
                                       prim.name,
                                       type(prim).__name__))

    return exit_nodes
class MapReduceFusion(pm.Transformation):
    """ Implements the map-reduce-fusion transformation.

        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else.

        Three patterns are matched (see ``expressions``):
          0. tasklet -> map exit -> array -> reduction map (entry, tasklet,
             exit with a WCR input) -> output array;
          1. like 0, but the reduction map is nested inside an outer map
             (partial reduction of axes);
          2. tasklet -> map exit -> array -> Reduce node -> output array.
        Pattern 2 is fused by writing the tasklet output straight into the
        output array with the Reduce node's WCR. Patterns 0/1 replace the
        intermediate array by a new transient named 'tmp' that feeds the
        reduction tasklet directly.
    """

    # Producing map: tasklet -> map exit -> intermediate array
    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')

    # Map-based reduction (patterns 0/1). The "cr" exit is the one whose
    # incoming memlet must carry a WCR (checked in can_be_applied);
    # "cr" presumably stands for conflict resolution.
    _rmap_in_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_in_tasklet = nodes.Tasklet('_')
    _rmap_in_cr = nodes.MapExit(nodes.Map("", [], []))
    _rmap_out_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_out_exit = nodes.MapExit(nodes.Map("", [], []))
    _out_array = nodes.AccessNode('_')

    # Reduce-node based reduction (pattern 2)
    _reduce = nodes.Reduce('lambda: None', None)

    @staticmethod
    def expressions():
        """ Return the three subgraph patterns; the list index is the
            ``expr_index`` used throughout this class. """
        return [
            # Map, then reduce of all axes
            nxutil.node_path_graph(
                MapReduceFusion._tasklet, MapReduceFusion._tmap_exit,
                MapReduceFusion._in_array, MapReduceFusion._rmap_in_entry,
                MapReduceFusion._rmap_in_tasklet, MapReduceFusion._rmap_in_cr,
                MapReduceFusion._out_array),

            # Map, then partial reduction of axes
            nxutil.node_path_graph(
                MapReduceFusion._tasklet, MapReduceFusion._tmap_exit,
                MapReduceFusion._in_array, MapReduceFusion._rmap_out_entry,
                MapReduceFusion._rmap_in_entry,
                MapReduceFusion._rmap_in_tasklet, MapReduceFusion._rmap_in_cr,
                MapReduceFusion._rmap_out_exit, MapReduceFusion._out_array),

            # Map, then reduce node
            nxutil.node_path_graph(
                MapReduceFusion._tasklet, MapReduceFusion._tmap_exit,
                MapReduceFusion._in_array, MapReduceFusion._reduce,
                MapReduceFusion._out_array)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Return True if the match can be fused.

            Requires that (a) the intermediate array is written only by the
            producing map exit and read only by the reduction, (b) for
            patterns 0/1 the reduction map's exit receives a WCR memlet, and
            (c) the subset written by the map equals the subset read by the
            reduction.
        """
        tmap_exit = graph.nodes()[candidate[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapReduceFusion._in_array]]
        # The node that reads the intermediate array depends on the pattern
        if expr_index == 0:  # Reduce without outer map
            rmap_entry = graph.nodes()[candidate[
                MapReduceFusion._rmap_in_entry]]
        elif expr_index == 1:  # Reduce with outer map
            rmap_entry = graph.nodes()[candidate[
                MapReduceFusion._rmap_out_entry]]
        else:  # Reduce node
            rmap_entry = graph.nodes()[candidate[MapReduceFusion._reduce]]

        # Make sure that the array is only accessed by the map and the reduce
        if any([
                src != tmap_exit
                for src, _, _, _, memlet in graph.in_edges(in_array)
        ]):
            return False
        if any([
                dest != rmap_entry
                for _, _, dest, _, memlet in graph.out_edges(in_array)
        ]):
            return False

        # Make sure that there is a reduction in the second map
        if expr_index < 2:
            rmap_cr = graph.nodes()[candidate[MapReduceFusion._rmap_in_cr]]
            reduce_edge = graph.in_edges(rmap_cr)[0]
            if reduce_edge.data.wcr is None:
                return False

        # NOTE: Earlier revisions also checked that the transient is not
        # shared with other states, that the reduction covers full array
        # ranges, and that each tasklet-map range matches the corresponding
        # reduction range. Those checks were disabled (they referenced
        # `cgen_state` and `reduce.inslice`, which this code no longer uses)
        # and their commented-out code was dropped. Only the
        # subset-equality check below remains.

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Human-readable "tasklet -> map exit -> reduction" label. """
        tasklet = candidate[MapReduceFusion._tasklet]
        map_exit = candidate[MapReduceFusion._tmap_exit]

        # Pattern 2 is the only one with exactly five nodes
        if len(candidate) == 5:  # Expression 2
            reduce = candidate[MapReduceFusion._reduce]
        else:
            reduce = candidate[MapReduceFusion._rmap_in_cr]

        return ' -> '.join(str(node) for node in [tasklet, map_exit, reduce])

    @staticmethod
    def find_memlet_map_permutation(memlet: Memlet, map: nodes.Map):
        """ For each dimension i of ``memlet.subset``, return the index of the
            first not-yet-used map parameter whose symbol equals that
            dimension's expression, or None if there is no such parameter. """
        perm = [None] * len(memlet.subset)
        indices = set()
        for i, dim in enumerate(memlet.subset):
            for j, mapdim in enumerate(map.params):
                if symbolic.pystr_to_symbolic(
                        mapdim) == dim and j not in indices:
                    perm[i] = j
                    indices.add(j)
                    break
        return perm

    @staticmethod
    def find_permutation(tasklet_map: nodes.Map, red_outer_map: nodes.Map,
                         red_inner_map: nodes.Map, tmem: Memlet):
        """ Split the tasklet map's dimensions between the reduction's outer
            and inner maps.

            Each tasklet-map range is matched to an identical, not-yet-used
            range of ``red_outer_map``, otherwise of ``red_inner_map``.
            ``tmem`` is not used.

            :return: ([tasklet-map indices matched to the outer map],
                      [tasklet-map indices matched to the inner (CR) map]).
            Raises AssertionError if the dimension counts do not add up or
            if not every tasklet-map range was matched.
        """
        result = [], []

        assert len(tasklet_map.range) == len(red_inner_map.range) + len(
            red_outer_map.range)

        # Match map ranges with reduce ranges
        unavailable_ranges_out = set()
        unavailable_ranges_in = set()
        for i, tmap_rng in enumerate(tasklet_map.range):
            found = False
            for j, rng in enumerate(red_outer_map.range):
                if tmap_rng == rng and j not in unavailable_ranges_out:
                    result[0].append(i)
                    unavailable_ranges_out.add(j)
                    found = True
                    break
            if found:
                continue

            for j, rng in enumerate(red_inner_map.range):
                if tmap_rng == rng and j not in unavailable_ranges_in:
                    result[1].append(i)
                    unavailable_ranges_in.add(j)
                    found = True
                    break

            # Stop at the first unmatched range; the assertion below fails
            if not found:
                break

        # Ensure all map variables matched with reduce variables
        assert len(result[0]) + len(result[1]) == len(tasklet_map.range)

        # Returns ([outer map indices], [inner (CR) map indices])
        return result

    @staticmethod
    def find_permutation_reduce(tasklet_map: nodes.Map,
                                reduce_node: nodes.Reduce, graph: SDFGState,
                                tmem: Memlet):
        """ Return tasklet-map parameter indices for the non-reduced axes of
            a Reduce node's input.

            If ``reduce_node.axes`` is None (all axes reduced), the result is
            empty. ``apply`` uses the result to index ``tasklet_map.params``.
        """

        in_memlet = graph.in_edges(reduce_node)[0].data
        out_memlet = graph.out_edges(reduce_node)[0].data  # not used below
        assert len(tasklet_map.range) == in_memlet.subset.dims()

        # Find permutation between tasklet-exit memlet and tasklet map
        tmem_perm = MapReduceFusion.find_memlet_map_permutation(
            tmem, tasklet_map)

        mapred_perm = []

        # Match map ranges with reduce ranges
        unavailable_ranges = set()
        for i, tmap_rng in enumerate(tasklet_map.range):
            found = False
            for j, in_rng in enumerate(in_memlet.subset):
                if tmap_rng == in_rng and j not in unavailable_ranges:
                    mapred_perm.append(i)
                    unavailable_ranges.add(j)
                    found = True
                    break
            if not found:
                break

        # Ensure all map variables matched with reduce variables
        assert len(tmem_perm) == len(tmem.subset)
        assert len(mapred_perm) == len(in_memlet.subset)

        # NOTE(review): mapred_perm only appends the loop index i in
        # increasing order and stops at the first miss, so once the assertion
        # above passes it is [0, 1, ..., n-1]. The lookup below then reduces
        # to tmem_perm[i]. An entry of tmem_perm may also be None (see
        # find_memlet_map_permutation), which would raise TypeError here.
        # Confirm the intended mapping.

        # Prepare result from the two permutations and the reduction axes
        result = []
        for i in range(len(mapred_perm)):
            if reduce_node.axes is None or i in reduce_node.axes:
                continue
            result.append(mapred_perm[tmem_perm[i]])

        return result

    def apply(self, sdfg):
        """ Fuse the matched map and reduction in place.

            Pattern 2: removes the intermediate array and the Reduce node.
            The tasklet then writes, through the map exit, directly into the
            output array using the Reduce node's WCR and identity.

            Patterns 0/1: replaces the intermediate array with a new
            transient 'tmp' written by the tasklet and read by the reduction
            tasklet. For pattern 1 the tasklet map is also split into an
            outer map over the non-reduced dimensions and the remaining
            inner map.
        """
        # Resolve a pattern node to the matched node in this state
        def gnode(nname):
            return graph.nodes()[self.subgraph[nname]]

        expr_index = self.expr_index
        graph = sdfg.nodes()[self.state_id]
        tasklet = gnode(MapReduceFusion._tasklet)
        tmap_exit = graph.nodes()[self.subgraph[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[self.subgraph[MapReduceFusion._in_array]]
        if expr_index == 0:  # Reduce without outer map
            rmap_entry = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_in_entry]]
        elif expr_index == 1:  # Reduce with outer map
            rmap_out_entry = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_out_entry]]
            rmap_out_exit = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_out_exit]]
            rmap_in_entry = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_in_entry]]
            # NOTE(review): rmap_tasklet is only bound for pattern 1, but the
            # pattern-0 path below also reads it
            # (graph.in_edges(rmap_tasklet)). Confirm whether pattern 0 is
            # expected to work.
            rmap_tasklet = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_in_tasklet]]

        # The node that carries the reduction's WCR: the Reduce node
        # (pattern 2) or the reduction map's exit (patterns 0/1)
        if expr_index == 2:
            rmap_cr = graph.nodes()[self.subgraph[MapReduceFusion._reduce]]
        else:
            rmap_cr = graph.nodes()[self.subgraph[MapReduceFusion._rmap_in_cr]]

        out_array = gnode(MapReduceFusion._out_array)

        # Set nodes to remove according to the expression index.
        # Index 0 must stay the intermediate array (relied upon below).
        nodes_to_remove = [in_array]
        if expr_index == 0:
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_in_entry))
        elif expr_index == 1:
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_out_entry))
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_in_entry))
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_out_exit))
        else:
            nodes_to_remove.append(gnode(MapReduceFusion._reduce))

        # If no other edges lead to mapexit, remove it. Otherwise, keep
        # it and remove reduction incoming/outgoing edges
        if expr_index != 2 and len(graph.in_edges(tmap_exit)) == 1:
            nodes_to_remove.append(tmap_exit)

        # Find the edge into the map exit that writes the intermediate array
        memlet_edge = None
        for edge in graph.in_edges(tmap_exit):
            if edge.data.data == in_array.data:
                memlet_edge = edge
                break
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        if expr_index == 0:  # Reduce without outer map
            # Index order does not matter, merge as-is
            pass
        elif expr_index == 1:  # Reduce with outer map
            tmap = tmap_exit.map
            perm_outer, perm_inner = MapReduceFusion.find_permutation(
                tmap, rmap_out_entry.map, rmap_in_entry.map, memlet_edge.data)

            # Split tasklet map into tmap_out -> tmap_in (according to
            # reduction)
            omap = nodes.Map(
                tmap.label + '_nonreduce',
                [p for i, p in enumerate(tmap.params) if i in perm_outer],
                [r for i, r in enumerate(tmap.range) if i in perm_outer],
                tmap.schedule, tmap.unroll, tmap.is_async)
            tmap.params = [
                p for i, p in enumerate(tmap.params) if i in perm_inner
            ]
            tmap.range = [
                r for i, r in enumerate(tmap.range) if i in perm_inner
            ]
            omap_entry = nodes.MapEntry(omap)
            # The reduction's outer exit is reused as the new outer map exit
            omap_exit = rmap_out_exit
            rmap_out_exit.map = omap

            # Reconnect graph to new map
            tmap_entry = graph.entry_node(tmap_exit)
            tmap_in_edges = list(graph.in_edges(tmap_entry))
            # NOTE(review): the loop variable is unused. If change_edge_dest
            # moves all incoming edges at once, iterations after the first
            # have no effect.
            for e in tmap_in_edges:
                nxutil.change_edge_dest(graph, tmap_entry, omap_entry)
            # Route each original input through the new outer map entry
            for e in tmap_in_edges:
                graph.add_edge(omap_entry, e.src_conn, tmap_entry, e.dst_conn,
                               copy.copy(e.data))
        elif expr_index == 2:  # Reduce node
            # Find correspondence between map indices and array outputs
            tmap = tmap_exit.map
            perm = MapReduceFusion.find_permutation_reduce(
                tmap, rmap_cr, graph, memlet_edge.data)

            output_subset = [tmap.params[d] for d in perm]
            if len(output_subset) == 0:  # Output is a scalar
                output_subset = [0]

            array_edge = graph.out_edges(rmap_cr)[0]

            # Delete relevant edges and nodes
            graph.remove_edge(memlet_edge)
            graph.remove_nodes_from(nodes_to_remove)

            # Add new edges and nodes
            # From tasklet to map exit
            graph.add_edge(
                memlet_edge.src, memlet_edge.src_conn, memlet_edge.dst,
                memlet_edge.dst_conn,
                Memlet(out_array.data, memlet_edge.data.num_accesses,
                       subsets.Indices(output_subset), memlet_edge.data.veclen,
                       rmap_cr.wcr, rmap_cr.identity))

            # From map exit to output array. The out-connector name is
            # derived by replacing the first three characters of the input
            # connector with 'OUT_' (assumes an 'IN_' prefix).
            graph.add_edge(
                memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:],
                array_edge.dst, array_edge.dst_conn,
                Memlet(array_edge.data.data, array_edge.data.num_accesses,
                       array_edge.data.subset, array_edge.data.veclen,
                       rmap_cr.wcr, rmap_cr.identity))

            return

        # ---- Patterns 0 and 1 only from here on ----

        # Remove tmp array node prior to the others, so that a new one
        # can be created in its stead (see below)
        graph.remove_node(nodes_to_remove[0])
        nodes_to_remove = nodes_to_remove[1:]

        # Create tasklet -> tmp -> tasklet connection
        tmp = graph.add_array(
            'tmp',
            memlet_edge.data.subset.bounding_box_size(),
            sdfg.arrays[memlet_edge.data.data].dtype,
            transient=True)
        tasklet_tmp_memlet = copy.deepcopy(memlet_edge.data)
        tasklet_tmp_memlet.data = tmp.data
        tasklet_tmp_memlet.subset = ShapeProperty.to_string(tmp.shape)

        # Modify memlet to point to output array
        memlet_edge.data.data = out_array.data

        # Recover reduction axes from CR reduce subset: an axis counts as
        # reduced if its index expression contains the substring '__i'
        reduce_cr_subset = graph.in_edges(rmap_tasklet)[0].data.subset
        reduce_axes = []
        for ind, crvar in enumerate(reduce_cr_subset.indices):
            if '__i' in str(crvar):
                reduce_axes.append(ind)

        # Modify memlet access index by filtering out reduction axes
        if True:  # expr_index == 0:
            newindices = []
            for ind, ovar in enumerate(memlet_edge.data.subset.indices):
                if ind not in reduce_axes:
                    newindices.append(ovar)
            # Every axis reduced: index the output as a scalar
            if len(newindices) == 0:
                newindices = [0]

            memlet_edge.data.subset = subsets.Indices(newindices)

        # Redirect the tasklet's output from the map exit to the new 'tmp'
        graph.remove_edge(memlet_edge)
        graph.add_edge(memlet_edge.src, memlet_edge.src_conn, tmp,
                       memlet_edge.dst_conn, tasklet_tmp_memlet)

        red_edges = list(graph.in_edges(rmap_tasklet))
        if len(red_edges) != 1:
            raise RuntimeError('CR edge must be unique')

        # Feed the reduction tasklet from 'tmp' using the same connector
        tmp_tasklet_memlet = copy.deepcopy(tasklet_tmp_memlet)
        graph.add_edge(tmp, None, rmap_tasklet, red_edges[0].dst_conn,
                       tmp_tasklet_memlet)

        # The reduction tasklet now writes at the filtered output indices
        for e in graph.edges_between(rmap_tasklet, rmap_cr):
            e.data.subset = memlet_edge.data.subset

        # Move output edges to point directly to CR node
        if expr_index == 1:
            # Set output memlet between CR node and outer reduction map to
            # contain the same subset as the one pointing to the CR node
            for e in graph.out_edges(rmap_cr):
                e.data.subset = memlet_edge.data.subset

            # NOTE(review): rmap_out and omap_exit are the same node
            # (omap_exit = rmap_out_exit above), so this call moves edges
            # onto their own source.
            rmap_out = gnode(MapReduceFusion._rmap_out_exit)
            nxutil.change_edge_src(graph, rmap_out, omap_exit)

        # Remove nodes
        # NOTE(review): for pattern 1, nodes_to_remove still contains
        # rmap_out_exit, which is also omap_exit. Confirm that removing the
        # new outer map exit here is intended.
        graph.remove_nodes_from(nodes_to_remove)

        # For unrelated outputs, connect original output to rmap_out
        if expr_index == 1 and tmap_exit not in nodes_to_remove:
            other_out_edges = list(graph.out_edges(tmap_exit))
            for e in other_out_edges:
                graph.remove_edge(e)
                graph.add_edge(e.src, e.src_conn, omap_exit, None, e.data)
                graph.add_edge(omap_exit, None, e.dst, e.dst_conn,
                               copy.copy(e.data))

    def modifies_graph(self):
        return True
class Vectorization(pattern_matching.Transformation):
    """ Implements the vectorization transformation.

        Matches a map entry -> tasklet -> map exit path. It applies when:
          - at least one memlet adjacent to the tasklet uses the innermost
            map parameter in its last dimension;
          - no memlet uses that parameter in any other dimension;
          - no memlet carries a write-conflict resolution (WCR);
          - no memlet already has a vector length greater than 1.
        Empty and stream memlets are ignored by these checks.

        The transformation sets the step of the innermost map dimension to
        the vector length. It also sets that vector length on every memlet
        path whose last index depends on the innermost map parameter.
    """

    vector_len = Property(desc="Vector length", dtype=int, default=4)
    propagate_parent = Property(desc="Propagate vector length through "
                                "parent SDFGs",
                                dtype=bool,
                                default=False)

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        return [
            nxutil.node_path_graph(Vectorization._map_entry,
                                   Vectorization._tasklet,
                                   Vectorization._map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        map_entry = graph.nodes()[candidate[Vectorization._map_entry]]
        tasklet = graph.nodes()[candidate[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        found = False

        # Each relevant memlet adjacent to the tasklet may use the innermost
        # map parameter only in its last dimension, and at least one of them
        # must actually use it there.
        for _src, _, _dest, _, memlet in graph.all_edges(tasklet):

            # Cases that do not matter for vectorization
            if memlet.data is None:  # Empty memlets
                continue
            if isinstance(sdfg.arrays[memlet.data], data.Stream):  # Streams
                continue

            # Vectorization can not be applied in WCR
            if memlet.wcr is not None:
                return False

            try:
                subset = memlet.subset
                veclen = memlet.veclen
            except AttributeError:
                return False

            if subset is None:
                return False

            try:
                # Already-vectorized memlets are not vectorized again
                if veclen > symbolic.pystr_to_symbolic('1'):
                    return False

                for idx, expr in enumerate(subset):
                    # A subset entry is either a tuple of range components
                    # or a single index expression; inspect every component.
                    components = expr if isinstance(expr, tuple) else (expr, )
                    for component in components:
                        symbols = symbolic.pystr_to_symbolic(
                            component).free_symbols
                        if param in symbols:
                            if idx == subset.dims() - 1:
                                found = True
                            else:
                                return False
            except TypeError:  # cannot determine truth value of Relational
                return False

        return found

    @staticmethod
    def match_to_str(graph, candidate):
        map_entry = candidate[Vectorization._map_entry]
        tasklet = candidate[Vectorization._tasklet]
        map_exit = candidate[Vectorization._map_exit]

        return ' -> '.join(
            str(node) for node in [map_entry, tasklet, map_exit])

    def apply(self, sdfg):
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = graph.nodes()[self.subgraph[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        # Create new vector size.
        vector_size = self.vector_len

        # Change the step of the inner-most dimension.
        dim_from, dim_to, _dim_step = map_entry.map.range[-1]
        map_entry.map.range[-1] = (dim_from, dim_to, vector_size)

        # Vectorize memlets adjacent to the tasklet.
        for edge in graph.all_edges(tasklet):
            memlet = edge.data
            if memlet.data is None:  # Empty memlets
                continue

            # Collect the free symbols of the last subset entry
            lastindex = memlet.subset[-1]
            if isinstance(lastindex, tuple):
                symbols = set()
                for indd in lastindex:
                    symbols.update(
                        symbolic.pystr_to_symbolic(indd).free_symbols)
            else:
                symbols = symbolic.pystr_to_symbolic(lastindex).free_symbols

            if param not in symbols:
                continue

            # Propagate the vector length along the whole memlet path inside
            # this SDFG. The path is computed once and reused below.
            path = graph.memlet_path(edge)
            for e in path:
                e.data.veclen = vector_size
            source_edge = path[0]
            sink_edge = path[-1]

            # Propagate to the parent (TODO: handle multiple levels of nesting)
            if self.propagate_parent and sdfg.parent is not None:
                # Find parent Nested SDFG node
                parent_node = next(
                    (n for n in sdfg.parent.nodes()
                     if isinstance(n, nodes.NestedSDFG)
                     and n.sdfg.name == sdfg.name), None)
                if parent_node is None:
                    raise ValueError(
                        'Cannot propagate vector length: no NestedSDFG node '
                        'for SDFG "%s" found in its parent' % sdfg.name)

                # Continue propagating the vector length along the parent
                # paths that arrive at source_edge or start from sink_edge
                for pe in sdfg.parent.all_edges(parent_node):
                    if str(pe.dst_conn) == str(source_edge.src) or str(
                            pe.src_conn) == str(sink_edge.dst):
                        for ppe in sdfg.parent.memlet_path(pe):
                            ppe.data.veclen = vector_size
        return
class AccumulateTransient(pattern_matching.Transformation):
    """ Implements the AccumulateTransient transformation.

        Matches a tasklet -> inner map exit -> outer map exit path in which
        the tasklet writes into the inner map exit with a write-conflict
        resolution (WCR). It then inserts a transient data node between the
        two map exits, delegating to LocalStorage. The transient acts as a
        local accumulator. For sum reductions it is marked to be initialized
        to zero; for other WCR types a warning is emitted.
    """

    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    array = Property(
        dtype=str,
        desc="Array to create local storage for (if empty, first available)",
        default=None,
        allow_none=True)

    @staticmethod
    def expressions():
        return [
            nxutil.node_path_graph(AccumulateTransient._tasklet,
                                   AccumulateTransient._map_exit,
                                   AccumulateTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        tasklet = graph.nodes()[candidate[AccumulateTransient._tasklet]]
        map_exit = graph.nodes()[candidate[AccumulateTransient._map_exit]]

        # Check if there is an accumulation (WCR) output into the inner map
        return any(memlet.wcr is not None and dest == map_exit
                   for _src, _, dest, _, memlet in graph.out_edges(tasklet))

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[AccumulateTransient._tasklet]
        map_exit = candidate[AccumulateTransient._map_exit]
        outer_map_exit = candidate[AccumulateTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg):
        graph = sdfg.node(self.state_id)

        # Avoid import loop
        from dace.transformation.dataflow.local_storage import LocalStorage

        # Let LocalStorage create the transient between the inner map exit
        # (node A) and the outer map exit (node B).
        local_storage_subgraph = {
            LocalStorage._node_a:
            self.subgraph[AccumulateTransient._map_exit],
            LocalStorage._node_b:
            self.subgraph[AccumulateTransient._outer_map_exit]
        }
        sdfg_id = sdfg.sdfg_list.index(sdfg)
        in_local_storage = LocalStorage(sdfg_id, self.state_id,
                                        local_storage_subgraph,
                                        self.expr_index)
        in_local_storage.array = self.array
        in_local_storage.apply(sdfg)

        # Initialize transient to zero in case of summation
        # TODO: Initialize transient in other WCR types
        memlet = graph.in_edges(in_local_storage._data_node)[0].data
        if detect_reduction_type(memlet.wcr) == dtypes.ReductionType.Sum:
            in_local_storage._data_node.setzero = True
        else:
            # Note the trailing space: the two literals are concatenated.
            warnings.warn('AccumulateTransient did not properly initialize '
                          'newly-created transient!')
def my_add_mapped_tasklet(
        state,
        inpdict,
        outdict,
        name: str,
        map_ranges: Dict[str, dace.subsets.Subset],
        inputs: Dict[str, dace.memlet.Memlet],
        code: str,
        outputs: Dict[str, dace.memlet.Memlet],
        schedule=dace.dtypes.ScheduleType.Default,
        unroll_map=False,
        code_global="",
        code_init="",
        code_exit="",
        location="-1",
        language=dace.dtypes.Language.Python,
        debuginfo=None,
        external_edges=True,
) -> Tuple[dace.graph.nodes.Node, dace.graph.nodes.Node,
           dace.graph.nodes.Node]:
    """ Convenience function that adds a map entry, tasklet, map exit, and
        the respective edges to caller-provided external access nodes.

        :param state: State to which the nodes and edges are added.
        :param inpdict: Mapping from input data name to the existing access
                        node that feeds the map entry (used only when
                        `external_edges` is True).
        :param outdict: Mapping from output data name to the existing access
                        node written by the map exit (used only when
                        `external_edges` is True).
        :param name: Tasklet (and wrapping map) name
        :param map_ranges: Mapping between variable names and their subsets
        :param inputs: Mapping between input local variable names and their
                       memlets
        :param code: Code (written in `language`)
        :param outputs: Mapping between output local variable names and their
                        memlets
        :param schedule: Map schedule
        :param unroll_map: True if map should be unrolled in code generation
        :param code_global: (optional) Global code (outside functions)
        :param code_init: (optional) Tasklet initialization code
        :param code_exit: (optional) Tasklet exit code
        :param location: Tasklet location, forwarded to the Tasklet node
        :param language: Programming language in which the code is written
        :param debuginfo: Debugging information (mostly for DIODE)
        :param external_edges: Connect the nodes in `inpdict`/`outdict` to the
                               map with propagated memlets
        :return: tuple of (tasklet, map_entry, map_exit)
        :raise KeyError: if `inpdict`/`outdict` names a data container that
                         no memlet in `inputs`/`outputs` accesses.

        Note: the memlet objects in `inputs` and `outputs` are attached to the
        graph as-is, and their `name` attribute is set to the connector name.
    """
    import dace.graph.nodes as nd
    from dace.sdfg import getdebuginfo
    from dace.graph.labeling import propagate_memlet

    map_name = name + "_map"

    debuginfo = getdebuginfo(debuginfo)

    tasklet = nd.Tasklet(
        name,
        set(inputs.keys()),
        set(outputs.keys()),
        code,
        language=language,
        code_global=code_global,
        code_init=code_init,
        code_exit=code_exit,
        location=location,
        debuginfo=debuginfo,
    )
    map_obj = state._map_from_ndrange(map_name,
                                      schedule,
                                      unroll_map,
                                      map_ranges,
                                      debuginfo=debuginfo)
    map_entry = nd.MapEntry(map_obj)
    map_exit = nd.MapExit(map_obj)
    state.add_nodes_from([map_entry, tasklet, map_exit])

    # Inner edges: map entry -> tasklet, one per input connector
    inner_memlets = {}
    for conn, memlet in inputs.items():
        memlet.name = conn
        state.add_edge(map_entry, None, tasklet, conn, memlet)
        inner_memlets[memlet.data] = memlet

    # Without inputs, attach the tasklet to the map entry via an empty memlet
    if not inputs:
        state.add_edge(map_entry, None, tasklet, None,
                       dace.memlet.EmptyMemlet())

    if external_edges:
        for inp, inpnode in inpdict.items():
            outer_memlet = propagate_memlet(state, inner_memlets[inp],
                                            map_entry, True)
            state.add_edge(inpnode, None, map_entry, "IN_" + inp,
                           outer_memlet)
            # Name the matching inner edges' source connectors "OUT_<data>"
            for e in state.out_edges(map_entry):
                if e.data.data == inp:
                    e._src_conn = "OUT_" + inp
            map_entry.add_in_connector("IN_" + inp)
            map_entry.add_out_connector("OUT_" + inp)

    # Inner edges: tasklet -> map exit, one per output connector
    inner_memlets = {}
    for conn, memlet in outputs.items():
        memlet.name = conn
        state.add_edge(tasklet, conn, map_exit, None, memlet)
        inner_memlets[memlet.data] = memlet

    # Without outputs, attach the tasklet to the map exit via an empty memlet
    # (same constructor as the input case above)
    if not outputs:
        state.add_edge(tasklet, None, map_exit, None,
                       dace.memlet.EmptyMemlet())

    if external_edges:
        for out, outnode in outdict.items():
            outer_memlet = propagate_memlet(state, inner_memlets[out],
                                            map_exit, True)
            state.add_edge(map_exit, "OUT_" + out, outnode, None,
                           outer_memlet)
            # Name the matching inner edges' destination connectors "IN_<data>"
            for e in state.in_edges(map_exit):
                if e.data.data == out:
                    e._dst_conn = "IN_" + out
            map_exit.add_in_connector("IN_" + out)
            map_exit.add_out_connector("OUT_" + out)

    return tasklet, map_entry, map_exit