class NestedSDFG(CodeNode):
    """ An SDFG state node that contains an SDFG of its own, runnable using
        the data dependencies specified using its connectors.

        It is encouraged to use nested SDFGs instead of coarse-grained tasklets
        since they are analyzable with respect to transformations.

        The ``sdfg`` property is declared with ``allow_none=True``; the
        methods below therefore tolerate a node that has no SDFG attached.

        @note: A nested SDFG cannot create recursion (one of its parent SDFGs).
    """

    # NOTE: We cannot use SDFG as the type because of an import loop
    sdfg = SDFGReferenceProperty(desc="The SDFG", allow_none=True)
    schedule = Property(dtype=dtypes.ScheduleType,
                        desc="SDFG schedule",
                        allow_none=True,
                        choices=dtypes.ScheduleType,
                        from_string=lambda x: dtypes.ScheduleType[x],
                        default=dtypes.ScheduleType.Default)
    symbol_mapping = DictProperty(
        key_type=str,
        value_type=dace.symbolic.pystr_to_symbolic,
        desc="Mapping between internal symbols and their values, expressed as "
        "symbolic expressions")
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = Property(
        choices=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 label,
                 sdfg,
                 inputs: Set[str],
                 outputs: Set[str],
                 symbol_mapping: Dict[str, Any] = None,
                 schedule=dtypes.ScheduleType.Default,
                 location=None,
                 debuginfo=None):
        """ Creates a nested SDFG node.

            :param label: Name of the node.
            :param sdfg: The SDFG this node contains.
            :param inputs: Input connector names.
            :param outputs: Output connector names.
            :param symbol_mapping: Mapping from symbols of the nested SDFG to
                                   their values; a falsy value is stored as
                                   an empty dictionary.
            :param schedule: Schedule type of the node.
            :param location: Forwarded as-is to the base-class constructor.
            :param debuginfo: Debugging information attached to the node.
        """
        super(NestedSDFG, self).__init__(label, location, inputs, outputs)

        # Properties
        self.sdfg = sdfg
        self.symbol_mapping = symbol_mapping or {}
        self.schedule = schedule
        self.debuginfo = debuginfo

    @staticmethod
    def from_json(json_obj, context=None):
        """ Builds a nested SDFG node from its JSON representation.

            :param json_obj: The JSON object holding the node properties.
            :param context: Optional dictionary; its 'sdfg_state' and 'sdfg'
                            entries (if present) become the parent state and
                            parent SDFG of the nested SDFG.
            :return: The deserialized node.
        """
        from dace import SDFG  # Avoid import loop

        # We have to load the SDFG first.
        ret = NestedSDFG("nolabel", SDFG('nosdfg'), {}, {})

        dace.serialize.set_properties_from_json(ret, json_obj, context)

        # The ``sdfg`` property may be None (allow_none=True), in which case
        # there is no nested SDFG to link back to its parents.
        if ret.sdfg is not None:
            if context and 'sdfg_state' in context:
                ret.sdfg.parent = context['sdfg_state']
            if context and 'sdfg' in context:
                ret.sdfg.parent_sdfg = context['sdfg']

            ret.sdfg.parent_nsdfg_node = ret

            ret.sdfg.update_sdfg_list([])

        return ret

    @property
    def free_symbols(self) -> Set[str]:
        """ Names of the free symbols that appear in the values of the symbol
            mapping and in the values of the node location. """
        return set().union(
            *(map(str,
                  pystr_to_symbolic(v).free_symbols)
              for v in self.symbol_mapping.values()),
            *(map(str,
                  pystr_to_symbolic(v).free_symbols)
              for v in self.location.values()))

    def infer_connector_types(self, sdfg, state):
        """ Runs connector type inference on the nested SDFG, if there is one.

            :param sdfg: Unused by this implementation.
            :param state: Unused by this implementation.
        """
        # Avoid import loop
        from dace.sdfg.infer_types import infer_connector_types

        # Without a nested SDFG there are no internal connectors to infer
        if self.sdfg is None:
            return

        # Infer internal connector types
        infer_connector_types(self.sdfg)

    def __str__(self):
        if not self.label:
            return "SDFG"
        else:
            return self.label

    def validate(self, sdfg, state):
        """ Validates this node and, recursively, its nested SDFG.

            :param sdfg: Unused by this implementation.
            :param state: Unused by this implementation.
            :raise NameError: If the label or a connector name is invalid, if
                              a non-transient, non-scalar data descriptor of
                              the nested SDFG has no matching connector, or if
                              a connector refers to a transient descriptor.
            :raise ValueError: If a free symbol of the nested SDFG is neither
                               a connector nor part of the symbol mapping.
        """
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid nested SDFG name "%s"' % self.label)
        for in_conn in self.in_connectors:
            if not dtypes.validate_name(in_conn):
                raise NameError('Invalid input connector "%s"' % in_conn)
        for out_conn in self.out_connectors:
            if not dtypes.validate_name(out_conn):
                raise NameError('Invalid output connector "%s"' % out_conn)

        # The ``sdfg`` property may be None (allow_none=True); all remaining
        # checks inspect the nested SDFG, so there is nothing left to verify.
        if self.sdfg is None:
            return

        connectors = self.in_connectors.keys() | self.out_connectors.keys()
        for dname, desc in self.sdfg.arrays.items():
            # TODO(later): Disallow scalars without access nodes (so that this
            #              check passes for them too).
            if isinstance(desc, data.Scalar):
                continue
            if not desc.transient and dname not in connectors:
                raise NameError('Data descriptor "%s" not found in nested '
                                'SDFG connectors' % dname)
            if dname in connectors and desc.transient:
                raise NameError(
                    '"%s" is a connector but its corresponding array is transient'
                    % dname)

        # Validate undefined symbols
        symbols = set(k for k in self.sdfg.free_symbols if k not in connectors)
        missing_symbols = [s for s in symbols if s not in self.symbol_mapping]
        if missing_symbols:
            raise ValueError('Missing symbols on nested SDFG: %s' %
                             (missing_symbols))

        # Recursively validate nested SDFG
        self.sdfg.validate()
