Exemplo n.º 1
0
class NestedSDFG(CodeNode):
    """ An SDFG state node that contains an SDFG of its own, runnable using
        the data dependencies specified using its connectors.

        It is encouraged to use nested SDFGs instead of coarse-grained tasklets
        since they are analyzable with respect to transformations.

        @note: A nested SDFG cannot create recursion (one of its parent SDFGs).
    """

    # NOTE: We cannot use SDFG as the type because of an import loop
    sdfg = SDFGReferenceProperty(desc="The SDFG", allow_none=True)
    schedule = Property(dtype=dtypes.ScheduleType,
                        desc="SDFG schedule",
                        allow_none=True,
                        choices=dtypes.ScheduleType,
                        from_string=lambda x: dtypes.ScheduleType[x],
                        default=dtypes.ScheduleType.Default)
    symbol_mapping = DictProperty(
        key_type=str,
        value_type=dace.symbolic.pystr_to_symbolic,
        desc="Mapping between internal symbols and their values, expressed as "
        "symbolic expressions")
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = Property(choices=dtypes.InstrumentationType,
                          desc="Measure execution statistics with given method",
                          default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 label,
                 sdfg,
                 inputs: Set[str],
                 outputs: Set[str],
                 symbol_mapping: Dict[str, Any] = None,
                 schedule=dtypes.ScheduleType.Default,
                 location=None,
                 debuginfo=None):
        """ Creates a nested SDFG node.

            @param label: Name of the node (forwarded to ``CodeNode``).
            @param sdfg: The SDFG contained in this node.
            @param inputs: Names of the input connectors.
            @param outputs: Names of the output connectors.
            @param symbol_mapping: Mapping between internal symbols and their
                                   values. A falsy value (e.g. ``None``) is
                                   replaced by a fresh empty dict.
            @param schedule: Schedule type of this node.
            @param location: Location descriptor, forwarded to ``CodeNode``.
            @param debuginfo: Optional debug information.
        """
        super(NestedSDFG, self).__init__(label, location, inputs, outputs)

        # Properties
        self.sdfg = sdfg
        # A fresh dict per instance: never share a mutable default
        self.symbol_mapping = symbol_mapping or {}
        self.schedule = schedule
        self.debuginfo = debuginfo

    @staticmethod
    def from_json(json_obj, context=None):
        """ Reconstructs a nested SDFG node from its JSON representation.

            @param json_obj: JSON object holding the serialized properties.
            @param context: Optional deserialization context. If it contains
                            ``'sdfg_state'`` / ``'sdfg'``, those are set as
                            the nested SDFG's ``parent`` / ``parent_sdfg``.
            @return: The deserialized ``NestedSDFG`` node.
        """
        from dace import SDFG  # Avoid import loop

        # Build a placeholder node first (empty connector *sets*, matching the
        # ``Set[str]`` signature of ``__init__``), then overwrite its
        # properties from the JSON object.
        ret = NestedSDFG("nolabel", SDFG('nosdfg'), set(), set())

        dace.serialize.set_properties_from_json(ret, json_obj, context)

        if context and 'sdfg_state' in context:
            ret.sdfg.parent = context['sdfg_state']
        if context and 'sdfg' in context:
            ret.sdfg.parent_sdfg = context['sdfg']

        # Link the inner SDFG back to the node that contains it
        ret.sdfg.parent_nsdfg_node = ret

        ret.sdfg.update_sdfg_list([])

        return ret

    @property
    def free_symbols(self) -> Set[str]:
        """ Names of the free symbols appearing in the values of
            ``symbol_mapping`` and ``location``. """
        return set().union(
            *(map(str,
                  pystr_to_symbolic(v).free_symbols)
              for v in self.symbol_mapping.values()),
            *(map(str,
                  pystr_to_symbolic(v).free_symbols)
              for v in self.location.values()))

    def infer_connector_types(self, sdfg, state):
        """ Runs connector type inference on the contained SDFG.

            The ``sdfg`` and ``state`` arguments are not used by this
            implementation.
        """
        # Avoid import loop
        from dace.sdfg.infer_types import infer_connector_types
        # Infer internal connector types
        infer_connector_types(self.sdfg)

    def __str__(self):
        return self.label or "SDFG"

    def validate(self, sdfg, state):
        """ Validates this node and then, recursively, the contained SDFG.

            The ``sdfg`` and ``state`` arguments are not used by this
            implementation.

            @raise NameError: If the label or a connector name is invalid, if a
                              non-transient, non-scalar data descriptor of the
                              nested SDFG has no matching connector, or if a
                              connector maps to a transient descriptor.
            @raise ValueError: If a free symbol of the nested SDFG is neither
                               a connector nor a key of ``symbol_mapping``.
        """
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid nested SDFG name "%s"' % self.label)
        for in_conn in self.in_connectors:
            if not dtypes.validate_name(in_conn):
                raise NameError('Invalid input connector "%s"' % in_conn)
        for out_conn in self.out_connectors:
            if not dtypes.validate_name(out_conn):
                raise NameError('Invalid output connector "%s"' % out_conn)
        connectors = self.in_connectors.keys() | self.out_connectors.keys()
        for dname, desc in self.sdfg.arrays.items():
            # TODO(later): Disallow scalars without access nodes (so that this
            #              check passes for them too).
            if isinstance(desc, data.Scalar):
                continue
            if not desc.transient and dname not in connectors:
                raise NameError('Data descriptor "%s" not found in nested '
                                'SDFG connectors' % dname)
            if dname in connectors and desc.transient:
                raise NameError(
                    '"%s" is a connector but its corresponding array is transient'
                    % dname)

        # Validate undefined symbols
        symbols = {k for k in self.sdfg.free_symbols if k not in connectors}
        # Sort so the error message is deterministic; ``key=str`` keeps the
        # sort safe even if symbols are not plain strings.
        missing_symbols = sorted(
            (s for s in symbols if s not in self.symbol_mapping), key=str)
        if missing_symbols:
            raise ValueError('Missing symbols on nested SDFG: %s' %
                             (missing_symbols, ))

        # Recursively validate nested SDFG
        self.sdfg.validate()
