class Map(object):
    """ A Map is a two-node representation of parametric graphs, containing
        an integer set by which the contents (nodes dominated by an entry
        node and post-dominated by an exit node) are replicated.

        Maps contain a `schedule` property, which specifies how the scope
        should be scheduled (execution order). Code generators can use the
        schedule property to generate appropriate code, e.g., GPU kernels.
    """

    # List of (editable) properties
    label = Property(dtype=str, desc="Label of the map")
    params = ListProperty(element_type=str, desc="Mapped parameters")
    range = RangeProperty(desc="Ranges of map parameters",
                          default=sbs.Range([]))
    schedule = EnumProperty(dtype=dtypes.ScheduleType,
                            desc="Map schedule",
                            default=dtypes.ScheduleType.Default)
    unroll = Property(dtype=bool, desc="Map unrolling")
    collapse = Property(dtype=int,
                        default=1,
                        desc="How many dimensions to"
                        " collapse into the parallel range")
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)
    instrument = EnumProperty(
        dtype=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 label,
                 params,
                 ndrange,
                 schedule=dtypes.ScheduleType.Default,
                 unroll=False,
                 collapse=1,
                 fence_instrumentation=False,
                 debuginfo=None):
        super(Map, self).__init__()

        # Assign properties
        self.label = label
        self.schedule = schedule
        self.unroll = unroll
        self.collapse = collapse
        self.params = params
        self.range = ndrange
        self.debuginfo = debuginfo
        self._fence_instrumentation = fence_instrumentation

    def __str__(self):
        return self.label + "[" + ", ".join([
            "{}={}".format(i, r) for i, r in zip(
                self._params,
                [sbs.Range.dim_to_string(d) for d in self._range])
        ]) + "]"

    def validate(self, sdfg, state, node):
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid map name "%s"' % self.label)

    def get_param_num(self):
        """ Returns the number of map dimension parameters/symbols. """
        return len(self.params)
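# Illustrative sketch (not part of the class above): constructing a Map
# directly from the constructor shown. Assumes `sbs` is `dace.subsets`, as in
# the `range` property default; ranges are (begin, end, step) tuples with
# inclusive end values.

from dace import dtypes, subsets as sbs
from dace.sdfg import nodes

m = nodes.Map(label='compute',
              params=['i', 'j'],
              ndrange=sbs.Range([(0, 127, 1), (0, 63, 1)]),
              schedule=dtypes.ScheduleType.CPU_Multicore)
assert m.get_param_num() == 2
print(m)  # e.g., compute[i=0:128, j=0:64]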
class Tasklet(CodeNode): """ A node that contains a tasklet: a functional computation procedure that can only access external data specified using connectors. Tasklets may be implemented in Python, C++, or any supported language by the code generator. """ code = CodeProperty(desc="Tasklet code", default=CodeBlock("")) state_fields = ListProperty( element_type=str, desc="Fields that are added to the global state") code_global = CodeProperty( desc="Global scope code needed for tasklet execution", default=CodeBlock("", dtypes.Language.CPP)) code_init = CodeProperty( desc="Extra code that is called on DaCe runtime initialization", default=CodeBlock("", dtypes.Language.CPP)) code_exit = CodeProperty( desc="Extra code that is called on DaCe runtime cleanup", default=CodeBlock("", dtypes.Language.CPP)) debuginfo = DebugInfoProperty() instrument = EnumProperty( dtype=dtypes.InstrumentationType, desc="Measure execution statistics with given method", default=dtypes.InstrumentationType.No_Instrumentation) def __init__(self, label, inputs=None, outputs=None, code="", language=dtypes.Language.Python, state_fields=None, code_global="", code_init="", code_exit="", location=None, debuginfo=None): super(Tasklet, self).__init__(label, location, inputs, outputs) self.code = CodeBlock(code, language) self.state_fields = state_fields or [] self.code_global = CodeBlock(code_global, dtypes.Language.CPP) self.code_init = CodeBlock(code_init, dtypes.Language.CPP) self.code_exit = CodeBlock(code_exit, dtypes.Language.CPP) self.debuginfo = debuginfo @property def language(self): return self.code.language @staticmethod def from_json(json_obj, context=None): ret = Tasklet("dummylabel") dace.serialize.set_properties_from_json(ret, json_obj, context=context) return ret @property def name(self): return self._label def validate(self, sdfg, state): if not dtypes.validate_name(self.label): raise NameError('Invalid tasklet name "%s"' % self.label) for in_conn in self.in_connectors: if not dtypes.validate_name(in_conn): raise NameError('Invalid input connector "%s"' % in_conn) for out_conn in self.out_connectors: if not dtypes.validate_name(out_conn): raise NameError('Invalid output connector "%s"' % out_conn) @property def free_symbols(self) -> Set[str]: return self.code.get_free_symbols(self.in_connectors.keys() | self.out_connectors.keys()) def infer_connector_types(self, sdfg, state): # If a MLIR tasklet, simply read out the types (it's explicit) if self.code.language == dtypes.Language.MLIR: # Inline import because mlir.utils depends on pyMLIR which may not be installed # Doesn't cause crashes due to missing pyMLIR if a MLIR tasklet is not present from dace.codegen.targets.mlir import utils mlir_ast = utils.get_ast(self.code.code) mlir_is_generic = utils.is_generic(mlir_ast) mlir_entry_func = utils.get_entry_func(mlir_ast, mlir_is_generic) mlir_result_type = utils.get_entry_result_type( mlir_entry_func, mlir_is_generic) mlir_out_name = next(iter(self.out_connectors.keys())) if self.out_connectors[ mlir_out_name] is None or self.out_connectors[ mlir_out_name].ctype == "void": self.out_connectors[mlir_out_name] = utils.get_dace_type( mlir_result_type) elif self.out_connectors[mlir_out_name] != utils.get_dace_type( mlir_result_type): warnings.warn( "Type mismatch between MLIR tasklet out connector and MLIR code" ) for mlir_arg in utils.get_entry_args(mlir_entry_func, mlir_is_generic): if self.in_connectors[ mlir_arg[0]] is None or self.in_connectors[ mlir_arg[0]].ctype == "void": self.in_connectors[mlir_arg[0]] = 
utils.get_dace_type( mlir_arg[1]) elif self.in_connectors[mlir_arg[0]] != utils.get_dace_type( mlir_arg[1]): warnings.warn( "Type mismatch between MLIR tasklet in connector and MLIR code" ) return # If a Python tasklet, use type inference to figure out all None output # connectors if all(cval.type is not None for cval in self.out_connectors.values()): return if self.code.language != dtypes.Language.Python: return if any(cval.type is None for cval in self.in_connectors.values()): raise TypeError('Cannot infer output connectors of tasklet "%s", ' 'not all input connectors have types' % str(self)) # Avoid import loop from dace.codegen.tools.type_inference import infer_types # Get symbols defined at beginning of node, and infer all types in # tasklet syms = state.symbols_defined_at(self) syms.update(self.in_connectors) new_syms = infer_types(self.code.code, syms) for cname, oconn in self.out_connectors.items(): if oconn.type is None: if cname not in new_syms: raise TypeError('Cannot infer type of tasklet %s output ' '"%s", please specify manually.' % (self.label, cname)) self.out_connectors[cname] = new_syms[cname] def __str__(self): if not self.label: return "--Empty--" else: return self.label
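# Illustrative sketch: a minimal Python tasklet built with the constructor
# above. The connector names ('x', 'y', 'out') are hypothetical; in a full
# SDFG they would be wired to access nodes through memlets.

from dace import dtypes
from dace.sdfg.nodes import Tasklet

t = Tasklet('axpy_elem',
            inputs={'x', 'y'},
            outputs={'out'},
            code='out = 2 * x + y',
            language=dtypes.Language.Python)
assert t.language == dtypes.Language.Python
print(t)  # axpy_elem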
class NestedSDFG(CodeNode): """ An SDFG state node that contains an SDFG of its own, runnable using the data dependencies specified using its connectors. It is encouraged to use nested SDFGs instead of coarse-grained tasklets since they are analyzable with respect to transformations. :note: A nested SDFG cannot create recursion (one of its parent SDFGs). """ # NOTE: We cannot use SDFG as the type because of an import loop sdfg = SDFGReferenceProperty(desc="The SDFG", allow_none=True) schedule = EnumProperty(dtype=dtypes.ScheduleType, desc="SDFG schedule", allow_none=True, default=dtypes.ScheduleType.Default) symbol_mapping = DictProperty( key_type=str, value_type=dace.symbolic.pystr_to_symbolic, desc="Mapping between internal symbols and their values, expressed as " "symbolic expressions") debuginfo = DebugInfoProperty() is_collapsed = Property(dtype=bool, desc="Show this node/scope/state as collapsed", default=False) instrument = EnumProperty( dtype=dtypes.InstrumentationType, desc="Measure execution statistics with given method", default=dtypes.InstrumentationType.No_Instrumentation) no_inline = Property( dtype=bool, desc="If True, this nested SDFG will not be inlined during " "simplification", default=False) unique_name = Property(dtype=str, desc="Unique name of the SDFG", default="") def __init__(self, label, sdfg, inputs: Set[str], outputs: Set[str], symbol_mapping: Dict[str, Any] = None, schedule=dtypes.ScheduleType.Default, location=None, debuginfo=None): from dace.sdfg import SDFG super(NestedSDFG, self).__init__(label, location, inputs, outputs) # Properties self.sdfg: SDFG = sdfg self.symbol_mapping = symbol_mapping or {} self.schedule = schedule self.debuginfo = debuginfo @staticmethod def from_json(json_obj, context=None): from dace import SDFG # Avoid import loop # We have to load the SDFG first. ret = NestedSDFG("nolabel", SDFG('nosdfg'), {}, {}) dace.serialize.set_properties_from_json(ret, json_obj, context) if context and 'sdfg_state' in context: ret.sdfg.parent = context['sdfg_state'] if context and 'sdfg' in context: ret.sdfg.parent_sdfg = context['sdfg'] ret.sdfg.parent_nsdfg_node = ret ret.sdfg.update_sdfg_list([]) return ret @property def free_symbols(self) -> Set[str]: return set().union( *(map(str, pystr_to_symbolic(v).free_symbols) for v in self.symbol_mapping.values()), *(map(str, pystr_to_symbolic(v).free_symbols) for v in self.location.values())) def infer_connector_types(self, sdfg, state): # Avoid import loop from dace.sdfg.infer_types import infer_connector_types # Infer internal connector types infer_connector_types(self.sdfg) def __str__(self): if not self.label: return "SDFG" else: return self.label def validate(self, sdfg, state): if not dtypes.validate_name(self.label): raise NameError('Invalid nested SDFG name "%s"' % self.label) for in_conn in self.in_connectors: if not dtypes.validate_name(in_conn): raise NameError('Invalid input connector "%s"' % in_conn) for out_conn in self.out_connectors: if not dtypes.validate_name(out_conn): raise NameError('Invalid output connector "%s"' % out_conn) connectors = self.in_connectors.keys() | self.out_connectors.keys() for dname, desc in self.sdfg.arrays.items(): # TODO(later): Disallow scalars without access nodes (so that this # check passes for them too). 
if isinstance(desc, data.Scalar): continue if not desc.transient and dname not in connectors: raise NameError('Data descriptor "%s" not found in nested ' 'SDFG connectors' % dname) if dname in connectors and desc.transient: raise NameError( '"%s" is a connector but its corresponding array is transient' % dname) # Validate undefined symbols symbols = set(k for k in self.sdfg.free_symbols if k not in connectors) missing_symbols = [s for s in symbols if s not in self.symbol_mapping] if missing_symbols: raise ValueError('Missing symbols on nested SDFG: %s' % (missing_symbols)) extra_symbols = self.symbol_mapping.keys() - symbols if len(extra_symbols) > 0: # TODO: Elevate to an error? warnings.warn( f"{self.label} maps to unused symbol(s): {extra_symbols}") # Recursively validate nested SDFG self.sdfg.validate()
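# Illustrative sketch: embedding one SDFG in another. Per `validate` above,
# every non-transient array of the inner SDFG must correspond to a connector,
# and every free symbol must be covered by `symbol_mapping`. In practice,
# nested SDFG nodes are typically created via `state.add_nested_sdfg`; the
# direct constructor is used here to mirror the class definition.

import dace
from dace.sdfg import SDFG
from dace.sdfg.nodes import NestedSDFG

inner = SDFG('inner')
inner.add_array('A', shape=('N', ), dtype=dace.float64)  # 'N' is a free symbol

nsdfg = NestedSDFG('inner_node',
                   inner,
                   inputs={'A'},
                   outputs=set(),
                   symbol_mapping={'N': 'M'})  # inner N maps to outer M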
class LibraryNode(CodeNode):

    name = Property(dtype=str, desc="Name of node")
    implementation = LibraryImplementationProperty(
        dtype=str,
        allow_none=True,
        desc=("Which implementation this library node will expand into. "
              "Must match a key in the list of possible implementations."))
    schedule = EnumProperty(
        dtype=dtypes.ScheduleType,
        desc="If set, determines the default device mapping of "
        "the node upon expansion, if expanded to a nested SDFG.",
        default=dtypes.ScheduleType.Default)
    debuginfo = DebugInfoProperty()

    def __init__(self, name, *args, schedule=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = name
        self.label = name
        self.schedule = schedule or dtypes.ScheduleType.Default

    # Override so that subclasses return LibraryNode as their JSON type
    @property
    def __jsontype__(self):
        return 'LibraryNode'

    def to_json(self, parent):
        jsonobj = super().to_json(parent)
        jsonobj['classpath'] = full_class_path(self)
        return jsonobj

    @classmethod
    def from_json(cls, json_obj, context=None):
        if cls == LibraryNode:
            clazz = pydoc.locate(json_obj['classpath'])
            if clazz is None:
                return UnregisteredLibraryNode.from_json(json_obj, context)
            return clazz.from_json(json_obj, context)
        else:  # Subclasses are actual library nodes
            ret = cls(json_obj['attributes']['name'])
            dace.serialize.set_properties_from_json(ret,
                                                    json_obj,
                                                    context=context)
            return ret

    def expand(self, sdfg, state, *args, **kwargs) -> str:
        """ Create and perform the expansion transformation for this library
            node.

            :return: the name of the expanded implementation
        """
        implementation = self.implementation
        library_name = getattr(type(self), '_dace_library_name', '')
        try:
            if library_name:
                config_implementation = Config.get("library", library_name,
                                                   "default_implementation")
            else:
                config_implementation = None
        except KeyError:
            # Non-standard libraries are not defined in the config schema, and
            # thus might not exist in the config.
config_implementation = None if config_implementation is not None: try: config_override = Config.get("library", library_name, "override") if config_override and implementation in self.implementations: if implementation is not None: warnings.warn( "Overriding explicitly specified " "implementation {} for {} with {}.".format( implementation, self.label, config_implementation)) implementation = config_implementation except KeyError: config_override = False # If not explicitly set, try the node default if implementation is None: implementation = type(self).default_implementation # If no node default, try library default if implementation is None: import dace.library # Avoid cyclic dependency lib = dace.library._DACE_REGISTERED_LIBRARIES[type( self)._dace_library_name] implementation = lib.default_implementation # Try the default specified in the config if implementation is None: implementation = config_implementation # Otherwise we don't know how to expand if implementation is None: raise ValueError("No implementation or default " "implementation specified.") if implementation not in self.implementations.keys(): raise KeyError("Unknown implementation for node {}: {}".format( type(self).__name__, implementation)) transformation_type = type(self).implementations[implementation] sdfg_id = sdfg.sdfg_id state_id = sdfg.nodes().index(state) subgraph = {transformation_type._match_node: state.node_id(self)} transformation = transformation_type(sdfg, sdfg_id, state_id, subgraph, 0) if not transformation.can_be_applied(state, 0, sdfg): raise RuntimeError("Library node " "expansion applicability check failed.") sdfg.append_transformation(transformation) transformation.apply(state, sdfg, *args, **kwargs) return implementation @classmethod def register_implementation(cls, name, transformation_type): """Register an implementation to belong to this library node type.""" cls.implementations[name] = transformation_type transformation_type._match_node = cls
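# Illustrative sketch: how a hypothetical library node could register an
# expansion, mirroring `register_implementation` and the lookup order in
# `expand` above (explicit -> config override -> node default -> library
# default -> config default). `MyGemm` and `ExpandMyGemmPure` are made-up
# names, not real DaCe classes, and the `@dace.library.node` decorator is an
# assumption about the registration machinery.
#
# @dace.library.node
# class MyGemm(LibraryNode):
#     implementations = {}              # filled by register_implementation
#     default_implementation = 'pure'
#
# MyGemm.register_implementation('pure', ExpandMyGemmPure)
# # register_implementation also sets ExpandMyGemmPure._match_node, so the
# # expansion transformation can later locate the node within its state.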
class AccessNode(Node): """ A node that accesses data in the SDFG. Denoted by a circular shape. """ access = EnumProperty(dtype=dtypes.AccessType, desc="Type of access to this array", default=dtypes.AccessType.ReadWrite) setzero = Property(dtype=bool, desc="Initialize to zero", default=False) debuginfo = DebugInfoProperty() data = DataProperty(desc="Data (array, stream, scalar) to access") def __init__(self, data, access=dtypes.AccessType.ReadWrite, debuginfo=None): super(AccessNode, self).__init__() # Properties self.debuginfo = debuginfo self.access = access if not isinstance(data, str): raise TypeError('Data for AccessNode must be a string') self.data = data @staticmethod def from_json(json_obj, context=None): ret = AccessNode("Nodata") dace.serialize.set_properties_from_json(ret, json_obj, context=context) return ret def __deepcopy__(self, memo): node = object.__new__(AccessNode) node._access = self._access node._data = self._data node._setzero = self._setzero node._in_connectors = dcpy(self._in_connectors, memo=memo) node._out_connectors = dcpy(self._out_connectors, memo=memo) node._debuginfo = dcpy(self._debuginfo, memo=memo) return node @property def label(self): return self.data def __label__(self, sdfg, state): return self.data def desc(self, sdfg): from dace.sdfg import SDFGState, ScopeSubgraphView if isinstance(sdfg, (SDFGState, ScopeSubgraphView)): sdfg = sdfg.parent return sdfg.arrays[self.data] def validate(self, sdfg, state): if self.data not in sdfg.arrays: raise KeyError('Array "%s" not found in SDFG' % self.data) def has_writes(self, state): for e in state.in_edges(self): if not e.data.is_empty(): return True return False def has_reads(self, state): for e in state.out_edges(self): if not e.data.is_empty(): return True return False
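# Illustrative sketch: access nodes refer to data descriptors by name, and
# `desc` resolves that name against `sdfg.arrays`, as implemented above. In
# practice access nodes are usually created via `state.add_access('A')`.

import dace
from dace.sdfg.nodes import AccessNode

sdfg = dace.SDFG('prog')
sdfg.add_array('A', shape=(16, ), dtype=dace.float32)

anode = AccessNode('A')
print(anode.label)       # A
print(anode.desc(sdfg))  # the Array descriptor registered above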
class Consume(object): """ Consume is a scope, like `Map`, that is a part of the parametric graph extension of the SDFG. It creates a producer-consumer relationship between the input stream and the scope subgraph. The subgraph is scheduled to a given number of processing elements for processing, and they will try to pop elements from the input stream until a given quiescence condition is reached. """ # Properties label = Property(dtype=str, desc="Name of the consume node") pe_index = Property(dtype=str, desc="Processing element identifier") num_pes = SymbolicProperty(desc="Number of processing elements", default=1) condition = CodeProperty(desc="Quiescence condition", allow_none=True) schedule = EnumProperty(dtype=dtypes.ScheduleType, desc="Consume schedule", default=dtypes.ScheduleType.Default) chunksize = Property(dtype=int, desc="Maximal size of elements to consume at a time", default=1) debuginfo = DebugInfoProperty() is_collapsed = Property(dtype=bool, desc="Show this node/scope/state as collapsed", default=False) instrument = EnumProperty( dtype=dtypes.InstrumentationType, desc="Measure execution statistics with given method", default=dtypes.InstrumentationType.No_Instrumentation) def as_map(self): """ Compatibility function that allows to view the consume as a map, mainly in memlet propagation. """ return Map(self.label, [self.pe_index], sbs.Range([(0, self.num_pes - 1, 1)]), self.schedule) def __init__(self, label, pe_tuple, condition, schedule=dtypes.ScheduleType.Default, chunksize=1, debuginfo=None): super(Consume, self).__init__() # Properties self.label = label self.pe_index, self.num_pes = pe_tuple self.condition = condition self.schedule = schedule self.chunksize = chunksize self.debuginfo = debuginfo def __str__(self): if self.condition is not None: return ("%s [%s=0:%s], Condition: %s" % (self._label, self.pe_index, self.num_pes, CodeProperty.to_string(self.condition))) else: return ("%s [%s=0:%s]" % (self._label, self.pe_index, self.num_pes)) def validate(self, sdfg, state, node): if not dtypes.validate_name(self.label): raise NameError('Invalid consume name "%s"' % self.label) def get_param_num(self): """ Returns the number of consume dimension parameters/symbols. """ return 1
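# Illustrative sketch: a consume scope with four processing elements popping
# from a stream. `pe_tuple` packs the PE index parameter name and the PE
# count. Passing `condition=None` (permitted by the property above) is
# assumed here to mean "consume until the stream is empty".

from dace.sdfg.nodes import Consume

c = Consume('worker', ('p', 4), None)
print(c)           # e.g., worker [p=0:4]
print(c.as_map())  # the equivalent Map view used for memlet propagation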
class Data(object): """ Data type descriptors that can be used as references to memory. Examples: Arrays, Streams, custom arrays (e.g., sparse matrices). """ dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses) shape = ShapeProperty(default=[]) transient = Property(dtype=bool, default=False) storage = EnumProperty(dtype=dtypes.StorageType, desc="Storage location", default=dtypes.StorageType.Default) lifetime = EnumProperty(dtype=dtypes.AllocationLifetime, desc='Data allocation span', default=dtypes.AllocationLifetime.Scope) location = DictProperty(key_type=str, value_type=str, desc='Full storage location identifier (e.g., rank, GPU ID)') debuginfo = DebugInfoProperty(allow_none=True) def __init__(self, dtype, shape, transient, storage, location, lifetime, debuginfo): self.dtype = dtype self.shape = shape self.transient = transient self.storage = storage self.location = location if location is not None else {} self.lifetime = lifetime self.debuginfo = debuginfo self._validate() def validate(self): """ Validate the correctness of this object. Raises an exception on error. """ self._validate() # Validation of this class is in a separate function, so that this # class can call `_validate()` without calling the subclasses' # `validate` function. def _validate(self): if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol, symbolic.sympy.Basic)) for s in self.shape): raise TypeError('Shape must be a list or tuple of integer values ' 'or symbols') return True def to_json(self): attrs = serialize.all_properties_to_json(self) retdict = {"type": type(self).__name__, "attributes": attrs} return retdict @property def toplevel(self): return self.lifetime is not dtypes.AllocationLifetime.Scope def copy(self): raise RuntimeError('Data descriptors are unique and should not be copied') def is_equivalent(self, other): """ Check for equivalence (shape and type) of two data descriptors. """ raise NotImplementedError def as_arg(self, with_types=True, for_call=False, name=None): """Returns a string for a C++ function signature (e.g., `int *A`). """ raise NotImplementedError @property def free_symbols(self) -> Set[symbolic.SymbolicType]: """ Returns a set of undefined symbols in this data descriptor. """ result = set() for s in self.shape: if isinstance(s, sp.Basic): result |= set(s.free_symbols) return result def __repr__(self): return 'Abstract Data Container, DO NOT USE' @property def veclen(self): return self.dtype.veclen if hasattr(self.dtype, "veclen") else 1 @property def ctype(self): return self.dtype.ctype def strides_from_layout( self, *dimensions: int, alignment: symbolic.SymbolicType = 1, only_first_aligned: bool = False, ) -> Tuple[Tuple[symbolic.SymbolicType], symbolic.SymbolicType]: """ Returns the absolute strides and total size of this data descriptor, according to the given dimension ordering and alignment. :param dimensions: A sequence of integers representing a permutation of the descriptor's dimensions. :param alignment: Padding (in elements) at the end, ensuring stride is a multiple of this number. 1 (default) means no padding. :param only_first_aligned: If True, only the first dimension is padded with ``alignment``. Otherwise all dimensions are. :return: A 2-tuple of (tuple of strides, total size). 
""" # Verify dimensions if tuple(sorted(dimensions)) != tuple(range(len(self.shape))): raise ValueError('Every dimension must be given and appear once.') if (alignment < 1) == True or (alignment < 0) == True: raise ValueError('Invalid alignment value') strides = [1] * len(dimensions) total_size = 1 first = True for dim in dimensions: strides[dim] = total_size if not only_first_aligned or first: dimsize = (((self.shape[dim] + alignment - 1) // alignment) * alignment) else: dimsize = self.shape[dim] total_size *= dimsize first = False return (tuple(strides), total_size) def set_strides_from_layout(self, *dimensions: int, alignment: symbolic.SymbolicType = 1, only_first_aligned: bool = False): """ Sets the absolute strides and total size of this data descriptor, according to the given dimension ordering and alignment. :param dimensions: A sequence of integers representing a permutation of the descriptor's dimensions. :param alignment: Padding (in elements) at the end, ensuring stride is a multiple of this number. 1 (default) means no padding. :param only_first_aligned: If True, only the first dimension is padded with ``alignment``. Otherwise all dimensions are. """ strides, totalsize = self.strides_from_layout(*dimensions, alignment=alignment, only_first_aligned=only_first_aligned) self.strides = strides self.total_size = totalsize
class CompositeFusion(transformation.SubgraphTransformation):
    """ MultiExpansion + SubgraphFusion in one transformation.
        Additional StencilTiling is also possible as a canonicalizing
        transformation before fusion.
    """

    debug = Property(desc="Debug mode", dtype=bool, default=False)

    allow_expansion = Property(desc="Allow MultiExpansion first",
                               dtype=bool,
                               default=True)

    allow_tiling = Property(desc="Allow StencilTiling (after MultiExpansion)",
                            dtype=bool,
                            default=False)

    transient_allocation = EnumProperty(
        desc="Storage Location to push transients to that are "
        "fully contained within the subgraph.",
        dtype=dtypes.StorageType,
        default=dtypes.StorageType.Default)

    schedule_innermaps = Property(desc="Schedule of inner fused maps",
                                  dtype=dtypes.ScheduleType,
                                  default=None,
                                  allow_none=True)

    stencil_unroll_loops = Property(
        desc="Unroll inner stencil loops if they have size > 1",
        dtype=bool,
        default=False)

    stencil_strides = ShapeProperty(dtype=tuple,
                                    default=(1, ),
                                    desc="Stencil tile stride")

    expansion_split = Property(
        desc="Allow MultiExpansion to split up maps, if enabled",
        dtype=bool,
        default=True)

    def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool:
        graph = subgraph.graph
        if self.allow_expansion:
            subgraph_fusion = SubgraphFusion(subgraph)
            if subgraph_fusion.can_be_applied(sdfg, subgraph):
                # Try without a copy first
                return True

            expansion = MultiExpansion(subgraph)
            expansion.permutation_only = not self.expansion_split
            if expansion.can_be_applied(sdfg, subgraph):
                # Deep-copy the SDFG and apply the expansion to the copy, so
                # that the applicability check leaves the original untouched
                graph_indices = [
                    i for (i, n) in enumerate(graph.nodes()) if n in subgraph
                ]
                sdfg_copy = SDFG.from_json(sdfg.to_json())
                graph_copy = sdfg_copy.nodes()[sdfg.nodes().index(graph)]
                subgraph_copy = SubgraphView(
                    graph_copy, [graph_copy.nodes()[i] for i in graph_indices])

                expansion = MultiExpansion(subgraph_copy)
                expansion.permutation_only = not self.expansion_split
                expansion.apply(sdfg_copy)

                subgraph_fusion = SubgraphFusion(subgraph_copy)
                if subgraph_fusion.can_be_applied(sdfg_copy, subgraph_copy):
                    return True

                stencil_tiling = StencilTiling(subgraph_copy)
                if self.allow_tiling and stencil_tiling.can_be_applied(
                        sdfg_copy, subgraph_copy):
                    return True
        else:
            subgraph_fusion = SubgraphFusion(subgraph)
            if subgraph_fusion.can_be_applied(sdfg, subgraph):
                return True
            if self.allow_tiling:
                stencil_tiling = StencilTiling(subgraph)
                if stencil_tiling.can_be_applied(sdfg, subgraph):
                    return True

        return False

    def apply(self, sdfg):
        subgraph = self.subgraph_view(sdfg)
        graph = subgraph.graph
        scope_dict = graph.scope_dict()
        map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph,
                                                       scope_dict)
        first_entry = next(iter(map_entries))

        if self.allow_expansion:
            expansion = MultiExpansion(subgraph, self.sdfg_id, self.state_id)
            expansion.permutation_only = not self.expansion_split
            if expansion.can_be_applied(sdfg, subgraph):
                expansion.apply(sdfg)

        sf = SubgraphFusion(subgraph, self.sdfg_id, self.state_id)
        if sf.can_be_applied(sdfg, self.subgraph_view(sdfg)):
            # Set SubgraphFusion properties
            sf.debug = self.debug
            sf.transient_allocation = self.transient_allocation
            sf.schedule_innermaps = self.schedule_innermaps
            sf.apply(sdfg)
            self._global_map_entry = sf._global_map_entry
            return
        elif self.allow_tiling:
            st = StencilTiling(subgraph, self.sdfg_id, self.state_id)
            if st.can_be_applied(sdfg, self.subgraph_view(sdfg)):
                # Set StencilTiling properties
                st.debug = self.debug
                st.unroll_loops = self.stencil_unroll_loops
                st.strides = self.stencil_strides
                st.apply(sdfg)
                # StencilTiling: update nodes
                new_entries = st._outer_entries
                subgraph = helpers.subgraph_from_maps(sdfg, graph, new_entries)
                sf = SubgraphFusion(subgraph, self.sdfg_id, self.state_id)
                # Set SubgraphFusion properties
                sf.debug = self.debug
                sf.transient_allocation = self.transient_allocation
                sf.schedule_innermaps = self.schedule_innermaps
                sf.apply(sdfg)
                self._global_map_entry = sf._global_map_entry
                return

        warnings.warn("CompositeFusion::Apply did not perform as expected")
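# Usage sketch (illustrative): fusing the maps of a single-state SDFG. The
# constructor arity mirrors the in-class usage above (subgraph, sdfg_id,
# state_id); `sdfg` is assumed to be a one-state SDFG with fusible maps.
#
# graph = sdfg.nodes()[0]
# subgraph = SubgraphView(graph, graph.nodes())
# fusion = CompositeFusion(subgraph, sdfg.sdfg_id, sdfg.nodes().index(graph))
# fusion.allow_tiling = True  # permit StencilTiling as a canonicalization
# if fusion.can_be_applied(sdfg, subgraph):
#     fusion.apply(sdfg)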
class InstrumentationReport(object):

    sortingType = EnumProperty(
        dtype=dtypes.InstrumentationReportPrintType,
        default=dtypes.InstrumentationReportPrintType.SDFG,
        desc="SDFG: sort by SDFG, State, Node, Location. "
        "Location: sort by Location, SDFG, State, Node.")

    @staticmethod
    def get_event_uuid(event):
        try:
            args = event['args']
        except KeyError:
            return (-1, -1, -1, -1)

        uuid = (args.get('sdfg_id', -1), args.get('state_id', -1),
                args.get('id', -1), args.get('loc_id', -1))
        return uuid

    def __init__(self,
                 filename: str,
                 sortingType: dtypes.InstrumentationReportPrintType = dtypes.
                 InstrumentationReportPrintType.SDFG):
        self.sortingType = sortingType
        # Parse file
        match = re.match(r'.*report-(\d+)\.json', filename)
        self._name = match.groups()[0] if match is not None else 'N/A'

        self._durations = {}
        self._counters = {}

        with open(filename, 'r') as fp:
            report = json.load(fp)

            if 'traceEvents' not in report or 'sdfgHash' not in report:
                print(filename, 'is not a valid SDFG instrumentation report!')
                return

            self._sdfg_hash = report['sdfgHash']

            events = report['traceEvents']
            for event in events:
                if 'ph' in event:
                    phase = event['ph']
                    name = event['name']
                    if phase == 'X':
                        uuid = self.get_event_uuid(event)
                        if uuid not in self._durations:
                            self._durations[uuid] = {}
                        if name not in self._durations[uuid]:
                            self._durations[uuid][name] = []
                        self._durations[uuid][name].append(event['dur'] /
                                                           1000)
                    if phase == 'C':
                        if name not in self._counters:
                            self._counters[name] = 0
                        self._counters[name] += event['args'][name]

    def __repr__(self):
        return 'InstrumentationReport(name=%s)' % self._name

    def _get_runtimes_string(self,
                             label,
                             runtimes,
                             element,
                             sdfg,
                             state,
                             string,
                             row_format,
                             format_dict,
                             with_element_heading=True):
        location_label = f'Device: {element[3]}' if element[3] != -1 else 'CPU'
        indent = ''
        if len(runtimes) > 0:
            element_label = ''
            format_dict['loc'] = ''
            format_dict['min'] = ''
            format_dict['mean'] = ''
            format_dict['median'] = ''
            format_dict['max'] = ''
            if element[0] > -1 and element[1] > -1 and element[2] > -1:
                # This element is a node.
                if sdfg != element[0]:
                    # No parent SDFG row present yet.
                    format_dict['elem'] = 'SDFG (' + str(element[0]) + ')'
                    sdfg = element[0]
                if state != element[1]:
                    # No parent state row present yet.
                    format_dict['elem'] = '|-State (' + str(element[1]) + ')'
                    # Print
                    string += row_format.format(**format_dict)
                    state = element[1]
                element_label = '| |-Node (' + str(element[2]) + ')'
                indent = '| | |'
            elif element[0] > -1 and element[1] > -1:
                # This element is a state.
                if sdfg != element[0]:
                    # No parent SDFG row present yet, print it.
                    format_dict['elem'] = 'SDFG (' + str(element[0]) + ')'
                    string += row_format.format(**format_dict)
                    sdfg = element[0]
                state = element[1]
                element_label = '|-State (' + str(element[1]) + ')'
                indent = '| |'
            elif element[0] > -1:
                # This element is an SDFG.
sdfg = element[0] state = -1 element_label = 'SDFG (' + str(element[0]) + ')' indent = '|' else: element_label = 'N/A' if with_element_heading: format_dict['elem'] = element_label string += row_format.format(**format_dict) format_dict['elem'] = indent + label + ':' string += row_format.format(**format_dict) format_dict['elem'] = indent format_dict['loc'] = location_label format_dict['min'] = '%.3f' % np.min(runtimes) format_dict['mean'] = '%.3f' % np.mean(runtimes) format_dict['median'] = '%.3f' % np.median(runtimes) format_dict['max'] = '%.3f' % np.max(runtimes) string += row_format.format(**format_dict) return string, sdfg, state def __str__(self): element_list = list(self._durations.keys()) element_list.sort() string = 'Instrumentation report\n' string += 'SDFG Hash: ' + self._sdfg_hash + '\n' if len(self._durations) > 0: COLW_ELEM = 30 COLW_LOC = 15 COLW_RUNTIME = 15 NUM_RUNTIME_COLS = 4 line_string = ('-' * (COLW_RUNTIME * NUM_RUNTIME_COLS + COLW_ELEM + COLW_LOC)) + '\n' string += line_string if self.sortingType == dtypes.InstrumentationReportPrintType.Location: row_format = ('{loc:<{loc_width}}') + ( '{elem:<{elem_width}}') + ('{min:<{width}}') + ( '{mean:<{width}}') + ('{median:<{width}}') + ( '{max:<{width}}') + '\n' string += ('{:<{width}}').format('Location', width=COLW_LOC) string += ('{:<{width}}').format('Element', width=COLW_ELEM) else: row_format = ('{elem:<{elem_width}}') + ( '{loc:<{loc_width}}') + ('{min:<{width}}') + ( '{mean:<{width}}') + ('{median:<{width}}') + ( '{max:<{width}}') + '\n' string += ('{:<{width}}').format('Element', width=COLW_ELEM) string += ('{:<{width}}').format('Location', width=COLW_LOC) string += ('{:<{width}}').format('Runtime (ms)', width=COLW_RUNTIME) string += '\n' format_dict = { 'elem_width': COLW_ELEM, 'loc_width': COLW_LOC, 'width': COLW_RUNTIME, 'loc': '', 'elem': '', 'min': 'Min', 'mean': 'Mean', 'median': 'Median', 'max': 'Max' } string += row_format.format(**format_dict) string += line_string if self.sortingType == dtypes.InstrumentationReportPrintType.Location: location_elements = defaultdict(list) for element in element_list: location_elements[element[3]].append(element) for location in sorted(location_elements): elements_list = location_elements[location] sdfg = -1 state = -1 for element in elements_list: events = self._durations[element] if len(events) > 0: with_element_heading = True for event in events.keys(): runtimes = events[event] string, sdfg, state = self._get_runtimes_string( event, runtimes, element, sdfg, state, string, row_format, format_dict, with_element_heading) with_element_heading = False else: sdfg = -1 state = -1 for element in element_list: events = self._durations[element] if len(events) > 0: with_element_heading = True for event in events.keys(): runtimes = events[event] string, sdfg, state = self._get_runtimes_string( event, runtimes, element, sdfg, state, string, row_format, format_dict, with_element_heading) with_element_heading = False runtimes = self._durations[element] string += line_string if len(self._counters) > 0: COUNTER_COLW = 39 counter_format = ('{:<{width}}' * 2) + '\n' string += ('-' * (COUNTER_COLW * 2)) + '\n' string += ('{:<{width}}' * 2).format( 'Counter', 'Value', width=COUNTER_COLW) + '\n' string += ('-' * (COUNTER_COLW * 2)) + '\n' for counter in self._counters: string += counter_format.format(counter, self._counters[counter], width=COUNTER_COLW) string += ('-' * (COUNTER_COLW * 2)) + '\n' return string
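# Usage sketch (illustrative): loading and printing a report produced by an
# instrumented SDFG run. The path below is hypothetical; the file name must
# match the `report-(\d+)\.json` pattern parsed in __init__.
#
# from dace import dtypes
# report = InstrumentationReport('.dacecache/myprog/perf/report-1234.json')
# report.sortingType = dtypes.InstrumentationReportPrintType.Location
# print(report)  # table of min/mean/median/max runtimes (ms) per element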
class StripMining(transformation.Transformation):
    """ Implements the strip-mining transformation.

        Strip-mining takes as input a map dimension and splits it into two
        dimensions. The new dimension iterates over the range of the original
        one with a parameterizable step, called the tile size. The original
        dimension is changed to iterate over the range of the tile size, with
        the same step as before.
    """

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))

    # Properties
    dim_idx = Property(dtype=int,
                       default=-1,
                       desc="Index of dimension to be strip-mined")
    new_dim_prefix = Property(dtype=str,
                              default="tile",
                              desc="Prefix for new dimension name")
    tile_size = SymbolicProperty(
        default=64,
        desc="Tile size of strip-mined dimension, "
        "or number of tiles if tiling_type=number_of_tiles")
    tile_stride = SymbolicProperty(default=0,
                                   desc="Stride between two tiles of the "
                                   "strip-mined dimension. If zero, it is set "
                                   "equal to the tile size.")
    tile_offset = SymbolicProperty(default=0,
                                   desc="Tile stride offset (negative)")
    divides_evenly = Property(dtype=bool,
                              default=False,
                              desc="Tile size divides dimension range evenly?")
    strided = Property(
        dtype=bool,
        default=False,
        desc="Continuous (false) or strided (true) elements in tile")

    tiling_type = EnumProperty(
        dtype=dtypes.TilingType,
        default=dtypes.TilingType.Normal,
        allow_none=True,
        desc="normal: the outer loop increments with tile_size, "
        "ceilrange: uses ceiling(N/tile_size) in outer range, "
        "number_of_tiles: tiles the map into the number of provided tiles, "
        "provide the number of tiles over tile_size")

    skew = Property(
        dtype=bool,
        default=False,
        desc="If True, offsets inner tile back such that it starts with zero")

    @staticmethod
    def annotates_memlets():
        return True

    @staticmethod
    def expressions():
        return [
            sdutil.node_path_graph(StripMining._map_entry)
            # StripMining._tasklet, StripMining._map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        map_entry = graph.nodes()[candidate[StripMining._map_entry]]
        return map_entry.map.label + ': ' + str(map_entry.map.params)

    def apply(self, sdfg: SDFG) -> nodes.Map:
        graph = sdfg.nodes()[self.state_id]
        # Strip-mine selected dimension.
        _, _, new_map = self._stripmine(sdfg, graph, self.subgraph)
        return new_map

    # def __init__(self, tag=True):
    def __init__(self, *args, **kwargs):
        self._entry = nodes.EntryNode()
        self._tasklet = nodes.Tasklet('_')
        self._exit = nodes.ExitNode()
        super().__init__(*args, **kwargs)
        # self.tag = tag

    @property
    def entry(self):
        return self._entry

    @property
    def exit(self):
        return self._exit

    @property
    def tasklet(self):
        return self._tasklet

    def print_match_pattern(self, candidate):
        gentry = candidate[self.entry]
        return str(gentry.map.params[-1])

    def _find_new_dim(self, sdfg: SDFG, state: SDFGState,
                      entry: nodes.MapEntry, prefix: str, target_dim: str):
        """ Finds a variable that is not already defined in scope.
""" stree = state.scope_tree() if len(prefix) == 0: return target_dim candidate = '%s_%s' % (prefix, target_dim) index = 1 defined_vars = set( str(s) for s in (state.symbols_defined_at(entry).keys() | sdfg.symbols.keys())) while candidate in defined_vars: candidate = '%s%d_%s' % (prefix, index, target_dim) index += 1 return candidate def _create_strided_range(self, sdfg: SDFG, state: SDFGState, map_entry: nodes.MapEntry): map_exit = state.exit_node(map_entry) dim_idx = self.dim_idx new_dim_prefix = self.new_dim_prefix tile_size = self.tile_size divides_evenly = self.divides_evenly tile_stride = self.tile_stride if tile_stride == 0: tile_stride = tile_size if tile_stride != tile_size: raise NotImplementedError # Retrieve parameter and range of dimension to be strip-mined. target_dim = map_entry.map.params[dim_idx] td_from, td_to, td_step = map_entry.map.range[dim_idx] new_dim = self._find_new_dim(sdfg, state, map_entry, new_dim_prefix, target_dim) new_dim_range = (td_from, td_to, tile_size) new_map = nodes.Map(map_entry.map.label, [new_dim], subsets.Range([new_dim_range])) dimsym = dace.symbolic.pystr_to_symbolic(new_dim) td_from_new = dimsym if divides_evenly: td_to_new = dimsym + tile_size - 1 else: if isinstance(td_to, dace.symbolic.SymExpr): td_to = td_to.expr td_to_new = dace.symbolic.SymExpr( sympy.Min(dimsym + tile_size - 1, td_to), dimsym + tile_size - 1) td_step_new = td_step return new_dim, new_map, (td_from_new, td_to_new, td_step_new) def _create_ceil_range(self, sdfg: SDFG, graph: SDFGState, map_entry: nodes.MapEntry): map_exit = graph.exit_node(map_entry) # Retrieve transformation properties. dim_idx = self.dim_idx new_dim_prefix = self.new_dim_prefix tile_size = self.tile_size divides_evenly = self.divides_evenly strided = self.strided offset = self.tile_offset tile_stride = self.tile_stride if tile_stride == 0: tile_stride = tile_size # Retrieve parameter and range of dimension to be strip-mined. target_dim = map_entry.map.params[dim_idx] td_from, td_to, td_step = map_entry.map.range[dim_idx] # Create new map. Replace by cloning map object? 
        new_dim = self._find_new_dim(sdfg, graph, map_entry, new_dim_prefix,
                                     target_dim)
        nd_from = 0
        if tile_stride == 1:
            nd_to = td_to - td_from
        else:
            nd_to = symbolic.pystr_to_symbolic(
                'int_ceil(%s + 1 - %s, %s) - 1' %
                (symbolic.symstr(td_to), symbolic.symstr(td_from),
                 symbolic.symstr(tile_stride)))
        nd_step = 1
        new_dim_range = (nd_from, nd_to, nd_step)
        new_map = nodes.Map(new_dim + '_' + map_entry.map.label, [new_dim],
                            subsets.Range([new_dim_range]))

        # Change the range of the selected dimension to iterate over a single
        # tile
        if strided:
            td_from_new = symbolic.pystr_to_symbolic(new_dim)
            td_to_new_approx = td_to
            td_step = tile_size
        elif offset == 0:
            td_from_new = symbolic.pystr_to_symbolic(
                '%s + %s * %s' %
                (symbolic.symstr(td_from), symbolic.symstr(new_dim),
                 symbolic.symstr(tile_stride)))
            td_to_new_exact = symbolic.pystr_to_symbolic(
                'min(%s + 1, %s + %s * %s + %s) - 1' %
                (symbolic.symstr(td_to), symbolic.symstr(td_from),
                 symbolic.symstr(tile_stride), symbolic.symstr(new_dim),
                 symbolic.symstr(tile_size)))
            td_to_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s * %s + %s - 1' %
                (symbolic.symstr(td_from), symbolic.symstr(tile_stride),
                 symbolic.symstr(new_dim), symbolic.symstr(tile_size)))
        else:
            # Include offset
            td_from_new_exact = symbolic.pystr_to_symbolic(
                'max(%s, %s + %s * %s - %s)' %
                (symbolic.symstr(td_from), symbolic.symstr(td_from),
                 symbolic.symstr(tile_stride), symbolic.symstr(new_dim),
                 symbolic.symstr(offset)))
            td_from_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s * %s - %s' %
                (symbolic.symstr(td_from), symbolic.symstr(tile_stride),
                 symbolic.symstr(new_dim), symbolic.symstr(offset)))
            td_from_new = dace.symbolic.SymExpr(td_from_new_exact,
                                                td_from_new_approx)
            td_to_new_exact = symbolic.pystr_to_symbolic(
                'min(%s + 1, %s + %s * %s + %s - %s) - 1' %
                (symbolic.symstr(td_to), symbolic.symstr(td_from),
                 symbolic.symstr(tile_stride), symbolic.symstr(new_dim),
                 symbolic.symstr(tile_size), symbolic.symstr(offset)))
            td_to_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s * %s + %s - %s - 1' %
                (symbolic.symstr(td_from), symbolic.symstr(tile_stride),
                 symbolic.symstr(new_dim), symbolic.symstr(tile_size),
                 symbolic.symstr(offset)))
        if divides_evenly or strided:
            td_to_new = td_to_new_approx
        else:
            td_to_new = dace.symbolic.SymExpr(td_to_new_exact,
                                              td_to_new_approx)
        return new_dim, new_map, (td_from_new, td_to_new, td_step)

    def _create_from_tile_numbers(self, sdfg: SDFG, state: SDFGState,
                                  map_entry: nodes.MapEntry):
        map_exit = state.exit_node(map_entry)

        # Retrieve transformation properties.
        dim_idx = self.dim_idx
        new_dim_prefix = self.new_dim_prefix
        divides_evenly = self.divides_evenly
        number_of_tiles = self.tile_size
        tile_stride = self.tile_stride

        number_of_tiles = dace.symbolic.pystr_to_symbolic(number_of_tiles)

        # Retrieve parameter and range of dimension to be strip-mined.
target_dim = map_entry.map.params[dim_idx] td_from, td_to, td_step = map_entry.map.range[dim_idx] size = map_entry.map.range.size_exact()[dim_idx] if tile_stride != 0: raise NotImplementedError new_dim = self._find_new_dim(sdfg, state, map_entry, new_dim_prefix, target_dim) new_dim_range = (td_from, number_of_tiles - 1, 1) new_map = nodes.Map(map_entry.map.label, [new_dim], subsets.Range([new_dim_range])) dimsym = dace.symbolic.pystr_to_symbolic(new_dim) td_from_new = (dimsym * size) // number_of_tiles if divides_evenly: td_to_new = ((dimsym + 1) * size) // number_of_tiles - 1 else: if isinstance(td_to, dace.symbolic.SymExpr): td_to = td_to.expr td_to_new = dace.symbolic.SymExpr( sympy.Min( ((dimsym + 1) * size) // number_of_tiles, td_to + 1) - 1, ((dimsym + 1) * size) // number_of_tiles - 1) td_step_new = td_step return new_dim, new_map, (td_from_new, td_to_new, td_step_new) def _stripmine(self, sdfg, graph, candidate): # Retrieve map entry and exit nodes. map_entry = graph.nodes()[candidate[StripMining._map_entry]] map_exit = graph.exit_node(map_entry) # Retrieve transformation properties. dim_idx = self.dim_idx target_dim = map_entry.map.params[dim_idx] if self.tiling_type == dtypes.TilingType.CeilRange: new_dim, new_map, td_rng = self._create_ceil_range( sdfg, graph, map_entry) elif self.tiling_type == dtypes.TilingType.NumberOfTiles: new_dim, new_map, td_rng = self._create_from_tile_numbers( sdfg, graph, map_entry) else: new_dim, new_map, td_rng = self._create_strided_range( sdfg, graph, map_entry) new_map_entry = nodes.MapEntry(new_map) new_map_exit = nodes.MapExit(new_map) td_to_new_approx = td_rng[1] if isinstance(td_to_new_approx, dace.symbolic.SymExpr): td_to_new_approx = td_to_new_approx.approx # Special case: If range is 1 and no prefix was specified, skip range if td_rng[0] == td_to_new_approx and target_dim == new_dim: map_entry.map.range = subsets.Range( [r for i, r in enumerate(map_entry.map.range) if i != dim_idx]) map_entry.map.params = [ p for i, p in enumerate(map_entry.map.params) if i != dim_idx ] if len(map_entry.map.params) == 0: raise ValueError('Strip-mining all dimensions of the map with ' 'empty tiles is disallowed') else: map_entry.map.range[dim_idx] = td_rng # Make internal map's schedule to "not parallel" new_map.schedule = map_entry.map.schedule map_entry.map.schedule = dtypes.ScheduleType.Sequential # Redirect edges new_map_entry.in_connectors = dcpy(map_entry.in_connectors) sdutil.change_edge_dest(graph, map_entry, new_map_entry) new_map_exit.out_connectors = dcpy(map_exit.out_connectors) sdutil.change_edge_src(graph, map_exit, new_map_exit) # Create new entry edges new_in_edges = dict() entry_in_conn = {} entry_out_conn = {} for _src, src_conn, _dst, _, memlet in graph.out_edges(map_entry): if (src_conn is not None and src_conn[:4] == 'OUT_' and not isinstance( sdfg.arrays[memlet.data], dace.data.Scalar)): new_subset = calc_set_image( map_entry.map.params, map_entry.map.range, memlet.subset, ) conn = src_conn[4:] key = (memlet.data, 'IN_' + conn, 'OUT_' + conn) if key in new_in_edges.keys(): old_subset = new_in_edges[key].subset new_in_edges[key].subset = calc_set_union( old_subset, new_subset) else: entry_in_conn['IN_' + conn] = None entry_out_conn['OUT_' + conn] = None new_memlet = dcpy(memlet) new_memlet.subset = new_subset if memlet.dynamic: new_memlet.num_accesses = memlet.num_accesses else: new_memlet.num_accesses = new_memlet.num_elements( ).simplify() new_in_edges[key] = new_memlet else: if src_conn is not None and src_conn[:4] == 'OUT_': conn = 
src_conn[4:] in_conn = 'IN_' + conn out_conn = 'OUT_' + conn else: in_conn = src_conn out_conn = src_conn if in_conn: entry_in_conn[in_conn] = None if out_conn: entry_out_conn[out_conn] = None new_in_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet) new_map_entry.out_connectors = entry_out_conn map_entry.in_connectors = entry_in_conn for (_, in_conn, out_conn), memlet in new_in_edges.items(): graph.add_edge(new_map_entry, out_conn, map_entry, in_conn, memlet) # Create new exit edges new_out_edges = dict() exit_in_conn = {} exit_out_conn = {} for _src, _, _dst, dst_conn, memlet in graph.in_edges(map_exit): if (dst_conn is not None and dst_conn[:3] == 'IN_' and not isinstance( sdfg.arrays[memlet.data], dace.data.Scalar)): new_subset = calc_set_image( map_entry.map.params, map_entry.map.range, memlet.subset, ) conn = dst_conn[3:] key = (memlet.data, 'IN_' + conn, 'OUT_' + conn) if key in new_out_edges.keys(): old_subset = new_out_edges[key].subset new_out_edges[key].subset = calc_set_union( old_subset, new_subset) else: exit_in_conn['IN_' + conn] = None exit_out_conn['OUT_' + conn] = None new_memlet = dcpy(memlet) new_memlet.subset = new_subset if memlet.dynamic: new_memlet.num_accesses = memlet.num_accesses else: new_memlet.num_accesses = new_memlet.num_elements( ).simplify() new_out_edges[key] = new_memlet else: if dst_conn is not None and dst_conn[:3] == 'IN_': conn = dst_conn[3:] in_conn = 'IN_' + conn out_conn = 'OUT_' + conn else: in_conn = dst_conn out_conn = dst_conn if in_conn: exit_in_conn[in_conn] = None if out_conn: exit_out_conn[out_conn] = None new_out_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet) new_map_exit.in_connectors = exit_in_conn map_exit.out_connectors = exit_out_conn for (_, in_conn, out_conn), memlet in new_out_edges.items(): graph.add_edge(map_exit, out_conn, new_map_exit, in_conn, memlet) # Skew if necessary if self.skew: xfh.offset_map(sdfg, graph, map_entry, dim_idx, td_rng[0]) # Return strip-mined dimension. return target_dim, new_dim, new_map
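# Usage sketch (illustrative): tiling the first dimension of every matching
# map with 64-element tiles through the standard transformation driver;
# `sdfg` is assumed to be an SDFG whose maps should be strip-mined.
#
# sdfg.apply_transformations(StripMining,
#                            options={'dim_idx': 0, 'tile_size': 64})
# # The original map now iterates within a tile and is made sequential; the
# # new outer map (prefixed with 'tile') iterates over tiles and keeps the
# # original schedule, as returned by `apply` above.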