class CodeNode(Node):
    """ Base class for nodes that hold runnable code whose external data
        dependencies are acyclic. Concrete code nodes are either tasklets or
        nested SDFGs; they are drawn with an octagonal shape. """

    label = Property(dtype=str, desc="Name of the CodeNode")
    location = DictProperty(
        key_type=str,
        value_type=dace.symbolic.pystr_to_symbolic,
        desc='Full storage location identifier (e.g., rank, GPU ID)')
    environments = SetProperty(
        str,
        desc="Environments required by CMake to build and run this code node.",
        default=set())

    def __init__(self, label="", location=None, inputs=None, outputs=None):
        """ Creates a code node.
            :param label: Name of the node.
            :param location: Optional dictionary describing the storage
                             location; an empty dictionary is used if None.
            :param inputs: Input connectors (empty set if not given).
            :param outputs: Output connectors (empty set if not given).
        """
        in_conns = inputs or set()
        out_conns = outputs or set()
        super(CodeNode, self).__init__(in_conns, out_conns)

        # Properties
        self.label = label
        if location is None:
            location = {}
        self.location = location

    @property
    def free_symbols(self) -> Set[str]:
        """ Names of all symbols appearing in the values of ``location``. """
        names = set()
        for value in self.location.values():
            for sym in pystr_to_symbolic(value).free_symbols:
                names.add(str(sym))
        return names
class CodeObject(object):
    """ A piece of generated code together with the information needed to
        name, compile and link it. """

    name = Property(dtype=str, desc="Filename to use")
    code = Property(dtype=str, desc="The code attached to this object")
    language = Property(dtype=str,
                        desc="Language used for this code (same "
                        "as its file extension)")
    target = Property(dtype=type,
                      desc="Target to use for compilation",
                      allow_none=True)
    target_type = Property(
        dtype=str,
        desc="Sub-target within target (e.g., host or device code)",
        default="")
    title = Property(dtype=str, desc="Title of code for GUI")
    extra_compiler_kwargs = DictProperty(key_type=str,
                                         value_type=str,
                                         desc="Additional compiler argument "
                                         "variables to add to template")
    linkable = Property(dtype=bool,
                        desc='Should this file participate in '
                        'overall linkage?')
    environments = SetProperty(
        str,
        desc="Environments required by CMake to build and run this code node.",
        default=set())

    def __init__(self,
                 name,
                 code,
                 language,
                 target,
                 title,
                 target_type="",
                 additional_compiler_kwargs=None,
                 linkable=True,
                 environments=None,
                 sdfg=None):
        """ Stores the given values in the corresponding properties.
            Falsy ``additional_compiler_kwargs`` / ``environments`` are
            replaced by a new empty dict / set. If the object is C++ code
            titled 'Frame' and a truthy ``sdfg`` is given, source maps are
            created for it via ``sourcemap.create_maps``.
        """
        super(CodeObject, self).__init__()

        self.name = name
        self.code = code
        self.language = language
        self.target = target
        self.target_type = target_type
        self.title = title
        self.extra_compiler_kwargs = additional_compiler_kwargs or {}
        self.linkable = linkable
        self.environments = environments or set()

        # Source maps are only generated for the C++ frame code
        is_frame_code = language == 'cpp' and title == 'Frame'
        if is_frame_code and sdfg:
            sourcemap.create_maps(sdfg, code, self.target.target_name)

    @property
    def clean_code(self):
        """ The code with every ``////__DACE:`` / ``////__CODEGEN;`` marker
            (plus preceding spaces/tabs, up to the end of its line)
            removed. """
        marker = re.compile(r'[ \t]*////__(DACE:|CODEGEN;)[^\n]*')
        return marker.sub('', self.code)
class CodeObject(object):
    """ A piece of generated code together with the information needed to
        name, compile and link it. """

    name = Property(dtype=str, desc="Filename to use")
    code = Property(dtype=str, desc="The code attached to this object")
    language = Property(dtype=str,
                        desc="Language used for this code (same " +
                        "as its file extension)")  # dtype=dtypes.Language?
    target = Property(dtype=type,
                      desc="Target to use for compilation",
                      allow_none=True)
    target_type = Property(
        dtype=str,
        desc="Sub-target within target (e.g., host or device code)",
        default="")
    title = Property(dtype=str, desc="Title of code for GUI")
    extra_compiler_kwargs = DictProperty(key_type=str,
                                         value_type=str,
                                         desc="Additional compiler argument "
                                         "variables to add to template")
    linkable = Property(dtype=bool,
                        desc='Should this file participate in '
                        'overall linkage?')
    environments = SetProperty(
        str,
        desc="Environments required by CMake to build and run this code node.",
        default=set())

    def __init__(self,
                 name,
                 code,
                 language,
                 target,
                 title,
                 target_type="",
                 additional_compiler_kwargs=None,
                 linkable=True,
                 environments=None):
        """ Creates a code object.
            :param name: Filename to use.
            :param code: The code attached to this object.
            :param language: Language of the code (same as its file
                             extension).
            :param target: Target to use for compilation (may be None).
            :param title: Title of the code for the GUI.
            :param target_type: Sub-target within the target (e.g., host or
                                device code).
            :param additional_compiler_kwargs: Additional compiler argument
                                               variables; a new empty dict
                                               is used if falsy.
            :param linkable: Whether this file participates in overall
                             linkage.
            :param environments: Set of required environment names; a new
                                 empty set is created if None.
        """
        super(CodeObject, self).__init__()

        self.name = name
        self.code = code
        self.language = language
        self.target = target
        self.target_type = target_type
        self.title = title
        self.extra_compiler_kwargs = additional_compiler_kwargs or {}
        self.linkable = linkable
        # A fresh set per instance: a ``set()`` default in the signature
        # would be one object shared by every call that omits the argument.
        self.environments = set() if environments is None else environments
class RTLTasklet(Tasklet):
    """ A tasklet variant for code written in System Verilog. Like any
        tasklet it is a functional computation that accesses external data
        only through its connectors; in addition it can record metadata
        about the IP cores it uses. """

    # TODO to be replaced when enums have embedded properties
    ip_cores = DictProperty(key_type=str,
                            value_type=dict,
                            desc="A set of IP cores used by the tasklet.")

    @property
    def __jsontype__(self):
        # Reported type name is that of a plain tasklet
        return 'Tasklet'

    def add_ip_core(self, module_name, name, vendor, version, params):
        """ Registers (or overwrites) the metadata entry stored under
            ``module_name`` in ``ip_cores``.
            :param module_name: Key under which the entry is stored.
            :param name: Stored under the 'name' key.
            :param vendor: Stored under the 'vendor' key.
            :param version: Stored under the 'version' key.
            :param params: Stored under the 'params' key.
        """
        core_info = dict(name=name,
                         vendor=vendor,
                         version=version,
                         params=params)
        self.ip_cores[module_name] = core_info
class NestedSDFG(CodeNode):
    """ A node inside an SDFG state that holds a complete SDFG, executed
        with the data dependencies given by the node's connectors. Nested
        SDFGs are preferred over coarse-grained tasklets because
        transformations can analyze them.

        @note: A nested SDFG cannot create recursion (one of its parent
               SDFGs).
    """

    # NOTE: We cannot use SDFG as the type because of an import loop
    sdfg = SDFGReferenceProperty(desc="The SDFG", allow_none=True)
    schedule = Property(dtype=dtypes.ScheduleType,
                        desc="SDFG schedule",
                        allow_none=True,
                        choices=dtypes.ScheduleType,
                        from_string=lambda x: dtypes.ScheduleType[x],
                        default=dtypes.ScheduleType.Default)
    symbol_mapping = DictProperty(
        key_type=str,
        value_type=dace.symbolic.pystr_to_symbolic,
        desc="Mapping between internal symbols and their values, expressed as "
        "symbolic expressions")
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = Property(choices=dtypes.InstrumentationType,
                          desc="Measure execution statistics with given method",
                          default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 label,
                 sdfg,
                 inputs: Set[str],
                 outputs: Set[str],
                 symbol_mapping: Dict[str, Any] = None,
                 schedule=dtypes.ScheduleType.Default,
                 location=None,
                 debuginfo=None):
        """ Creates a nested SDFG node.
            :param label: Name of the node.
            :param sdfg: The SDFG contained in this node.
            :param inputs: Input connector names.
            :param outputs: Output connector names.
            :param symbol_mapping: Mapping from internal symbols to their
                                   values (empty if falsy).
            :param schedule: Schedule of the nested SDFG.
            :param location: Storage location dictionary.
            :param debuginfo: Debugging information for this node.
        """
        super(NestedSDFG, self).__init__(label, location, inputs, outputs)

        # Properties
        self.sdfg = sdfg
        self.symbol_mapping = symbol_mapping or {}
        self.schedule = schedule
        self.debuginfo = debuginfo

    @staticmethod
    def from_json(json_obj, context=None):
        """ Reconstructs a nested SDFG node from its JSON representation,
            linking the contained SDFG to the parent state/SDFG found in
            ``context`` (keys 'sdfg_state' and 'sdfg'), if present. """
        from dace import SDFG  # Avoid import loop

        # We have to load the SDFG first.
        node = NestedSDFG("nolabel", SDFG('nosdfg'), {}, {})

        dace.serialize.set_properties_from_json(node, json_obj, context)

        if context:
            if 'sdfg_state' in context:
                node.sdfg.parent = context['sdfg_state']
            if 'sdfg' in context:
                node.sdfg.parent_sdfg = context['sdfg']

        node.sdfg.parent_nsdfg_node = node
        node.sdfg.update_sdfg_list([])

        return node

    @property
    def free_symbols(self) -> Set[str]:
        """ Names of all symbols appearing in the values of the symbol
            mapping and of the location dictionary. """
        expressions = list(self.symbol_mapping.values())
        expressions.extend(self.location.values())
        return {
            str(sym)
            for expr in expressions
            for sym in pystr_to_symbolic(expr).free_symbols
        }

    def infer_connector_types(self, sdfg, state):
        """ Runs connector type inference on the contained SDFG. """
        # Avoid import loop
        from dace.sdfg.infer_types import infer_connector_types as infer

        # Infer internal connector types
        infer(self.sdfg)

    def __str__(self):
        return self.label if self.label else "SDFG"

    def validate(self, sdfg, state):
        """ Checks the node's name, connector names, the agreement between
            connectors and the nested SDFG's data descriptors, and that all
            free symbols of the nested SDFG are mapped; then validates the
            nested SDFG itself.
            :raise NameError: On an invalid name or a connector/descriptor
                              mismatch.
            :raise ValueError: If symbols are missing from the symbol
                               mapping.
        """
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid nested SDFG name "%s"' % self.label)

        for conn in self.in_connectors:
            if not dtypes.validate_name(conn):
                raise NameError('Invalid input connector "%s"' % conn)
        for conn in self.out_connectors:
            if not dtypes.validate_name(conn):
                raise NameError('Invalid output connector "%s"' % conn)

        connectors = self.in_connectors.keys() | self.out_connectors.keys()

        for dname, desc in self.sdfg.arrays.items():
            # TODO(later): Disallow scalars without access nodes (so that this
            #              check passes for them too).
            if isinstance(desc, data.Scalar):
                continue
            is_connector = dname in connectors
            if not desc.transient and not is_connector:
                raise NameError('Data descriptor "%s" not found in nested '
                                'SDFG connectors' % dname)
            if is_connector and desc.transient:
                raise NameError(
                    '"%s" is a connector but its corresponding array is transient'
                    % dname)

        # Validate undefined symbols
        symbols = {k for k in self.sdfg.free_symbols if k not in connectors}
        missing_symbols = [s for s in symbols if s not in self.symbol_mapping]
        if missing_symbols:
            raise ValueError('Missing symbols on nested SDFG: %s' %
                             (missing_symbols))

        # Recursively validate nested SDFG
        self.sdfg.validate()