Exemplo n.º 2
0
class NestedSDFG(CodeNode):
    """ A state node that wraps a complete SDFG of its own. The inner SDFG is
        executed using the data dependencies expressed by the node's
        connectors.

        Nested SDFGs are preferable to coarse-grained tasklets, because their
        contents stay analyzable by transformations.

        @note: A nested SDFG cannot create recursion (one of its parent SDFGs).
    """

    label = Property(dtype=str, desc="Name of the SDFG")
    # NOTE: We cannot use SDFG as the type because of an import loop
    sdfg = SDFGReferenceProperty(dtype=graph.OrderedDiGraph, desc="The SDFG")
    schedule = Property(dtype=dtypes.ScheduleType,
                        desc="SDFG schedule",
                        choices=dtypes.ScheduleType,
                        from_string=lambda name: dtypes.ScheduleType[name])
    location = Property(dtype=str, desc="SDFG execution location descriptor")
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = Property(choices=dtypes.InstrumentationType,
                          desc="Measure execution statistics with given method",
                          default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 label,
                 sdfg,
                 inputs: Set[str],
                 outputs: Set[str],
                 schedule=dtypes.ScheduleType.Default,
                 location="-1",
                 debuginfo=None):
        """ Builds the node; ``inputs``/``outputs`` name its connectors. """
        super().__init__(inputs, outputs)

        # Assign the declared properties
        self.label = label
        self.sdfg = sdfg
        self.schedule = schedule
        self.location = location
        self.debuginfo = debuginfo

    @staticmethod
    def from_json(json_obj, context=None):
        """ Rebuilds a node from JSON. If the optional ``context`` holds
            ``'sdfg_state'`` / ``'sdfg'``, they become the inner SDFG's
            ``parent`` / ``parent_sdfg``. """
        from dace import SDFG  # Avoid import loop

        # Start from a placeholder node, then overwrite its properties
        node = NestedSDFG("nolabel", SDFG('nosdfg'), set(), set())
        dace.serialize.set_properties_from_json(node, json_obj, context)

        if context:
            if 'sdfg_state' in context:
                node.sdfg.parent = context['sdfg_state']
            if 'sdfg' in context:
                node.sdfg.parent_sdfg = context['sdfg']

        return node

    def draw_node(self, sdfg, graph):
        """ Renders this node with a double-octagon shape. """
        return dot.draw_node(sdfg, graph, self, shape="doubleoctagon")

    def __str__(self):
        return self.label or "SDFG"

    def validate(self, sdfg, state):
        """ Checks the node label and every connector name, then validates
            the inner SDFG. Raises ``NameError`` on the first invalid name. """
        if not data.validate_name(self.label):
            raise NameError('Invalid nested SDFG name "%s"' % self.label)

        # Inputs are checked before outputs, each with its own message
        connector_checks = (
            (self.in_connectors, 'Invalid input connector "%s"'),
            (self.out_connectors, 'Invalid output connector "%s"'),
        )
        for conns, message in connector_checks:
            for conn in conns:
                if not data.validate_name(conn):
                    raise NameError(message % conn)

        # Recursively validate nested SDFG
        self.sdfg.validate()