class NestedSDFG(CodeNode):
    """ A state node that wraps an entire SDFG, which is run using the data
        dependencies given by the node's connectors.

        Nested SDFGs are preferred over coarse-grained tasklets, since
        transformations are able to analyze them.

        @note: A nested SDFG cannot create recursion (one of its parent SDFGs).
    """

    label = Property(dtype=str, desc="Name of the SDFG")
    # NOTE: We cannot use SDFG as the type because of an import loop
    sdfg = SDFGReferenceProperty(dtype=graph.OrderedDiGraph, desc="The SDFG")
    schedule = Property(dtype=dtypes.ScheduleType,
                        desc="SDFG schedule",
                        choices=dtypes.ScheduleType,
                        from_string=lambda x: dtypes.ScheduleType[x])
    location = Property(dtype=str, desc="SDFG execution location descriptor")
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = Property(
        choices=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 label,
                 sdfg,
                 inputs: Set[str],
                 outputs: Set[str],
                 schedule=dtypes.ScheduleType.Default,
                 location="-1",
                 debuginfo=None):
        """ Creates a nested SDFG node.

            :param label: Name of the node.
            :param sdfg: The SDFG this node contains.
            :param inputs: Input connector names.
            :param outputs: Output connector names.
            :param schedule: Schedule type of the node.
            :param location: Execution location descriptor.
            :param debuginfo: Debugging information attached to the node.
        """
        super(NestedSDFG, self).__init__(inputs, outputs)

        # Store the constructor arguments in the node properties
        self.label = label
        self.sdfg = sdfg
        self.schedule = schedule
        self.location = location
        self.debuginfo = debuginfo

    @staticmethod
    def from_json(json_obj, context=None):
        """ Builds a nested SDFG node from its JSON representation.

            :param json_obj: The JSON object holding the node properties.
            :param context: Optional dictionary; its 'sdfg_state' and 'sdfg'
                            entries (if present) become the parent state and
                            parent SDFG of the nested SDFG.
            :return: The deserialized node.
        """
        from dace import SDFG  # Avoid import loop

        # A placeholder node is created first; its properties (including the
        # SDFG) are then overwritten from the JSON object.
        node = NestedSDFG("nolabel", SDFG('nosdfg'), set(), set())
        dace.serialize.set_properties_from_json(node, json_obj, context)

        # Link the loaded SDFG to whatever parents the context provides
        if context:
            if 'sdfg_state' in context:
                node.sdfg.parent = context['sdfg_state']
            if 'sdfg' in context:
                node.sdfg.parent_sdfg = context['sdfg']

        return node

    def draw_node(self, sdfg, graph):
        """ Draws this node through ``dot.draw_node`` using the
            "doubleoctagon" shape. """
        return dot.draw_node(sdfg, graph, self, shape="doubleoctagon")

    def __str__(self):
        # Fall back to a generic name when the label is empty
        return self.label or "SDFG"

    def validate(self, sdfg, state):
        """ Validates the node names and, recursively, the nested SDFG.

            :param sdfg: Unused by this implementation.
            :param state: Unused by this implementation.
            :raise NameError: If the label or a connector name is invalid.
        """
        is_valid = data.validate_name

        if not is_valid(self.label):
            raise NameError('Invalid nested SDFG name "%s"' % self.label)

        for conn in self.in_connectors:
            if is_valid(conn):
                continue
            raise NameError('Invalid input connector "%s"' % conn)

        for conn in self.out_connectors:
            if is_valid(conn):
                continue
            raise NameError('Invalid output connector "%s"' % conn)

        # Recursively validate nested SDFG
        self.sdfg.validate()