Example #1
class Map(object):
    """ A Map is a two-node representation of parametric graphs, containing
        an integer set by which the contents (nodes dominated by an entry
        node and post-dominated by an exit node) are replicated.

        Maps contain a `schedule` property, which specifies how the scope
        should be scheduled (execution order). Code generators can use the
        schedule property to generate appropriate code, e.g., GPU kernels.
    """

    # List of (editable) properties
    label = Property(dtype=str, desc="Label of the map")
    params = ListProperty(element_type=str, desc="Mapped parameters")
    range = RangeProperty(desc="Ranges of map parameters",
                          default=sbs.Range([]))
    schedule = EnumProperty(dtype=dtypes.ScheduleType,
                            desc="Map schedule",
                            default=dtypes.ScheduleType.Default)
    unroll = Property(dtype=bool, desc="Map unrolling")
    collapse = Property(dtype=int,
                        default=1,
                        desc="How many dimensions to"
                        " collapse into the parallel range")
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = EnumProperty(
        dtype=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 label,
                 params,
                 ndrange,
                 schedule=dtypes.ScheduleType.Default,
                 unroll=False,
                 collapse=1,
                 fence_instrumentation=False,
                 debuginfo=None):
        super(Map, self).__init__()

        # Assign properties
        self.label = label
        self.schedule = schedule
        self.unroll = unroll
        self.collapse = collapse
        self.params = params
        self.range = ndrange
        self.debuginfo = debuginfo
        self._fence_instrumentation = fence_instrumentation

    def __str__(self):
        return self.label + "[" + ", ".join([
            "{}={}".format(i, r)
            for i, r in zip(self._params,
                            [sbs.Range.dim_to_string(d) for d in self._range])
        ]) + "]"

    def validate(self, sdfg, state, node):
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid map name "%s"' % self.label)

    def get_param_num(self):
        """ Returns the number of map dimension parameters/symbols. """
        return len(self.params)
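
A minimal usage sketch of the class above, assuming the standard DaCe package layout (dace.sdfg.nodes, dace.subsets); the label, parameters, and range are illustrative:

from dace import dtypes, subsets
from dace.sdfg import nodes

# A 2-D map scope over a 64x64 iteration space
m = nodes.Map(label='compute',
              params=['i', 'j'],
              ndrange=subsets.Range([(0, 63, 1), (0, 63, 1)]),
              schedule=dtypes.ScheduleType.CPU_Multicore)
print(m)                  # compute[i=0:64, j=0:64]
print(m.get_param_num())  # 2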
Example #2
class Data(object):
    """ Data type descriptors that can be used as references to memory.
        Examples: Arrays, Streams, custom arrays (e.g., sparse matrices).
    """

    dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses)
    shape = ShapeProperty(default=[])
    transient = Property(dtype=bool, default=False)
    storage = EnumProperty(dtype=dtypes.StorageType,
                           desc="Storage location",
                           default=dtypes.StorageType.Default)
    lifetime = EnumProperty(dtype=dtypes.AllocationLifetime,
                            desc='Data allocation span',
                            default=dtypes.AllocationLifetime.Scope)
    location = DictProperty(
        key_type=str,
        value_type=symbolic.pystr_to_symbolic,
        desc='Full storage location identifier (e.g., rank, GPU ID)')
    debuginfo = DebugInfoProperty(allow_none=True)

    def __init__(self, dtype, shape, transient, storage, location, lifetime,
                 debuginfo):
        self.dtype = dtype
        self.shape = shape
        self.transient = transient
        self.storage = storage
        self.location = location if location is not None else {}
        self.lifetime = lifetime
        self.debuginfo = debuginfo
        self._validate()

    def validate(self):
        """ Validate the correctness of this object.
            Raises an exception on error. """
        self._validate()

    # Validation of this class is in a separate function, so that this
    # class can call `_validate()` without calling the subclasses'
    # `validate` function.
    def _validate(self):
        if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol,
                                  symbolic.sympy.Basic)) for s in self.shape):
            raise TypeError('Shape must be a list or tuple of integer values '
                            'or symbols')
        return True

    def to_json(self):
        attrs = serialize.all_properties_to_json(self)

        retdict = {"type": type(self).__name__, "attributes": attrs}

        return retdict

    @property
    def toplevel(self):
        return self.lifetime is not dtypes.AllocationLifetime.Scope

    def copy(self):
        raise RuntimeError(
            'Data descriptors are unique and should not be copied')

    def is_equivalent(self, other):
        """ Check for equivalence (shape and type) of two data descriptors. """
        raise NotImplementedError

    def as_arg(self, with_types=True, for_call=False, name=None):
        """Returns a string for a C++ function signature (e.g., `int *A`). """
        raise NotImplementedError

    @property
    def free_symbols(self) -> Set[symbolic.SymbolicType]:
        """ Returns a set of undefined symbols in this data descriptor. """
        result = set()
        for s in self.shape:
            if isinstance(s, sp.Basic):
                result |= set(s.free_symbols)
        return result

    def __repr__(self):
        return 'Abstract Data Container, DO NOT USE'

    @property
    def veclen(self):
        return self.dtype.veclen if hasattr(self.dtype, "veclen") else 1

    @property
    def ctype(self):
        return self.dtype.ctype
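
Data itself is abstract, so a hedged sketch below uses the Array subclass (assumed to live in dace.data) to exercise the shared descriptor API, including symbolic shapes:

import dace
from dace import data, symbolic

N = symbolic.symbol('N')
arr = data.Array(dace.float64, [N, 32])  # descriptor with a symbolic dimension
print(arr.ctype)         # double
print(arr.free_symbols)  # contains N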
Example #3
class Tasklet(CodeNode):
    """ A node that contains a tasklet: a functional computation procedure
        that can only access external data specified using connectors.

        Tasklets may be implemented in Python, C++, or any supported
        language by the code generator.
    """

    code = CodeProperty(desc="Tasklet code", default=CodeBlock(""))
    state_fields = ListProperty(
        element_type=str, desc="Fields that are added to the global state")
    code_global = CodeProperty(
        desc="Global scope code needed for tasklet execution",
        default=CodeBlock("", dtypes.Language.CPP))
    code_init = CodeProperty(
        desc="Extra code that is called on DaCe runtime initialization",
        default=CodeBlock("", dtypes.Language.CPP))
    code_exit = CodeProperty(
        desc="Extra code that is called on DaCe runtime cleanup",
        default=CodeBlock("", dtypes.Language.CPP))
    debuginfo = DebugInfoProperty()

    instrument = EnumProperty(
        dtype=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 label,
                 inputs=None,
                 outputs=None,
                 code="",
                 language=dtypes.Language.Python,
                 state_fields=None,
                 code_global="",
                 code_init="",
                 code_exit="",
                 location=None,
                 debuginfo=None):
        super(Tasklet, self).__init__(label, location, inputs, outputs)

        self.code = CodeBlock(code, language)

        self.state_fields = state_fields or []
        self.code_global = CodeBlock(code_global, dtypes.Language.CPP)
        self.code_init = CodeBlock(code_init, dtypes.Language.CPP)
        self.code_exit = CodeBlock(code_exit, dtypes.Language.CPP)
        self.debuginfo = debuginfo

    @property
    def language(self):
        return self.code.language

    @staticmethod
    def from_json(json_obj, context=None):
        ret = Tasklet("dummylabel")
        dace.serialize.set_properties_from_json(ret, json_obj, context=context)
        return ret

    @property
    def name(self):
        return self._label

    def validate(self, sdfg, state):
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid tasklet name "%s"' % self.label)
        for in_conn in self.in_connectors:
            if not dtypes.validate_name(in_conn):
                raise NameError('Invalid input connector "%s"' % in_conn)
        for out_conn in self.out_connectors:
            if not dtypes.validate_name(out_conn):
                raise NameError('Invalid output connector "%s"' % out_conn)

    @property
    def free_symbols(self) -> Set[str]:
        return self.code.get_free_symbols(self.in_connectors.keys()
                                          | self.out_connectors.keys())

    def infer_connector_types(self, sdfg, state):
        # If an MLIR tasklet, simply read out the types (they are explicit)
        if self.code.language == dtypes.Language.MLIR:
            # Inline import: mlir.utils depends on pyMLIR, which may not be
            # installed. This avoids import errors when no MLIR tasklet is present.
            from dace.codegen.targets.mlir import utils

            mlir_ast = utils.get_ast(self.code.code)
            mlir_is_generic = utils.is_generic(mlir_ast)
            mlir_entry_func = utils.get_entry_func(mlir_ast, mlir_is_generic)

            mlir_result_type = utils.get_entry_result_type(
                mlir_entry_func, mlir_is_generic)
            mlir_out_name = next(iter(self.out_connectors.keys()))

            if self.out_connectors[
                    mlir_out_name] is None or self.out_connectors[
                        mlir_out_name].ctype == "void":
                self.out_connectors[mlir_out_name] = utils.get_dace_type(
                    mlir_result_type)
            elif self.out_connectors[mlir_out_name] != utils.get_dace_type(
                    mlir_result_type):
                warnings.warn(
                    "Type mismatch between MLIR tasklet out connector and MLIR code"
                )

            for mlir_arg in utils.get_entry_args(mlir_entry_func,
                                                 mlir_is_generic):
                if self.in_connectors[
                        mlir_arg[0]] is None or self.in_connectors[
                            mlir_arg[0]].ctype == "void":
                    self.in_connectors[mlir_arg[0]] = utils.get_dace_type(
                        mlir_arg[1])
                elif self.in_connectors[mlir_arg[0]] != utils.get_dace_type(
                        mlir_arg[1]):
                    warnings.warn(
                        "Type mismatch between MLIR tasklet in connector and MLIR code"
                    )

            return

        # If a Python tasklet, use type inference to figure out all None output
        # connectors
        if all(cval.type is not None for cval in self.out_connectors.values()):
            return
        if self.code.language != dtypes.Language.Python:
            return

        if any(cval.type is None for cval in self.in_connectors.values()):
            raise TypeError('Cannot infer output connectors of tasklet "%s", '
                            'not all input connectors have types' % str(self))

        # Avoid import loop
        from dace.codegen.tools.type_inference import infer_types

        # Get symbols defined at beginning of node, and infer all types in
        # tasklet
        syms = state.symbols_defined_at(self)
        syms.update(self.in_connectors)
        new_syms = infer_types(self.code.code, syms)
        for cname, oconn in self.out_connectors.items():
            if oconn.type is None:
                if cname not in new_syms:
                    raise TypeError('Cannot infer type of tasklet %s output '
                                    '"%s", please specify manually.' %
                                    (self.label, cname))
                self.out_connectors[cname] = new_syms[cname]

    def __str__(self):
        if not self.label:
            return "--Empty--"
        else:
            return self.label
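
A hedged sketch of constructing a Python tasklet with one input and one output connector (the names and code are illustrative), assuming the dace.sdfg.nodes layout:

from dace.sdfg import nodes

t = nodes.Tasklet('add_one', inputs={'a'}, outputs={'b'}, code='b = a + 1')
print(t.language)  # Language.Python
print(t)           # add_one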
Example #4
class NestedSDFG(CodeNode):
    """ An SDFG state node that contains an SDFG of its own, runnable using
        the data dependencies specified using its connectors.

        It is encouraged to use nested SDFGs instead of coarse-grained tasklets
        since they are analyzable with respect to transformations.

        :note: A nested SDFG cannot create recursion, i.e., it may not contain any of its parent SDFGs.
    """

    # NOTE: We cannot use SDFG as the type because of an import loop
    sdfg = SDFGReferenceProperty(desc="The SDFG", allow_none=True)
    schedule = EnumProperty(dtype=dtypes.ScheduleType,
                            desc="SDFG schedule",
                            allow_none=True,
                            default=dtypes.ScheduleType.Default)
    symbol_mapping = DictProperty(
        key_type=str,
        value_type=dace.symbolic.pystr_to_symbolic,
        desc="Mapping between internal symbols and their values, expressed as "
        "symbolic expressions")
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = EnumProperty(
        dtype=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    no_inline = Property(
        dtype=bool,
        desc="If True, this nested SDFG will not be inlined during "
        "simplification",
        default=False)

    unique_name = Property(dtype=str,
                           desc="Unique name of the SDFG",
                           default="")

    def __init__(self,
                 label,
                 sdfg,
                 inputs: Set[str],
                 outputs: Set[str],
                 symbol_mapping: Dict[str, Any] = None,
                 schedule=dtypes.ScheduleType.Default,
                 location=None,
                 debuginfo=None):
        from dace.sdfg import SDFG
        super(NestedSDFG, self).__init__(label, location, inputs, outputs)

        # Properties
        self.sdfg: SDFG = sdfg
        self.symbol_mapping = symbol_mapping or {}
        self.schedule = schedule
        self.debuginfo = debuginfo

    @staticmethod
    def from_json(json_obj, context=None):
        from dace import SDFG  # Avoid import loop

        # We have to load the SDFG first.
        ret = NestedSDFG("nolabel", SDFG('nosdfg'), {}, {})

        dace.serialize.set_properties_from_json(ret, json_obj, context)

        if context and 'sdfg_state' in context:
            ret.sdfg.parent = context['sdfg_state']
        if context and 'sdfg' in context:
            ret.sdfg.parent_sdfg = context['sdfg']

        ret.sdfg.parent_nsdfg_node = ret

        ret.sdfg.update_sdfg_list([])

        return ret

    @property
    def free_symbols(self) -> Set[str]:
        return set().union(
            *(map(str,
                  pystr_to_symbolic(v).free_symbols)
              for v in self.symbol_mapping.values()),
            *(map(str,
                  pystr_to_symbolic(v).free_symbols)
              for v in self.location.values()))

    def infer_connector_types(self, sdfg, state):
        # Avoid import loop
        from dace.sdfg.infer_types import infer_connector_types
        # Infer internal connector types
        infer_connector_types(self.sdfg)

    def __str__(self):
        if not self.label:
            return "SDFG"
        else:
            return self.label

    def validate(self, sdfg, state):
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid nested SDFG name "%s"' % self.label)
        for in_conn in self.in_connectors:
            if not dtypes.validate_name(in_conn):
                raise NameError('Invalid input connector "%s"' % in_conn)
        for out_conn in self.out_connectors:
            if not dtypes.validate_name(out_conn):
                raise NameError('Invalid output connector "%s"' % out_conn)
        connectors = self.in_connectors.keys() | self.out_connectors.keys()
        for dname, desc in self.sdfg.arrays.items():
            # TODO(later): Disallow scalars without access nodes (so that this
            #              check passes for them too).
            if isinstance(desc, data.Scalar):
                continue
            if not desc.transient and dname not in connectors:
                raise NameError('Data descriptor "%s" not found in nested '
                                'SDFG connectors' % dname)
            if dname in connectors and desc.transient:
                raise NameError(
                    '"%s" is a connector but its corresponding array is transient'
                    % dname)

        # Validate undefined symbols
        symbols = set(k for k in self.sdfg.free_symbols if k not in connectors)
        missing_symbols = [s for s in symbols if s not in self.symbol_mapping]
        if missing_symbols:
            raise ValueError('Missing symbols on nested SDFG: %s' %
                             (missing_symbols))
        extra_symbols = self.symbol_mapping.keys() - symbols
        if len(extra_symbols) > 0:
            # TODO: Elevate to an error?
            warnings.warn(
                f"{self.label} maps to unused symbol(s): {extra_symbols}")

        # Recursively validate nested SDFG
        self.sdfg.validate()
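
A hedged sketch of wrapping an inner SDFG in a NestedSDFG node; the inner graph construction is elided and the connector names and symbol mapping are illustrative:

import dace
from dace.sdfg import nodes

inner = dace.SDFG('inner')
# ... inner arrays/states elided ...
nsdfg = nodes.NestedSDFG('call_inner', inner,
                         inputs={'x_in'}, outputs={'y_out'},
                         symbol_mapping={'N': 'M + 1'})
print(nsdfg.free_symbols)  # {'M'}, from the symbol mapping's right-hand sides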
Example #5
class LibraryNode(CodeNode):

    name = Property(dtype=str, desc="Name of node")
    implementation = LibraryImplementationProperty(
        dtype=str,
        allow_none=True,
        desc=("Which implementation this library node will expand into."
              "Must match a key in the list of possible implementations."))
    schedule = EnumProperty(
        dtype=dtypes.ScheduleType,
        desc="If set, determines the default device mapping of "
        "the node upon expansion, if expanded to a nested SDFG.",
        default=dtypes.ScheduleType.Default)
    debuginfo = DebugInfoProperty()

    def __init__(self, name, *args, schedule=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = name
        self.label = name
        self.schedule = schedule or dtypes.ScheduleType.Default

    # Overrides subclasses to return LibraryNode as their JSON type
    @property
    def __jsontype__(self):
        return 'LibraryNode'

    def to_json(self, parent):
        jsonobj = super().to_json(parent)
        jsonobj['classpath'] = full_class_path(self)
        return jsonobj

    @classmethod
    def from_json(cls, json_obj, context=None):
        if cls == LibraryNode:
            clazz = pydoc.locate(json_obj['classpath'])
            if clazz is None:
                return UnregisteredLibraryNode.from_json(json_obj, context)
            return clazz.from_json(json_obj, context)
        else:  # Subclasses are actual library nodes
            ret = cls(json_obj['attributes']['name'])
            dace.serialize.set_properties_from_json(ret,
                                                    json_obj,
                                                    context=context)
            return ret

    def expand(self, sdfg, state, *args, **kwargs) -> str:
        """ Create and perform the expansion transformation for this library
            node.
            :return: the name of the expanded implementation
        """
        implementation = self.implementation
        library_name = getattr(type(self), '_dace_library_name', '')
        try:
            if library_name:
                config_implementation = Config.get("library", library_name,
                                                   "default_implementation")
            else:
                config_implementation = None
        except KeyError:
            # Non-standard libraries are not defined in the config schema, and
            # thus might not exist in the config.
            config_implementation = None
        if config_implementation is not None:
            try:
                config_override = Config.get("library", library_name,
                                             "override")
                if config_override and implementation in self.implementations:
                    if implementation is not None:
                        warnings.warn(
                            "Overriding explicitly specified "
                            "implementation {} for {} with {}.".format(
                                implementation, self.label,
                                config_implementation))
                    implementation = config_implementation
            except KeyError:
                config_override = False
        # If not explicitly set, try the node default
        if implementation is None:
            implementation = type(self).default_implementation
            # If no node default, try library default
            if implementation is None:
                import dace.library  # Avoid cyclic dependency
                lib = dace.library._DACE_REGISTERED_LIBRARIES[type(
                    self)._dace_library_name]
                implementation = lib.default_implementation
                # Try the default specified in the config
                if implementation is None:
                    implementation = config_implementation
                    # Otherwise we don't know how to expand
                    if implementation is None:
                        raise ValueError("No implementation or default "
                                         "implementation specified.")
        if implementation not in self.implementations.keys():
            raise KeyError("Unknown implementation for node {}: {}".format(
                type(self).__name__, implementation))
        transformation_type = type(self).implementations[implementation]
        sdfg_id = sdfg.sdfg_id
        state_id = sdfg.nodes().index(state)
        subgraph = {transformation_type._match_node: state.node_id(self)}
        transformation = transformation_type(sdfg, sdfg_id, state_id, subgraph,
                                             0)
        if not transformation.can_be_applied(state, 0, sdfg):
            raise RuntimeError("Library node "
                               "expansion applicability check failed.")
        sdfg.append_transformation(transformation)
        transformation.apply(state, sdfg, *args, **kwargs)
        return implementation

    @classmethod
    def register_implementation(cls, name, transformation_type):
        """Register an implementation to belong to this library node type."""
        cls.implementations[name] = transformation_type
        transformation_type._match_node = cls
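
A hedged sketch of the registration hook above; MyLibNode and MyExpansion are hypothetical placeholders for a registered library node class and its expansion transformation:

MyLibNode.register_implementation('pure', MyExpansion)  # hypothetical names
node = MyLibNode('my_node')
node.implementation = 'pure'  # explicit choice; otherwise expand() falls back
                              # to node, library, and config defaults in turn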
Example #6
class AccessNode(Node):
    """ A node that accesses data in the SDFG. Denoted by a circular shape. """

    access = EnumProperty(dtype=dtypes.AccessType,
                          desc="Type of access to this array",
                          default=dtypes.AccessType.ReadWrite)
    setzero = Property(dtype=bool, desc="Initialize to zero", default=False)
    debuginfo = DebugInfoProperty()
    data = DataProperty(desc="Data (array, stream, scalar) to access")

    def __init__(self,
                 data,
                 access=dtypes.AccessType.ReadWrite,
                 debuginfo=None):
        super(AccessNode, self).__init__()

        # Properties
        self.debuginfo = debuginfo
        self.access = access
        if not isinstance(data, str):
            raise TypeError('Data for AccessNode must be a string')
        self.data = data

    @staticmethod
    def from_json(json_obj, context=None):
        ret = AccessNode("Nodata")
        dace.serialize.set_properties_from_json(ret, json_obj, context=context)
        return ret

    def __deepcopy__(self, memo):
        node = object.__new__(AccessNode)
        node._access = self._access
        node._data = self._data
        node._setzero = self._setzero
        node._in_connectors = dcpy(self._in_connectors, memo=memo)
        node._out_connectors = dcpy(self._out_connectors, memo=memo)
        node._debuginfo = dcpy(self._debuginfo, memo=memo)
        return node

    @property
    def label(self):
        return self.data

    def __label__(self, sdfg, state):
        return self.data

    def desc(self, sdfg):
        from dace.sdfg import SDFGState, ScopeSubgraphView
        if isinstance(sdfg, (SDFGState, ScopeSubgraphView)):
            sdfg = sdfg.parent
        return sdfg.arrays[self.data]

    def validate(self, sdfg, state):
        if self.data not in sdfg.arrays:
            raise KeyError('Array "%s" not found in SDFG' % self.data)

    def has_writes(self, state):
        for e in state.in_edges(self):
            if not e.data.is_empty():
                return True
        return False

    def has_reads(self, state):
        for e in state.out_edges(self):
            if not e.data.is_empty():
                return True
        return False
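
In practice, access nodes are usually created through SDFGState.add_access rather than constructed directly; a short sketch assuming the dace top-level API:

import dace

sdfg = dace.SDFG('prog')
sdfg.add_array('A', [64], dace.float32)
state = sdfg.add_state()
a = state.add_access('A')      # creates an AccessNode for 'A'
print(a.label)                 # A
print(a.desc(sdfg).shape)      # (64,)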
Example #7
class Consume(object):
    """ Consume is a scope, like `Map`, that is a part of the parametric
        graph extension of the SDFG. It creates a producer-consumer
        relationship between the input stream and the scope subgraph. The
        subgraph is scheduled to a given number of processing elements
        for processing, and they will try to pop elements from the input
        stream until a given quiescence condition is reached. """

    # Properties
    label = Property(dtype=str, desc="Name of the consume node")
    pe_index = Property(dtype=str, desc="Processing element identifier")
    num_pes = SymbolicProperty(desc="Number of processing elements", default=1)
    condition = CodeProperty(desc="Quiescence condition", allow_none=True)
    schedule = EnumProperty(dtype=dtypes.ScheduleType,
                            desc="Consume schedule",
                            default=dtypes.ScheduleType.Default)
    chunksize = Property(dtype=int,
                         desc="Maximal size of elements to consume at a time",
                         default=1)
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = EnumProperty(
        dtype=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    def as_map(self):
        """ Compatibility function that allows viewing the consume scope as a
            map, mainly for memlet propagation. """
        return Map(self.label, [self.pe_index],
                   sbs.Range([(0, self.num_pes - 1, 1)]), self.schedule)

    def __init__(self,
                 label,
                 pe_tuple,
                 condition,
                 schedule=dtypes.ScheduleType.Default,
                 chunksize=1,
                 debuginfo=None):
        super(Consume, self).__init__()

        # Properties
        self.label = label
        self.pe_index, self.num_pes = pe_tuple
        self.condition = condition
        self.schedule = schedule
        self.chunksize = chunksize
        self.debuginfo = debuginfo

    def __str__(self):
        if self.condition is not None:
            return ("%s [%s=0:%s], Condition: %s" %
                    (self._label, self.pe_index, self.num_pes,
                     CodeProperty.to_string(self.condition)))
        else:
            return ("%s [%s=0:%s]" %
                    (self._label, self.pe_index, self.num_pes))

    def validate(self, sdfg, state, node):
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid consume name "%s"' % self.label)

    def get_param_num(self):
        """ Returns the number of consume dimension parameters/symbols. """
        return 1
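
A hedged sketch: a consume scope with four processing elements and no quiescence condition, viewed as a map (names and counts illustrative):

from dace.sdfg import nodes

c = nodes.Consume('reader', ('pe', 4), condition=None)
print(c)           # reader [pe=0:4]
print(c.as_map())  # reader[pe=0:4]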
Example #8
class Data(object):
    """ Data type descriptors that can be used as references to memory.
        Examples: Arrays, Streams, custom arrays (e.g., sparse matrices).
    """

    dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses)
    shape = ShapeProperty(default=[])
    transient = Property(dtype=bool, default=False)
    storage = EnumProperty(dtype=dtypes.StorageType, desc="Storage location", default=dtypes.StorageType.Default)
    lifetime = EnumProperty(dtype=dtypes.AllocationLifetime,
                            desc='Data allocation span',
                            default=dtypes.AllocationLifetime.Scope)
    location = DictProperty(key_type=str, value_type=str, desc='Full storage location identifier (e.g., rank, GPU ID)')
    debuginfo = DebugInfoProperty(allow_none=True)

    def __init__(self, dtype, shape, transient, storage, location, lifetime, debuginfo):
        self.dtype = dtype
        self.shape = shape
        self.transient = transient
        self.storage = storage
        self.location = location if location is not None else {}
        self.lifetime = lifetime
        self.debuginfo = debuginfo
        self._validate()

    def validate(self):
        """ Validate the correctness of this object.
            Raises an exception on error. """
        self._validate()

    # Validation of this class is in a separate function, so that this
    # class can call `_validate()` without calling the subclasses'
    # `validate` function.
    def _validate(self):
        if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol, symbolic.sympy.Basic)) for s in self.shape):
            raise TypeError('Shape must be a list or tuple of integer values or symbols')
        return True

    def to_json(self):
        attrs = serialize.all_properties_to_json(self)

        retdict = {"type": type(self).__name__, "attributes": attrs}

        return retdict

    @property
    def toplevel(self):
        return self.lifetime is not dtypes.AllocationLifetime.Scope

    def copy(self):
        raise RuntimeError('Data descriptors are unique and should not be copied')

    def is_equivalent(self, other):
        """ Check for equivalence (shape and type) of two data descriptors. """
        raise NotImplementedError

    def as_arg(self, with_types=True, for_call=False, name=None):
        """Returns a string for a C++ function signature (e.g., `int *A`). """
        raise NotImplementedError

    @property
    def free_symbols(self) -> Set[symbolic.SymbolicType]:
        """ Returns a set of undefined symbols in this data descriptor. """
        result = set()
        for s in self.shape:
            if isinstance(s, sp.Basic):
                result |= set(s.free_symbols)
        return result

    def __repr__(self):
        return 'Abstract Data Container, DO NOT USE'

    @property
    def veclen(self):
        return self.dtype.veclen if hasattr(self.dtype, "veclen") else 1

    @property
    def ctype(self):
        return self.dtype.ctype

    def strides_from_layout(
        self,
        *dimensions: int,
        alignment: symbolic.SymbolicType = 1,
        only_first_aligned: bool = False,
    ) -> Tuple[Tuple[symbolic.SymbolicType], symbolic.SymbolicType]:
        """
        Returns the absolute strides and total size of this data descriptor,
        according to the given dimension ordering and alignment.
        :param dimensions: A sequence of integers representing a permutation
                           of the descriptor's dimensions.
        :param alignment: Padding (in elements) at the end, ensuring stride
                          is a multiple of this number. 1 (default) means no
                          padding.
        :param only_first_aligned: If True, only the first dimension is padded
                                   with ``alignment``. Otherwise all dimensions
                                   are.
        :return: A 2-tuple of (tuple of strides, total size).
        """
        # Verify dimensions
        if tuple(sorted(dimensions)) != tuple(range(len(self.shape))):
            raise ValueError('Every dimension must be given and appear once.')
        if (alignment < 1) == True or (alignment < 0) == True:
            raise ValueError('Invalid alignment value')

        strides = [1] * len(dimensions)
        total_size = 1
        first = True
        for dim in dimensions:
            strides[dim] = total_size
            if not only_first_aligned or first:
                dimsize = (((self.shape[dim] + alignment - 1) // alignment) * alignment)
            else:
                dimsize = self.shape[dim]
            total_size *= dimsize
            first = False

        return (tuple(strides), total_size)

    def set_strides_from_layout(self,
                                *dimensions: int,
                                alignment: symbolic.SymbolicType = 1,
                                only_first_aligned: bool = False):
        """
        Sets the absolute strides and total size of this data descriptor,
        according to the given dimension ordering and alignment.
        :param dimensions: A sequence of integers representing a permutation
                           of the descriptor's dimensions.
        :param alignment: Padding (in elements) at the end, ensuring stride
                          is a multiple of this number. 1 (default) means no
                          padding.
        :param only_first_aligned: If True, only the first dimension is padded
                                   with ``alignment``. Otherwise all dimensions
                                   are.
        """
        strides, totalsize = self.strides_from_layout(*dimensions,
                                                      alignment=alignment,
                                                      only_first_aligned=only_first_aligned)
        self.strides = strides
        self.total_size = totalsize
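
A worked sketch of the layout computation above, using the Array subclass (assumed in dace.data); passing dimensions (2, 1, 0) means dimension 2 varies fastest, i.e., C order:

import dace
from dace import data

desc = data.Array(dace.float64, [4, 3, 5])
strides, size = desc.strides_from_layout(2, 1, 0)
print(strides, size)  # (15, 5, 1) 60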
Example #9
class CompositeFusion(transformation.SubgraphTransformation):
    """ MultiExpansion + SubgraphFusion in one Transformation
        Additional StencilTiling is also possible as a canonicalizing
        transformation before fusion.
    """

    debug = Property(desc="Debug mode", dtype=bool, default=False)

    allow_expansion = Property(desc="Allow MultiExpansion first",
                               dtype=bool,
                               default=True)

    allow_tiling = Property(desc="Allow StencilTiling (after MultiExpansion)",
                            dtype=bool,
                            default=False)

    transient_allocation = EnumProperty(
        desc="Storage Location to push transients to that are "
        "fully contained within the subgraph.",
        dtype=dtypes.StorageType,
        default=dtypes.StorageType.Default)

    schedule_innermaps = Property(desc="Schedule of inner fused maps",
                                  dtype=dtypes.ScheduleType,
                                  default=None,
                                  allow_none=True)

    stencil_unroll_loops = Property(
        desc="Unroll inner stencil loops if they have size > 1",
        dtype=bool,
        default=False)
    stencil_strides = ShapeProperty(dtype=tuple,
                                    default=(1, ),
                                    desc="Stencil tile stride")

    expansion_split = Property(
        desc="Allow MultiExpansion to split up maps, if enabled",
        dtype=bool,
        default=True)

    def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool:
        graph = subgraph.graph
        if self.allow_expansion:
            subgraph_fusion = SubgraphFusion(subgraph)
            if subgraph_fusion.can_be_applied(sdfg, subgraph):
                # try w/o copy first
                return True

            expansion = MultiExpansion(subgraph)
            expansion.permutation_only = not self.expansion_split
            if expansion.can_be_applied(sdfg, subgraph):
                # deepcopy
                graph_indices = [
                    i for (i, n) in enumerate(graph.nodes()) if n in subgraph
                ]
                sdfg_copy = SDFG.from_json(sdfg.to_json())
                graph_copy = sdfg_copy.nodes()[sdfg.nodes().index(graph)]
                subgraph_copy = SubgraphView(
                    graph_copy, [graph_copy.nodes()[i] for i in graph_indices])

                expansion = MultiExpansion(subgraph_copy)
                expansion.permutation_only = not self.expansion_split
                expansion.apply(sdfg_copy)

                subgraph_fusion = SubgraphFusion(subgraph_copy)
                if subgraph_fusion.can_be_applied(sdfg_copy, subgraph_copy):
                    return True

                stencil_tiling = StencilTiling(subgraph_copy)
                if self.allow_tiling and stencil_tiling.can_be_applied(
                        sdfg_copy, subgraph_copy):
                    return True

        else:
            subgraph_fusion = SubgraphFusion(subgraph)
            if subgraph_fusion.can_be_applied(sdfg, subgraph):
                return True

        if self.allow_tiling:
            stencil_tiling = StencilTiling(subgraph)
            if stencil_tiling.can_be_applied(sdfg, subgraph):
                return True

        return False

    def apply(self, sdfg):
        subgraph = self.subgraph_view(sdfg)
        graph = subgraph.graph
        scope_dict = graph.scope_dict()
        map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph,
                                                       scope_dict)
        first_entry = next(iter(map_entries))

        if self.allow_expansion:
            expansion = MultiExpansion(subgraph, self.sdfg_id, self.state_id)
            expansion.permutation_only = not self.expansion_split
            if expansion.can_be_applied(sdfg, subgraph):
                expansion.apply(sdfg)

        sf = SubgraphFusion(subgraph, self.sdfg_id, self.state_id)
        if sf.can_be_applied(sdfg, self.subgraph_view(sdfg)):
            # set SubgraphFusion properties
            sf.debug = self.debug
            sf.transient_allocation = self.transient_allocation
            sf.schedule_innermaps = self.schedule_innermaps
            sf.apply(sdfg)
            self._global_map_entry = sf._global_map_entry
            return

        elif self.allow_tiling:
            st = StencilTiling(subgraph, self.sdfg_id, self.state_id)
            if st.can_be_applied(sdfg, self.subgraph_view(sdfg)):
                # set StencilTiling properties
                st.debug = self.debug
                st.unroll_loops = self.stencil_unroll_loops
                st.strides = self.stencil_strides
                st.apply(sdfg)
                # StencilTiling: update nodes
                new_entries = st._outer_entries
                subgraph = helpers.subgraph_from_maps(sdfg, graph, new_entries)
                sf = SubgraphFusion(subgraph, self.sdfg_id, self.state_id)
                # set SubgraphFusion properties
                sf.debug = self.debug
                sf.transient_allocation = self.transient_allocation
                sf.schedule_innermaps = self.schedule_innermaps

                sf.apply(sdfg)
                self._global_map_entry = sf._global_map_entry
                return

        warnings.warn("CompositeFusion::Apply did not perform as expected")
Example #10
class InstrumentationReport(object):

    sortingType = EnumProperty(
        dtype=dtypes.InstrumentationReportPrintType,
        default=dtypes.InstrumentationReportPrintType.SDFG,
        desc="SDFG: Sorting: SDFG, State, Node, Location"
        "Location:  Sorting: Location, SDFG, State, Node")

    @staticmethod
    def get_event_uuid(event):
        try:
            args = event['args']
        except KeyError:
            return (-1, -1, -1, -1)

        uuid = (args.get('sdfg_id', -1), args.get('state_id', -1),
                args.get('id', -1), args.get('loc_id', -1))

        return uuid

    def __init__(self,
                 filename: str,
                 sortingType: dtypes.InstrumentationReportPrintType =
                 dtypes.InstrumentationReportPrintType.SDFG):
        self.sortingType = sortingType
        # Parse file
        match = re.match(r'.*report-(\d+)\.json', filename)
        self._name = match.groups()[0] if match is not None else 'N/A'

        self._durations = {}
        self._counters = {}

        with open(filename, 'r') as fp:
            report = json.load(fp)

            if 'traceEvents' not in report or 'sdfgHash' not in report:
                print(filename, 'is not a valid SDFG instrumentation report!')
                return

            self._sdfg_hash = report['sdfgHash']

            events = report['traceEvents']

            for event in events:
                if 'ph' in event:
                    phase = event['ph']
                    name = event['name']
                    if phase == 'X':
                        uuid = self.get_event_uuid(event)
                        if uuid not in self._durations:
                            self._durations[uuid] = {}
                        if name not in self._durations[uuid]:
                            self._durations[uuid][name] = []
                        self._durations[uuid][name].append(event['dur'] / 1000)
                    if phase == 'C':
                        if name not in self._counters:
                            self._counters[name] = 0
                        self._counters[name] += event['args'][name]

    def __repr__(self):
        return 'InstrumentationReport(name=%s)' % self._name

    def _get_runtimes_string(self,
                             label,
                             runtimes,
                             element,
                             sdfg,
                             state,
                             string,
                             row_format,
                             format_dict,
                             with_element_heading=True):
        location_label = f'Device: {element[3]}' if element[3] != -1 else 'CPU'
        indent = ''
        if len(runtimes) > 0:
            element_label = ''
            format_dict['loc'] = ''
            format_dict['min'] = ''
            format_dict['mean'] = ''
            format_dict['median'] = ''
            format_dict['max'] = ''
            if element[0] > -1 and element[1] > -1 and element[2] > -1:

                # This element is a node.
                if sdfg != element[0]:
                    # No parent SDFG row present yet.
                    format_dict['elem'] = 'SDFG (' + str(element[0]) + ')'
                sdfg = element[0]

                if state != element[1]:
                    # No parent state row present yet.
                    format_dict['elem'] = '|-State (' + str(element[1]) + ')'

                # Print
                string += row_format.format(**format_dict)
                state = element[1]
                element_label = '| |-Node (' + str(element[2]) + ')'
                indent = '| | |'
            elif element[0] > -1 and element[1] > -1:
                # This element is a state.
                if sdfg != element[0]:
                    # No parent SDFG row present yet, print it.
                    format_dict['elem'] = 'SDFG (' + str(element[0]) + ')'
                    string += row_format.format(**format_dict)

                sdfg = element[0]
                state = element[1]
                element_label = '|-State (' + str(element[1]) + ')'
                indent = '| |'
            elif element[0] > -1:
                # This element is an SDFG.
                sdfg = element[0]
                state = -1
                element_label = 'SDFG (' + str(element[0]) + ')'
                indent = '|'
            else:
                element_label = 'N/A'

            if with_element_heading:
                format_dict['elem'] = element_label
                string += row_format.format(**format_dict)

            format_dict['elem'] = indent + label + ':'
            string += row_format.format(**format_dict)

            format_dict['elem'] = indent
            format_dict['loc'] = location_label
            format_dict['min'] = '%.3f' % np.min(runtimes)
            format_dict['mean'] = '%.3f' % np.mean(runtimes)
            format_dict['median'] = '%.3f' % np.median(runtimes)
            format_dict['max'] = '%.3f' % np.max(runtimes)
            string += row_format.format(**format_dict)

        return string, sdfg, state

    def __str__(self):
        element_list = list(self._durations.keys())
        element_list.sort()

        string = 'Instrumentation report\n'
        string += 'SDFG Hash: ' + self._sdfg_hash + '\n'

        if len(self._durations) > 0:
            COLW_ELEM = 30
            COLW_LOC = 15
            COLW_RUNTIME = 15
            NUM_RUNTIME_COLS = 4

            line_string = ('-' * (COLW_RUNTIME * NUM_RUNTIME_COLS + COLW_ELEM +
                                  COLW_LOC)) + '\n'

            string += line_string

            if self.sortingType == dtypes.InstrumentationReportPrintType.Location:
                row_format = ('{loc:<{loc_width}}') + (
                    '{elem:<{elem_width}}') + ('{min:<{width}}') + (
                        '{mean:<{width}}') + ('{median:<{width}}') + (
                            '{max:<{width}}') + '\n'
                string += ('{:<{width}}').format('Location', width=COLW_LOC)
                string += ('{:<{width}}').format('Element', width=COLW_ELEM)
            else:
                row_format = ('{elem:<{elem_width}}') + (
                    '{loc:<{loc_width}}') + ('{min:<{width}}') + (
                        '{mean:<{width}}') + ('{median:<{width}}') + (
                            '{max:<{width}}') + '\n'
                string += ('{:<{width}}').format('Element', width=COLW_ELEM)
                string += ('{:<{width}}').format('Location', width=COLW_LOC)

            string += ('{:<{width}}').format('Runtime (ms)',
                                             width=COLW_RUNTIME)
            string += '\n'
            format_dict = {
                'elem_width': COLW_ELEM,
                'loc_width': COLW_LOC,
                'width': COLW_RUNTIME,
                'loc': '',
                'elem': '',
                'min': 'Min',
                'mean': 'Mean',
                'median': 'Median',
                'max': 'Max'
            }
            string += row_format.format(**format_dict)
            string += line_string

            if self.sortingType == dtypes.InstrumentationReportPrintType.Location:
                location_elements = defaultdict(list)
                for element in element_list:
                    location_elements[element[3]].append(element)
                for location in sorted(location_elements):
                    elements_list = location_elements[location]
                    sdfg = -1
                    state = -1
                    for element in elements_list:
                        events = self._durations[element]
                        if len(events) > 0:
                            with_element_heading = True
                            for event in events.keys():
                                runtimes = events[event]
                                string, sdfg, state = self._get_runtimes_string(
                                    event, runtimes, element, sdfg, state,
                                    string, row_format, format_dict,
                                    with_element_heading)
                                with_element_heading = False

            else:
                sdfg = -1
                state = -1
                for element in element_list:
                    events = self._durations[element]
                    if len(events) > 0:
                        with_element_heading = True
                        for event in events.keys():
                            runtimes = events[event]
                            string, sdfg, state = self._get_runtimes_string(
                                event, runtimes, element, sdfg, state, string,
                                row_format, format_dict, with_element_heading)
                            with_element_heading = False

            string += line_string

        if len(self._counters) > 0:
            COUNTER_COLW = 39
            counter_format = ('{:<{width}}' * 2) + '\n'

            string += ('-' * (COUNTER_COLW * 2)) + '\n'
            string += ('{:<{width}}' * 2).format(
                'Counter', 'Value', width=COUNTER_COLW) + '\n'
            string += ('-' * (COUNTER_COLW * 2)) + '\n'
            for counter in self._counters:
                string += counter_format.format(counter,
                                                self._counters[counter],
                                                width=COUNTER_COLW)
            string += ('-' * (COUNTER_COLW * 2)) + '\n'

        return string
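
A hedged sketch of loading a report produced by an instrumented run; the path is illustrative:

report = InstrumentationReport('.dacecache/prog/perf/report-1234.json')
print(report)  # formatted runtime/counter tables, sorted by SDFG by default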
Example #11
class StripMining(transformation.Transformation):
    """ Implements the strip-mining transformation.

        Strip-mining takes as input a map dimension and splits it into
        two dimensions. The new dimension iterates over the range of
        the original one with a parameterizable step, called the tile
        size. The original dimension is changed to iterate over the
        range of the tile size, with the same step as before.
    """

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))

    # Properties
    dim_idx = Property(dtype=int,
                       default=-1,
                       desc="Index of dimension to be strip-mined")
    new_dim_prefix = Property(dtype=str,
                              default="tile",
                              desc="Prefix for new dimension name")
    tile_size = SymbolicProperty(
        default=64,
        desc="Tile size of strip-mined dimension, "
        "or number of tiles if tiling_type=number_of_tiles")
    tile_stride = SymbolicProperty(default=0,
                                   desc="Stride between two tiles of the "
                                   "strip-mined dimension. If zero, it is set "
                                   "equal to the tile size.")
    tile_offset = SymbolicProperty(default=0,
                                   desc="Tile stride offset (negative)")
    divides_evenly = Property(dtype=bool,
                              default=False,
                              desc="Tile size divides dimension range evenly?")
    strided = Property(
        dtype=bool,
        default=False,
        desc="Continuous (false) or strided (true) elements in tile")

    tiling_type = EnumProperty(
        dtype=dtypes.TilingType,
        default=dtypes.TilingType.Normal,
        allow_none=True,
        desc="normal: the outerloop increments with tile_size, "
        "ceilrange: uses ceiling(N/tile_size) in outer range, "
        "number_of_tiles: tiles the map into the number of provided tiles, "
        "provide the number of tiles over tile_size")

    skew = Property(
        dtype=bool,
        default=False,
        desc="If True, offsets inner tile back such that it starts with zero")

    @staticmethod
    def annotates_memlets():
        return True

    @staticmethod
    def expressions():
        return [sdutil.node_path_graph(StripMining._map_entry)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        map_entry = graph.nodes()[candidate[StripMining._map_entry]]
        return map_entry.map.label + ': ' + str(map_entry.map.params)

    def apply(self, sdfg: SDFG) -> nodes.Map:
        graph = sdfg.nodes()[self.state_id]
        # Strip-mine selected dimension.
        _, _, new_map = self._stripmine(sdfg, graph, self.subgraph)
        return new_map

    def __init__(self, *args, **kwargs):
        self._entry = nodes.EntryNode()
        self._tasklet = nodes.Tasklet('_')
        self._exit = nodes.ExitNode()
        super().__init__(*args, **kwargs)

    @property
    def entry(self):
        return self._entry

    @property
    def exit(self):
        return self._exit

    @property
    def tasklet(self):
        return self._tasklet

    def print_match_pattern(self, candidate):
        gentry = candidate[self.entry]
        return str(gentry.map.params[-1])

    def _find_new_dim(self, sdfg: SDFG, state: SDFGState,
                      entry: nodes.MapEntry, prefix: str, target_dim: str):
        """ Finds a variable that is not already defined in scope. """
        stree = state.scope_tree()
        if len(prefix) == 0:
            return target_dim
        candidate = '%s_%s' % (prefix, target_dim)
        index = 1
        defined_vars = set(
            str(s) for s in (state.symbols_defined_at(entry).keys()
                             | sdfg.symbols.keys()))
        while candidate in defined_vars:
            candidate = '%s%d_%s' % (prefix, index, target_dim)
            index += 1
        return candidate

    def _create_strided_range(self, sdfg: SDFG, state: SDFGState,
                              map_entry: nodes.MapEntry):
        map_exit = state.exit_node(map_entry)
        dim_idx = self.dim_idx
        new_dim_prefix = self.new_dim_prefix
        tile_size = self.tile_size
        divides_evenly = self.divides_evenly
        tile_stride = self.tile_stride
        if tile_stride == 0:
            tile_stride = tile_size
        if tile_stride != tile_size:
            raise NotImplementedError

        # Retrieve parameter and range of dimension to be strip-mined.
        target_dim = map_entry.map.params[dim_idx]
        td_from, td_to, td_step = map_entry.map.range[dim_idx]
        new_dim = self._find_new_dim(sdfg, state, map_entry, new_dim_prefix,
                                     target_dim)
        new_dim_range = (td_from, td_to, tile_size)
        new_map = nodes.Map(map_entry.map.label, [new_dim],
                            subsets.Range([new_dim_range]))

        dimsym = dace.symbolic.pystr_to_symbolic(new_dim)
        td_from_new = dimsym
        if divides_evenly:
            td_to_new = dimsym + tile_size - 1
        else:
            if isinstance(td_to, dace.symbolic.SymExpr):
                td_to = td_to.expr
            td_to_new = dace.symbolic.SymExpr(
                sympy.Min(dimsym + tile_size - 1, td_to),
                dimsym + tile_size - 1)
        td_step_new = td_step

        return new_dim, new_map, (td_from_new, td_to_new, td_step_new)

    def _create_ceil_range(self, sdfg: SDFG, graph: SDFGState,
                           map_entry: nodes.MapEntry):
        map_exit = graph.exit_node(map_entry)

        # Retrieve transformation properties.
        dim_idx = self.dim_idx
        new_dim_prefix = self.new_dim_prefix
        tile_size = self.tile_size
        divides_evenly = self.divides_evenly
        strided = self.strided
        offset = self.tile_offset

        tile_stride = self.tile_stride
        if tile_stride == 0:
            tile_stride = tile_size

        # Retrieve parameter and range of dimension to be strip-mined.
        target_dim = map_entry.map.params[dim_idx]
        td_from, td_to, td_step = map_entry.map.range[dim_idx]
        # Create new map. Replace by cloning map object?
        new_dim = self._find_new_dim(sdfg, graph, map_entry, new_dim_prefix,
                                     target_dim)
        nd_from = 0
        if tile_stride == 1:
            nd_to = td_to - td_from
        else:
            nd_to = symbolic.pystr_to_symbolic(
                'int_ceil(%s + 1 - %s, %s) - 1' %
                (symbolic.symstr(td_to), symbolic.symstr(td_from),
                 symbolic.symstr(tile_stride)))
        nd_step = 1
        new_dim_range = (nd_from, nd_to, nd_step)
        new_map = nodes.Map(new_dim + '_' + map_entry.map.label, [new_dim],
                            subsets.Range([new_dim_range]))

        # Change the range of the selected dimension to iterate over a single
        # tile
        if strided:
            td_from_new = symbolic.pystr_to_symbolic(new_dim)
            td_to_new_approx = td_to
            td_step = tile_size

        elif offset == 0:
            td_from_new = symbolic.pystr_to_symbolic(
                '%s + %s * %s' %
                (symbolic.symstr(td_from), symbolic.symstr(new_dim),
                 symbolic.symstr(tile_stride)))
            td_to_new_exact = symbolic.pystr_to_symbolic(
                'min(%s + 1, %s + %s * %s + %s) - 1' %
                (symbolic.symstr(td_to), symbolic.symstr(td_from),
                 symbolic.symstr(tile_stride), symbolic.symstr(new_dim),
                 symbolic.symstr(tile_size)))
            td_to_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s * %s + %s - 1' %
                (symbolic.symstr(td_from), symbolic.symstr(tile_stride),
                 symbolic.symstr(new_dim), symbolic.symstr(tile_size)))

        else:
            # include offset
            td_from_new_exact = symbolic.pystr_to_symbolic(
                'max(%s, %s + %s * %s - %s)' %
                (symbolic.symstr(td_from), symbolic.symstr(td_from),
                 symbolic.symstr(tile_stride), symbolic.symstr(new_dim),
                 symbolic.symstr(offset)))
            td_from_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s * %s - %s' %
                (symbolic.symstr(td_from), symbolic.symstr(tile_stride),
                 symbolic.symstr(new_dim), symbolic.symstr(offset)))
            td_from_new = dace.symbolic.SymExpr(td_from_new_exact,
                                                td_from_new_approx)

            td_to_new_exact = symbolic.pystr_to_symbolic(
                'min(%s + 1, %s + %s * %s + %s - %s) - 1' %
                (symbolic.symstr(td_to), symbolic.symstr(td_from),
                 symbolic.symstr(tile_stride), symbolic.symstr(new_dim),
                 symbolic.symstr(tile_size), symbolic.symstr(offset)))
            td_to_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s * %s + %s - %s - 1' %
                (symbolic.symstr(td_from), symbolic.symstr(tile_stride),
                 symbolic.symstr(new_dim), symbolic.symstr(tile_size),
                 symbolic.symstr(offset)))

        if divides_evenly or strided:
            td_to_new = td_to_new_approx
        else:
            td_to_new = dace.symbolic.SymExpr(td_to_new_exact,
                                              td_to_new_approx)
        return new_dim, new_map, (td_from_new, td_to_new, td_step)

    def _create_from_tile_numbers(self, sdfg: SDFG, state: SDFGState,
                                  map_entry: nodes.MapEntry):
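        """ Splits the dimension into a fixed number of near-equal tiles;
            here ``tile_size`` is reinterpreted as the number of tiles.

            Illustrative example: a dimension of exact size 10 split into
            3 tiles yields inner ranges ``[0:2]``, ``[3:5]`` and ``[6:9]``,
            since tile ``t`` spans
            ``[(t*size)//ntiles : ((t+1)*size)//ntiles - 1]``.
        """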
        map_exit = state.exit_node(map_entry)

        # Retrieve transformation properties.
        dim_idx = self.dim_idx
        new_dim_prefix = self.new_dim_prefix
        divides_evenly = self.divides_evenly
        number_of_tiles = self.tile_size
        tile_stride = self.tile_stride

        number_of_tiles = dace.symbolic.pystr_to_symbolic(number_of_tiles)

        # Retrieve parameter and range of dimension to be strip-mined.
        target_dim = map_entry.map.params[dim_idx]
        td_from, td_to, td_step = map_entry.map.range[dim_idx]
        size = map_entry.map.range.size_exact()[dim_idx]

        if tile_stride != 0:
            raise NotImplementedError(
                'Tiling by number of tiles does not support a custom '
                'tile stride')

        new_dim = self._find_new_dim(sdfg, state, map_entry, new_dim_prefix,
                                     target_dim)
        new_dim_range = (td_from, number_of_tiles - 1, 1)
        new_map = nodes.Map(map_entry.map.label, [new_dim],
                            subsets.Range([new_dim_range]))

        dimsym = dace.symbolic.pystr_to_symbolic(new_dim)
        td_from_new = (dimsym * size) // number_of_tiles
        if divides_evenly:
            td_to_new = ((dimsym + 1) * size) // number_of_tiles - 1
        else:
            if isinstance(td_to, dace.symbolic.SymExpr):
                td_to = td_to.expr
            td_to_new = dace.symbolic.SymExpr(
                sympy.Min(
                    ((dimsym + 1) * size) // number_of_tiles, td_to + 1) - 1,
                ((dimsym + 1) * size) // number_of_tiles - 1)
        td_step_new = td_step
        return new_dim, new_map, (td_from_new, td_to_new, td_step_new)

    def _stripmine(self, sdfg, graph, candidate):
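        """ Performs the actual strip-mining: builds the new (outer) map
            with one of the range-creation helpers above, restricts the
            original map to a single tile, redirects all entry/exit edges
            through the new map while recomputing memlet subsets, and
            optionally skews the inner range. Returns the original
            dimension name, the new dimension name, and the new map.
        """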
        # Retrieve map entry and exit nodes.
        map_entry = graph.nodes()[candidate[StripMining._map_entry]]
        map_exit = graph.exit_node(map_entry)

        # Retrieve transformation properties.
        dim_idx = self.dim_idx
        target_dim = map_entry.map.params[dim_idx]

        if self.tiling_type == dtypes.TilingType.CeilRange:
            new_dim, new_map, td_rng = self._create_ceil_range(
                sdfg, graph, map_entry)
        elif self.tiling_type == dtypes.TilingType.NumberOfTiles:
            new_dim, new_map, td_rng = self._create_from_tile_numbers(
                sdfg, graph, map_entry)
        else:
            new_dim, new_map, td_rng = self._create_strided_range(
                sdfg, graph, map_entry)

        new_map_entry = nodes.MapEntry(new_map)
        new_map_exit = nodes.MapExit(new_map)

        td_to_new_approx = td_rng[1]
        if isinstance(td_to_new_approx, dace.symbolic.SymExpr):
            td_to_new_approx = td_to_new_approx.approx

        # Special case: if the tiled range spans a single iteration and no
        # new dimension prefix was given, drop the dimension from the inner
        # map entirely.
        if td_rng[0] == td_to_new_approx and target_dim == new_dim:
            map_entry.map.range = subsets.Range(
                [r for i, r in enumerate(map_entry.map.range) if i != dim_idx])
            map_entry.map.params = [
                p for i, p in enumerate(map_entry.map.params) if i != dim_idx
            ]
            if len(map_entry.map.params) == 0:
                raise ValueError('Strip-mining all dimensions of the map with '
                                 'empty tiles is disallowed')
        else:
            map_entry.map.range[dim_idx] = td_rng

        # Hoist the original schedule onto the new outer map and make the
        # inner (per-tile) map sequential
        new_map.schedule = map_entry.map.schedule
        map_entry.map.schedule = dtypes.ScheduleType.Sequential

        # Redirect edges
        new_map_entry.in_connectors = dcpy(map_entry.in_connectors)
        sdutil.change_edge_dest(graph, map_entry, new_map_entry)
        new_map_exit.out_connectors = dcpy(map_exit.out_connectors)
        sdutil.change_edge_src(graph, map_exit, new_map_exit)

        # Create new entry edges
        new_in_edges = {}
        entry_in_conn = {}
        entry_out_conn = {}
        for _src, src_conn, _dst, _, memlet in graph.out_edges(map_entry):
            if (src_conn is not None
                    and src_conn[:4] == 'OUT_' and not isinstance(
                        sdfg.arrays[memlet.data], dace.data.Scalar)):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
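                # Informally: if the inner map now iterates
                # i in [tile_i:tile_i+63], the image of a subset A[i, 0:M]
                # under that range is A[tile_i:tile_i+63, 0:M], which is the
                # region the outer map's memlet must now cover.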
                conn = src_conn[4:]
                key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
                if key in new_in_edges:
                    old_subset = new_in_edges[key].subset
                    new_in_edges[key].subset = calc_set_union(
                        old_subset, new_subset)
                else:
                    entry_in_conn['IN_' + conn] = None
                    entry_out_conn['OUT_' + conn] = None
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    if memlet.dynamic:
                        new_memlet.num_accesses = memlet.num_accesses
                    else:
                        new_memlet.num_accesses = new_memlet.num_elements(
                        ).simplify()
                    new_in_edges[key] = new_memlet
            else:
                if src_conn is not None and src_conn[:4] == 'OUT_':
                    conn = src_conn[4:]
                    in_conn = 'IN_' + conn
                    out_conn = 'OUT_' + conn
                else:
                    in_conn = src_conn
                    out_conn = src_conn
                if in_conn:
                    entry_in_conn[in_conn] = None
                if out_conn:
                    entry_out_conn[out_conn] = None
                new_in_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
        new_map_entry.out_connectors = entry_out_conn
        map_entry.in_connectors = entry_in_conn
        for (_, in_conn, out_conn), memlet in new_in_edges.items():
            graph.add_edge(new_map_entry, out_conn, map_entry, in_conn, memlet)

        # Create new exit edges
        new_out_edges = {}
        exit_in_conn = {}
        exit_out_conn = {}
        for _src, _, _dst, dst_conn, memlet in graph.in_edges(map_exit):
            if (dst_conn is not None
                    and dst_conn[:3] == 'IN_' and not isinstance(
                        sdfg.arrays[memlet.data], dace.data.Scalar)):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
                conn = dst_conn[3:]
                key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
                if key in new_out_edges:
                    old_subset = new_out_edges[key].subset
                    new_out_edges[key].subset = calc_set_union(
                        old_subset, new_subset)
                else:
                    exit_in_conn['IN_' + conn] = None
                    exit_out_conn['OUT_' + conn] = None
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    if memlet.dynamic:
                        new_memlet.num_accesses = memlet.num_accesses
                    else:
                        new_memlet.num_accesses = new_memlet.num_elements(
                        ).simplify()
                    new_out_edges[key] = new_memlet
            else:
                if dst_conn is not None and dst_conn[:3] == 'IN_':
                    conn = dst_conn[3:]
                    in_conn = 'IN_' + conn
                    out_conn = 'OUT_' + conn
                else:
                    in_conn = dst_conn
                    out_conn = dst_conn
                if in_conn:
                    exit_in_conn[in_conn] = None
                if out_conn:
                    exit_out_conn[out_conn] = None
                new_out_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
        new_map_exit.in_connectors = exit_in_conn
        map_exit.out_connectors = exit_out_conn
        for (_, in_conn, out_conn), memlet in new_out_edges.items():
            graph.add_edge(map_exit, out_conn, new_map_exit, in_conn, memlet)

        # Skew if necessary
        if self.skew:
            xfh.offset_map(sdfg, graph, map_entry, dim_idx, td_rng[0])

        # Return strip-mined dimension.
        return target_dim, new_dim, new_map
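
# A minimal usage sketch (an assumption, not part of the listing above):
# if this class is registered as DaCe's StripMining transformation, it can
# be applied to an SDFG to tile the first dimension of eligible maps, e.g.:
#
#     from dace.transformation.dataflow import StripMining
#     sdfg.apply_transformations(StripMining,
#                                options={'dim_idx': 0, 'tile_size': 32})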