class Node(object):
    """ Common base class of all nodes. """

    in_connectors = DictProperty(
        key_type=str,
        value_type=dtypes.typeclass,
        desc="A set of input connectors for this node.")
    out_connectors = DictProperty(
        key_type=str,
        value_type=dtypes.typeclass,
        desc="A set of output connectors for this node.")

    def __init__(self, in_connectors=None, out_connectors=None):
        """ Creates a node with the given connectors. Connectors given as a
            set, list or dictionary key view are turned into dictionaries
            whose types are None (auto-detect). """
        untyped = (set, list, KeysView)

        # Convert connectors to typed connectors with autodetect type
        if isinstance(in_connectors, untyped):
            in_connectors = dict.fromkeys(in_connectors)
        if isinstance(out_connectors, untyped):
            out_connectors = dict.fromkeys(out_connectors)

        self.in_connectors = in_connectors or {}
        self.out_connectors = out_connectors or {}

    def __str__(self):
        if not hasattr(self, 'label'):
            return type(self).__name__
        return self.label

    def validate(self, sdfg, state):
        pass

    def to_json(self, parent):
        """ Serializes this node to a JSON-compatible dictionary.
            :param parent: The graph containing this node; used to look up
                           node IDs and the enclosing scope.
            :return: Dictionary with the keys "type", "label", "attributes",
                     "id", "scope_entry" and "scope_exit".
        """
        labelstr = str(self)
        typestr = getattr(self, '__jsontype__', str(type(self).__name__))

        try:
            entry = parent.entry_node(self)
        except (RuntimeError, StopIteration):
            entry = None

        if entry is None:
            scope_entry_node = None
            scope_exit_node = None
        else:
            ens = parent.exit_node(parent.entry_node(self))
            scope_exit_node = str(parent.node_id(ens))
            scope_entry_node = str(parent.node_id(entry))

        # The scope exit of an entry node is the matching exit node
        if isinstance(self, EntryNode):
            try:
                scope_exit_node = str(parent.node_id(parent.exit_node(self)))
            except (RuntimeError, StopIteration):
                scope_exit_node = None

        return {
            "type": typestr,
            "label": labelstr,
            "attributes": dace.serialize.all_properties_to_json(self),
            "id": parent.node_id(self),
            "scope_entry": scope_entry_node,
            "scope_exit": scope_exit_node
        }

    def __repr__(self):
        class_name = type(self).__name__
        return class_name + ' (' + self.__str__() + ')'

    def add_in_connector(self,
                         connector_name: str,
                         dtype: dtypes.typeclass = None):
        """ Adds a new input connector to the node. Nothing is added if a
            connector (input or output) with that name already exists.

            :param connector_name: The name of the new connector.
            :param dtype: The type of the connector, or None for auto-detect.
            :return: True if the connector was added, otherwise False.
        """
        name_taken = (connector_name in self.in_connectors
                      or connector_name in self.out_connectors)
        if name_taken:
            return False

        # Re-assign so that the property setter sees the updated dictionary
        updated = self.in_connectors
        updated[connector_name] = dtype
        self.in_connectors = updated
        return True

    def add_out_connector(self,
                          connector_name: str,
                          dtype: dtypes.typeclass = None):
        """ Adds a new output connector to the node. Nothing is added if a
            connector (input or output) with that name already exists.

            :param connector_name: The name of the new connector.
            :param dtype: The type of the connector, or None for auto-detect.
            :return: True if the connector was added, otherwise False.
        """
        name_taken = (connector_name in self.in_connectors
                      or connector_name in self.out_connectors)
        if name_taken:
            return False

        # Re-assign so that the property setter sees the updated dictionary
        updated = self.out_connectors
        updated[connector_name] = dtype
        self.out_connectors = updated
        return True

    def remove_in_connector(self, connector_name: str):
        """ Removes an input connector from the node, if it exists.

            :param connector_name: The name of the connector to remove.
            :return: True.
        """
        if connector_name in self.in_connectors:
            remaining = self.in_connectors
            del remaining[connector_name]
            self.in_connectors = remaining
        return True

    def remove_out_connector(self, connector_name: str):
        """ Removes an output connector from the node, if it exists.

            :param connector_name: The name of the connector to remove.
            :return: True.
        """
        if connector_name in self.out_connectors:
            remaining = self.out_connectors
            del remaining[connector_name]
            self.out_connectors = remaining
        return True

    def _next_connector_int(self) -> int:
        """ Returns the next unused connector ID (as an integer): one more
            than the largest integer suffix among connectors named
            ``IN_<n>`` / ``OUT_<n>``, and at least 1. Used for filling
            connectors when adding edges to scopes. """
        next_number = 1
        for conn in itertools.chain(self.in_connectors, self.out_connectors):
            if conn.startswith('IN_'):
                suffix = conn[3:]
            elif conn.startswith('OUT_'):
                suffix = conn[4:]
            else:
                continue

            try:
                number = int(suffix)
            except (TypeError, ValueError):  # not integral
                continue
            next_number = max(next_number, number + 1)
        return next_number

    def next_connector(self) -> str:
        """ Returns the next unused connector ID (as a string). Used for
            filling connectors when adding edges to scopes. """
        return str(self._next_connector_int())

    def last_connector(self) -> str:
        """ Returns the last used connector ID (as a string). Used for
            filling connectors when adding edges to scopes. """
        return str(self._next_connector_int() - 1)

    @property
    def free_symbols(self) -> Set[str]:
        """ Returns a set of symbols used in this node's properties. The
            base implementation returns an empty set. """
        return set()

    def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]:
        """ Returns a mapping between symbols defined by this node (e.g., for
            scope entries) to their type. The base implementation returns an
            empty dictionary. """
        return {}

    def infer_connector_types(self, sdfg, state):
        """ Infers and fills remaining connectors (i.e., set to None) with
            their types. The base implementation does nothing. """
        pass
class Data(object):
    """ Base descriptor of a data container that can be used as a reference
        to memory.

        Examples: Arrays, Streams, custom arrays (e.g., sparse matrices).
    """

    dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses)
    shape = ShapeProperty(default=[])
    transient = Property(dtype=bool, default=False)
    storage = EnumProperty(dtype=dtypes.StorageType,
                           desc="Storage location",
                           default=dtypes.StorageType.Default)
    lifetime = EnumProperty(dtype=dtypes.AllocationLifetime,
                            desc='Data allocation span',
                            default=dtypes.AllocationLifetime.Scope)
    location = DictProperty(
        key_type=str,
        value_type=symbolic.pystr_to_symbolic,
        desc='Full storage location identifier (e.g., rank, GPU ID)')
    debuginfo = DebugInfoProperty(allow_none=True)

    def __init__(self, dtype, shape, transient, storage, location, lifetime,
                 debuginfo):
        """ Stores the given values in the corresponding properties (a None
            ``location`` becomes an empty dictionary) and validates the
            descriptor. """
        self.dtype = dtype
        self.shape = shape
        self.transient = transient
        self.storage = storage
        if location is None:
            location = {}
        self.location = location
        self.lifetime = lifetime
        self.debuginfo = debuginfo
        self._validate()

    def validate(self):
        """ Validate the correctness of this object.
            Raises an exception on error. """
        self._validate()

    # Validation of this class is in a separate function, so that this
    # class can call `_validate()` without calling the subclasses'
    # `validate` function.
    def _validate(self):
        """ Checks that every shape entry is an integer or a symbolic value.
            :return: True if the shape is valid.
            :raise TypeError: If any shape entry has another type.
        """
        allowed = (int, symbolic.SymExpr, symbolic.symbol,
                   symbolic.sympy.Basic)
        if not all(isinstance(s, allowed) for s in self.shape):
            raise TypeError('Shape must be a list or tuple of integer values '
                            'or symbols')
        return True

    def to_json(self):
        """ Serializes this descriptor to a dictionary holding its class
            name ("type") and its properties ("attributes"). """
        return {
            "type": type(self).__name__,
            "attributes": serialize.all_properties_to_json(self)
        }

    @property
    def toplevel(self):
        """ True if the lifetime is anything other than ``Scope``. """
        return self.lifetime is not dtypes.AllocationLifetime.Scope

    def copy(self):
        raise RuntimeError(
            'Data descriptors are unique and should not be copied')

    def is_equivalent(self, other):
        """ Check for equivalence (shape and type) of two data descriptors.
            Not implemented in the base class. """
        raise NotImplementedError

    def as_arg(self, with_types=True, for_call=False, name=None):
        """ Returns a string for a C++ function signature (e.g., `int *A`).
            Not implemented in the base class. """
        raise NotImplementedError

    @property
    def free_symbols(self) -> Set[symbolic.SymbolicType]:
        """ Returns a set of undefined symbols in this data descriptor,
            gathered from the symbolic entries of its shape. """
        found = set()
        for dim in self.shape:
            if not isinstance(dim, sp.Basic):
                continue
            found.update(dim.free_symbols)
        return found

    def __repr__(self):
        return 'Abstract Data Container, DO NOT USE'

    @property
    def veclen(self):
        """ Vector length of the data type, or 1 if the type has none. """
        if hasattr(self.dtype, "veclen"):
            return self.dtype.veclen
        return 1

    @property
    def ctype(self):
        """ The ``ctype`` of the data type. """
        return self.dtype.ctype
