def apply(self, graph: SDFGState, sdfg: SDFG):
    """Route an inner map's WCR output through a new transient and
    initialize that transient with ``self.identity``.

    The transient and its access node are created by applying
    ``OutLocalStorage`` between ``self.map_exit`` and ``self.outer_map_exit``.
    If ``self.identity`` is ``None``, a warning is emitted and the transient is
    left uninitialized.

    Otherwise, the inner map scope (map entry, map exit and everything between
    them) is nested into a NestedSDFG. A new state is then inserted before the
    nested body; it writes ``self.identity`` to every element of the transient.

    :param graph: State containing the matched map exits.
    :param sdfg: SDFG that owns the state.
    :raises NameError: If ``self.array`` is empty and no edge between the two
                       map exits carries a write-conflict resolution (WCR).
    """
    map_exit = self.map_exit
    outer_map_exit = self.outer_map_exit

    # Choose array: default to the first WCR-carrying edge between the exits.
    array = self.array
    if not array:
        array = next((e.data.data for e in graph.edges_between(map_exit, outer_map_exit) if e.data.wcr is not None),
                     None)
        if array is None:
            # A bare ``next`` would leak StopIteration out of this method,
            # which is confusing and becomes RuntimeError inside generators.
            raise NameError('AccumulateTransient: no edge with write-conflict resolution found '
                            'between %s and %s' % (map_exit, outer_map_exit))

    # Avoid import loop
    from dace.transformation.dataflow.local_storage import OutLocalStorage

    data_node: nodes.AccessNode = OutLocalStorage.apply_to(sdfg,
                                                           dict(array=array),
                                                           verify=False,
                                                           save=False,
                                                           node_a=map_exit,
                                                           node_b=outer_map_exit)

    if self.identity is None:
        warnings.warn('AccumulateTransient did not properly initialize '
                      'newly-created transient!')
        return

    # NOTE(review): presumably the same state as ``graph``; it is looked up by
    # ``self.state_id`` here — confirm if the two can ever differ.
    sdfg_state: SDFGState = sdfg.node(self.state_id)
    map_entry = sdfg_state.entry_node(map_exit)

    # Nest the inner map scope so an initialization state can run before it.
    nested_sdfg: NestedSDFG = nest_state_subgraph(
        sdfg=sdfg,
        state=sdfg_state,
        subgraph=SubgraphView(sdfg_state,
                              {map_entry, map_exit} | sdfg_state.all_nodes_between(map_entry, map_exit)))

    nested_sdfg_state: SDFGState = nested_sdfg.sdfg.nodes()[0]
    init_state = nested_sdfg.sdfg.add_state_before(nested_sdfg_state)

    temp_array: Array = sdfg.arrays[data_node.data]

    # One map parameter per dimension of the transient: _o0, _o1, ...
    index_names = ['_o%d' % i for i, _ in enumerate(temp_array.shape)]

    # Fill every element of the transient with the identity value.
    init_state.add_mapped_tasklet(
        name='acctrans_init',
        map_ranges={idx: '0:%s' % symstr(d)
                    for idx, d in zip(index_names, temp_array.shape)},
        inputs={},
        code='out = %s' % self.identity,
        outputs={'out': dace.Memlet.simple(data=data_node.data, subset_str=','.join(index_names))},
        external_edges=True)
def apply(self, graph: SDFGState, sdfg: SDFG):
    """Insert an access node for a local copy of an array between
    ``self.node_a`` and ``self.node_b``.

    The edge between the two nodes that carries the chosen array is removed.
    It is replaced by two edges routed through a new access node named
    ``self.prefix + <array name>``.

    - If ``self.create_array`` is set, a transient with that base name is added
      to ``sdfg``, sized to the memlet's bounding box. ``find_new_name`` may
      change the final name.
    - Every memlet in the memlet tree of the new edge on the ``node_b`` side is
      offset by the start indices of the original subset and renamed to the
      new array. When ``node_a``'s scope does not contain ``node_b``'s, the
      ``node_a`` side is used instead.

    Side effects: sets ``self._local_name`` and ``self._data_node`` so that
    other transformations can use them.

    :param graph: State containing ``node_a`` and ``node_b``.
    :param sdfg: SDFG that owns the state.
    :return: The newly inserted access node.
    :raises NameError: If no suitable edge/array is found between the nodes.
    """
    node_a = self.node_a
    node_b = self.node_b
    prefix = self.prefix

    # Determine direction of new memlet
    scope_dict = graph.scope_dict()
    propagate_forward = sd.scope_contains_scope(scope_dict, node_a, node_b)

    # Materialize once: the edge set is scanned several times below.
    candidate_edges = list(graph.edges_between(node_a, node_b))

    array = self.array
    if not array:
        # Default to the first non-empty memlet without write-conflict resolution.
        array = next((e.data.data for e in candidate_edges if e.data.data is not None and e.data.wcr is None), None)
        if array is None:
            # A bare ``next`` would leak StopIteration out of this method.
            raise NameError('No array without write-conflict resolution found between %s and %s' %
                            (node_a, node_b))

    original_edge = next((e for e in candidate_edges if array == e.data.data), None)
    if original_edge is None:
        # Fall back to the first edge that actually carries data; an empty
        # memlet would otherwise crash below on ``prefix + None``.
        original_edge = next((e for e in candidate_edges if e.data.data is not None), None)
        if original_edge is None:
            raise NameError('Array %s not found!' % array)
        warnings.warn('Array %s not found! Using array %s instead.' % (array, original_edge.data.data))
    invariant_memlet = original_edge.data

    if self.create_array:
        # Add transient array sized to an overapproximation of the bounding box.
        new_data, _ = sdfg.add_transient(
            name=prefix + invariant_memlet.data,
            shape=[symbolic.overapproximate(r).simplify() for r in invariant_memlet.bounding_box_size()],
            dtype=sdfg.arrays[invariant_memlet.data].dtype,
            find_new_name=True)
    else:
        new_data = prefix + invariant_memlet.data
    data_node = nodes.AccessNode(new_data)

    # Store as fields so that other transformations can use them
    self._local_name = new_data
    self._data_node = data_node

    to_data_mm = copy.deepcopy(invariant_memlet)
    from_data_mm = copy.deepcopy(invariant_memlet)
    offset = subsets.Indices([r[0] for r in invariant_memlet.subset])

    # Reconnect, assuming one edge to the access node
    graph.remove_edge(original_edge)
    if propagate_forward:
        graph.add_edge(node_a, original_edge.src_conn, data_node, None, to_data_mm)
        new_edge = graph.add_edge(data_node, None, node_b, original_edge.dst_conn, from_data_mm)
    else:
        new_edge = graph.add_edge(node_a, original_edge.src_conn, data_node, None, to_data_mm)
        graph.add_edge(data_node, None, node_b, original_edge.dst_conn, from_data_mm)

    # Offset all edges in the memlet tree (including the new edge) by the start
    # of the original subset. The ``True`` flag presumably selects subtraction;
    # confirm against ``subsets.Range.offset``.
    for edge in graph.memlet_tree(new_edge):
        edge.data.subset.offset(offset, True)
        edge.data.data = new_data

    return data_node