Example #1
class CodeNode(Node):
    """ A node that contains runnable code with acyclic external data
        dependencies. May either be a tasklet or a nested SDFG, and
        denoted by an octagonal shape. """

    label = Property(dtype=str, desc="Name of the CodeNode")
    location = DictProperty(
        key_type=str,
        value_type=dace.symbolic.pystr_to_symbolic,
        desc='Full storage location identifier (e.g., rank, GPU ID)')
    environments = SetProperty(
        str,
        desc="Environments required by CMake to build and run this code node.",
        default=set())

    def __init__(self, label="", location=None, inputs=None, outputs=None):
        super(CodeNode, self).__init__(inputs or set(), outputs or set())
        # Properties
        self.label = label
        self.location = location if location is not None else {}

    @property
    def free_symbols(self) -> Set[str]:
        return set().union(*(map(str,
                                 pystr_to_symbolic(v).free_symbols)
                             for v in self.location.values()))
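
A quick usage sketch (not from the original source): symbols appearing in `location` values become the node's free symbols, because `free_symbols` parses each value through `pystr_to_symbolic` on demand, so plain strings work too.

node = CodeNode(label='work')          # illustrative instance
node.location = {'gpu': 'gpu_id + 1'}  # value is parsed symbolically
assert node.free_symbols == {'gpu_id'}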
Example #2
class CodeObject(object):
    name = Property(dtype=str, desc="Filename to use")
    code = Property(dtype=str, desc="The code attached to this object")
    language = Property(dtype=str,
                        desc="Language used for this code (same " +
                        "as its file extension)")
    target = Property(dtype=type,
                      desc="Target to use for compilation",
                      allow_none=True)
    target_type = Property(
        dtype=str,
        desc="Sub-target within target (e.g., host or device code)",
        default="")
    title = Property(dtype=str, desc="Title of code for GUI")
    extra_compiler_kwargs = DictProperty(key_type=str,
                                         value_type=str,
                                         desc="Additional compiler argument "
                                         "variables to add to template")
    linkable = Property(dtype=bool,
                        desc='Should this file participate in '
                        'overall linkage?')
    environments = SetProperty(
        str,
        desc="Environments required by CMake to build and run this code node.",
        default=set())

    def __init__(self,
                 name,
                 code,
                 language,
                 target,
                 title,
                 target_type="",
                 additional_compiler_kwargs=None,
                 linkable=True,
                 environments=None,
                 sdfg=None):
        super(CodeObject, self).__init__()

        self.name = name
        self.code = code
        self.language = language
        self.target = target
        self.target_type = target_type
        self.title = title
        self.extra_compiler_kwargs = additional_compiler_kwargs or {}
        self.linkable = linkable
        self.environments = environments or set()

        if language == 'cpp' and title == 'Frame' and sdfg:
            sourcemap.create_maps(sdfg, code, self.target.target_name)

    @property
    def clean_code(self):
        return re.sub(r'[ \t]*////__(DACE:|CODEGEN;)[^\n]*', '', self.code)
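
A hedged sketch of `clean_code`: generated sources can carry `////__DACE:`-style source-map markers, which the regular expression strips. The constructor argument values below are illustrative only.

obj = CodeObject(name='program',
                 code='int x;\n    ////__DACE:0:0:0\nint y;\n',
                 language='cpp', target=None, title='Main')
assert obj.clean_code == 'int x;\n\nint y;\n'  # marker text removed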
Example #3
class CodeObject(object):
    name = Property(dtype=str, desc="Filename to use")
    code = Property(dtype=str, desc="The code attached to this object")
    language = Property(dtype=str,
                        desc="Language used for this code (same " +
                        "as its file extension)")  # dtype=dtypes.Language?
    target = Property(dtype=type,
                      desc="Target to use for compilation",
                      allow_none=True)
    target_type = Property(
        dtype=str,
        desc="Sub-target within target (e.g., host or device code)",
        default="")
    title = Property(dtype=str, desc="Title of code for GUI")
    extra_compiler_kwargs = DictProperty(key_type=str,
                                         value_type=str,
                                         desc="Additional compiler argument "
                                         "variables to add to template")
    linkable = Property(dtype=bool,
                        desc='Should this file participate in '
                        'overall linkage?')
    environments = SetProperty(
        str,
        desc="Environments required by CMake to build and run this code node.",
        default=set())

    def __init__(self,
                 name,
                 code,
                 language,
                 target,
                 title,
                 target_type="",
                 additional_compiler_kwargs=None,
                 linkable=True,
                 environments=None):
        super(CodeObject, self).__init__()

        self.name = name
        self.code = code
        self.language = language
        self.target = target
        self.target_type = target_type
        self.title = title
        self.extra_compiler_kwargs = additional_compiler_kwargs or {}
        self.linkable = linkable
        self.environments = environments or set()
Example #4
class RTLTasklet(Tasklet):
    """ A specialized tasklet, which is a functional computation procedure
        that can only access external data specified using connectors.

        This tasklet is specialized for tasklets implemented in System Verilog
        in that it adds support for adding metadata about the IP cores in use.
    """
    # TODO to be replaced when enums have embedded properties
    ip_cores = DictProperty(key_type=str,
                            value_type=dict,
                            desc="A set of IP cores used by the tasklet.")

    @property
    def __jsontype__(self):
        return 'Tasklet'

    def add_ip_core(self, module_name, name, vendor, version, params):
        self.ip_cores[module_name] = {
            'name': name,
            'vendor': vendor,
            'version': version,
            'params': params
        }
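
Illustrative use of `add_ip_core` (module, vendor, and parameter values are made up, and this assumes `Tasklet` accepts a label as its first constructor argument):

tasklet = RTLTasklet('rtl_mult')  # hypothetical label
tasklet.add_ip_core(module_name='mult0', name='my_mult_ip',
                    vendor='acme.com', version='1.0',
                    params={'WIDTH': 32})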
Example #5
class NestedSDFG(CodeNode):
    """ An SDFG state node that contains an SDFG of its own, runnable using
        the data dependencies specified using its connectors.

        It is encouraged to use nested SDFGs instead of coarse-grained tasklets
        since they are analyzable with respect to transformations.

        @note: A nested SDFG cannot create recursion (i.e., it may not
               contain any of its parent SDFGs).
    """

    # NOTE: We cannot use SDFG as the type because of an import loop
    sdfg = SDFGReferenceProperty(desc="The SDFG", allow_none=True)
    schedule = Property(dtype=dtypes.ScheduleType,
                        desc="SDFG schedule",
                        allow_none=True,
                        choices=dtypes.ScheduleType,
                        from_string=lambda x: dtypes.ScheduleType[x],
                        default=dtypes.ScheduleType.Default)
    symbol_mapping = DictProperty(
        key_type=str,
        value_type=dace.symbolic.pystr_to_symbolic,
        desc="Mapping between internal symbols and their values, expressed as "
        "symbolic expressions")
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = Property(choices=dtypes.InstrumentationType,
                          desc="Measure execution statistics with given method",
                          default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 label,
                 sdfg,
                 inputs: Set[str],
                 outputs: Set[str],
                 symbol_mapping: Dict[str, Any] = None,
                 schedule=dtypes.ScheduleType.Default,
                 location=None,
                 debuginfo=None):
        super(NestedSDFG, self).__init__(label, location, inputs, outputs)

        # Properties
        self.sdfg = sdfg
        self.symbol_mapping = symbol_mapping or {}
        self.schedule = schedule
        self.debuginfo = debuginfo

    @staticmethod
    def from_json(json_obj, context=None):
        from dace import SDFG  # Avoid import loop

        # We have to load the SDFG first.
        ret = NestedSDFG("nolabel", SDFG('nosdfg'), {}, {})

        dace.serialize.set_properties_from_json(ret, json_obj, context)

        if context and 'sdfg_state' in context:
            ret.sdfg.parent = context['sdfg_state']
        if context and 'sdfg' in context:
            ret.sdfg.parent_sdfg = context['sdfg']

        ret.sdfg.parent_nsdfg_node = ret

        ret.sdfg.update_sdfg_list([])

        return ret

    @property
    def free_symbols(self) -> Set[str]:
        return set().union(
            *(map(str,
                  pystr_to_symbolic(v).free_symbols)
              for v in self.symbol_mapping.values()),
            *(map(str,
                  pystr_to_symbolic(v).free_symbols)
              for v in self.location.values()))

    def infer_connector_types(self, sdfg, state):
        # Avoid import loop
        from dace.sdfg.infer_types import infer_connector_types
        # Infer internal connector types
        infer_connector_types(self.sdfg)

    def __str__(self):
        if not self.label:
            return "SDFG"
        else:
            return self.label

    def validate(self, sdfg, state):
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid nested SDFG name "%s"' % self.label)
        for in_conn in self.in_connectors:
            if not dtypes.validate_name(in_conn):
                raise NameError('Invalid input connector "%s"' % in_conn)
        for out_conn in self.out_connectors:
            if not dtypes.validate_name(out_conn):
                raise NameError('Invalid output connector "%s"' % out_conn)
        connectors = self.in_connectors.keys() | self.out_connectors.keys()
        for dname, desc in self.sdfg.arrays.items():
            # TODO(later): Disallow scalars without access nodes (so that this
            #              check passes for them too).
            if isinstance(desc, data.Scalar):
                continue
            if not desc.transient and dname not in connectors:
                raise NameError('Data descriptor "%s" not found in nested '
                                'SDFG connectors' % dname)
            if dname in connectors and desc.transient:
                raise NameError(
                    '"%s" is a connector but its corresponding array is transient'
                    % dname)

        # Validate undefined symbols
        symbols = set(k for k in self.sdfg.free_symbols if k not in connectors)
        missing_symbols = [s for s in symbols if s not in self.symbol_mapping]
        if missing_symbols:
            raise ValueError('Missing symbols on nested SDFG: %s' %
                             (missing_symbols))

        # Recursively validate nested SDFG
        self.sdfg.validate()
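
A short sketch of how `free_symbols` draws from `symbol_mapping` (and `location`); `inner_sdfg` is assumed to be an existing SDFG whose non-transient arrays match the connectors:

nsdfg = NestedSDFG('child', inner_sdfg, inputs={'a'}, outputs={'b'},
                   symbol_mapping={'N': 'M + 1'})
assert 'M' in nsdfg.free_symbols  # 'M' occurs in a mapping value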
Example #6
class Node(object):
    """ Base node class. """

    in_connectors = DictProperty(
        key_type=str,
        value_type=dtypes.typeclass,
        desc="A set of input connectors for this node.")
    out_connectors = DictProperty(
        key_type=str,
        value_type=dtypes.typeclass,
        desc="A set of output connectors for this node.")

    def __init__(self, in_connectors=None, out_connectors=None):
        # Convert connectors to typed connectors with autodetect type
        if isinstance(in_connectors, (set, list, KeysView)):
            in_connectors = {k: None for k in in_connectors}
        if isinstance(out_connectors, (set, list, KeysView)):
            out_connectors = {k: None for k in out_connectors}

        self.in_connectors = in_connectors or {}
        self.out_connectors = out_connectors or {}

    def __str__(self):
        if hasattr(self, 'label'):
            return self.label
        else:
            return type(self).__name__

    def validate(self, sdfg, state):
        pass

    def to_json(self, parent):
        labelstr = str(self)
        typestr = getattr(self, '__jsontype__', str(type(self).__name__))

        try:
            scope_entry_node = parent.entry_node(self)
        except (RuntimeError, StopIteration):
            scope_entry_node = None

        if scope_entry_node is not None:
            ens = parent.exit_node(scope_entry_node)
            scope_exit_node = str(parent.node_id(ens))
            scope_entry_node = str(parent.node_id(scope_entry_node))
        else:
            scope_entry_node = None
            scope_exit_node = None

        # The scope exit of an entry node is the matching exit node
        if isinstance(self, EntryNode):
            try:
                scope_exit_node = str(parent.node_id(parent.exit_node(self)))
            except (RuntimeError, StopIteration):
                scope_exit_node = None

        retdict = {
            "type": typestr,
            "label": labelstr,
            "attributes": dace.serialize.all_properties_to_json(self),
            "id": parent.node_id(self),
            "scope_entry": scope_entry_node,
            "scope_exit": scope_exit_node
        }
        return retdict

    def __repr__(self):
        return type(self).__name__ + ' (' + self.__str__() + ')'

    def add_in_connector(self,
                         connector_name: str,
                         dtype: dtypes.typeclass = None):
        """ Adds a new input connector to the node. The operation will fail if
            a connector (either input or output) with the same name already
            exists in the node.

            :param connector_name: The name of the new connector.
            :param dtype: The type of the connector, or None for auto-detect.
            :return: True if the operation is successful, otherwise False.
        """

        if (connector_name in self.in_connectors
                or connector_name in self.out_connectors):
            return False
        connectors = self.in_connectors
        connectors[connector_name] = dtype
        self.in_connectors = connectors
        return True

    def add_out_connector(self,
                          connector_name: str,
                          dtype: dtypes.typeclass = None):
        """ Adds a new output connector to the node. The operation will fail if
            a connector (either input or output) with the same name already
            exists in the node.

            :param connector_name: The name of the new connector.
            :param dtype: The type of the connector, or None for auto-detect.
            :return: True if the operation is successful, otherwise False.
        """

        if (connector_name in self.in_connectors
                or connector_name in self.out_connectors):
            return False
        connectors = self.out_connectors
        connectors[connector_name] = dtype
        self.out_connectors = connectors
        return True

    def remove_in_connector(self, connector_name: str):
        """ Removes an input connector from the node.
            :param connector_name: The name of the connector to remove.
            :return: True if the operation was successful.
        """

        if connector_name in self.in_connectors:
            connectors = self.in_connectors
            del connectors[connector_name]
            self.in_connectors = connectors
        return True

    def remove_out_connector(self, connector_name: str):
        """ Removes an output connector from the node.
            :param connector_name: The name of the connector to remove.
            :return: True if the operation was successful.
        """

        if connector_name in self.out_connectors:
            connectors = self.out_connectors
            del connectors[connector_name]
            self.out_connectors = connectors
        return True

    def _next_connector_int(self) -> int:
        """ Returns the next unused connector ID (as an integer). Used for
            filling connectors when adding edges to scopes. """
        next_number = 1
        for conn in itertools.chain(self.in_connectors, self.out_connectors):
            if conn.startswith('IN_'):
                cconn = conn[3:]
            elif conn.startswith('OUT_'):
                cconn = conn[4:]
            else:
                continue
            try:
                curconn = int(cconn)
                if curconn >= next_number:
                    next_number = curconn + 1
            except (TypeError, ValueError):  # not integral
                continue
        return next_number

    def next_connector(self) -> str:
        """ Returns the next unused connector ID (as a string). Used for
            filling connectors when adding edges to scopes. """
        return str(self._next_connector_int())

    def last_connector(self) -> str:
        """ Returns the last used connector ID (as a string). Used for
            filling connectors when adding edges to scopes. """
        return str(self._next_connector_int() - 1)

    @property
    def free_symbols(self) -> Set[str]:
        """ Returns a set of symbols used in this node's properties. """
        return set()

    def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]:
        """ Returns a mapping between symbols defined by this node (e.g., for
            scope entries) to their type. """
        return {}

    def infer_connector_types(self, sdfg, state):
        """
        Infers and fills remaining connectors (i.e., set to None) with their
        types.
        """
        pass
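
A sketch of the connector API above; `Node` is normally subclassed, but the base class is instantiable:

n = Node({'IN_1'}, {'OUT_1'})
assert n.add_in_connector('IN_2')       # new name: succeeds
assert not n.add_in_connector('OUT_1')  # clashes with an output: fails
assert n.next_connector() == '3'        # IN_1 and IN_2 are already taken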
Example #7
class Data(object):
    """ Data type descriptors that can be used as references to memory.
        Examples: Arrays, Streams, custom arrays (e.g., sparse matrices).
    """

    dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses)
    shape = ShapeProperty(default=[])
    transient = Property(dtype=bool, default=False)
    storage = EnumProperty(dtype=dtypes.StorageType,
                           desc="Storage location",
                           default=dtypes.StorageType.Default)
    lifetime = EnumProperty(dtype=dtypes.AllocationLifetime,
                            desc='Data allocation span',
                            default=dtypes.AllocationLifetime.Scope)
    location = DictProperty(
        key_type=str,
        value_type=symbolic.pystr_to_symbolic,
        desc='Full storage location identifier (e.g., rank, GPU ID)')
    debuginfo = DebugInfoProperty(allow_none=True)

    def __init__(self, dtype, shape, transient, storage, location, lifetime,
                 debuginfo):
        self.dtype = dtype
        self.shape = shape
        self.transient = transient
        self.storage = storage
        self.location = location if location is not None else {}
        self.lifetime = lifetime
        self.debuginfo = debuginfo
        self._validate()

    def validate(self):
        """ Validate the correctness of this object.
            Raises an exception on error. """
        self._validate()

    # Validation of this class is in a separate function, so that this
    # class can call `_validate()` without calling the subclasses'
    # `validate` function.
    def _validate(self):
        if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol,
                                  symbolic.sympy.Basic)) for s in self.shape):
            raise TypeError('Shape must be a list or tuple of integer values '
                            'or symbols')
        return True

    def to_json(self):
        attrs = serialize.all_properties_to_json(self)

        retdict = {"type": type(self).__name__, "attributes": attrs}

        return retdict

    @property
    def toplevel(self):
        return self.lifetime is not dtypes.AllocationLifetime.Scope

    def copy(self):
        raise RuntimeError(
            'Data descriptors are unique and should not be copied')

    def is_equivalent(self, other):
        """ Check for equivalence (shape and type) of two data descriptors. """
        raise NotImplementedError

    def as_arg(self, with_types=True, for_call=False, name=None):
        """Returns a string for a C++ function signature (e.g., `int *A`). """
        raise NotImplementedError

    @property
    def free_symbols(self) -> Set[symbolic.SymbolicType]:
        """ Returns a set of undefined symbols in this data descriptor. """
        result = set()
        for s in self.shape:
            if isinstance(s, sp.Basic):
                result |= set(s.free_symbols)
        return result

    def __repr__(self):
        return 'Abstract Data Container, DO NOT USE'

    @property
    def veclen(self):
        return self.dtype.veclen if hasattr(self.dtype, "veclen") else 1

    @property
    def ctype(self):
        return self.dtype.ctype
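
Sketch of `free_symbols` with a symbolic shape. A concrete subclass (e.g., an array descriptor) would normally be used; instantiating `Data` directly works here because `_validate` only checks the shape:

N = symbolic.symbol('N')
d = Data(dtypes.float64, (N, 10), False, dtypes.StorageType.Default,
         None, dtypes.AllocationLifetime.Scope, None)
assert d.free_symbols == {N}  # the constant 10 contributes no symbols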
Example #8
class AccumulateTransient(transformation.Transformation):
    """ Implements the AccumulateTransient transformation, which adds
        transient stream and data nodes between nested maps that lead to a
        stream. The transient data nodes then act as a local accumulator.
    """

    map_exit = transformation.PatternNode(nodes.MapExit)
    outer_map_exit = transformation.PatternNode(nodes.MapExit)

    array_identity_dict = DictProperty(key_type=str,
                                       value_type=symbolic.pystr_to_symbolic,
                                       desc="Mapping from array name to the "
                                       "identity value to set for it",
                                       default=dict(),
                                       allow_none=True)

    array = Property(
        dtype=str,
        desc="Array to create local storage for (if empty, first available)",
        default=None,
        allow_none=True)

    prefix = Property(dtype=str,
                      default="trans_",
                      allow_none=True,
                      desc='Prefix for new data node')

    identity = SymbolicProperty(desc="Identity value to set",
                                default=None,
                                allow_none=True)

    @staticmethod
    def expressions():
        return [
            sdutil.node_path_graph(AccumulateTransient.map_exit,
                                   AccumulateTransient.outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        map_exit = graph.nodes()[candidate[AccumulateTransient.map_exit]]
        outer_map_exit = graph.nodes()[candidate[
            AccumulateTransient.outer_map_exit]]

        # Check if there is an accumulation output
        for e in graph.edges_between(map_exit, outer_map_exit):
            if e.data.wcr is not None:
                return True

        return False

    @staticmethod
    def match_to_str(graph, candidate):
        map_exit = candidate[AccumulateTransient.map_exit]
        outer_map_exit = candidate[AccumulateTransient.outer_map_exit]

        return ' -> '.join(str(node) for node in [map_exit, outer_map_exit])

    def apply(self, sdfg: SDFG):
        graph = sdfg.node(self.state_id)
        map_exit = graph.node(self.subgraph[AccumulateTransient.map_exit])
        outer_map_exit = graph.node(
            self.subgraph[AccumulateTransient.outer_map_exit])

        # Avoid import loop
        from dace.transformation.dataflow.local_storage import OutLocalStorage

        array_identity_dict = self.array_identity_dict

        # Choose array
        array = self.array
        if array is not None and len(array) != 0:
            array_identity_dict[array] = self.identity
        elif ((array is None or len(array) == 0)
              and len(array_identity_dict) == 0):
            array = next(e.data.data
                         for e in graph.edges_between(map_exit, outer_map_exit)
                         if e.data.wcr is not None)
            array_identity_dict[array] = self.identity

        transients: Dict[str, Any] = {}
        for array, identity in array_identity_dict.items():
            data_node: nodes.AccessNode = OutLocalStorage.apply_to(
                sdfg,
                dict(array=array, prefix=self.prefix),
                verify=False,
                save=False,
                node_a=map_exit,
                node_b=outer_map_exit)

            transients[data_node.data] = identity

            if identity is None:
                warnings.warn(
                    'AccumulateTransient did not properly initialize '
                    'newly-created transient!')
                return

        sdfg_state: SDFGState = sdfg.node(self.state_id)

        map_entry = sdfg_state.entry_node(map_exit)

        nested_sdfg: nodes.NestedSDFG = nest_state_subgraph(
            sdfg=sdfg,
            state=sdfg_state,
            subgraph=SubgraphView(
                sdfg_state, {map_entry, map_exit}
                | sdfg_state.all_nodes_between(map_entry, map_exit)))

        nested_sdfg_state: SDFGState = nested_sdfg.sdfg.nodes()[0]

        init_state = nested_sdfg.sdfg.add_state_before(nested_sdfg_state)

        for data_name, identity in transients.items():
            temp_array: Array = sdfg.arrays[data_name]

            init_state.add_mapped_tasklet(
                name='acctrans_init',
                map_ranges={
                    '_o%d' % i: '0:%s' % symbolic.symstr(d)
                    for i, d in enumerate(temp_array.shape)
                },
                inputs={},
                code='out = %s' % identity,
                outputs={
                    'out':
                    dace.Memlet.simple(
                        data=data_name,
                        subset_str=','.join([
                            '_o%d' % i for i, _ in enumerate(temp_array.shape)
                        ]))
                },
                external_edges=True)

        # TODO: once trivial map elimination is merged, use it here to remove
        # maps with trivial ranges

        return nested_sdfg
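
A hedged application sketch, mirroring the `apply_to` pattern documented in the `Transformation` base class (Example #9 below); `inner_exit` and `outer_exit` stand for existing MapExit nodes:

AccumulateTransient.apply_to(sdfg,
                             options=dict(identity=0),
                             map_exit=inner_exit,        # placeholder node
                             outer_map_exit=outer_exit)  # placeholder node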
Example #9
class Transformation(TransformationBase):
    """ Base class for pattern-matching transformations, as well as a static
        registry of transformations, where new transformations can be added in a
        decentralized manner.
        An instance of a Transformation represents a match of the transformation
        on an SDFG, complete with a subgraph candidate and properties.

        New transformations that extend this class must contain static
        `PatternNode` fields that represent the nodes in the pattern graph, and
        use them to implement at least three methods:
          * `expressions`: A method that returns a list of graph
                           patterns (SDFG or SDFGState objects) that match this
                           transformation.
          * `can_be_applied`: A method that, given a subgraph candidate,
                              checks for additional conditions whether it can
                              be transformed.
          * `apply`: A method that applies the transformation
                     on the given SDFG.

        For more information and optimization opportunities, see the respective
        methods' documentation.

        In order to be included in lists and apply through the
        `sdfg.apply_transformations` API, each transformation should be
        registered with ``Transformation.register`` (or, more commonly,
        the ``@dace.registry.autoregister_params`` class decorator) with two
        optional boolean keyword arguments: ``singlestate`` (default: False)
        and ``strict`` (default: False).
        If ``singlestate`` is True, the transformation is matched on subgraphs
        inside an SDFGState; otherwise, subgraphs of the SDFG state machine are
        matched.
        If ``strict`` is True, this transformation will be considered strict
        (i.e., always beneficial to perform) and will be performed automatically
        as part of SDFG strict transformations.
    """

    # Properties
    sdfg_id = Property(dtype=int, category="(Debug)")
    state_id = Property(dtype=int, category="(Debug)")
    _subgraph = DictProperty(key_type=int, value_type=int, category="(Debug)")
    expr_index = Property(dtype=int, category="(Debug)")

    def annotates_memlets(self) -> bool:
        """ Indicates whether the transformation annotates the edges it creates
            or modifies with the appropriate memlets. This determines
            whether to apply memlet propagation after the transformation.
        """
        return False

    def expressions(self) -> List[gr.SubgraphView]:
        """ Returns a list of Graph objects that will be matched in the
            subgraph isomorphism phase. Used as a pre-pass before calling
            `can_be_applied`.
            :see: Transformation.can_be_applied
        """
        raise NotImplementedError

    def can_be_applied(self,
                       graph: Union[SDFG, SDFGState],
                       candidate: Dict['PatternNode', int],
                       expr_index: int,
                       sdfg: SDFG,
                       strict: bool = False) -> bool:
        """ Returns True if this transformation can be applied on the candidate
            matched subgraph.
            :param graph: SDFGState object if this Transformation is
                          single-state, or SDFG object otherwise.
            :param candidate: A mapping between node IDs returned from
                              `Transformation.expressions` and the nodes in
                              `graph`.
            :param expr_index: The list index from `Transformation.expressions`
                               that was matched.
            :param sdfg: If `graph` is an SDFGState, its parent SDFG. Otherwise
                         should be equal to `graph`.
            :param strict: Whether transformation should run in strict mode.
            :return: True if the transformation can be applied.
        """
        raise NotImplementedError

    def apply(self, sdfg: SDFG) -> Union[Any, None]:
        """
        Applies this transformation instance on the matched pattern graph.
        :param sdfg: The SDFG to apply the transformation to.
        :return: A transformation-defined return value, which could be used
                 to pass analysis data out, or nothing.
        """
        raise NotImplementedError

    def match_to_str(self, graph: Union[SDFG, SDFGState],
                     candidate: Dict['PatternNode', int]) -> str:
        """ Returns a string representation of the pattern match on the
            candidate subgraph. Used when identifying matches in the console
            UI.
        """
        return str(list(candidate.values()))

    def __init__(self,
                 sdfg_id: int,
                 state_id: int,
                 subgraph: Dict['PatternNode', int],
                 expr_index: int,
                 override: bool = False,
                 options: Optional[Dict[str, Any]] = None) -> None:
        """ Initializes an instance of Transformation match.
            :param sdfg_id: A unique ID of the SDFG.
            :param state_id: The node ID of the SDFG state, if applicable. If
                             transformation does not operate on a single state,
                             the value should be -1.
            :param subgraph: A mapping between node IDs returned from
                             `Transformation.expressions` and the nodes in
                             `graph`.
            :param expr_index: The list index from `Transformation.expressions`
                               that was matched.
            :param override: If True, accepts the subgraph dictionary as-is
                             (mostly for internal use).
            :param options: An optional dictionary of transformation properties
            :raise TypeError: When transformation is not subclass of
                              Transformation.
            :raise TypeError: When state_id is not instance of int.
            :raise TypeError: When subgraph is not a dict of
                              PatternNode : int.
        """

        self.sdfg_id = sdfg_id
        self.state_id = state_id
        if not override:
            expr = self.expressions()[expr_index]
            for value in subgraph.values():
                if not isinstance(value, int):
                    raise TypeError('All values of the subgraph dictionary '
                                    'must be instances of int.')
            self._subgraph = {expr.node_id(k): v for k, v in subgraph.items()}
        else:
            self._subgraph = {-1: -1}
        # Serializable subgraph with node IDs as keys
        self._subgraph_user = copy.copy(subgraph)
        self.expr_index = expr_index

        # Ease-of-use API: Set new pattern-nodes with information about this
        # instance.
        for pname, pval in self._get_pattern_nodes().items():
            # Create new pattern node from existing field
            new_pnode = PatternNode(
                pval.node if isinstance(pval, PatternNode) else type(pval))
            new_pnode.match_instance = self

            # Append existing values in subgraph dictionary
            if pval in self._subgraph_user:
                self._subgraph_user[new_pnode] = self._subgraph_user[pval]

            # Override static field with the new node in this instance only
            setattr(self, pname, new_pnode)

        # Set properties
        if options is not None:
            for optname, optval in options.items():
                setattr(self, optname, optval)

    @property
    def subgraph(self):
        return self._subgraph_user

    def apply_pattern(self,
                      sdfg: SDFG,
                      append: bool = True,
                      annotate: bool = True) -> Union[Any, None]:
        """
        Applies this transformation on the given SDFG, using the transformation
        instance to find the right SDFG object (based on SDFG ID), and applying
        memlet propagation as necessary.
        :param sdfg: The SDFG (or an SDFG in the same hierarchy) to apply the
                     transformation to.
        :param append: If True, appends the transformation to the SDFG
                       transformation history.
        :return: A transformation-defined return value, which could be used
                 to pass analysis data out, or nothing.
        """
        if append:
            sdfg.append_transformation(self)
        tsdfg: SDFG = sdfg.sdfg_list[self.sdfg_id]
        retval = self.apply(tsdfg)
        if annotate and not self.annotates_memlets():
            propagation.propagate_memlets_sdfg(tsdfg)
        return retval

    def __lt__(self, other: 'Transformation') -> bool:
        """
        Comparing two transformations by their class name and node IDs
        in match. Used for ordering transformations consistently.
        """
        if type(self) != type(other):
            return type(self).__name__ < type(other).__name__

        self_ids = iter(self.subgraph.values())
        other_ids = iter(other.subgraph.values())

        try:
            self_id = next(self_ids)
        except StopIteration:
            return True
        try:
            other_id = next(other_ids)
        except StopIteration:
            return False

        self_end = False

        while self_id is not None and other_id is not None:
            if self_id != other_id:
                return self_id < other_id
            try:
                self_id = next(self_ids)
            except StopIteration:
                self_end = True
            try:
                other_id = next(other_ids)
            except StopIteration:
                if self_end:  # Transformations are equal
                    return False
                return False
            if self_end:
                return True

    @classmethod
    def _get_pattern_nodes(cls) -> Dict[str, 'PatternNode']:
        """
        Returns a dictionary of pattern-matching node in this transformation
        subclass. Used internally for pattern-matching.
        :return: A dictionary mapping between pattern-node name and its type.
        """
        return {
            k: getattr(cls, k)
            for k in dir(cls)
            if isinstance(getattr(cls, k), PatternNode) or (k.startswith(
                '_') and isinstance(getattr(cls, k), (nd.Node, SDFGState)))
        }

    @classmethod
    def apply_to(cls,
                 sdfg: SDFG,
                 options: Optional[Dict[str, Any]] = None,
                 expr_index: int = 0,
                 verify: bool = True,
                 annotate: bool = True,
                 strict: bool = False,
                 save: bool = True,
                 **where: Union[nd.Node, SDFGState]):
        """
        Applies this transformation to a given subgraph, defined by a set of
        nodes. Raises an error if arguments are invalid or transformation is
        not applicable.

        The subgraph is defined by the `where` dictionary, where each key is
        taken from the `PatternNode` fields of the transformation. For example,
        applying `MapCollapse` on two maps can be performed as follows:

        ```
        MapCollapse.apply_to(sdfg, outer_map_entry=map_a, inner_map_entry=map_b)
        ```

        :param sdfg: The SDFG to apply the transformation to.
        :param options: A set of parameters to use for applying the
                        transformation.
        :param expr_index: The pattern expression index to try to match with.
        :param verify: Check that `can_be_applied` returns True before applying.
        :param annotate: Run memlet propagation after application if necessary.
        :param strict: Apply transformation in strict mode.
        :param save: Save transformation as part of the SDFG file. Set to
                     False if composing transformations.
        :param where: A dictionary of node names (from the transformation) to
                      nodes in the SDFG or a single state.
        """
        if len(where) == 0:
            raise ValueError('At least one node is required')
        options = options or {}

        # Check that all keyword arguments are nodes and if interstate or not
        sample_node = next(iter(where.values()))

        if isinstance(sample_node, SDFGState):
            graph = sdfg
            state_id = -1
        elif isinstance(sample_node, nd.Node):
            graph = next(s for s in sdfg.nodes() if sample_node in s.nodes())
            state_id = sdfg.node_id(graph)
        else:
            raise TypeError('Invalid node type "%s"' %
                            type(sample_node).__name__)

        # Check that all nodes in the pattern are set
        required_nodes = cls.expressions()[expr_index].nodes()
        required_node_names = {
            pname: pval
            for pname, pval in cls._get_pattern_nodes().items()
            if pval in required_nodes
        }
        required = set(required_node_names.keys())
        intersection = required & set(where.keys())
        if len(required - intersection) > 0:
            raise ValueError('Missing nodes for transformation subgraph: %s' %
                             (required - intersection))

        # Construct subgraph and instantiate transformation
        subgraph = {
            required_node_names[k]: graph.node_id(where[k])
            for k in required
        }
        instance = cls(sdfg.sdfg_id, state_id, subgraph, expr_index)

        # Construct transformation parameters
        for optname, optval in options.items():
            if optname not in cls.__properties__:
                raise ValueError('Property "%s" not found in transformation' %
                                 optname)
            setattr(instance, optname, optval)

        if verify:
            if not instance.can_be_applied(
                    graph, subgraph, expr_index, sdfg, strict=strict):
                raise ValueError('Transformation cannot be applied on the '
                                 'given subgraph ("can_be_applied" failed)')

        # Apply to SDFG
        return instance.apply_pattern(sdfg, annotate=annotate, append=save)

    def __str__(self) -> str:
        return type(self).__name__

    def print_match(self, sdfg: SDFG) -> str:
        """ Returns a string representation of the pattern match on the
            given SDFG. Used for printing matches in the console UI.
        """
        if not isinstance(sdfg, SDFG):
            raise TypeError("Expected SDFG, got: {}".format(
                type(sdfg).__name__))
        if self.state_id == -1:
            graph = sdfg
        else:
            graph = sdfg.nodes()[self.state_id]
        string = type(self).__name__ + ' in '
        string += self.match_to_str(graph, self.subgraph)
        return string

    def to_json(self, parent=None) -> Dict[str, Any]:
        props = serialize.all_properties_to_json(self)
        return {
            'type': 'Transformation',
            'transformation': type(self).__name__,
            **props
        }

    @staticmethod
    def from_json(json_obj: Dict[str, Any],
                  context: Dict[str, Any] = None) -> 'Transformation':
        xform = next(ext for ext in Transformation.extensions().keys()
                     if ext.__name__ == json_obj['transformation'])

        # Recreate subgraph
        expr = xform.expressions()[json_obj['expr_index']]
        subgraph = {
            expr.node(int(k)): int(v)
            for k, v in json_obj['_subgraph'].items()
        }

        # Reconstruct transformation
        ret = xform(json_obj['sdfg_id'], json_obj['state_id'], subgraph,
                    json_obj['expr_index'])
        context = context or {}
        context['transformation'] = ret
        serialize.set_properties_from_json(
            ret,
            json_obj,
            context=context,
            ignore_properties={'transformation', 'type'})
        return ret
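
A minimal subclass sketch following the contract described in the docstring; the pattern node and method bodies are illustrative only:

class DoNothing(Transformation):
    node = transformation.PatternNode(nd.Node)  # hypothetical pattern node

    def expressions(self):
        return [sdutil.node_path_graph(DoNothing.node)]

    def can_be_applied(self, graph, candidate, expr_index, sdfg, strict=False):
        return True  # no additional conditions in this sketch

    def apply(self, sdfg):
        pass  # a real transformation would mutate the SDFG here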
Example #10
class Data(object):
    """ Data type descriptors that can be used as references to memory.
        Examples: Arrays, Streams, custom arrays (e.g., sparse matrices).
    """

    dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses)
    shape = ShapeProperty(default=[])
    transient = Property(dtype=bool, default=False)
    storage = EnumProperty(dtype=dtypes.StorageType, desc="Storage location", default=dtypes.StorageType.Default)
    lifetime = EnumProperty(dtype=dtypes.AllocationLifetime,
                            desc='Data allocation span',
                            default=dtypes.AllocationLifetime.Scope)
    location = DictProperty(key_type=str, value_type=str, desc='Full storage location identifier (e.g., rank, GPU ID)')
    debuginfo = DebugInfoProperty(allow_none=True)

    def __init__(self, dtype, shape, transient, storage, location, lifetime, debuginfo):
        self.dtype = dtype
        self.shape = shape
        self.transient = transient
        self.storage = storage
        self.location = location if location is not None else {}
        self.lifetime = lifetime
        self.debuginfo = debuginfo
        self._validate()

    def validate(self):
        """ Validate the correctness of this object.
            Raises an exception on error. """
        self._validate()

    # Validation of this class is in a separate function, so that this
    # class can call `_validate()` without calling the subclasses'
    # `validate` function.
    def _validate(self):
        if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol, symbolic.sympy.Basic)) for s in self.shape):
            raise TypeError('Shape must be a list or tuple of integer values '
                            'or symbols')
        return True

    def to_json(self):
        attrs = serialize.all_properties_to_json(self)

        retdict = {"type": type(self).__name__, "attributes": attrs}

        return retdict

    @property
    def toplevel(self):
        return self.lifetime is not dtypes.AllocationLifetime.Scope

    def copy(self):
        raise RuntimeError('Data descriptors are unique and should not be copied')

    def is_equivalent(self, other):
        """ Check for equivalence (shape and type) of two data descriptors. """
        raise NotImplementedError

    def as_arg(self, with_types=True, for_call=False, name=None):
        """Returns a string for a C++ function signature (e.g., `int *A`). """
        raise NotImplementedError

    @property
    def free_symbols(self) -> Set[symbolic.SymbolicType]:
        """ Returns a set of undefined symbols in this data descriptor. """
        result = set()
        for s in self.shape:
            if isinstance(s, sp.Basic):
                result |= set(s.free_symbols)
        return result

    def __repr__(self):
        return 'Abstract Data Container, DO NOT USE'

    @property
    def veclen(self):
        return self.dtype.veclen if hasattr(self.dtype, "veclen") else 1

    @property
    def ctype(self):
        return self.dtype.ctype

    def strides_from_layout(
        self,
        *dimensions: int,
        alignment: symbolic.SymbolicType = 1,
        only_first_aligned: bool = False,
    ) -> Tuple[Tuple[symbolic.SymbolicType], symbolic.SymbolicType]:
        """
        Returns the absolute strides and total size of this data descriptor,
        according to the given dimension ordering and alignment.
        :param dimensions: A sequence of integers representing a permutation
                           of the descriptor's dimensions.
        :param alignment: Padding (in elements) at the end, ensuring stride
                          is a multiple of this number. 1 (default) means no
                          padding.
        :param only_first_aligned: If True, only the first dimension is padded
                                   with ``alignment``. Otherwise all dimensions
                                   are.
        :return: A 2-tuple of (tuple of strides, total size).
        """
        # Verify dimensions
        if tuple(sorted(dimensions)) != tuple(range(len(self.shape))):
            raise ValueError('Every dimension must be given and appear once.')
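        # Note: '== True' keeps the check meaningful when `alignment` is
        # symbolic, where the comparison may be an unresolved relational
        # rather than a plain bool.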
        if (alignment < 1) == True or (alignment < 0) == True:
            raise ValueError('Invalid alignment value')

        strides = [1] * len(dimensions)
        total_size = 1
        first = True
        for dim in dimensions:
            strides[dim] = total_size
            if not only_first_aligned or first:
                dimsize = (((self.shape[dim] + alignment - 1) // alignment) * alignment)
            else:
                dimsize = self.shape[dim]
            total_size *= dimsize
            first = False

        return (tuple(strides), total_size)

    def set_strides_from_layout(self,
                                *dimensions: int,
                                alignment: symbolic.SymbolicType = 1,
                                only_first_aligned: bool = False):
        """
        Sets the absolute strides and total size of this data descriptor,
        according to the given dimension ordering and alignment.
        :param dimensions: A sequence of integers representing a permutation
                           of the descriptor's dimensions.
        :param alignment: Padding (in elements) at the end, ensuring stride
                          is a multiple of this number. 1 (default) means no
                          padding.
        :param only_first_aligned: If True, only the first dimension is padded
                                   with ``alignment``. Otherwise all dimensions
                                   are.
        """
        strides, totalsize = self.strides_from_layout(*dimensions,
                                                      alignment=alignment,
                                                      only_first_aligned=only_first_aligned)
        self.strides = strides
        self.total_size = totalsize
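
Worked example of the stride computation above, using a hypothetical descriptor of shape (3, 4) with dimension 1 fastest-varying (i.e., row-major order):

d = Data(dtypes.int32, (3, 4), False, dtypes.StorageType.Default,
         None, dtypes.AllocationLifetime.Scope, None)
strides, total = d.strides_from_layout(1, 0)
assert strides == (4, 1) and total == 12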
Example #11
class Transformation(object):
    """ Base class for transformations, as well as a static registry of
        transformations, where new transformations can be added in a
        decentralized manner.

        New transformations are registered with ``Transformation.register``
        (or ``dace.registry.autoregister_params``) with two optional boolean
        keyword arguments: ``singlestate`` (default: False) and ``strict``
        (default: False).
        If ``singlestate`` is True, the transformation is matched on subgraphs
        inside an SDFGState; otherwise, subgraphs of the SDFG state machine are
        matched.
        If ``strict`` is True, this transformation will be considered strict
        (i.e., always beneficial to perform) and will be performed automatically
        as part of SDFG strict transformations.
    """

    # Properties
    sdfg_id = Property(dtype=int, category="(Debug)")
    state_id = Property(dtype=int, category="(Debug)")
    _subgraph = DictProperty(key_type=int, value_type=int, category="(Debug)")
    expr_index = Property(dtype=int, category="(Debug)")

    @staticmethod
    def annotates_memlets():
        """ Indicates whether the transformation annotates the edges it creates
            or modifies with the appropriate memlets. This determines
            whether to apply memlet propagation after the transformation.
        """

        return False

    @staticmethod
    def expressions():
        """ Returns a list of Graph objects that will be matched in the
            subgraph isomorphism phase. Used as a pre-pass before calling
            `can_be_applied`.
            @see Transformation.can_be_applied
        """

        raise NotImplementedError

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if this transformation can be applied on the candidate
            matched subgraph.
            :param graph: SDFGState object if this Transformation is
                          single-state, or SDFG object otherwise.
            :param candidate: A mapping between node IDs returned from
                              `Transformation.expressions` and the nodes in
                              `graph`.
            :param expr_index: The list index from `Transformation.expressions`
                               that was matched.
            :param sdfg: If `graph` is an SDFGState, its parent SDFG. Otherwise
                         should be equal to `graph`.
            :param strict: Whether transformation should run in strict mode.
            :return: True if the transformation can be applied.
        """
        raise NotImplementedError

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a string representation of the pattern match on the
            candidate subgraph. Used when identifying matches in the console
            UI.
        """
        raise NotImplementedError

    def __init__(self, sdfg_id, state_id, subgraph, expr_index):
        """ Initializes an instance of Transformation.
            :param sdfg_id: A unique ID of the SDFG.
            :param state_id: The node ID of the SDFG state, if applicable.
            :param subgraph: A mapping between node IDs returned from
                             `Transformation.expressions` and the nodes in
                             `graph`.
            :param expr_index: The list index from `Transformation.expressions`
                               that was matched.
            :raise TypeError: When transformation is not subclass of
                              Transformation.
            :raise TypeError: When state_id is not instance of int.
            :raise TypeError: When subgraph is not a dict of
                              dace.sdfg.nodes.Node : int.
        """

        self.sdfg_id = sdfg_id
        self.state_id = state_id
        for value in subgraph.values():
            if not isinstance(value, int):
                raise TypeError('All values of the subgraph dictionary '
                                'must be instances of int.')
        # Serializable subgraph with node IDs as keys
        expr = self.expressions()[expr_index]
        self._subgraph = {expr.node_id(k): v for k, v in subgraph.items()}
        self._subgraph_user = subgraph
        self.expr_index = expr_index

    @property
    def subgraph(self):
        return self._subgraph_user

    def __lt__(self, other):
        """ Comparing two transformations by their class name and node IDs
            in match. Used for ordering transformations consistently.
        """
        if type(self) != type(other):
            return type(self).__name__ < type(other).__name__

        self_ids = iter(self.subgraph.values())
        other_ids = iter(other.subgraph.values())

        try:
            self_id = next(self_ids)
        except StopIteration:
            return True
        try:
            other_id = next(other_ids)
        except StopIteration:
            return False

        self_end = False

        while self_id is not None and other_id is not None:
            if self_id != other_id:
                return self_id < other_id
            try:
                self_id = next(self_ids)
            except StopIteration:
                self_end = True
            try:
                other_id = next(other_ids)
            except StopIteration:
                if self_end:  # Transformations are equal
                    return False
                return False
            if self_end:
                return True

    def apply_pattern(self, sdfg):
        """ Applies this transformation on the given SDFG. """
        self.apply(sdfg)
        if not self.annotates_memlets():
            propagation.propagate_memlets_sdfg(sdfg)

    def __str__(self):
        return type(self).__name__

    def modifies_graph(self):
        return True

    def print_match(self, sdfg):
        """ Returns a string representation of the pattern match on the
            given SDFG. Used for printing matches in the console UI.
        """
        if not isinstance(sdfg, dace.SDFG):
            raise TypeError("Expected SDFG, got: {}".format(
                type(sdfg).__name__))
        if self.state_id == -1:
            graph = sdfg
        else:
            graph = sdfg.nodes()[self.state_id]
        string = type(self).__name__ + ' in '
        string += type(self).match_to_str(graph, self.subgraph)
        return string

    def to_json(self, parent=None):
        props = dace.serialize.all_properties_to_json(self)
        return {
            'type': 'Transformation',
            'transformation': type(self).__name__,
            **props
        }

    @staticmethod
    def from_json(json_obj, context=None):
        xform = next(ext for ext in Transformation.extensions().keys()
                     if ext.__name__ == json_obj['transformation'])

        # Recreate subgraph
        expr = xform.expressions()[json_obj['expr_index']]
        subgraph = {
            expr.node(int(k)): int(v)
            for k, v in json_obj['_subgraph'].items()
        }

        # Reconstruct transformation
        ret = xform(json_obj['sdfg_id'], json_obj['state_id'], subgraph,
                    json_obj['expr_index'])
        context = context or {}
        context['transformation'] = ret
        dace.serialize.set_properties_from_json(
            ret,
            json_obj,
            context=context,
            ignore_properties={'transformation', 'type'})
        return ret
Example #12
class ONNXSchema:
    """Python representation of an ONNX schema"""

    name = Property(dtype=str, desc="The operator name")
    domain = Property(dtype=str, desc="The operator domain")
    doc = Property(dtype=str, desc="The operator's docstring")
    since_version = Property(dtype=int, desc="The version of the operator")
    attributes = DictProperty(key_type=str,
                              value_type=ONNXAttribute,
                              desc="The operator attributes")
    type_constraints = DictProperty(
        key_type=str,
        value_type=ONNXTypeConstraint,
        desc="The type constraints for inputs and outputs")
    inputs = ListProperty(element_type=ONNXParameter,
                          desc="The operator input parameter descriptors")
    outputs = ListProperty(element_type=ONNXParameter,
                           desc="The operator output parameter descriptors")

    def __repr__(self):
        return self.domain + "." + self.name

    def validate(self):
        # Check that all parameters with a type string have an entry in the type constraints
        for param in chain(self.inputs, self.outputs):
            if param.type_str not in self.type_constraints:
                # Some operators put a type descriptor here; for those, try to insert a new type constraint
                cons_name = param.name + "_constraint"
                if cons_name in self.type_constraints:
                    raise ValueError(
                        "Attempted to insert new type constraint, but the name already existed. Please open an issue."
                    )
                parsed_typeclass = onnx_type_str_to_typeclass(param.type_str)

                if parsed_typeclass is None:
                    print("Could not parse typeStr '{}' for parameter '{}'".
                          format(param.type_str, param.name))

                cons = ONNXTypeConstraint(
                    cons_name,
                    [parsed_typeclass] if parsed_typeclass is not None else [])
                self.type_constraints[cons_name] = cons
                param.type_str = cons_name

        # check for required parameters with no supported type
        for param in chain(self.inputs, self.outputs):
            if ((param.param_type == ONNXParameterType.Single
                 or param.param_type == ONNXParameterType.Variadic)
                    and len(self.type_constraints[param.type_str].types) == 0):
                raise NotImplementedError(
                    "None of the types for parameter '{}' are supported".
                    format(param.name))

        # check that all variadic parameter names do not contain "__"
        for param in chain(self.inputs, self.outputs):
            if param.param_type == ONNXParameterType.Variadic and "__" in param.name:
                raise ValueError(
                    "Unsupported parameter name '{}': variadic parameter names must not contain '__'"
                    .format(param.name))

        # check that all inputs and outputs have unique names
        seen = set()
        for param in self.inputs:
            if param.name in seen:
                raise ValueError(
                    "Got duplicate input parameter name '{}'".format(
                        param.name))
            seen.add(param.name)

        seen = set()
        for param in self.outputs:
            if param.name in seen:
                raise ValueError(
                    "Got duplicate output parameter name '{}'".format(
                        param.name))
            seen.add(param.name)
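A hypothetical standalone sketch of the fallback performed in `validate()` above: a parameter whose `type_str` is not a declared constraint gets a synthesized `<name>_constraint` entry (empty when the type string cannot be parsed). All names below are illustrative stand-ins:

# Illustrative stand-in for self.type_constraints (maps name -> allowed types)
type_constraints = {'T': ['float32', 'float64']}

def ensure_constraint(param_name, type_str, parse=lambda s: None):
    if type_str in type_constraints:
        return type_str  # Already a declared constraint
    cons_name = param_name + "_constraint"
    if cons_name in type_constraints:
        raise ValueError("Constraint name collision: " + cons_name)
    parsed = parse(type_str)  # Stand-in for onnx_type_str_to_typeclass
    type_constraints[cons_name] = [parsed] if parsed is not None else []
    return cons_name

print(ensure_constraint('X', 'T'))             # 'T' (declared constraint)
print(ensure_constraint('Y', 'tensor(int8)'))  # 'Y_constraint' (synthesized)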
Example No. 13
class Map(object):
    """ A Map is a two-node representation of parametric graphs, containing
        an integer set by which the contents (nodes dominated by an entry
        node and post-dominated by an exit node) are replicated.

        Maps contain a `schedule` property, which specifies how the scope
        should be scheduled (execution order). Code generators can use the
        schedule property to generate appropriate code, e.g., GPU kernels.
    """

    # List of (editable) properties
    label = Property(dtype=str, desc="Label of the map")
    params = ListProperty(element_type=str, desc="Mapped parameters")
    range = RangeProperty(desc="Ranges of map parameters",
                          default=sbs.Range([]))
    schedule = EnumProperty(dtype=dtypes.ScheduleType,
                            desc="Map schedule",
                            default=dtypes.ScheduleType.Default)
    unroll = Property(dtype=bool, desc="Map unrolling")
    collapse = Property(dtype=int,
                        default=1,
                        desc="How many dimensions to"
                        " collapse into the parallel range")
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = EnumProperty(
        dtype=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    location = DictProperty(
        key_type=str,
        value_type=dace.symbolic.pystr_to_symbolic,
        desc='Full storage location identifier (e.g., rank, GPU ID)')

    def __init__(self,
                 label,
                 params,
                 ndrange,
                 schedule=dtypes.ScheduleType.Default,
                 unroll=False,
                 collapse=1,
                 fence_instrumentation=False,
                 debuginfo=None,
                 location=None):
        super(Map, self).__init__()

        # Assign properties
        self.label = label
        self.schedule = schedule
        self.unroll = unroll
        self.collapse = collapse
        self.params = params
        self.range = ndrange
        self.debuginfo = debuginfo
        self.location = location if location is not None else {}
        self._fence_instrumentation = fence_instrumentation

    def __str__(self):
        return self.label + "[" + ", ".join([
            "{}={}".format(i, r)
            for i, r in zip(self._params,
                            [sbs.Range.dim_to_string(d) for d in self._range])
        ]) + "]"

    def validate(self, sdfg, state, node):
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid map name "%s"' % self.label)

    def get_param_num(self):
        """ Returns the number of map dimension parameters/symbols. """
        return len(self.params)
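A minimal usage sketch, assuming the dace package layout these examples are taken from (the import path of `Map` is an assumption):

import dace.subsets as sbs
from dace import dtypes
from dace.sdfg.nodes import Map  # assumed import path

# A 2-D map scope descriptor over i in [0, 63] and j in [0, 31]
m = Map(label='tile', params=['i', 'j'],
        ndrange=sbs.Range([(0, 63, 1), (0, 31, 1)]),
        schedule=dtypes.ScheduleType.CPU_Multicore)
print(m)                  # prints something like tile[i=0:64, j=0:32]
print(m.get_param_num())  # 2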
Example No. 14
class Consume(object):
    """ Consume is a scope, like `Map`, that is a part of the parametric
        graph extension of the SDFG. It creates a producer-consumer
        relationship between the input stream and the scope subgraph. The
        subgraph is scheduled to a given number of processing elements
        for processing, and they will try to pop elements from the input
        stream until a given quiescence condition is reached. """

    # Properties
    label = Property(dtype=str, desc="Name of the consume node")
    pe_index = Property(dtype=str, desc="Processing element identifier")
    num_pes = SymbolicProperty(desc="Number of processing elements", default=1)
    condition = CodeProperty(desc="Quiescence condition", allow_none=True)
    schedule = EnumProperty(dtype=dtypes.ScheduleType,
                            desc="Consume schedule",
                            default=dtypes.ScheduleType.Default)
    chunksize = Property(dtype=int,
                         desc="Maximal size of elements to consume at a time",
                         default=1)
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = EnumProperty(
        dtype=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    location = DictProperty(key_type=str,
                            value_type=dace.symbolic.pystr_to_symbolic,
                            desc='Full storage location identifier '
                            '(e.g., rank, GPU ID)')

    def as_map(self):
        """ Compatibility function that allows viewing the consume scope as
            an equivalent map, mainly for memlet propagation. """
        return Map(self.label, [self.pe_index],
                   sbs.Range([(0, self.num_pes - 1, 1)]), self.schedule)

    def __init__(self,
                 label,
                 pe_tuple,
                 condition,
                 schedule=dtypes.ScheduleType.Default,
                 chunksize=1,
                 debuginfo=None,
                 location=None):
        super(Consume, self).__init__()

        # Properties
        self.label = label
        self.pe_index, self.num_pes = pe_tuple
        self.condition = condition
        self.schedule = schedule
        self.chunksize = chunksize
        self.debuginfo = debuginfo
        self.location = location if location is not None else {}

    def __str__(self):
        if self.condition is not None:
            return ("%s [%s=0:%s], Condition: %s" %
                    (self._label, self.pe_index, self.num_pes,
                     CodeProperty.to_string(self.condition)))
        else:
            return ("%s [%s=0:%s]" %
                    (self._label, self.pe_index, self.num_pes))

    def validate(self, sdfg, state, node):
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid consume name "%s"' % self.label)

    def get_param_num(self):
        """ Returns the number of consume dimension parameters/symbols. """
        return 1
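A minimal usage sketch of the map compatibility view (the import path is an assumption, as above): a consume scope over four processing elements and its equivalent map range.

from dace.sdfg.nodes import Consume  # assumed import path

# Four processing elements popping from a stream until it is empty
# (condition=None means "run until the stream is exhausted")
c = Consume(label='worker', pe_tuple=('p', 4), condition=None)
print(c)           # worker [p=0:4]
print(c.as_map())  # worker[p=0:4], a Map over the processing elements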
Example No. 15
class PatternTransformation(TransformationBase):
    """ 
    Abstract class for pattern-matching transformations.
    Please extend either ``SingleStateTransformation`` or ``MultiStateTransformation``.
    
    :see: SingleStateTransformation
    :see: MultiStateTransformation
    :seealso: PatternNode
    """

    # Properties
    sdfg_id = Property(dtype=int, category="(Debug)")
    state_id = Property(dtype=int, category="(Debug)")
    _subgraph = DictProperty(key_type=int, value_type=int, category="(Debug)")
    expr_index = Property(dtype=int, category="(Debug)")

    @classmethod
    def subclasses_recursive(
            cls,
            all_subclasses: bool = False
    ) -> Set[Type['PatternTransformation']]:
        """
        Returns all subclasses of this class, including subclasses of subclasses.
        :param all_subclasses: If True, also include direct subclasses such as
                               ``ExpandTransformation``, rather than only
                               single- and multi-state transformations.
        """
        if not all_subclasses and cls is PatternTransformation:
            subclasses = set(SingleStateTransformation.__subclasses__()) | set(
                MultiStateTransformation.__subclasses__())
        else:
            subclasses = set(cls.__subclasses__())
        subsubclasses = set()
        for sc in subclasses:
            subsubclasses.update(sc.subclasses_recursive())

        # Ignore abstract classes
        result = subclasses | subsubclasses
        result = set(sc for sc in result
                     if not getattr(sc, '__abstractmethods__', False))

        return result

    def annotates_memlets(self) -> bool:
        """ Indicates whether the transformation annotates the edges it creates
            or modifies with the appropriate memlets. This determines
            whether to apply memlet propagation after the transformation.
        """
        return False

    @classmethod
    def expressions(cls) -> List[gr.SubgraphView]:
        """ Returns a list of Graph objects that will be matched in the
            subgraph isomorphism phase. Used as a pre-pass before calling
            `can_be_applied`.
            :see: PatternTransformation.can_be_applied
        """
        raise NotImplementedError

    def can_be_applied(self,
                       graph: Union[SDFG, SDFGState],
                       expr_index: int,
                       sdfg: SDFG,
                       permissive: bool = False) -> bool:
        """ Returns True if this transformation can be applied on the candidate
            matched subgraph.
            :param graph: SDFGState object if this transformation is
                          single-state, or SDFG object otherwise.
            :param expr_index: The list index from `PatternTransformation.expressions`
                               that was matched.
            :param sdfg: If `graph` is an SDFGState, its parent SDFG. Otherwise
                         should be equal to `graph`.
            :param permissive: Whether transformation should run in permissive mode.
            :return: True if the transformation can be applied.
        """
        raise NotImplementedError

    def apply(self, graph: Union[SDFG, SDFGState],
              sdfg: SDFG) -> Union[Any, None]:
        """
        Applies this transformation instance on the matched pattern graph.
        :param graph: The SDFG or SDFG state matched by this transformation.
        :param sdfg: The SDFG to apply the transformation to.
        :return: A transformation-defined return value, which could be used
                 to pass analysis data out, or nothing.
        """
        raise NotImplementedError

    def match_to_str(self, graph: Union[SDFG, SDFGState]) -> str:
        """ Returns a string representation of the pattern match on the
            candidate subgraph. Used when identifying matches in the console
            UI.
        """
        candidate = []
        node_to_name = {v: k for k, v in self._get_pattern_nodes().items()}
        for cnode in self.subgraph.keys():
            cname = node_to_name[cnode]
            candidate.append(getattr(self, cname))
        return str(candidate)

    def __init__(self,
                 sdfg: SDFG,
                 sdfg_id: int,
                 state_id: int,
                 subgraph: Dict['PatternNode', int],
                 expr_index: int,
                 override: bool = False,
                 options: Optional[Dict[str, Any]] = None) -> None:
        """ Initializes an instance of Transformation match.
            :param sdfg_id: A unique ID of the SDFG.
            :param state_id: The node ID of the SDFG state, if applicable. If
                             transformation does not operate on a single state,
                             the value should be -1.
            :param subgraph: A mapping between node IDs returned from
                             `PatternTransformation.expressions` and the nodes in
                             `graph`.
            :param expr_index: The list index from `PatternTransformation.expressions`
                               that was matched.
            :param override: If True, accepts the subgraph dictionary as-is
                             (mostly for internal use).
            :param options: An optional dictionary of transformation properties
            :raise TypeError: When transformation is not subclass of
                              PatternTransformation.
            :raise TypeError: When state_id is not instance of int.
            :raise TypeError: When subgraph is not a dict of
                              PatternNode : int.
        """

        self._sdfg = sdfg
        self.sdfg_id = sdfg_id
        self.state_id = state_id
        if not override:
            expr = self.expressions()[expr_index]
            for value in subgraph.values():
                if not isinstance(value, int):
                    raise TypeError('All values of the subgraph dictionary '
                                    'must be instances of int.')
            self._subgraph = {expr.node_id(k): v for k, v in subgraph.items()}
        else:
            self._subgraph = {-1: -1}
        # Serializable subgraph with node IDs as keys
        self._subgraph_user = copy.copy(subgraph)
        self.expr_index = expr_index

        # Set properties
        if options is not None:
            for optname, optval in options.items():
                setattr(self, optname, optval)

    @property
    def subgraph(self):
        return self._subgraph_user

    def apply_pattern(self,
                      append: bool = True,
                      annotate: bool = True) -> Union[Any, None]:
        """
        Applies this transformation on its matched SDFG, using the
        transformation instance to find the right SDFG object (based on the
        SDFG ID), and applying memlet propagation as necessary.
        :param append: If True, appends the transformation to the SDFG
                       transformation history.
        :param annotate: If True, applies memlet propagation after the
                         transformation (unless it annotates memlets itself).
        :return: A transformation-defined return value, which could be used
                 to pass analysis data out, or nothing.
        """
        if append:
            self._sdfg.append_transformation(self)
        tsdfg: SDFG = self._sdfg.sdfg_list[self.sdfg_id]
        tgraph = tsdfg.node(self.state_id) if self.state_id >= 0 else tsdfg
        retval = self.apply(tgraph, tsdfg)
        if annotate and not self.annotates_memlets():
            propagation.propagate_memlets_sdfg(tsdfg)
        return retval

    def __lt__(self, other: 'PatternTransformation') -> bool:
        """
        Comparing two transformations by their class name and node IDs
        in match. Used for ordering transformations consistently.
        """
        if type(self) != type(other):
            return type(self).__name__ < type(other).__name__

        self_ids = iter(self.subgraph.values())
        other_ids = iter(other.subgraph.values())

        try:
            self_id = next(self_ids)
        except StopIteration:
            return True
        try:
            other_id = next(other_ids)
        except StopIteration:
            return False

        self_end = False

        while self_id is not None and other_id is not None:
            if self_id != other_id:
                return self_id < other_id
            try:
                self_id = next(self_ids)
            except StopIteration:
                self_end = True
            try:
                other_id = next(other_ids)
            except StopIteration:
                # Either both iterators ended (the matches are equal) or
                # other is shorter; in both cases, self is not less than other
                return False
            if self_end:  # Self is shorter, hence less than other
                return True

    @classmethod
    def _get_pattern_nodes(cls) -> Dict[str, 'PatternNode']:
        """
        Returns a dictionary of the pattern-matching nodes in this
        transformation subclass. Used internally for pattern matching.
        :return: A dictionary mapping between pattern-node name and its type.
        """
        return {
            k: getattr(cls, k)
            for k in dir(cls)
            if isinstance(getattr(cls, k), PatternNode) or (k.startswith(
                '_') and isinstance(getattr(cls, k), (nd.Node, SDFGState)))
        }

    @classmethod
    def apply_to(cls,
                 sdfg: SDFG,
                 options: Optional[Dict[str, Any]] = None,
                 expr_index: int = 0,
                 verify: bool = True,
                 annotate: bool = True,
                 permissive: bool = False,
                 save: bool = True,
                 **where: Union[nd.Node, SDFGState]):
        """
        Applies this transformation to a given subgraph, defined by a set of
        nodes. Raises an error if arguments are invalid or transformation is
        not applicable.

        The subgraph is defined by the `where` dictionary, where each key is
        taken from the `PatternNode` fields of the transformation. For example,
        applying `MapCollapse` on two maps can be performed as follows:

        ```
        MapCollapse.apply_to(sdfg, outer_map_entry=map_a, inner_map_entry=map_b)
        ```

        :param sdfg: The SDFG to apply the transformation to.
        :param options: A set of parameters to use for applying the
                        transformation.
        :param expr_index: The pattern expression index to try to match with.
        :param verify: Check that `can_be_applied` returns True before applying.
        :param annotate: Run memlet propagation after application if necessary.
        :param permissive: Apply transformation in permissive mode.
        :param save: Save transformation as part of the SDFG file. Set to
                     False if composing transformations.
        :param where: A dictionary of node names (from the transformation) to
                      nodes in the SDFG or a single state.
        """
        if len(where) == 0:
            raise ValueError('At least one node is required')
        options = options or {}

        # Check that all keyword arguments are nodes and if interstate or not
        sample_node = next(iter(where.values()))

        if isinstance(sample_node, SDFGState):
            graph = sdfg
            state_id = -1
        elif isinstance(sample_node, nd.Node):
            graph = next(s for s in sdfg.nodes() if sample_node in s.nodes())
            state_id = sdfg.node_id(graph)
        else:
            raise TypeError('Invalid node type "%s"' %
                            type(sample_node).__name__)

        # Check that all nodes in the pattern are set
        required_nodes = cls.expressions()[expr_index].nodes()
        required_node_names = {
            pname: pval
            for pname, pval in cls._get_pattern_nodes().items()
            if pval in required_nodes
        }
        required = set(required_node_names.keys())
        intersection = required & set(where.keys())
        if len(required - intersection) > 0:
            raise ValueError('Missing nodes for transformation subgraph: %s' %
                             (required - intersection))

        # Construct subgraph and instantiate transformation
        subgraph = {
            required_node_names[k]: graph.node_id(where[k])
            for k in required
        }
        instance = cls(sdfg, sdfg.sdfg_id, state_id, subgraph, expr_index)

        # Construct transformation parameters
        for optname, optval in options.items():
            if optname not in cls.__properties__:
                raise ValueError('Property "%s" not found in transformation' %
                                 optname)
            setattr(instance, optname, optval)

        if verify:
            if not instance.can_be_applied(
                    graph, expr_index, sdfg, permissive=permissive):
                raise ValueError('Transformation cannot be applied on the '
                                 'given subgraph ("can_be_applied" failed)')

        # Apply to SDFG
        return instance.apply_pattern(annotate=annotate, append=save)

    def __str__(self) -> str:
        return type(self).__name__

    def print_match(self, sdfg: SDFG) -> str:
        """ Returns a string representation of the pattern match on the
            given SDFG. Used for printing matches in the console UI.
        """
        if not isinstance(sdfg, SDFG):
            raise TypeError("Expected SDFG, got: {}".format(
                type(sdfg).__name__))
        if self.state_id == -1:
            graph = sdfg
        else:
            graph = sdfg.nodes()[self.state_id]
        string = type(self).__name__ + ' in '
        string += self.match_to_str(graph)
        return string

    def to_json(self, parent=None) -> Dict[str, Any]:
        props = serialize.all_properties_to_json(self)
        return {
            'type': 'PatternTransformation',
            'transformation': type(self).__name__,
            **props
        }

    @staticmethod
    def from_json(json_obj: Dict[str, Any],
                  context: Dict[str, Any] = None) -> 'PatternTransformation':
        xform = next(ext for ext in PatternTransformation.subclasses_recursive(
            all_subclasses=True) if ext.__name__ == json_obj['transformation'])

        # Recreate subgraph
        expr = xform.expressions()[json_obj['expr_index']]
        subgraph = {
            expr.node(int(k)): int(v)
            for k, v in json_obj['_subgraph'].items()
        }

        # Reconstruct transformation
        ret = xform(None, json_obj['sdfg_id'], json_obj['state_id'], subgraph,
                    json_obj['expr_index'])
        context = context or {}
        context['transformation'] = ret
        serialize.set_properties_from_json(
            ret,
            json_obj,
            context=context,
            ignore_properties={'transformation', 'type'})
        return ret
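To make the abstract interface concrete, here is a hedged sketch of a minimal subclass, assuming the `SingleStateTransformation` and `PatternNode` API referenced in the class docstring above (the module paths are assumptions):

from dace.sdfg import nodes as nd
from dace.sdfg import utils as sdutil  # assumed home of node_path_graph
from dace.transformation import transformation as xf

class TouchTasklet(xf.SingleStateTransformation):
    """Toy transformation that matches any single tasklet and returns it."""
    tasklet = xf.PatternNode(nd.Tasklet)

    @classmethod
    def expressions(cls):
        # One pattern graph containing a single node
        return [sdutil.node_path_graph(cls.tasklet)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        return True  # Any tasklet qualifies

    def apply(self, graph, sdfg):
        # PatternNode attributes resolve to the matched nodes here
        return self.tasklet

With a subclass like this in place, `TouchTasklet.apply_to(sdfg, tasklet=some_tasklet)` would drive the same `apply_pattern` machinery shown above.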
Example No. 16
class Transformation(object):
    """ Base class for transformations, as well as a static registry of
        transformations, where new transformations can be added in a
        decentralized manner.

        New transformations are registered with ``Transformation.register``
        (or ``dace.registry.autoregister_params``) with two optional boolean
        keyword arguments: ``singlestate`` (default: False) and ``strict``
        (default: False).
        If ``singlestate`` is True, the transformation is matched on subgraphs
        inside an SDFGState; otherwise, subgraphs of the SDFG state machine are
        matched.
        If ``strict`` is True, this transformation will be considered strict
        (i.e., always beneficial to perform) and will be performed automatically
        as part of SDFG strict transformations.
    """

    # Properties
    sdfg_id = Property(dtype=int, category="(Debug)")
    state_id = Property(dtype=int, category="(Debug)")
    _subgraph = DictProperty(key_type=int, value_type=int, category="(Debug)")
    expr_index = Property(dtype=int, category="(Debug)")

    @staticmethod
    def annotates_memlets():
        """ Indicates whether the transformation annotates the edges it creates
            or modifies with the appropriate memlets. This determines
            whether to apply memlet propagation after the transformation.
        """
        return False

    @staticmethod
    def expressions():
        """ Returns a list of Graph objects that will be matched in the
            subgraph isomorphism phase. Used as a pre-pass before calling
            `can_be_applied`.
            :see: Transformation.can_be_applied
        """
        raise NotImplementedError

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if this transformation can be applied on the candidate
            matched subgraph.
            :param graph: SDFGState object if this Transformation is
                          single-state, or SDFG object otherwise.
            :param candidate: A mapping between node IDs returned from
                              `Transformation.expressions` and the nodes in
                              `graph`.
            :param expr_index: The list index from `Transformation.expressions`
                               that was matched.
            :param sdfg: If `graph` is an SDFGState, its parent SDFG. Otherwise
                         should be equal to `graph`.
            :param strict: Whether transformation should run in strict mode.
            :return: True if the transformation can be applied.
        """
        raise NotImplementedError

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a string representation of the pattern match on the
            candidate subgraph. Used when identifying matches in the console
            UI.
        """
        return str(list(candidate.values()))

    def __init__(self, sdfg_id, state_id, subgraph, expr_index):
        """ Initializes an instance of Transformation.
            :param sdfg_id: A unique ID of the SDFG.
            :param state_id: The node ID of the SDFG state, if applicable.
            :param subgraph: A mapping between node IDs returned from
                             `Transformation.expressions` and the nodes in
                             `graph`.
            :param expr_index: The list index from `Transformation.expressions`
                               that was matched.
            :raise TypeError: When transformation is not subclass of
                              Transformation.
            :raise TypeError: When state_id is not instance of int.
            :raise TypeError: When subgraph is not a dict of
                              dace.sdfg.nodes.Node : int.
        """

        self.sdfg_id = sdfg_id
        self.state_id = state_id
        for value in subgraph.values():
            if not isinstance(value, int):
                raise TypeError('All values of the subgraph dictionary '
                                'must be instances of int.')
        # Serializable subgraph with node IDs as keys
        expr = self.expressions()[expr_index]
        self._subgraph = {expr.node_id(k): v for k, v in subgraph.items()}
        self._subgraph_user = subgraph
        self.expr_index = expr_index

    @property
    def subgraph(self):
        return self._subgraph_user

    def query_node(
            self, sdfg: SDFG,
            pattern_node: Union[nd.Node,
                                SDFGState]) -> Union[nd.Node, SDFGState]:
        """ 
        Returns the matched node object (from a subgraph pattern node) in its
        original graph.
        :param sdfg: The SDFG on which this transformation is applied.
        :param pattern_node: The node object in the transformation properties.
        :return: The node object in the matched graph.
        """
        graph = sdfg if self.state_id == -1 else sdfg.node(self.state_id)
        return graph.node(self.subgraph[pattern_node])

    def __lt__(self, other):
        """ Comparing two transformations by their class name and node IDs
            in match. Used for ordering transformations consistently.
        """
        if type(self) != type(other):
            return type(self).__name__ < type(other).__name__

        self_ids = iter(self.subgraph.values())
        other_ids = iter(other.subgraph.values())

        try:
            self_id = next(self_ids)
        except StopIteration:
            return True
        try:
            other_id = next(other_ids)
        except StopIteration:
            return False

        self_end = False

        while self_id is not None and other_id is not None:
            if self_id != other_id:
                return self_id < other_id
            try:
                self_id = next(self_ids)
            except StopIteration:
                self_end = True
            try:
                other_id = next(other_ids)
            except StopIteration:
                # Either both iterators ended (the matches are equal) or
                # other is shorter; in both cases, self is not less than other
                return False
            if self_end:  # Self is shorter, hence less than other
                return True

    def apply_pattern(self, sdfg):
        """ Applies this transformation on the given SDFG. """
        sdfg.append_transformation(self)
        self.apply(sdfg)
        if not self.annotates_memlets():
            propagation.propagate_memlets_sdfg(sdfg)

    @classmethod
    def apply_to(cls,
                 sdfg: SDFG,
                 options: Optional[Dict[str, Any]] = None,
                 expr_index: int = 0,
                 verify: bool = True,
                 strict: bool = False,
                 **where: Union[nd.Node, SDFGState]):
        """
        Applies this transformation to a given subgraph, defined by a set of
        nodes. Raises an error if arguments are invalid or transformation is
        not applicable.
        :param sdfg: The SDFG to apply the transformation to.
        :param options: A set of parameters to use for applying the 
                        transformation.
        :param expr_index: The pattern expression index to try to match with.
        :param verify: Check that `can_be_applied` returns True before applying.
        :param strict: Apply transformation in strict mode.
        :param where: A dictionary of node names (from the transformation) to
                      nodes in the SDFG or a single state.
        """
        if len(where) == 0:
            raise ValueError('At least one node is required')
        options = options or {}

        # Check that all keyword arguments are nodes and if interstate or not
        sample_node = next(iter(where.values()))

        if isinstance(sample_node, SDFGState):
            graph = sdfg
            state_id = -1
        elif isinstance(sample_node, nd.Node):
            graph = next(s for s in sdfg.nodes() if sample_node in s.nodes())
            state_id = sdfg.node_id(graph)
        else:
            raise TypeError('Invalid node type "%s"' %
                            type(sample_node).__name__)

        # Check that all nodes in the pattern are set
        required_nodes = cls.expressions()[expr_index].nodes()
        required_node_names = {
            pname[1:]: pval
            for pname, pval in cls.__dict__.items()
            if pname.startswith('_') and pval in required_nodes
        }
        required = set(required_node_names.keys())
        intersection = required & set(where.keys())
        if len(required - intersection) > 0:
            raise ValueError('Missing nodes for transformation subgraph: %s' %
                             (required - intersection))

        # Construct subgraph and instantiate transformation
        subgraph = {
            required_node_names[k]: graph.node_id(where[k])
            for k in required
        }
        instance = cls(sdfg.sdfg_id, state_id, subgraph, expr_index)

        # Construct transformation parameters
        for optname, optval in options.items():
            if optname not in cls.__properties__:
                raise ValueError('Property "%s" not found in transformation' %
                                 optname)
            setattr(instance, optname, optval)

        if verify:
            if not cls.can_be_applied(
                    graph, subgraph, expr_index, sdfg, strict=strict):
                raise ValueError('Transformation cannot be applied on the '
                                 'given subgraph ("can_be_applied" failed)')

        # Apply to SDFG
        instance.apply_pattern(sdfg)

    def __str__(self):
        return type(self).__name__

    def print_match(self, sdfg):
        """ Returns a string representation of the pattern match on the
            given SDFG. Used for printing matches in the console UI.
        """
        if not isinstance(sdfg, SDFG):
            raise TypeError("Expected SDFG, got: {}".format(
                type(sdfg).__name__))
        if self.state_id == -1:
            graph = sdfg
        else:
            graph = sdfg.nodes()[self.state_id]
        string = type(self).__name__ + ' in '
        string += type(self).match_to_str(graph, self.subgraph)
        return string

    def to_json(self, parent=None):
        props = serialize.all_properties_to_json(self)
        return {
            'type': 'Transformation',
            'transformation': type(self).__name__,
            **props
        }

    @staticmethod
    def from_json(json_obj, context=None):
        xform = next(ext for ext in Transformation.extensions().keys()
                     if ext.__name__ == json_obj['transformation'])

        # Recreate subgraph
        expr = xform.expressions()[json_obj['expr_index']]
        subgraph = {
            expr.node(int(k)): int(v)
            for k, v in json_obj['_subgraph'].items()
        }

        # Reconstruct transformation
        ret = xform(json_obj['sdfg_id'], json_obj['state_id'], subgraph,
                    json_obj['expr_index'])
        context = context or {}
        context['transformation'] = ret
        serialize.set_properties_from_json(
            ret,
            json_obj,
            context=context,
            ignore_properties={'transformation', 'type'})
        return ret
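The `from_json` reconstruction above hinges on a registry lookup by class name. A self-contained sketch of that pattern, with a toy registry standing in for `Transformation.extensions()`:

class Base:
    _extensions = []

    @classmethod
    def register(cls, subclass):
        cls._extensions.append(subclass)
        return subclass

@Base.register
class MyXform(Base):
    def __init__(self, sdfg_id, state_id):
        self.sdfg_id, self.state_id = sdfg_id, state_id

def from_json(obj):
    # Find the registered class whose name matches the serialized record
    xform = next(c for c in Base._extensions
                 if c.__name__ == obj['transformation'])
    return xform(obj['sdfg_id'], obj['state_id'])

ret = from_json({'transformation': 'MyXform', 'sdfg_id': 0, 'state_id': -1})
print(type(ret).__name__)  # MyXform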