class AccumulateTransient(transformation.Transformation):
    """ Implements the AccumulateTransient transformation, which inserts
        transient stream and data nodes between nested maps that lead to a
        stream. The inserted transient data nodes serve as local
        accumulators.
    """

    map_exit = transformation.PatternNode(nodes.MapExit)
    outer_map_exit = transformation.PatternNode(nodes.MapExit)

    array_identity_dict = DictProperty(key_type=str,
                                       value_type=symbolic.pystr_to_symbolic,
                                       desc="dict with key: Array and"
                                       "value: the Identity value to set",
                                       default=dict(),
                                       allow_none=True)

    array = Property(
        dtype=str,
        desc="Array to create local storage for (if empty, first available)",
        default=None,
        allow_none=True)

    prefix = Property(dtype=str,
                      default="trans_",
                      allow_none=True,
                      desc='Prefix for new data node')

    identity = SymbolicProperty(desc="Identity value to set",
                                default=None,
                                allow_none=True)

    @staticmethod
    def expressions():
        pattern = sdutil.node_path_graph(AccumulateTransient.map_exit,
                                         AccumulateTransient.outer_map_exit)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        all_nodes = graph.nodes()
        inner_exit = all_nodes[candidate[AccumulateTransient.map_exit]]
        outer_exit = graph.nodes()[candidate[
            AccumulateTransient.outer_map_exit]]

        # Check if there is an accumulation output
        return any(edge.data.wcr is not None
                   for edge in graph.edges_between(inner_exit, outer_exit))

    @staticmethod
    def match_to_str(graph, candidate):
        matched = [
            candidate[AccumulateTransient.map_exit],
            candidate[AccumulateTransient.outer_map_exit]
        ]
        return ' -> '.join(str(node) for node in matched)

    def apply(self, sdfg: SDFG):
        """ Creates local storage (via ``OutLocalStorage``) for each selected
            array between the two map exits, nests the inner map scope into
            a nested SDFG, and prepends a state to that SDFG which sets each
            new transient to its identity value.

            :param sdfg: The SDFG to apply the transformation to.
            :return: The created nested SDFG node, or None (after a warning)
                     if an array without identity value is encountered.
        """
        graph = sdfg.node(self.state_id)
        map_exit = graph.node(self.subgraph[AccumulateTransient.map_exit])
        outer_map_exit = graph.node(
            self.subgraph[AccumulateTransient.outer_map_exit])

        # Avoid import loop
        from dace.transformation.dataflow.local_storage import OutLocalStorage

        identities = self.array_identity_dict

        # Choose array
        chosen = self.array
        has_array = chosen is not None and len(chosen) != 0
        if has_array:
            identities[chosen] = self.identity
        elif len(identities) == 0:
            # Nothing specified: pick the first accumulation output
            chosen = next(edge.data.data
                          for edge in graph.edges_between(
                              map_exit, outer_map_exit)
                          if edge.data.wcr is not None)
            identities[chosen] = self.identity

        transients: Dict[str, Any] = {}
        for array_name, identity in identities.items():
            storage_options = dict(array=array_name, prefix=self.prefix)
            data_node: nodes.AccessNode = OutLocalStorage.apply_to(
                sdfg,
                storage_options,
                verify=False,
                save=False,
                node_a=map_exit,
                node_b=outer_map_exit)
            transients[data_node.data] = identity

            if identity is None:
                warnings.warn(
                    'AccumulateTransient did not properly initialize '
                    'newly-created transient!')
                return

        sdfg_state: SDFGState = sdfg.node(self.state_id)
        map_entry = sdfg_state.entry_node(map_exit)

        scope_nodes = ({map_entry, map_exit}
                       | sdfg_state.all_nodes_between(map_entry, map_exit))
        nested_sdfg: nodes.NestedSDFG = nest_state_subgraph(
            sdfg=sdfg,
            state=sdfg_state,
            subgraph=SubgraphView(sdfg_state, scope_nodes))

        nested_sdfg_state: SDFGState = nested_sdfg.sdfg.nodes()[0]
        init_state = nested_sdfg.sdfg.add_state_before(nested_sdfg_state)

        for data_name, identity in transients.items():
            temp_array: Array = sdfg.arrays[data_name]
            shape = temp_array.shape
            # One map parameter per dimension of the transient
            index_names = ['_o%d' % i for i, _ in enumerate(shape)]
            ranges = {
                idx: '0:%s' % symbolic.symstr(dim)
                for idx, dim in zip(index_names, shape)
            }
            init_state.add_mapped_tasklet(
                name='acctrans_init',
                map_ranges=ranges,
                inputs={},
                code='out = %s' % identity,
                outputs={
                    'out':
                    dace.Memlet.simple(data=data_name,
                                       subset_str=','.join(index_names))
                },
                external_edges=True)

        # TODO: use trivial map elimintation here when it will be merged to remove map if it has trivial ranges
        return nested_sdfg
class Transformation(TransformationBase):
    """ Base class for pattern-matching transformations, as well as a static
        registry of transformations, where new transformations can be added in a
        decentralized manner.
        An instance of a Transformation represents a match of the transformation
        on an SDFG, complete with a subgraph candidate and properties.

        New transformations that extend this class must contain static
        `PatternNode` fields that represent the nodes in the pattern graph, and
        use them to implement at least three methods:
          * `expressions`: A method that returns a list of graph
                           patterns (SDFG or SDFGState objects) that match this
                           transformation.
          * `can_be_applied`: A method that, given a subgraph candidate,
                              checks for additional conditions whether it can
                              be transformed.
          * `apply`: A method that applies the transformation on the given SDFG.

        For more information and optimization opportunities, see the respective
        methods' documentation.

        In order to be included in lists and apply through the
        `sdfg.apply_transformations` API, each transformation should be
        registered with ``Transformation.register`` (or, more commonly,
        the ``@dace.registry.autoregister_params`` class decorator) with two
        optional boolean keyword arguments: ``singlestate`` (default: False)
        and ``strict`` (default: False).
        If ``singlestate`` is True, the transformation is matched on subgraphs
        inside an SDFGState; otherwise, subgraphs of the SDFG state machine are
        matched.
        If ``strict`` is True, this transformation will be considered strict
        (i.e., always beneficial to perform) and will be performed automatically
        as part of SDFG strict transformations.
    """

    # Properties
    sdfg_id = Property(dtype=int, category="(Debug)")
    state_id = Property(dtype=int, category="(Debug)")
    _subgraph = DictProperty(key_type=int, value_type=int, category="(Debug)")
    expr_index = Property(dtype=int, category="(Debug)")

    def annotates_memlets(self) -> bool:
        """ Indicates whether the transformation annotates the edges it creates
            or modifies with the appropriate memlets. This determines
            whether to apply memlet propagation after the transformation.
        """
        return False

    def expressions(self) -> List[gr.SubgraphView]:
        """ Returns a list of Graph objects that will be matched in the
            subgraph isomorphism phase. Used as a pre-pass before calling
            `can_be_applied`.
            :see: Transformation.can_be_applied
        """
        raise NotImplementedError

    def can_be_applied(self,
                       graph: Union[SDFG, SDFGState],
                       candidate: Dict['PatternNode', int],
                       expr_index: int,
                       sdfg: SDFG,
                       strict: bool = False) -> bool:
        """ Returns True if this transformation can be applied on the candidate
            matched subgraph.
            :param graph: SDFGState object if this Transformation is
                          single-state, or SDFG object otherwise.
            :param candidate: A mapping between node IDs returned from
                              `Transformation.expressions` and the nodes in
                              `graph`.
            :param expr_index: The list index from `Transformation.expressions`
                               that was matched.
            :param sdfg: If `graph` is an SDFGState, its parent SDFG. Otherwise
                         should be equal to `graph`.
            :param strict: Whether transformation should run in strict mode.
            :return: True if the transformation can be applied.
        """
        raise NotImplementedError

    def apply(self, sdfg: SDFG) -> Union[Any, None]:
        """
        Applies this transformation instance on the matched pattern graph.
        :param sdfg: The SDFG to apply the transformation to.
        :return: A transformation-defined return value, which could be used
                 to pass analysis data out, or nothing.
        """
        raise NotImplementedError

    def match_to_str(self, graph: Union[SDFG, SDFGState],
                     candidate: Dict['PatternNode', int]) -> str:
        """ Returns a string representation of the pattern match on the
            candidate subgraph. Used when identifying matches in the console
            UI.
        """
        return str(list(candidate.values()))

    def __init__(self,
                 sdfg_id: int,
                 state_id: int,
                 subgraph: Dict['PatternNode', int],
                 expr_index: int,
                 override: bool = False,
                 options: Optional[Dict[str, Any]] = None) -> None:
        """ Initializes an instance of Transformation match.
            :param sdfg_id: A unique ID of the SDFG.
            :param state_id: The node ID of the SDFG state, if applicable. If
                             transformation does not operate on a single state,
                             the value should be -1.
            :param subgraph: A mapping between node IDs returned from
                             `Transformation.expressions` and the nodes in
                             `graph`.
            :param expr_index: The list index from `Transformation.expressions`
                               that was matched.
            :param override: If True, accepts the subgraph dictionary as-is
                             (mostly for internal use).
            :param options: An optional dictionary of transformation properties
            :raise TypeError: When transformation is not subclass of
                              Transformation.
            :raise TypeError: When state_id is not instance of int.
            :raise TypeError: When subgraph is not a dict of
                              PatternNode : int.
        """

        self.sdfg_id = sdfg_id
        self.state_id = state_id
        if not override:
            expr = self.expressions()[expr_index]
            for value in subgraph.values():
                if not isinstance(value, int):
                    raise TypeError('All values of '
                                    'subgraph'
                                    ' dictionary must be '
                                    'instances of int.')
            # Serializable subgraph with node IDs as keys
            self._subgraph = {expr.node_id(k): v for k, v in subgraph.items()}
        else:
            self._subgraph = {-1: -1}
        # User-facing subgraph, keyed by pattern nodes
        self._subgraph_user = copy.copy(subgraph)
        self.expr_index = expr_index

        # Ease-of-use API: Set new pattern-nodes with information about this
        # instance.
        for pname, pval in self._get_pattern_nodes().items():
            # Create new pattern node from existing field
            new_pnode = PatternNode(
                pval.node if isinstance(pval, PatternNode) else type(pval))
            new_pnode.match_instance = self

            # Append existing values in subgraph dictionary
            if pval in self._subgraph_user:
                self._subgraph_user[new_pnode] = self._subgraph_user[pval]

            # Override static field with the new node in this instance only
            setattr(self, pname, new_pnode)

        # Set properties
        if options is not None:
            for optname, optval in options.items():
                setattr(self, optname, optval)

    @property
    def subgraph(self):
        """ The matched subgraph as a mapping from pattern nodes to node
            IDs. """
        return self._subgraph_user

    def apply_pattern(self,
                      sdfg: SDFG,
                      append: bool = True,
                      annotate: bool = True) -> Union[Any, None]:
        """
        Applies this transformation on the given SDFG, using the transformation
        instance to find the right SDFG object (based on SDFG ID), and applying
        memlet propagation as necessary.
        :param sdfg: The SDFG (or an SDFG in the same hierarchy) to apply the
                     transformation to.
        :param append: If True, appends the transformation to the SDFG
                       transformation history.
        :param annotate: If True, and `annotates_memlets` returns False,
                         runs memlet propagation on the transformed SDFG.
        :return: A transformation-defined return value, which could be used
                 to pass analysis data out, or nothing.
        """
        if append:
            sdfg.append_transformation(self)
        tsdfg: SDFG = sdfg.sdfg_list[self.sdfg_id]
        retval = self.apply(tsdfg)
        if annotate and not self.annotates_memlets():
            propagation.propagate_memlets_sdfg(tsdfg)
        return retval

    def __lt__(self, other: 'Transformation') -> bool:
        """
        Comparing two transformations by their class name and node IDs
        in match. Used for ordering transformations consistently.

        Matches of different classes are ordered by class name. Matches of
        the same class are ordered lexicographically by the node IDs in
        their subgraphs; if one sequence is a prefix of the other, the
        shorter one comes first. Equal matches are not less than each other.
        """
        if type(self) != type(other):
            return type(self).__name__ < type(other).__name__

        self_ids = list(self.subgraph.values())
        # Must come from `other`: comparing `self` against itself would
        # never order two matches of the same class.
        other_ids = list(other.subgraph.values())

        for self_id, other_id in zip(self_ids, other_ids):
            if self_id != other_id:
                return self_id < other_id

        # Common prefix is identical: the shorter match is the smaller one
        return len(self_ids) < len(other_ids)

    @classmethod
    def _get_pattern_nodes(cls) -> Dict[str, 'PatternNode']:
        """
        Returns a dictionary of pattern-matching node in this transformation
        subclass. Used internally for pattern-matching.
        :return: A dictionary mapping between pattern-node name and its type.
        """
        return {
            k: getattr(cls, k)
            for k in dir(cls)
            if isinstance(getattr(cls, k), PatternNode) or (k.startswith(
                '_') and isinstance(getattr(cls, k), (nd.Node, SDFGState)))
        }

    @classmethod
    def apply_to(cls,
                 sdfg: SDFG,
                 options: Optional[Dict[str, Any]] = None,
                 expr_index: int = 0,
                 verify: bool = True,
                 annotate: bool = True,
                 strict: bool = False,
                 save: bool = True,
                 **where: Union[nd.Node, SDFGState]):
        """
        Applies this transformation to a given subgraph, defined by a set of
        nodes. Raises an error if arguments are invalid or transformation is
        not applicable.

        The subgraph is defined by the `where` dictionary, where each key is
        taken from the `PatternNode` fields of the transformation. For example,
        applying `MapCollapse` on two maps can be performed as follows:

        ```
        MapCollapse.apply_to(sdfg, outer_map_entry=map_a, inner_map_entry=map_b)
        ```

        :param sdfg: The SDFG to apply the transformation to.
        :param options: A set of parameters to use for applying the
                        transformation.
        :param expr_index: The pattern expression index to try to match with.
        :param verify: Check that `can_be_applied` returns True before applying.
        :param annotate: Run memlet propagation after application if necessary.
        :param strict: Apply transformation in strict mode.
        :param save: Save transformation as part of the SDFG file. Set to
                     False if composing transformations.
        :param where: A dictionary of node names (from the transformation) to
                      nodes in the SDFG or a single state.
        :raise ValueError: If no nodes are given, required pattern nodes are
                           missing, an option is not a property of the
                           transformation, or verification fails.
        :raise TypeError: If the given nodes are neither SDFG states nor
                          nodes.
        """
        if len(where) == 0:
            raise ValueError('At least one node is required')
        options = options or {}

        # Check that all keyword arguments are nodes and if interstate or not
        sample_node = next(iter(where.values()))

        if isinstance(sample_node, SDFGState):
            graph = sdfg
            state_id = -1
        elif isinstance(sample_node, nd.Node):
            graph = next(s for s in sdfg.nodes() if sample_node in s.nodes())
            state_id = sdfg.node_id(graph)
        else:
            raise TypeError('Invalid node type "%s"' %
                            type(sample_node).__name__)

        # Check that all nodes in the pattern are set
        required_nodes = cls.expressions()[expr_index].nodes()
        required_node_names = {
            pname: pval
            for pname, pval in cls._get_pattern_nodes().items()
            if pval in required_nodes
        }
        required = set(required_node_names.keys())
        intersection = required & set(where.keys())
        if len(required - intersection) > 0:
            raise ValueError('Missing nodes for transformation subgraph: %s' %
                             (required - intersection))

        # Construct subgraph and instantiate transformation
        subgraph = {
            required_node_names[k]: graph.node_id(where[k])
            for k in required
        }
        instance = cls(sdfg.sdfg_id, state_id, subgraph, expr_index)

        # Construct transformation parameters
        for optname, optval in options.items():
            if optname not in cls.__properties__:
                raise ValueError('Property "%s" not found in transformation' %
                                 optname)
            setattr(instance, optname, optval)

        if verify:
            if not instance.can_be_applied(
                    graph, subgraph, expr_index, sdfg, strict=strict):
                raise ValueError('Transformation cannot be applied on the '
                                 'given subgraph ("can_be_applied" failed)')

        # Apply to SDFG
        return instance.apply_pattern(sdfg, annotate=annotate, append=save)

    def __str__(self) -> str:
        return type(self).__name__

    def print_match(self, sdfg: SDFG) -> str:
        """ Returns a string representation of the pattern match on the
            given SDFG. Used for printing matches in the console UI.
            :raise TypeError: If `sdfg` is not an SDFG.
        """
        if not isinstance(sdfg, SDFG):
            raise TypeError("Expected SDFG, got: {}".format(
                type(sdfg).__name__))
        if self.state_id == -1:
            graph = sdfg
        else:
            graph = sdfg.nodes()[self.state_id]
        string = type(self).__name__ + ' in '
        string += self.match_to_str(graph, self.subgraph)
        return string

    def to_json(self, parent=None) -> Dict[str, Any]:
        """ Serializes this match: its properties plus the 'type' and
            'transformation' (class name) entries. """
        props = serialize.all_properties_to_json(self)
        return {
            'type': 'Transformation',
            'transformation': type(self).__name__,
            **props
        }

    @staticmethod
    def from_json(json_obj: Dict[str, Any],
                  context: Dict[str, Any] = None) -> 'Transformation':
        """ Reconstructs a transformation match from its JSON representation.
            The class is looked up by name among
            `Transformation.extensions()`. """
        xform = next(ext for ext in Transformation.extensions().keys()
                     if ext.__name__ == json_obj['transformation'])

        # Recreate subgraph
        expr = xform.expressions()[json_obj['expr_index']]
        subgraph = {
            expr.node(int(k)): int(v)
            for k, v in json_obj['_subgraph'].items()
        }

        # Reconstruct transformation
        ret = xform(json_obj['sdfg_id'], json_obj['state_id'], subgraph,
                    json_obj['expr_index'])
        context = context or {}
        context['transformation'] = ret
        serialize.set_properties_from_json(
            ret,
            json_obj,
            context=context,
            ignore_properties={'transformation', 'type'})
        return ret
class Data(object):
    """ Base descriptor for data containers that can be referenced as memory.

        Examples: Arrays, Streams, custom arrays (e.g., sparse matrices).
    """

    dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses)
    shape = ShapeProperty(default=[])
    transient = Property(dtype=bool, default=False)
    storage = EnumProperty(dtype=dtypes.StorageType,
                           desc="Storage location",
                           default=dtypes.StorageType.Default)
    lifetime = EnumProperty(dtype=dtypes.AllocationLifetime,
                            desc='Data allocation span',
                            default=dtypes.AllocationLifetime.Scope)
    location = DictProperty(
        key_type=str,
        value_type=str,
        desc='Full storage location identifier (e.g., rank, GPU ID)')
    debuginfo = DebugInfoProperty(allow_none=True)

    def __init__(self, dtype, shape, transient, storage, location, lifetime,
                 debuginfo):
        self.dtype = dtype
        self.shape = shape
        self.transient = transient
        self.storage = storage
        # A missing location is stored as a fresh, empty dictionary
        self.location = {} if location is None else location
        self.lifetime = lifetime
        self.debuginfo = debuginfo
        self._validate()

    def validate(self):
        """ Checks this descriptor for correctness; raises an exception if
            it is invalid. """
        self._validate()

    # The base-class checks are kept in their own method so that the
    # constructor above can run them without dispatching to a subclass
    # override of `validate`.
    def _validate(self):
        """ Ensures every shape entry is an integer or a symbolic value.

            :return: True if the shape is valid.
            :raise TypeError: If any shape entry has an unsupported type.
        """
        shape_ok = all(
            isinstance(dim, (int, symbolic.SymExpr, symbolic.symbol,
                             symbolic.sympy.Basic)) for dim in self.shape)
        if not shape_ok:
            raise TypeError(
                'Shape must be a list or tuple of integer values or symbols')
        return True

    def to_json(self):
        """ Serializes this descriptor to a dictionary holding its class
            name ("type") and its properties ("attributes"). """
        return {
            "type": type(self).__name__,
            "attributes": serialize.all_properties_to_json(self),
        }

    @property
    def toplevel(self):
        """ True if the allocation lifetime is anything other than
            ``AllocationLifetime.Scope``. """
        scope_lifetime = dtypes.AllocationLifetime.Scope
        return self.lifetime is not scope_lifetime

    def copy(self):
        """ Copying is not supported on the base descriptor; always raises.

            :raise RuntimeError: Unconditionally.
        """
        raise RuntimeError(
            'Data descriptors are unique and should not be copied')

    def is_equivalent(self, other):
        """ Check for equivalence (shape and type) of two data descriptors.
            Not implemented in the base class. """
        raise NotImplementedError

    def as_arg(self, with_types=True, for_call=False, name=None):
        """ Returns a string for a C++ function signature (e.g., `int *A`).
            Not implemented in the base class. """
        raise NotImplementedError

    @property
    def free_symbols(self) -> Set[symbolic.SymbolicType]:
        """ Returns the set of free symbols found in the symbolic entries
            of this descriptor's shape. """
        found = set()
        for dim in self.shape:
            # Plain integers carry no symbols; only sympy expressions do
            if isinstance(dim, sp.Basic):
                found.update(dim.free_symbols)
        return found

    def __repr__(self):
        return 'Abstract Data Container, DO NOT USE'

    @property
    def veclen(self):
        """ Vector length of the data type, or 1 if the type defines none. """
        return getattr(self.dtype, "veclen", 1)

    @property
    def ctype(self):
        """ The ``ctype`` attribute of the underlying data type. """
        return self.dtype.ctype

    def strides_from_layout(
        self,
        *dimensions: int,
        alignment: symbolic.SymbolicType = 1,
        only_first_aligned: bool = False,
    ) -> Tuple[Tuple[symbolic.SymbolicType], symbolic.SymbolicType]:
        """
        Computes absolute strides and the total size of this descriptor for
        a given dimension ordering and alignment.

        :param dimensions: A sequence of integers representing a permutation
                           of the descriptor's dimensions.
        :param alignment: Padding (in elements) at the end, ensuring stride
                          is a multiple of this number. 1 (default) means
                          no padding.
        :param only_first_aligned: If True, only the first dimension is
                                   padded with ``alignment``. Otherwise all
                                   dimensions are.
        :return: A 2-tuple of (tuple of strides, total size).
        :raise ValueError: If ``dimensions`` is not a permutation of all
                           dimensions, or if ``alignment`` is invalid.
        """
        # The given ordering must name every dimension exactly once
        if tuple(sorted(dimensions)) != tuple(range(len(self.shape))):
            raise ValueError('Every dimension must be given and appear once.')
        # Explicit comparison with True: for symbolic alignments the
        # inequality may stay an unevaluated relational expression
        if (alignment < 1) == True or (alignment < 0) == True:
            raise ValueError('Invalid alignment value')

        strides = [1] * len(dimensions)
        total_size = 1
        for position, dim in enumerate(dimensions):
            strides[dim] = total_size
            extent = self.shape[dim]
            if position == 0 or not only_first_aligned:
                # Round the extent up to the next multiple of `alignment`
                extent = ((extent + alignment - 1) // alignment) * alignment
            total_size *= extent

        return (tuple(strides), total_size)

    def set_strides_from_layout(self,
                                *dimensions: int,
                                alignment: symbolic.SymbolicType = 1,
                                only_first_aligned: bool = False):
        """
        Computes strides and total size via ``strides_from_layout`` and
        stores them in ``self.strides`` and ``self.total_size``.

        :param dimensions: A sequence of integers representing a permutation
                           of the descriptor's dimensions.
        :param alignment: Padding (in elements) at the end, ensuring stride
                          is a multiple of this number. 1 (default) means
                          no padding.
        :param only_first_aligned: If True, only the first dimension is
                                   padded with ``alignment``. Otherwise all
                                   dimensions are.
        """
        layout = self.strides_from_layout(
            *dimensions,
            alignment=alignment,
            only_first_aligned=only_first_aligned)
        self.strides, self.total_size = layout
class Transformation(object):
    """ Base class for transformations, as well as a static registry of
        transformations, where new transformations can be added in a
        decentralized manner.

        New transformations are registered with ``Transformation.register``
        (or ``dace.registry.autoregister_params``) with two optional boolean
        keyword arguments: ``singlestate`` (default: False) and ``strict``
        (default: False).
        If ``singlestate`` is True, the transformation is matched on
        subgraphs inside an SDFGState; otherwise, subgraphs of the SDFG
        state machine are matched.
        If ``strict`` is True, this transformation will be considered strict
        (i.e., always beneficial to perform) and will be performed
        automatically as part of SDFG strict transformations.
    """

    # Properties
    sdfg_id = Property(dtype=int, category="(Debug)")
    state_id = Property(dtype=int, category="(Debug)")
    _subgraph = DictProperty(key_type=int, value_type=int, category="(Debug)")
    expr_index = Property(dtype=int, category="(Debug)")

    @staticmethod
    def annotates_memlets():
        """ Indicates whether the transformation annotates the edges it
            creates or modifies with the appropriate memlets. This
            determines whether to apply memlet propagation after the
            transformation.
        """
        return False

    @staticmethod
    def expressions():
        """ Returns a list of Graph objects that will be matched in the
            subgraph isomorphism phase. Used as a pre-pass before calling
            `can_be_applied`.

            :see: Transformation.can_be_applied
        """
        raise NotImplementedError

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if this transformation can be applied on the
            candidate matched subgraph.

            :param graph: SDFGState object if this Transformation is
                          single-state, or SDFG object otherwise.
            :param candidate: A mapping between node IDs returned from
                              `Transformation.expressions` and the nodes in
                              `graph`.
            :param expr_index: The list index from
                               `Transformation.expressions` that was
                               matched.
            :param sdfg: If `graph` is an SDFGState, its parent SDFG.
                         Otherwise should be equal to `graph`.
            :param strict: Whether transformation should run in strict mode.
            :return: True if the transformation can be applied.
        """
        raise NotImplementedError

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a string representation of the pattern match on the
            candidate subgraph. Used when identifying matches in the console
            UI.
        """
        raise NotImplementedError

    def __init__(self, sdfg_id, state_id, subgraph, expr_index):
        """ Initializes an instance of Transformation.

            :param sdfg_id: A unique ID of the SDFG.
            :param state_id: The node ID of the SDFG state, if applicable.
            :param subgraph: A mapping between node IDs returned from
                             `Transformation.expressions` and the nodes in
                             `graph`.
            :param expr_index: The list index from
                               `Transformation.expressions` that was
                               matched.
            :raise TypeError: When a value of the subgraph dictionary is not
                              an int.
        """
        self.sdfg_id = sdfg_id
        self.state_id = state_id
        for value in subgraph.values():
            if not isinstance(value, int):
                raise TypeError('All values of '
                                'subgraph'
                                ' dictionary must be '
                                'instances of int.')

        # Serializable subgraph with node IDs as keys
        expr = self.expressions()[expr_index]
        self._subgraph = {expr.node_id(k): v for k, v in subgraph.items()}
        # The mapping exactly as given by the caller (pattern nodes as keys)
        self._subgraph_user = subgraph
        self.expr_index = expr_index

    @property
    def subgraph(self):
        """ The subgraph mapping as passed to the constructor. """
        return self._subgraph_user

    def __lt__(self, other):
        """ Comparing two transformations by their class name and node IDs
            in match. Used for ordering transformations consistently.

            Transformations of different classes are ordered by class name.
            Otherwise the matched node IDs are compared lexicographically;
            if one sequence is a prefix of the other, the shorter one is
            smaller. Equal matches are not less than each other.
        """
        if type(self) != type(other):
            return type(self).__name__ < type(other).__name__

        self_ids = list(self.subgraph.values())
        # Node IDs must come from `other`; reading them from `self` made
        # every pair of same-class transformations compare as equal
        other_ids = list(other.subgraph.values())

        for self_id, other_id in zip(self_ids, other_ids):
            if self_id != other_id:
                return self_id < other_id

        # All shared positions are equal: the shorter match sorts first
        return len(self_ids) < len(other_ids)

    def apply_pattern(self, sdfg):
        """ Applies this transformation on the given SDFG, then runs memlet
            propagation unless the transformation annotates memlets itself.
        """
        self.apply(sdfg)
        if not self.annotates_memlets():
            propagation.propagate_memlets_sdfg(sdfg)

    def __str__(self):
        return type(self).__name__

    def modifies_graph(self):
        return True

    def print_match(self, sdfg):
        """ Returns a string representation of the pattern match on the
            given SDFG. Used for printing matches in the console UI.

            :raise TypeError: If ``sdfg`` is not an SDFG.
        """
        if not isinstance(sdfg, dace.SDFG):
            raise TypeError("Expected SDFG, got: {}".format(
                type(sdfg).__name__))
        # A state ID of -1 means the match is on the SDFG state machine
        if self.state_id == -1:
            graph = sdfg
        else:
            graph = sdfg.nodes()[self.state_id]
        string = type(self).__name__ + ' in '
        string += type(self).match_to_str(graph, self.subgraph)
        return string

    def to_json(self, parent=None):
        """ Serializes this transformation (class name and all properties)
            to a dictionary. """
        props = dace.serialize.all_properties_to_json(self)
        return {
            'type': 'Transformation',
            'transformation': type(self).__name__,
            **props
        }

    @staticmethod
    def from_json(json_obj, context=None):
        """ Reconstructs a transformation instance from the dictionary
            produced by ``to_json``. The class is looked up by name among
            the registered extensions. """
        xform = next(ext for ext in Transformation.extensions().keys()
                     if ext.__name__ == json_obj['transformation'])

        # Recreate subgraph
        expr = xform.expressions()[json_obj['expr_index']]
        subgraph = {
            expr.node(int(k)): int(v)
            for k, v in json_obj['_subgraph'].items()
        }

        # Reconstruct transformation
        ret = xform(json_obj['sdfg_id'], json_obj['state_id'], subgraph,
                    json_obj['expr_index'])
        context = context or {}
        context['transformation'] = ret
        dace.serialize.set_properties_from_json(
            ret,
            json_obj,
            context=context,
            ignore_properties={'transformation', 'type'})
        return ret
class ONNXSchema:
    """Python representation of an ONNX schema"""

    name = Property(dtype=str, desc="The operator name")
    domain = Property(dtype=str, desc="The operator domain")
    doc = Property(dtype=str, desc="The operator's docstring")
    since_version = Property(dtype=int, desc="The version of the operator")
    attributes = DictProperty(key_type=str,
                              value_type=ONNXAttribute,
                              desc="The operator attributes")
    type_constraints = DictProperty(
        key_type=str,
        value_type=ONNXTypeConstraint,
        desc="The type constraints for inputs and outputs")
    inputs = ListProperty(element_type=ONNXParameter,
                          desc="The operator input parameter descriptors")
    outputs = ListProperty(element_type=ONNXParameter,
                           desc="The operator output parameter descriptors")

    def __repr__(self):
        return ".".join((self.domain, self.name))

    def validate(self):
        """ Validates the schema, raising an exception on the first problem.

            Runs four passes in order: (1) parameters whose type string is
            not a known type constraint get a new constraint named
            ``<param name>_constraint`` and are re-pointed to it; (2) single
            and variadic parameters must have at least one supported type;
            (3) variadic parameter names must not contain ``__``; (4) input
            names and output names must each be unique.

            :raise ValueError: On a constraint-name clash, an invalid
                               variadic name, or a duplicate parameter name.
            :raise NotImplementedError: If a single/variadic parameter has
                                        no supported type.
        """
        # Pass 1: every parameter type string must name a type constraint.
        # Some operators put a type descriptor there instead; for those, a
        # new type constraint is inserted.
        for param in chain(self.inputs, self.outputs):
            if param.type_str in self.type_constraints:
                continue

            cons_name = param.name + "_constraint"
            if cons_name in self.type_constraints:
                raise ValueError(
                    "Attempted to insert new type constraint, but the name already existed. Please open an issue."
                )

            parsed_typeclass = onnx_type_str_to_typeclass(param.type_str)
            supported_types = []
            if parsed_typeclass is None:
                print("Could not parse typeStr '{}' for parameter '{}'".
                      format(param.type_str, param.name))
            else:
                supported_types.append(parsed_typeclass)

            self.type_constraints[cons_name] = ONNXTypeConstraint(
                cons_name, supported_types)
            param.type_str = cons_name

        # Pass 2: required parameters must have at least one supported type
        for param in chain(self.inputs, self.outputs):
            needs_type = (param.param_type == ONNXParameterType.Single
                          or param.param_type == ONNXParameterType.Variadic)
            if not needs_type:
                continue
            if len(self.type_constraints[param.type_str].types) == 0:
                raise NotImplementedError(
                    "None of the types for parameter '{}' are supported".
                    format(param.name))

        # Pass 3: variadic parameter names must not contain "__"
        for param in chain(self.inputs, self.outputs):
            is_variadic = param.param_type == ONNXParameterType.Variadic
            if is_variadic and "__" in param.name:
                raise ValueError(
                    "Unsupported parameter name '{}': variadic parameter names must not contain '__'"
                    .format(param.name))

        # Pass 4: names must be unique among inputs, and among outputs
        uniqueness_checks = (
            (self.inputs, "Got duplicate input parameter name '{}'"),
            (self.outputs, "Got duplicate output parameter name '{}'"),
        )
        for params, message in uniqueness_checks:
            seen = set()
            for param in params:
                if param.name in seen:
                    raise ValueError(message.format(param.name))
                seen.add(param.name)
class Map(object):
    """ A Map is a two-node representation of parametric graphs, containing
        an integer set by which the contents (nodes dominated by an entry
        node and post-dominated by an exit node) are replicated.

        Maps contain a `schedule` property, which specifies how the scope
        should be scheduled (execution order). Code generators can use the
        schedule property to generate appropriate code, e.g., GPU kernels.
    """

    # List of (editable) properties
    label = Property(dtype=str, desc="Label of the map")
    params = ListProperty(element_type=str, desc="Mapped parameters")
    range = RangeProperty(desc="Ranges of map parameters",
                          default=sbs.Range([]))
    schedule = EnumProperty(dtype=dtypes.ScheduleType,
                            desc="Map schedule",
                            default=dtypes.ScheduleType.Default)
    unroll = Property(dtype=bool, desc="Map unrolling")
    collapse = Property(dtype=int,
                        default=1,
                        desc="How many dimensions to"
                        " collapse into the parallel range")
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = EnumProperty(
        dtype=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    location = DictProperty(
        key_type=str,
        value_type=dace.symbolic.pystr_to_symbolic,
        desc='Full storage location identifier (e.g., rank, GPU ID)')

    def __init__(self,
                 label,
                 params,
                 ndrange,
                 schedule=dtypes.ScheduleType.Default,
                 unroll=False,
                 collapse=1,
                 fence_instrumentation=False,
                 debuginfo=None,
                 location=None):
        """ Creates a map.

            :param label: Label of the map.
            :param params: Mapped parameters (names).
            :param ndrange: Ranges of the map parameters.
            :param schedule: Map schedule.
            :param unroll: Whether the map is unrolled.
            :param collapse: How many dimensions to collapse into the
                             parallel range.
            :param fence_instrumentation: Stored on the instance as
                                          ``_fence_instrumentation``.
            :param debuginfo: Debugging information for this map.
            :param location: Storage location identifier dictionary; None
                             is stored as an empty dictionary.
        """
        super(Map, self).__init__()

        # Assign properties
        self.label = label
        self.schedule = schedule
        self.unroll = unroll
        # Honor the caller's value; this was previously hard-coded to 1,
        # which silently discarded the `collapse` argument
        self.collapse = collapse
        self.params = params
        self.range = ndrange
        self.debuginfo = debuginfo
        self.location = location if location is not None else {}
        self._fence_instrumentation = fence_instrumentation

    def __str__(self):
        return self.label + "[" + ", ".join([
            "{}={}".format(i, r)
            for i, r in zip(self._params,
                            [sbs.Range.dim_to_string(d) for d in self._range])
        ]) + "]"

    def validate(self, sdfg, state, node):
        """ Checks that the map label is a valid name.

            :raise NameError: If the label is not a valid name.
        """
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid map name "%s"' % self.label)

    def get_param_num(self):
        """ Returns the number of map dimension parameters/symbols. """
        return len(self.params)
class Consume(object):
    """ Consume is a scope, like `Map`, that is a part of the parametric
        graph extension of the SDFG. It creates a producer-consumer
        relationship between the input stream and the scope subgraph. The
        subgraph is scheduled to a given number of processing elements
        for processing, and they will try to pop elements from the input
        stream until a given quiescence condition is reached. """

    # Properties
    label = Property(dtype=str, desc="Name of the consume node")
    pe_index = Property(dtype=str, desc="Processing element identifier")
    num_pes = SymbolicProperty(desc="Number of processing elements", default=1)
    condition = CodeProperty(desc="Quiescence condition", allow_none=True)
    schedule = EnumProperty(dtype=dtypes.ScheduleType,
                            desc="Consume schedule",
                            default=dtypes.ScheduleType.Default)
    chunksize = Property(dtype=int,
                         desc="Maximal size of elements to consume at a time",
                         default=1)
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = EnumProperty(
        dtype=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    # The description previously joined two literals without a separating
    # space ("identifier(e.g., ..."); it now matches the other node classes
    location = DictProperty(
        key_type=str,
        value_type=dace.symbolic.pystr_to_symbolic,
        desc='Full storage location identifier (e.g., rank, GPU ID)')

    def as_map(self):
        """ Compatibility function that allows to view the consume as a map,
            mainly in memlet propagation. The resulting map has the
            processing element index as its single parameter, ranging from
            0 to ``num_pes - 1``. """
        return Map(self.label, [self.pe_index],
                   sbs.Range([(0, self.num_pes - 1, 1)]), self.schedule)

    def __init__(self,
                 label,
                 pe_tuple,
                 condition,
                 schedule=dtypes.ScheduleType.Default,
                 chunksize=1,
                 debuginfo=None,
                 location=None):
        """ Creates a consume scope.

            :param label: Name of the consume node.
            :param pe_tuple: A 2-tuple of (processing element identifier,
                             number of processing elements).
            :param condition: Quiescence condition (may be None).
            :param schedule: Consume schedule.
            :param chunksize: Maximal size of elements to consume at a time.
            :param debuginfo: Debugging information for this scope.
            :param location: Storage location identifier dictionary; None
                             is stored as an empty dictionary.
        """
        super(Consume, self).__init__()

        # Properties
        self.label = label
        self.pe_index, self.num_pes = pe_tuple
        self.condition = condition
        self.schedule = schedule
        self.chunksize = chunksize
        self.debuginfo = debuginfo
        self.location = location if location is not None else {}

    def __str__(self):
        if self.condition is not None:
            return ("%s [%s=0:%s], Condition: %s" %
                    (self._label, self.pe_index, self.num_pes,
                     CodeProperty.to_string(self.condition)))
        else:
            return ("%s [%s=0:%s]" %
                    (self._label, self.pe_index, self.num_pes))

    def validate(self, sdfg, state, node):
        """ Checks that the consume label is a valid name.

            :raise NameError: If the label is not a valid name.
        """
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid consume name "%s"' % self.label)

    def get_param_num(self):
        """ Returns the number of consume dimension parameters/symbols. """
        return 1
class PatternTransformation(TransformationBase):
    """
    Abstract class for pattern-matching transformations.
    Please extend either ``SingleStateTransformation`` or
    ``MultiStateTransformation``.

    :see: SingleStateTransformation
    :see: MultiStateTransformation
    :seealso: PatternNode
    """

    # Properties
    sdfg_id = Property(dtype=int, category="(Debug)")
    state_id = Property(dtype=int, category="(Debug)")
    _subgraph = DictProperty(key_type=int, value_type=int, category="(Debug)")
    expr_index = Property(dtype=int, category="(Debug)")

    @classmethod
    def subclasses_recursive(
            cls,
            all_subclasses: bool = False
    ) -> Set[Type['PatternTransformation']]:
        """
        Returns all subclasses of this class, including subclasses of
        subclasses. Classes with non-empty ``__abstractmethods__`` are
        excluded from the result.

        :param all_subclasses: Include all subclasses (e.g., including
                               ``ExpandTransformation``).
        """
        if not all_subclasses and cls is PatternTransformation:
            subclasses = set(SingleStateTransformation.__subclasses__()) | set(
                MultiStateTransformation.__subclasses__())
        else:
            subclasses = set(cls.__subclasses__())
        subsubclasses = set()
        for sc in subclasses:
            subsubclasses.update(sc.subclasses_recursive())

        # Ignore abstract classes
        result = subclasses | subsubclasses
        result = set(sc for sc in result
                     if not getattr(sc, '__abstractmethods__', False))

        return result

    def annotates_memlets(self) -> bool:
        """ Indicates whether the transformation annotates the edges it
            creates or modifies with the appropriate memlets. This
            determines whether to apply memlet propagation after the
            transformation.
        """
        return False

    @classmethod
    def expressions(cls) -> List[gr.SubgraphView]:
        """ Returns a list of Graph objects that will be matched in the
            subgraph isomorphism phase. Used as a pre-pass before calling
            `can_be_applied`.

            :see: PatternTransformation.can_be_applied
        """
        raise NotImplementedError

    def can_be_applied(self,
                       graph: Union[SDFG, SDFGState],
                       expr_index: int,
                       sdfg: SDFG,
                       permissive: bool = False) -> bool:
        """ Returns True if this transformation can be applied on the
            candidate matched subgraph.

            :param graph: SDFGState object if this transformation is
                          single-state, or SDFG object otherwise.
            :param expr_index: The list index from
                               `PatternTransformation.expressions` that was
                               matched.
            :param sdfg: If `graph` is an SDFGState, its parent SDFG.
                         Otherwise should be equal to `graph`.
            :param permissive: Whether transformation should run in
                               permissive mode.
            :return: True if the transformation can be applied.
        """
        raise NotImplementedError

    def apply(self, graph: Union[SDFG, SDFGState],
              sdfg: SDFG) -> Union[Any, None]:
        """
        Applies this transformation instance on the matched pattern graph.

        :param graph: The graph the pattern was matched on: the SDFG state
                      for a non-negative state ID, the SDFG otherwise.
        :param sdfg: The SDFG to apply the transformation to.
        :return: A transformation-defined return value, which could be used
                 to pass analysis data out, or nothing.
        """
        raise NotImplementedError

    def match_to_str(self, graph: Union[SDFG, SDFGState]) -> str:
        """ Returns a string representation of the pattern match on the
            candidate subgraph. Used when identifying matches in the console
            UI.
        """
        candidate = []
        # Reverse lookup: pattern node object -> attribute name on the class
        node_to_name = {v: k for k, v in self._get_pattern_nodes().items()}
        for cnode in self.subgraph.keys():
            cname = node_to_name[cnode]
            candidate.append(getattr(self, cname))
        return str(candidate)

    def __init__(self,
                 sdfg: SDFG,
                 sdfg_id: int,
                 state_id: int,
                 subgraph: Dict['PatternNode', int],
                 expr_index: int,
                 override: bool = False,
                 options: Optional[Dict[str, Any]] = None) -> None:
        """ Initializes an instance of Transformation match.

            :param sdfg: The SDFG this match belongs to; used by
                         ``apply_pattern`` to locate the target SDFG.
            :param sdfg_id: A unique ID of the SDFG.
            :param state_id: The node ID of the SDFG state, if applicable.
                             If transformation does not operate on a single
                             state, the value should be -1.
            :param subgraph: A mapping between node IDs returned from
                             `PatternTransformation.expressions` and the
                             nodes in `graph`.
            :param expr_index: The list index from
                               `PatternTransformation.expressions` that was
                               matched.
            :param override: If True, accepts the subgraph dictionary as-is
                             (mostly for internal use).
            :param options: An optional dictionary of transformation
                            properties.
            :raise TypeError: When ``override`` is False and a value of the
                              subgraph dictionary is not an int.
        """

        self._sdfg = sdfg
        self.sdfg_id = sdfg_id
        self.state_id = state_id
        if not override:
            expr = self.expressions()[expr_index]
            for value in subgraph.values():
                if not isinstance(value, int):
                    raise TypeError('All values of '
                                    'subgraph'
                                    ' dictionary must be '
                                    'instances of int.')
            # Serializable subgraph with node IDs as keys
            self._subgraph = {expr.node_id(k): v for k, v in subgraph.items()}
        else:
            # Placeholder entry; the given dictionary is used as-is below
            self._subgraph = {-1: -1}
        # Shallow copy of the mapping as given by the caller
        self._subgraph_user = copy.copy(subgraph)
        self.expr_index = expr_index

        # Set properties
        if options is not None:
            for optname, optval in options.items():
                setattr(self, optname, optval)

    @property
    def subgraph(self):
        """ A shallow copy of the subgraph mapping passed to the
            constructor. """
        return self._subgraph_user

    def apply_pattern(self,
                      append: bool = True,
                      annotate: bool = True) -> Union[Any, None]:
        """
        Applies this transformation on the SDFG given at construction, using
        the transformation instance to find the right SDFG object (based on
        SDFG ID), and applying memlet propagation as necessary.

        :param append: If True, appends the transformation to the SDFG
                       transformation history.
        :param annotate: If True, runs memlet propagation after application
                         unless the transformation annotates memlets itself.
        :return: A transformation-defined return value, which could be used
                 to pass analysis data out, or nothing.
        """
        if append:
            self._sdfg.append_transformation(self)
        tsdfg: SDFG = self._sdfg.sdfg_list[self.sdfg_id]
        # A negative state ID means the match is on the SDFG state machine
        tgraph = tsdfg.node(self.state_id) if self.state_id >= 0 else tsdfg
        retval = self.apply(tgraph, tsdfg)
        if annotate and not self.annotates_memlets():
            propagation.propagate_memlets_sdfg(tsdfg)
        return retval

    def __lt__(self, other: 'PatternTransformation') -> bool:
        """
        Comparing two transformations by their class name and node IDs in
        match. Used for ordering transformations consistently.

        Transformations of different classes are ordered by class name.
        Otherwise the matched node IDs are compared lexicographically; if
        one sequence is a prefix of the other, the shorter one is smaller.
        Equal matches are not less than each other.
        """
        if type(self) != type(other):
            return type(self).__name__ < type(other).__name__

        self_ids = list(self.subgraph.values())
        # Node IDs must come from `other`; reading them from `self` made
        # every pair of same-class transformations compare as equal
        other_ids = list(other.subgraph.values())

        for self_id, other_id in zip(self_ids, other_ids):
            if self_id != other_id:
                return self_id < other_id

        # All shared positions are equal: the shorter match sorts first
        return len(self_ids) < len(other_ids)

    @classmethod
    def _get_pattern_nodes(cls) -> Dict[str, 'PatternNode']:
        """
        Returns a dictionary of pattern-matching node in this transformation
        subclass. Used internally for pattern-matching.

        Besides ``PatternNode`` attributes, underscore-prefixed class
        attributes holding a node or an SDFG state are also included.

        :return: A dictionary mapping between pattern-node name and its
                 type.
        """
        return {
            k: getattr(cls, k)
            for k in dir(cls)
            if isinstance(getattr(cls, k), PatternNode) or (k.startswith(
                '_') and isinstance(getattr(cls, k), (nd.Node, SDFGState)))
        }

    @classmethod
    def apply_to(cls,
                 sdfg: SDFG,
                 options: Optional[Dict[str, Any]] = None,
                 expr_index: int = 0,
                 verify: bool = True,
                 annotate: bool = True,
                 permissive: bool = False,
                 save: bool = True,
                 **where: Union[nd.Node, SDFGState]):
        """
        Applies this transformation to a given subgraph, defined by a set of
        nodes. Raises an error if arguments are invalid or transformation is
        not applicable.

        The subgraph is defined by the `where` dictionary, where each key is
        taken from the `PatternNode` fields of the transformation. For
        example, applying `MapCollapse` on two maps can be performed as
        follows:

        ```
        MapCollapse.apply_to(sdfg, outer_map_entry=map_a, inner_map_entry=map_b)
        ```

        :param sdfg: The SDFG to apply the transformation to.
        :param options: A set of parameters to use for applying the
                        transformation.
        :param expr_index: The pattern expression index to try to match
                           with.
        :param verify: Check that `can_be_applied` returns True before
                       applying.
        :param annotate: Run memlet propagation after application if
                         necessary.
        :param permissive: Apply transformation in permissive mode.
        :param save: Save transformation as part of the SDFG file. Set to
                     False if composing transformations.
        :param where: A dictionary of node names (from the transformation)
                      to nodes in the SDFG or a single state.
        :raise ValueError: If no nodes are given, required pattern nodes
                           are missing, an option is not a property of the
                           transformation, or verification fails.
        :raise TypeError: If the given nodes are neither SDFG states nor
                          nodes.
        """
        if len(where) == 0:
            raise ValueError('At least one node is required')
        options = options or {}

        # Check that all keyword arguments are nodes and if interstate or not
        sample_node = next(iter(where.values()))

        if isinstance(sample_node, SDFGState):
            graph = sdfg
            state_id = -1
        elif isinstance(sample_node, nd.Node):
            # Find the state that contains the sample node
            graph = next(s for s in sdfg.nodes() if sample_node in s.nodes())
            state_id = sdfg.node_id(graph)
        else:
            raise TypeError('Invalid node type "%s"' %
                            type(sample_node).__name__)

        # Check that all nodes in the pattern are set
        required_nodes = cls.expressions()[expr_index].nodes()
        required_node_names = {
            pname: pval
            for pname, pval in cls._get_pattern_nodes().items()
            if pval in required_nodes
        }
        required = set(required_node_names.keys())
        intersection = required & set(where.keys())
        if len(required - intersection) > 0:
            raise ValueError('Missing nodes for transformation subgraph: %s' %
                             (required - intersection))

        # Construct subgraph and instantiate transformation
        subgraph = {
            required_node_names[k]: graph.node_id(where[k])
            for k in required
        }
        instance = cls(sdfg, sdfg.sdfg_id, state_id, subgraph, expr_index)

        # Construct transformation parameters
        for optname, optval in options.items():
            if optname not in cls.__properties__:
                raise ValueError('Property "%s" not found in transformation' %
                                 optname)
            setattr(instance, optname, optval)

        if verify:
            if not instance.can_be_applied(
                    graph, expr_index, sdfg, permissive=permissive):
                raise ValueError('Transformation cannot be applied on the '
                                 'given subgraph ("can_be_applied" failed)')

        # Apply to SDFG
        return instance.apply_pattern(annotate=annotate, append=save)

    def __str__(self) -> str:
        return type(self).__name__

    def print_match(self, sdfg: SDFG) -> str:
        """ Returns a string representation of the pattern match on the
            given SDFG. Used for printing matches in the console UI.

            :raise TypeError: If ``sdfg`` is not an SDFG.
        """
        if not isinstance(sdfg, SDFG):
            raise TypeError("Expected SDFG, got: {}".format(
                type(sdfg).__name__))
        # A state ID of -1 means the match is on the SDFG state machine
        if self.state_id == -1:
            graph = sdfg
        else:
            graph = sdfg.nodes()[self.state_id]
        string = type(self).__name__ + ' in '
        string += self.match_to_str(graph)
        return string

    def to_json(self, parent=None) -> Dict[str, Any]:
        """ Serializes this transformation (class name and all properties)
            to a dictionary. """
        props = serialize.all_properties_to_json(self)
        return {
            'type': 'PatternTransformation',
            'transformation': type(self).__name__,
            **props
        }

    @staticmethod
    def from_json(json_obj: Dict[str, Any],
                  context: Dict[str, Any] = None) -> 'PatternTransformation':
        """ Reconstructs a transformation instance from the dictionary
            produced by ``to_json``. The class is looked up by name among
            all recursive subclasses; the instance is created without an
            attached SDFG object. """
        xform = next(ext for ext in PatternTransformation.subclasses_recursive(
            all_subclasses=True) if ext.__name__ == json_obj['transformation'])

        # Recreate subgraph
        expr = xform.expressions()[json_obj['expr_index']]
        subgraph = {
            expr.node(int(k)): int(v)
            for k, v in json_obj['_subgraph'].items()
        }

        # Reconstruct transformation
        ret = xform(None, json_obj['sdfg_id'], json_obj['state_id'], subgraph,
                    json_obj['expr_index'])
        context = context or {}
        context['transformation'] = ret
        serialize.set_properties_from_json(
            ret,
            json_obj,
            context=context,
            ignore_properties={'transformation', 'type'})
        return ret
class Transformation(object):
    """
    Base class for transformations, as well as a static registry of
    transformations, where new transformations can be added in a
    decentralized manner.

    New transformations are registered with ``Transformation.register``
    (or ``dace.registry.autoregister_params``) with two optional boolean
    keyword arguments: ``singlestate`` (default: False) and ``strict``
    (default: False).
    If ``singlestate`` is True, the transformation is matched on subgraphs
    inside an SDFGState; otherwise, subgraphs of the SDFG state machine are
    matched.
    If ``strict`` is True, this transformation will be considered strict
    (i.e., always beneficial to perform) and will be performed automatically
    as part of SDFG strict transformations.
    """

    # Properties
    sdfg_id = Property(dtype=int, category="(Debug)")
    state_id = Property(dtype=int, category="(Debug)")
    _subgraph = DictProperty(key_type=int, value_type=int, category="(Debug)")
    expr_index = Property(dtype=int, category="(Debug)")

    @staticmethod
    def annotates_memlets():
        """
        Indicates whether the transformation annotates the edges it creates
        or modifies with the appropriate memlets. This determines whether
        to apply memlet propagation after the transformation.

        :return: False in this base implementation, which makes
                 `apply_pattern` run memlet propagation after `apply`.
        """
        return False

    @staticmethod
    def expressions():
        """
        Returns a list of Graph objects that will be matched in the
        subgraph isomorphism phase. Used as a pre-pass before calling
        `can_be_applied`.

        :raise NotImplementedError: Always, in this base class; subclasses
                                    must override.
        :see: Transformation.can_be_applied
        """
        raise NotImplementedError

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Returns True if this transformation can be applied on the
        candidate matched subgraph.

        :param graph: SDFGState object if this Transformation is
                      single-state, or SDFG object otherwise.
        :param candidate: A mapping between node IDs returned from
                          `Transformation.expressions` and the nodes in
                          `graph`.
        :param expr_index: The list index from `Transformation.expressions`
                           that was matched.
        :param sdfg: If `graph` is an SDFGState, its parent SDFG. Otherwise
                     should be equal to `graph`.
        :param strict: Whether transformation should run in strict mode.
        :return: True if the transformation can be applied.
        :raise NotImplementedError: Always, in this base class; subclasses
                                    must override.
        """
        raise NotImplementedError

    @staticmethod
    def match_to_str(graph, candidate):
        """
        Returns a string representation of the pattern match on the
        candidate subgraph. Used when identifying matches in the console UI.

        :param graph: The graph the match was found in (unused by this
                      default implementation).
        :param candidate: The match mapping; only its values are printed.
        :return: The string form of the list of matched values.
        """
        return str(list(candidate.values()))

    def __init__(self, sdfg_id, state_id, subgraph, expr_index):
        """
        Initializes an instance of Transformation.

        :param sdfg_id: A unique ID of the SDFG.
        :param state_id: The node ID of the SDFG state, if applicable.
        :param subgraph: A mapping between nodes returned from
                         `Transformation.expressions` and the node IDs
                         (integers) in the matched graph.
        :param expr_index: The list index from `Transformation.expressions`
                           that was matched.
        :raise TypeError: When any value of `subgraph` is not an int.
        """
        self.sdfg_id = sdfg_id
        self.state_id = state_id
        for value in subgraph.values():
            if not isinstance(value, int):
                raise TypeError('All values of '
                                'subgraph'
                                ' dictionary must be '
                                'instances of int.')
        # Serializable subgraph with node IDs as keys
        expr = self.expressions()[expr_index]
        self._subgraph = {expr.node_id(k): v for k, v in subgraph.items()}
        # Keep the caller-provided mapping (pattern node -> node ID) for
        # lookups through the `subgraph` property.
        self._subgraph_user = subgraph
        self.expr_index = expr_index

    @property
    def subgraph(self):
        """ The mapping passed to the constructor (pattern node to ID). """
        return self._subgraph_user

    def query_node(
            self, sdfg: SDFG,
            pattern_node: Union[nd.Node, SDFGState]
    ) -> Union[nd.Node, SDFGState]:
        """
        Returns the matched node object (from a subgraph pattern node)
        in its original graph.

        :param sdfg: The SDFG on which this transformation is applied.
        :param pattern_node: The node object in the transformation
                             properties.
        :return: The node object in the matched graph.
        """
        # A state ID of -1 denotes a match on the SDFG state machine itself.
        graph = sdfg if self.state_id == -1 else sdfg.node(self.state_id)
        return graph.node(self.subgraph[pattern_node])

    def __lt__(self, other):
        """
        Comparing two transformations by their class name and node IDs
        in match. Used for ordering transformations consistently.

        Transformations of different classes are ordered by class name.
        Transformations of the same class are ordered lexicographically by
        their matched node IDs: the first differing ID decides, and a match
        that is a strict prefix of the other sorts first. Equal matches are
        not less than each other.

        :param other: The transformation to compare against.
        :return: True if this transformation sorts before `other`.
        """
        if type(self) is not type(other):
            return type(self).__name__ < type(other).__name__

        # Compare against the *other* match (the IDs are validated to be
        # integers in the constructor, so list comparison is well-defined).
        return list(self.subgraph.values()) < list(other.subgraph.values())

    def apply_pattern(self, sdfg):
        """
        Applies this transformation on the given SDFG.

        The transformation is first recorded on the SDFG, then applied;
        memlet propagation is run afterwards unless `annotates_memlets`
        returns True.

        :param sdfg: The SDFG to apply the transformation to.
        """
        sdfg.append_transformation(self)
        self.apply(sdfg)
        if not self.annotates_memlets():
            propagation.propagate_memlets_sdfg(sdfg)

    @classmethod
    def apply_to(cls,
                 sdfg: SDFG,
                 options: Optional[Dict[str, Any]] = None,
                 expr_index: int = 0,
                 verify: bool = True,
                 strict: bool = False,
                 **where: Union[nd.Node, SDFGState]):
        """
        Applies this transformation to a given subgraph, defined by a set of
        nodes. Raises an error if arguments are invalid or transformation is
        not applicable.

        :param sdfg: The SDFG to apply the transformation to.
        :param options: A set of parameters to use for applying the
                        transformation.
        :param expr_index: The pattern expression index to try to match with.
        :param verify: Check that `can_be_applied` returns True before
                       applying.
        :param strict: Apply transformation in strict mode.
        :param where: A dictionary of node names (from the transformation)
                      to nodes in the SDFG or a single state.
        :raise ValueError: If no nodes are given, if required pattern nodes
                           are missing, if an option is not a property of
                           the transformation, or if verification fails.
        :raise TypeError: If the given nodes are neither SDFG states nor
                          dataflow nodes.
        """
        if not where:
            raise ValueError('At least one node is required')
        options = options or {}

        # Check that all keyword arguments are nodes and if interstate or not
        sample_node = next(iter(where.values()))

        if isinstance(sample_node, SDFGState):
            graph = sdfg
            state_id = -1
        elif isinstance(sample_node, nd.Node):
            # Locate the state that contains the sample node
            graph = next(s for s in sdfg.nodes() if sample_node in s.nodes())
            state_id = sdfg.node_id(graph)
        else:
            raise TypeError('Invalid node type "%s"' %
                            type(sample_node).__name__)

        # Check that all nodes in the pattern are set. Pattern nodes are
        # class attributes named with a leading underscore; the keyword
        # names in `where` omit that underscore.
        required_nodes = cls.expressions()[expr_index].nodes()
        required_node_names = {
            pname[1:]: pval
            for pname, pval in cls.__dict__.items()
            if pname.startswith('_') and pval in required_nodes
        }
        required = set(required_node_names.keys())
        intersection = required & set(where.keys())
        if len(required - intersection) > 0:
            raise ValueError('Missing nodes for transformation subgraph: %s' %
                             (required - intersection))

        # Construct subgraph and instantiate transformation
        subgraph = {
            required_node_names[k]: graph.node_id(where[k])
            for k in required
        }
        instance = cls(sdfg.sdfg_id, state_id, subgraph, expr_index)

        # Construct transformation parameters
        for optname, optval in options.items():
            if optname not in cls.__properties__:
                raise ValueError('Property "%s" not found in transformation' %
                                 optname)
            setattr(instance, optname, optval)

        if verify:
            if not cls.can_be_applied(
                    graph, subgraph, expr_index, sdfg, strict=strict):
                raise ValueError('Transformation cannot be applied on the '
                                 'given subgraph ("can_be_applied" failed)')

        # Apply to SDFG
        instance.apply_pattern(sdfg)

    def __str__(self):
        return type(self).__name__

    def print_match(self, sdfg):
        """
        Returns a string representation of the pattern match on the
        given SDFG. Used for printing matches in the console UI.

        :param sdfg: The SDFG the match refers to.
        :return: A string of the form "<TransformationName> in <match>".
        :raise TypeError: If `sdfg` is not an SDFG.
        """
        if not isinstance(sdfg, SDFG):
            raise TypeError("Expected SDFG, got: {}".format(
                type(sdfg).__name__))
        # A state ID of -1 denotes a match on the SDFG state machine itself.
        if self.state_id == -1:
            graph = sdfg
        else:
            graph = sdfg.nodes()[self.state_id]
        string = type(self).__name__ + ' in '
        string += type(self).match_to_str(graph, self.subgraph)
        return string

    def to_json(self, parent=None):
        """
        Serializes this transformation to a JSON-compatible dictionary
        containing its type tag, class name, and all of its properties.

        :param parent: Unused by this implementation.
        :return: The dictionary representation of this transformation.
        """
        props = serialize.all_properties_to_json(self)
        return {
            'type': 'Transformation',
            'transformation': type(self).__name__,
            **props
        }

    @staticmethod
    def from_json(json_obj, context=None):
        """
        Reconstructs a transformation from the dictionary produced by
        `to_json`.

        :param json_obj: The serialized transformation.
        :param context: Optional deserialization context; the new
                        transformation is stored in it under the key
                        'transformation'.
        :return: The reconstructed transformation instance.
        """
        # Find the registered transformation class by name
        xform = next(ext for ext in Transformation.extensions().keys()
                     if ext.__name__ == json_obj['transformation'])

        # Recreate subgraph
        expr = xform.expressions()[json_obj['expr_index']]
        subgraph = {
            expr.node(int(k)): int(v)
            for k, v in json_obj['_subgraph'].items()
        }

        # Reconstruct transformation
        ret = xform(json_obj['sdfg_id'], json_obj['state_id'], subgraph,
                    json_obj['expr_index'])
        context = context or {}
        context['transformation'] = ret
        serialize.set_properties_from_json(
            ret,
            json_obj,
            context=context,
            ignore_properties={'transformation', 'type'})
        return ret