예제 #1
0
class CodeNode(Node):
    """ Base class for nodes that hold executable code whose external data
        dependencies are acyclic. Concrete kinds are tasklets and nested
        SDFGs; such nodes are drawn with an octagonal shape. """

    label = Property(dtype=str, desc="Name of the CodeNode")
    location = DictProperty(
        key_type=str,
        value_type=dace.symbolic.pystr_to_symbolic,
        desc='Full storage location identifier (e.g., rank, GPU ID)')
    environments = SetProperty(
        str,
        desc="Environments required by CMake to build and run this code node.",
        default=set())

    def __init__(self, label="", location=None, inputs=None, outputs=None):
        # Falsy (missing or empty) connector collections become fresh sets
        in_conns = inputs if inputs else set()
        out_conns = outputs if outputs else set()
        super(CodeNode, self).__init__(in_conns, out_conns)

        # Properties
        self.label = label
        if location is None:
            location = {}
        self.location = location

    @property
    def free_symbols(self) -> Set[str]:
        """ Names (as strings) of all free symbols appearing in the values
            of this node's ``location`` dictionary. """
        names = set()
        for loc_value in self.location.values():
            expr = pystr_to_symbolic(loc_value)
            names.update(str(sym) for sym in expr.free_symbols)
        return names
예제 #2
0
class CodeObject(object):
    """ A unit of generated source code together with the metadata used to
        write, compile, and link it. """

    name = Property(dtype=str, desc="Filename to use")
    code = Property(dtype=str, desc="The code attached to this object")
    language = Property(dtype=str,
                        desc="Language used for this code (same " +
                        "as its file extension)")
    target = Property(dtype=type,
                      desc="Target to use for compilation",
                      allow_none=True)
    target_type = Property(
        dtype=str,
        desc="Sub-target within target (e.g., host or device code)",
        default="")
    title = Property(dtype=str, desc="Title of code for GUI")
    extra_compiler_kwargs = DictProperty(key_type=str,
                                         value_type=str,
                                         desc="Additional compiler argument "
                                         "variables to add to template")
    linkable = Property(dtype=bool,
                        desc='Should this file participate in '
                        'overall linkage?')
    environments = SetProperty(
        str,
        desc="Environments required by CMake to build and run this code node.",
        default=set())

    def __init__(self,
                 name,
                 code,
                 language,
                 target,
                 title,
                 target_type="",
                 additional_compiler_kwargs=None,
                 linkable=True,
                 environments=None,
                 sdfg=None):
        """ Stores the code and its metadata.

            :param additional_compiler_kwargs: Stored as
                ``extra_compiler_kwargs``; a falsy value becomes ``{}``.
            :param environments: A falsy value becomes an empty set.
            :param sdfg: When given (truthy) for the C++ ``'Frame'`` code
                         object, source maps are generated through
                         ``sourcemap.create_maps``.
        """
        super(CodeObject, self).__init__()

        # Identification and payload
        self.name = name
        self.code = code
        self.language = language

        # Compilation target and display title
        self.target = target
        self.target_type = target_type
        self.title = title

        # Build options
        if additional_compiler_kwargs:
            self.extra_compiler_kwargs = additional_compiler_kwargs
        else:
            self.extra_compiler_kwargs = {}
        self.linkable = linkable
        self.environments = environments if environments else set()

        # Source maps are only produced for the C++ frame code of an SDFG
        is_cpp_frame = language == 'cpp' and title == 'Frame'
        if is_cpp_frame and sdfg:
            sourcemap.create_maps(sdfg, code, self.target.target_name)

    @property
    def clean_code(self):
        """ The code with every ``////__DACE:`` / ``////__CODEGEN;`` marker
            comment (plus any spaces or tabs before it) stripped out. """
        marker_pattern = r'[ \t]*////__(DACE:|CODEGEN;)[^\n]*'
        return re.sub(marker_pattern, '', self.code)
예제 #3
0
파일: codeobject.py 프로젝트: gronerl/dace
class CodeObject(object):
    """ A unit of generated source code together with the metadata used to
        write, compile, and link it. """

    name = Property(dtype=str, desc="Filename to use")
    code = Property(dtype=str, desc="The code attached to this object")
    language = Property(
        dtype=str,
        desc="Language used for this code (same " +
        "as its file extension)")  # dtype=dtypes.Language?
    target = Property(
        dtype=type, desc="Target to use for compilation", allow_none=True)
    target_type = Property(
        dtype=str,
        desc="Sub-target within target (e.g., host or device code)",
        default="")
    title = Property(dtype=str, desc="Title of code for GUI")
    extra_compiler_kwargs = Property(
        dtype=dict,
        desc="Additional compiler argument "
        "variables to add to template")
    linkable = Property(
        dtype=bool, desc='Should this file participate in '
        'overall linkage?')
    environments = SetProperty(
        str,
        desc="Environments required by CMake to build and run this code node.",
        default=set())

    def __init__(self,
                 name,
                 code,
                 language,
                 target,
                 title,
                 target_type="",
                 additional_compiler_kwargs=None,
                 linkable=True,
                 environments=None):
        """ Stores the code and its metadata.

            :param additional_compiler_kwargs: Stored as
                ``extra_compiler_kwargs``; a falsy value becomes ``{}``.
            :param environments: Environment names; ``None`` (the default)
                                 becomes a new, per-instance empty set.
        """
        super(CodeObject, self).__init__()

        self.name = name
        self.code = code
        self.language = language
        self.target = target
        self.target_type = target_type
        self.title = title
        self.extra_compiler_kwargs = additional_compiler_kwargs or {}
        self.linkable = linkable
        # Create the set per call: a ``set()`` default argument is evaluated
        # once and would be shared (and mutably aliased) by every CodeObject
        # constructed without explicit environments.
        self.environments = set() if environments is None else environments
예제 #4
0
파일: nodes.py 프로젝트: cpenny42/dace
class Node(object):
    """ Base node class. """

    in_connectors = SetProperty(
        str, default=set(), desc="A set of input connectors for this node.")
    out_connectors = SetProperty(
        str, default=set(), desc="A set of output connectors for this node.")

    def __init__(self, in_connectors=None, out_connectors=None):
        """ Initializes the node's connector sets.

            @param in_connectors: Initial input connector names; defaults to
                                  a new empty set.
            @param out_connectors: Initial output connector names; defaults
                                   to a new empty set.
        """
        # A ``set()`` default would be created once and shared by every Node.
        # The add/remove methods below mutate the stored set in place, so
        # connectors would otherwise leak between unrelated nodes.
        self.in_connectors = (set()
                              if in_connectors is None else in_connectors)
        self.out_connectors = (set()
                               if out_connectors is None else out_connectors)

    def __str__(self):
        if hasattr(self, 'label'):
            return self.label
        else:
            return type(self).__name__

    def validate(self, sdfg, state):
        """ Validates the node; the base class performs no checks. """
        pass

    def toJSON(self, indent=0):
        """ Serializes the node to a single-line JSON object string.

            @param indent: Number of spaces to prefix the line with.
            @return: The JSON text, terminated by a newline.
        """
        import json

        def _quote(value) -> str:
            # json.dumps escapes quotes, backslashes and control characters
            # that plain concatenation would emit verbatim (invalid JSON).
            # ensure_ascii=False keeps non-ASCII labels unchanged.
            return json.dumps(str(value), ensure_ascii=False)

        labelstr = _quote(self)
        typestr = _quote(type(self).__name__)
        inconn = "[" + ",".join(_quote(x) for x in self.in_connectors) + "]"
        outconn = "[" + ",".join(_quote(x) for x in self.out_connectors) + "]"
        result = " " * indent + "{ \"label\": " + labelstr
        result += ", \"type\": " + typestr + ", \"in_connectors\": " + inconn
        result += ", \"out_connectors\" :" + outconn
        result += "}\n"
        return result

    def __repr__(self):
        return type(self).__name__ + ' (' + self.__str__() + ')'

    def add_in_connector(self, connector_name: str):
        """ Adds a new input connector to the node. The operation will fail if
            a connector (either input or output) with the same name already
            exists in the node.

            @param connector_name: The name of the new connector.
            @return: True if the operation is successful, otherwise False.
        """

        if (connector_name in self.in_connectors
                or connector_name in self.out_connectors):
            return False
        connectors = self.in_connectors
        connectors.add(connector_name)
        self.in_connectors = connectors
        return True

    def add_out_connector(self, connector_name: str):
        """ Adds a new output connector to the node. The operation will fail if
            a connector (either input or output) with the same name already
            exists in the node.

            @param connector_name: The name of the new connector.
            @return: True if the operation is successful, otherwise False.
        """

        if (connector_name in self.in_connectors
                or connector_name in self.out_connectors):
            return False
        connectors = self.out_connectors
        connectors.add(connector_name)
        self.out_connectors = connectors
        return True

    def remove_in_connector(self, connector_name: str):
        """ Removes an input connector from the node.
            @param connector_name: The name of the connector to remove.
            @return: Always True; removing a missing connector is a no-op.
        """

        if connector_name in self.in_connectors:
            connectors = self.in_connectors
            connectors.remove(connector_name)
            self.in_connectors = connectors
        return True

    def remove_out_connector(self, connector_name: str):
        """ Removes an output connector from the node.
            @param connector_name: The name of the connector to remove.
            @return: Always True; removing a missing connector is a no-op.
        """

        if connector_name in self.out_connectors:
            connectors = self.out_connectors
            connectors.remove(connector_name)
            self.out_connectors = connectors
        return True

    def _next_connector_int(self) -> int:
        """ Returns the next unused connector ID (as an integer). Used for
            filling connectors when adding edges to scopes. Only connectors
            named ``IN_<int>`` / ``OUT_<int>`` are considered. """
        next_number = 1
        for conn in itertools.chain(self.in_connectors, self.out_connectors):
            if conn.startswith('IN_'):
                cconn = conn[3:]
            elif conn.startswith('OUT_'):
                cconn = conn[4:]
            else:
                continue
            try:
                curconn = int(cconn)
            except (TypeError, ValueError):  # not integral, e.g. "IN_a"
                continue
            if curconn >= next_number:
                next_number = curconn + 1
        return next_number

    def next_connector(self) -> str:
        """ Returns the next unused connector ID (as a string). Used for
            filling connectors when adding edges to scopes. """
        return str(self._next_connector_int())

    def last_connector(self) -> str:
        """ Returns the last used connector ID (as a string). Used for
            filling connectors when adding edges to scopes. """
        return str(self._next_connector_int() - 1)
예제 #5
0
class Node(object):
    """ Common base class for every node in an SDFG state. """

    in_connectors = SetProperty(
        str, default=set(), desc="A set of input connectors for this node.")
    out_connectors = SetProperty(
        str, default=set(), desc="A set of output connectors for this node.")

    def __init__(self, in_connectors=None, out_connectors=None):
        # Falsy (missing or empty) connector collections become fresh sets
        self.in_connectors = in_connectors if in_connectors else set()
        self.out_connectors = out_connectors if out_connectors else set()

    def __str__(self):
        # Labelled nodes print their label; others fall back to the class name
        return self.label if hasattr(self, 'label') else type(self).__name__

    def validate(self, sdfg, state):
        """ Hook for node-specific validation; the base class accepts all. """

    def to_json(self, parent):
        """ Serializes this node into a JSON-compatible dictionary.

            :param parent: The graph containing this node; used to resolve
                           node IDs and the enclosing scope entry/exit.
            :return: Dictionary with type, label, serialized properties, node
                     ID, and the IDs (as strings) of the scope entry/exit.
        """
        label = str(self)
        node_type = getattr(self, '__jsontype__', str(type(self).__name__))

        try:
            entry = parent.entry_node(self)
        except (RuntimeError, StopIteration):
            entry = None

        entry_id = None
        exit_id = None
        if entry is not None:
            enclosing_exit = parent.exit_node(parent.entry_node(self))
            exit_id = str(parent.node_id(enclosing_exit))
            entry_id = str(parent.node_id(entry))

        # An entry node's own scope exit is its matching exit node
        if isinstance(self, EntryNode):
            try:
                exit_id = str(parent.node_id(parent.exit_node(self)))
            except (RuntimeError, StopIteration):
                exit_id = None

        return {
            "type": node_type,
            "label": label,
            "attributes": dace.serialize.all_properties_to_json(self),
            "id": parent.node_id(self),
            "scope_entry": entry_id,
            "scope_exit": exit_id
        }

    def __repr__(self):
        return ''.join([type(self).__name__, ' (', self.__str__(), ')'])

    def add_in_connector(self, connector_name: str):
        """ Registers a new input connector. Fails when any connector (input
            or output) already carries that name.

            :param connector_name: Name of the connector to add.
            :return: True when added, False when the name was already taken.
        """
        taken = (connector_name in self.in_connectors
                 or connector_name in self.out_connectors)
        if taken:
            return False
        updated = self.in_connectors
        updated.add(connector_name)
        self.in_connectors = updated
        return True

    def add_out_connector(self, connector_name: str):
        """ Registers a new output connector. Fails when any connector (input
            or output) already carries that name.

            :param connector_name: Name of the connector to add.
            :return: True when added, False when the name was already taken.
        """
        taken = (connector_name in self.in_connectors
                 or connector_name in self.out_connectors)
        if taken:
            return False
        updated = self.out_connectors
        updated.add(connector_name)
        self.out_connectors = updated
        return True

    def remove_in_connector(self, connector_name: str):
        """ Drops an input connector if present.
            :param connector_name: Name of the connector to drop.
            :return: Always True.
        """
        if connector_name not in self.in_connectors:
            return True
        remaining = self.in_connectors
        remaining.remove(connector_name)
        self.in_connectors = remaining
        return True

    def remove_out_connector(self, connector_name: str):
        """ Drops an output connector if present.
            :param connector_name: Name of the connector to drop.
            :return: Always True.
        """
        if connector_name not in self.out_connectors:
            return True
        remaining = self.out_connectors
        remaining.remove(connector_name)
        self.out_connectors = remaining
        return True

    def _next_connector_int(self) -> int:
        """ Smallest connector number above every ``IN_<n>``/``OUT_<n>``
            suffix in use (at least 1). Used to fill scope connectors. """
        highest = 0
        for conn in itertools.chain(self.in_connectors, self.out_connectors):
            for prefix in ('IN_', 'OUT_'):
                if conn.startswith(prefix):
                    suffix = conn[len(prefix):]
                    break
            else:
                # Not a numbered scope connector
                continue
            try:
                number = int(suffix)
            except (TypeError, ValueError):  # not integral
                continue
            highest = max(highest, number)
        return highest + 1

    def next_connector(self) -> str:
        """ The next unused connector number, as a string. """
        upcoming = self._next_connector_int()
        return str(upcoming)

    def last_connector(self) -> str:
        """ The most recently used connector number, as a string. """
        latest = self._next_connector_int() - 1
        return str(latest)

    @property
    def free_symbols(self) -> Set[str]:
        """ Symbols referenced by this node's properties (none by default). """
        return set()

    def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]:
        """ Symbols this node defines (e.g., scope entries), mapped to their
            types. The base class defines none. """
        return {}
예제 #6
0
class SubgraphTransformation(TransformationBase):
    """
    Base class for transformations that apply on arbitrary subgraphs, rather
    than matching a specific pattern.

    Subclasses need to implement the `can_be_applied` and `apply` operations,
    and must be registered with the subclass registry. See the
    `Transformation` class docstring for more information.
    """

    sdfg_id = Property(dtype=int, desc='ID of SDFG to transform')
    state_id = Property(
        dtype=int,
        desc='ID of state to transform subgraph within, or -1 to transform the '
        'SDFG')
    subgraph = SetProperty(element_type=int,
                           desc='Subgraph in transformation instance')

    def __init__(self,
                 subgraph: Union[Set[int], gr.SubgraphView],
                 sdfg_id: int = None,
                 state_id: int = None):
        """
        Creates a transformation instance bound to a subgraph.

        :param subgraph: A subgraph view, a whole SDFG/state (taken as the
                         subgraph of all its nodes), or a set of node IDs.
        :param sdfg_id: ID of the SDFG to transform; required (and only used)
                        when ``subgraph`` is a set of node IDs.
        :param state_id: ID of the state, or -1 for the SDFG itself; required
                         (and only used) when ``subgraph`` is a set of IDs.
        :raises TypeError: If node IDs are given without both IDs, or the
                           subgraph's graph is neither an SDFG nor a state.
        """
        if (not isinstance(subgraph, (gr.SubgraphView, SDFG, SDFGState))
                and (sdfg_id is None or state_id is None)):
            raise TypeError(
                'Subgraph transformation either expects a SubgraphView or a '
                'set of node IDs, SDFG ID and state ID (or -1).')

        # An entire graph is given as a subgraph
        if isinstance(subgraph, (SDFG, SDFGState)):
            subgraph = gr.SubgraphView(subgraph, subgraph.nodes())

        if isinstance(subgraph, gr.SubgraphView):
            self.subgraph = set(
                subgraph.graph.node_id(n) for n in subgraph.nodes())

            if isinstance(subgraph.graph, SDFGState):
                sdfg = subgraph.graph.parent
                self.sdfg_id = sdfg.sdfg_id
                self.state_id = sdfg.node_id(subgraph.graph)
            elif isinstance(subgraph.graph, SDFG):
                self.sdfg_id = subgraph.graph.sdfg_id
                self.state_id = -1
            else:
                raise TypeError('Unrecognized graph type "%s"' %
                                type(subgraph.graph).__name__)
        else:
            self.subgraph = subgraph
            self.sdfg_id = sdfg_id
            self.state_id = state_id

    def subgraph_view(self, sdfg: SDFG) -> gr.SubgraphView:
        """ Resolves the stored IDs into a subgraph view over ``sdfg``. """
        graph = sdfg.sdfg_list[self.sdfg_id]
        if self.state_id != -1:
            graph = graph.node(self.state_id)
        return gr.SubgraphView(graph,
                               [graph.node(idx) for idx in self.subgraph])

    def can_be_applied(self, sdfg: SDFG, subgraph: gr.SubgraphView) -> bool:
        """
        Tries to match the transformation on a given subgraph, returning
        True if this transformation can be applied.
        :param sdfg: The SDFG that includes the subgraph.
        :param subgraph: The SDFG or state subgraph to try to apply the
                         transformation on.
        :return: True if the subgraph can be transformed, or False otherwise.
        """
        pass

    def apply(self, sdfg: SDFG):
        """
        Applies the transformation on the given subgraph.
        :param sdfg: The SDFG that includes the subgraph.
        """
        pass

    @classmethod
    def apply_to(cls,
                 sdfg: SDFG,
                 *where: Union[nd.Node, SDFGState, gr.SubgraphView],
                 verify: bool = True,
                 **options: Any):
        """
        Applies this transformation to a given subgraph, defined by a set of
        nodes. Raises an error if arguments are invalid or transformation is
        not applicable.

        To apply the transformation on a specific subgraph, the `where`
        parameter can be used either on a subgraph object (`SubgraphView`), or
        directly on a list of subgraph nodes, given as `Node` or `SDFGState`
        objects. Transformation properties can then be given as keyword
        arguments. For example, applying `SubgraphFusion` on a subgraph of three
        nodes can be called in one of two ways:
        ```
        # Subgraph
        SubgraphFusion.apply_to(
            sdfg, SubgraphView(state, [node_a, node_b, node_c]))

        # Simplified API: list of nodes
        SubgraphFusion.apply_to(sdfg, node_a, node_b, node_c)
        ```

        :param sdfg: The SDFG to apply the transformation to.
        :param where: A set of nodes in the SDFG/state, or a subgraph thereof.
        :param verify: Check that `can_be_applied` returns True before applying.
        :param options: A set of parameters to use for applying the
                        transformation.
        :raises ValueError: If no nodes are given, a node is not in any state
                            of ``sdfg``, an option is not a property, or
                            verification fails.
        :raises TypeError: If the first node is neither a `Node` nor a state.
        """
        subgraph = None
        if len(where) == 1:
            if isinstance(where[0], (list, tuple)):
                where = where[0]
            elif isinstance(where[0], gr.SubgraphView):
                subgraph = where[0]
        if len(where) == 0:
            raise ValueError('At least one node is required')

        # Check that all keyword arguments are nodes and if interstate or not
        if subgraph is None:
            sample_node = where[0]

            if isinstance(sample_node, SDFGState):
                graph = sdfg
                state_id = -1
            elif isinstance(sample_node, nd.Node):
                graph = next(
                    (s for s in sdfg.nodes() if sample_node in s.nodes()),
                    None)
                if graph is None:
                    # A bare StopIteration would be an opaque error here (and
                    # becomes RuntimeError if raised inside a generator).
                    raise ValueError('Node "%s" was not found in any state of '
                                     'the given SDFG' % sample_node)
                state_id = sdfg.node_id(graph)
            else:
                raise TypeError('Invalid node type "%s"' %
                                type(sample_node).__name__)

            # Construct subgraph and instantiate transformation
            subgraph = gr.SubgraphView(graph, where)
            instance = cls(subgraph, sdfg.sdfg_id, state_id)
        else:
            # Construct instance from subgraph directly
            instance = cls(subgraph)

        # Construct transformation parameters
        for optname, optval in options.items():
            if optname not in cls.__properties__:
                raise ValueError('Property "%s" not found in transformation' %
                                 optname)
            setattr(instance, optname, optval)

        if verify:
            if not instance.can_be_applied(sdfg, subgraph):
                raise ValueError('Transformation cannot be applied on the '
                                 'given subgraph ("can_be_applied" failed)')

        # Apply to SDFG
        return instance.apply(sdfg)

    def to_json(self, parent=None):
        """ Serializes the instance's properties plus its type/class name. """
        props = serialize.all_properties_to_json(self)
        return {
            'type': 'SubgraphTransformation',
            'transformation': type(self).__name__,
            **props
        }

    @staticmethod
    def from_json(json_obj: Dict[str, Any],
                  context: Dict[str, Any] = None) -> 'SubgraphTransformation':
        """
        Reconstructs a transformation instance from its JSON representation.

        :param json_obj: Dictionary produced by `to_json`.
        :param context: Optional deserialization context; the new instance is
                        stored in it under ``'transformation'``.
        :return: The reconstructed transformation.
        :raises ValueError: If no registered subclass has the serialized name.
        """
        xform_name = json_obj['transformation']
        xform = next((ext for ext in SubgraphTransformation.extensions().keys()
                      if ext.__name__ == xform_name), None)
        if xform is None:
            raise ValueError('Unknown subgraph transformation "%s"' %
                             xform_name)

        # Reconstruct transformation
        ret = xform(json_obj['subgraph'], json_obj['sdfg_id'],
                    json_obj['state_id'])
        context = context or {}
        context['transformation'] = ret
        serialize.set_properties_from_json(
            ret,
            json_obj,
            context=context,
            ignore_properties={'transformation', 'type'})
        return ret
예제 #7
0
class SubgraphTransformation(object):
    """
    Base class for transformations that apply on arbitrary subgraphs, rather than
    matching a specific pattern. Subclasses need to implement the `match` and `apply`
    operations.
    """

    sdfg_id = Property(dtype=int, desc='ID of SDFG to transform')
    state_id = Property(
        dtype=int,
        desc='ID of state to transform subgraph within, or -1 to transform the '
        'SDFG')
    subgraph = SetProperty(element_type=int,
                           desc='Subgraph in transformation instance')

    def __init__(self,
                 subgraph: Union[Set[int], SubgraphView],
                 sdfg_id: int = None,
                 state_id: int = None):
        """
        Creates a transformation instance bound to a subgraph.

        :param subgraph: A subgraph view, a whole SDFG/state (taken as the
                         subgraph of all its nodes), or a set of node IDs.
        :param sdfg_id: ID of the SDFG; required (and only used) when
                        ``subgraph`` is a set of node IDs.
        :param state_id: ID of the state, or -1 for the SDFG itself; required
                         (and only used) when ``subgraph`` is a set of IDs.
        :raises TypeError: If node IDs are given without both IDs, or the
                           subgraph's graph is neither an SDFG nor a state.
        """
        if (not isinstance(subgraph, (SubgraphView, SDFG, SDFGState))
                and (sdfg_id is None or state_id is None)):
            raise TypeError(
                'Subgraph transformation either expects a SubgraphView or a '
                'set of node IDs, SDFG ID and state ID (or -1).')

        # An entire graph is given as a subgraph
        if isinstance(subgraph, (SDFG, SDFGState)):
            subgraph = SubgraphView(subgraph, subgraph.nodes())

        if isinstance(subgraph, SubgraphView):
            self.subgraph = set(
                subgraph.graph.node_id(n) for n in subgraph.nodes())

            if isinstance(subgraph.graph, SDFGState):
                sdfg = subgraph.graph.parent
                self.sdfg_id = sdfg.sdfg_id
                self.state_id = sdfg.node_id(subgraph.graph)
            elif isinstance(subgraph.graph, SDFG):
                self.sdfg_id = subgraph.graph.sdfg_id
                self.state_id = -1
            else:
                raise TypeError('Unrecognized graph type "%s"' %
                                type(subgraph.graph).__name__)
        else:
            self.subgraph = subgraph
            self.sdfg_id = sdfg_id
            self.state_id = state_id

    def subgraph_view(self, sdfg: SDFG) -> SubgraphView:
        """ Resolves the stored IDs into a subgraph view over ``sdfg``. """
        graph = sdfg.sdfg_list[self.sdfg_id]
        if self.state_id != -1:
            graph = graph.node(self.state_id)
        return SubgraphView(graph, [graph.node(idx) for idx in self.subgraph])

    @staticmethod
    def match(sdfg: SDFG, subgraph: SubgraphView) -> bool:
        """
        Tries to match the transformation on a given subgraph, returning
        True if this transformation can be applied.
        :param sdfg: The SDFG that includes the subgraph.
        :param subgraph: The SDFG or state subgraph to try to apply the
                         transformation on.
        :return: True if the subgraph can be transformed, or False otherwise.
        """
        pass

    def apply(self, sdfg: SDFG):
        """
        Applies the transformation on the given subgraph.
        :param sdfg: The SDFG that includes the subgraph.
        """
        pass

    def to_json(self, parent=None):
        """ Serializes the instance's properties plus its type/class name. """
        props = dace.serialize.all_properties_to_json(self)
        return {
            'type': 'SubgraphTransformation',
            'transformation': type(self).__name__,
            **props
        }

    @staticmethod
    def from_json(json_obj, context=None):
        """
        Reconstructs a transformation instance from its JSON representation.

        :param json_obj: Dictionary produced by `to_json`.
        :param context: Optional deserialization context; the new instance is
                        stored in it under ``'transformation'``.
        :return: The reconstructed transformation.
        :raises ValueError: If no registered subclass has the serialized name.
        """
        xform_name = json_obj['transformation']
        # Default of None avoids leaking a bare StopIteration, which is
        # opaque to callers and becomes RuntimeError inside generators.
        xform = next((ext for ext in SubgraphTransformation.extensions().keys()
                      if ext.__name__ == xform_name), None)
        if xform is None:
            raise ValueError('Unknown subgraph transformation "%s"' %
                             xform_name)

        # Reconstruct transformation
        ret = xform(json_obj['subgraph'], json_obj['sdfg_id'],
                    json_obj['state_id'])
        context = context or {}
        context['transformation'] = ret
        dace.serialize.set_properties_from_json(
            ret,
            json_obj,
            context=context,
            ignore_properties={'transformation', 'type'})
        return ret
예제 #8
0
class Tasklet(CodeNode):
    """ A node that contains a tasklet: a functional computation procedure
        that can only access external data specified using connectors.

        Tasklets may be implemented in Python, C++, or any supported
        language by the code generator.
    """

    code = CodeProperty(desc="Tasklet code", default=CodeBlock(""))
    state_fields = ListProperty(
        element_type=str, desc="Fields that are added to the global state")
    code_global = CodeProperty(
        desc="Global scope code needed for tasklet execution",
        default=CodeBlock("", dtypes.Language.CPP))
    code_init = CodeProperty(
        desc="Extra code that is called on DaCe runtime initialization",
        default=CodeBlock("", dtypes.Language.CPP))
    code_exit = CodeProperty(
        desc="Extra code that is called on DaCe runtime cleanup",
        default=CodeBlock("", dtypes.Language.CPP))
    library_expansion_symbols = SetProperty(
        str,
        desc="Free symbols that get lost in the expansion of a Library Node")

    debuginfo = DebugInfoProperty()

    instrument = EnumProperty(
        dtype=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 label,
                 inputs=None,
                 outputs=None,
                 code="",
                 language=dtypes.Language.Python,
                 state_fields=None,
                 code_global="",
                 code_init="",
                 code_exit="",
                 location=None,
                 debuginfo=None,
                 library_expansion_symbols=None):
        """ Creates a tasklet.

            :param code: Tasklet body, wrapped in a CodeBlock of ``language``.
            :param code_global: Global-scope code (always C++).
            :param code_init: Runtime-initialization code (always C++).
            :param code_exit: Runtime-cleanup code (always C++).
            :param library_expansion_symbols: Extra free symbols to report;
                ``None`` (the default) becomes a new empty set.
        """
        super(Tasklet, self).__init__(label, location, inputs, outputs)

        self.code = CodeBlock(code, language)

        self.state_fields = state_fields or []
        self.code_global = CodeBlock(code_global, dtypes.Language.CPP)
        self.code_init = CodeBlock(code_init, dtypes.Language.CPP)
        self.code_exit = CodeBlock(code_exit, dtypes.Language.CPP)
        self.debuginfo = debuginfo
        # A ``set()`` default argument is evaluated once and would be shared
        # by every tasklet created without explicit symbols.
        self.library_expansion_symbols = (set()
                                          if library_expansion_symbols is None
                                          else library_expansion_symbols)

    @property
    def language(self):
        """ The language of the tasklet body. """
        return self.code.language

    @staticmethod
    def from_json(json_obj, context=None):
        """ Builds a placeholder tasklet and fills it from ``json_obj``. """
        ret = Tasklet("dummylabel")
        dace.serialize.set_properties_from_json(ret, json_obj, context=context)
        return ret

    @property
    def name(self):
        return self._label

    def validate(self, sdfg, state):
        """ Checks that the label and all connector names are valid names.

            :raises NameError: On the first invalid name found.
        """
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid tasklet name "%s"' % self.label)
        for in_conn in self.in_connectors:
            if not dtypes.validate_name(in_conn):
                raise NameError('Invalid input connector "%s"' % in_conn)
        for out_conn in self.out_connectors:
            if not dtypes.validate_name(out_conn):
                raise NameError('Invalid output connector "%s"' % out_conn)

    @property
    def free_symbols(self) -> Set[str]:
        """ Location symbols, symbols used in the code (excluding connector
            names), and the library expansion symbols. """
        result = super().free_symbols
        result |= self.code.get_free_symbols(self.in_connectors.keys()
                                             | self.out_connectors.keys())
        result |= self.library_expansion_symbols
        return result

    def infer_connector_types(self, sdfg, state):
        """ Fills in connector types that are missing.

            MLIR tasklets take their types from the MLIR entry function
            (warning on mismatches). Python tasklets infer untyped output
            connectors from the code, which requires all inputs to be typed.

            :raises TypeError: If a Python tasklet has untyped inputs, or an
                               output type cannot be inferred.
        """
        # If a MLIR tasklet, simply read out the types (it's explicit)
        if self.code.language == dtypes.Language.MLIR:
            # Inline import because mlir.utils depends on pyMLIR which may not be installed
            # Doesn't cause crashes due to missing pyMLIR if a MLIR tasklet is not present
            from dace.codegen.targets.mlir import utils

            mlir_ast = utils.get_ast(self.code.code)
            mlir_is_generic = utils.is_generic(mlir_ast)
            mlir_entry_func = utils.get_entry_func(mlir_ast, mlir_is_generic)

            mlir_result_type = utils.get_entry_result_type(
                mlir_entry_func, mlir_is_generic)
            # The connector *name* (the whole key). Indexing it with [0]
            # would take only its first character and break the lookups below.
            mlir_out_name = next(iter(self.out_connectors.keys()))

            result_dace_type = utils.get_dace_type(mlir_result_type)
            out_type = self.out_connectors[mlir_out_name]
            if out_type is None or out_type.ctype == "void":
                self.out_connectors[mlir_out_name] = result_dace_type
            elif out_type != result_dace_type:
                warnings.warn(
                    "Type mismatch between MLIR tasklet out connector and MLIR code"
                )

            for mlir_arg in utils.get_entry_args(mlir_entry_func,
                                                 mlir_is_generic):
                arg_dace_type = utils.get_dace_type(mlir_arg[1])
                in_type = self.in_connectors[mlir_arg[0]]
                if in_type is None or in_type.ctype == "void":
                    self.in_connectors[mlir_arg[0]] = arg_dace_type
                elif in_type != arg_dace_type:
                    warnings.warn(
                        "Type mismatch between MLIR tasklet in connector and MLIR code"
                    )

            return

        # If a Python tasklet, use type inference to figure out all None output
        # connectors
        if all(cval.type is not None for cval in self.out_connectors.values()):
            return
        if self.code.language != dtypes.Language.Python:
            return

        if any(cval.type is None for cval in self.in_connectors.values()):
            raise TypeError('Cannot infer output connectors of tasklet "%s", '
                            'not all input connectors have types' % str(self))

        # Avoid import loop
        from dace.codegen.tools.type_inference import infer_types

        # Get symbols defined at beginning of node, and infer all types in
        # tasklet
        syms = state.symbols_defined_at(self)
        syms.update(self.in_connectors)
        new_syms = infer_types(self.code.code, syms)
        for cname, oconn in self.out_connectors.items():
            if oconn.type is None:
                if cname not in new_syms:
                    raise TypeError('Cannot infer type of tasklet %s output '
                                    '"%s", please specify manually.' %
                                    (self.label, cname))
                self.out_connectors[cname] = new_syms[cname]

    def __str__(self):
        if not self.label:
            return "--Empty--"
        else:
            return self.label
예제 #9
0
class Node(object):
    """ Base node class.

        Holds the names of the node's input and output connectors and
        provides helpers to add, remove and auto-number them.
    """

    in_connectors = SetProperty(
        str, default=set(), desc="A set of input connectors for this node.")
    out_connectors = SetProperty(
        str, default=set(), desc="A set of output connectors for this node.")

    def __init__(self, in_connectors=None, out_connectors=None):
        # Falsy/omitted arguments become fresh empty sets, so instances never
        # share a connector set through a default value.
        self.in_connectors = in_connectors or set()
        self.out_connectors = out_connectors or set()

    def __str__(self):
        if hasattr(self, 'label'):
            return self.label
        else:
            return type(self).__name__

    def validate(self, sdfg, state):
        """ Validation hook; the base implementation accepts every node. """
        pass

    def to_json(self, parent):
        """ Serializes this node into a JSON-compatible dictionary.

            @param parent: The graph containing this node (presumably an SDFG
                           state); used to look up the node ID and its
                           enclosing scope entry/exit nodes.
            @return: A dictionary with the node type, label, serialized
                     properties, node ID, and the IDs (as strings) of the
                     enclosing scope entry and exit nodes, if any.
        """
        labelstr = str(self)
        typestr = str(type(self).__name__)

        scope_entry_node = parent.entry_node(self)
        if scope_entry_node is not None:
            # Reuse the scope entry found above rather than querying the
            # parent a second time.
            ens = parent.exit_nodes(scope_entry_node)
            scope_exit_nodes = [str(parent.node_id(x)) for x in ens]
            scope_entry_node = str(parent.node_id(scope_entry_node))
        else:
            scope_entry_node = None
            scope_exit_nodes = []

        retdict = {
            "type": typestr,
            "label": labelstr,
            "attributes": dace.serialize.all_properties_to_json(self),
            "id": parent.node_id(self),
            "scope_entry": scope_entry_node,
            "scope_exits": scope_exit_nodes
        }
        return retdict

    def __repr__(self):
        return type(self).__name__ + ' (' + self.__str__() + ')'

    def add_in_connector(self, connector_name: str):
        """ Adds a new input connector to the node. The operation will fail if
            a connector (either input or output) with the same name already
            exists in the node.

            @param connector_name: The name of the new connector.
            @return: True if the operation is successful, otherwise False.
        """

        if (connector_name in self.in_connectors
                or connector_name in self.out_connectors):
            return False
        connectors = self.in_connectors
        connectors.add(connector_name)
        # Re-assign so the property setter observes the modified set.
        self.in_connectors = connectors
        return True

    def add_out_connector(self, connector_name: str):
        """ Adds a new output connector to the node. The operation will fail if
            a connector (either input or output) with the same name already
            exists in the node.

            @param connector_name: The name of the new connector.
            @return: True if the operation is successful, otherwise False.
        """

        if (connector_name in self.in_connectors
                or connector_name in self.out_connectors):
            return False
        connectors = self.out_connectors
        connectors.add(connector_name)
        # Re-assign so the property setter observes the modified set.
        self.out_connectors = connectors
        return True

    def remove_in_connector(self, connector_name: str):
        """ Removes an input connector from the node.
            @param connector_name: The name of the connector to remove.
            @return: True (also when no such connector existed).
        """

        if connector_name in self.in_connectors:
            connectors = self.in_connectors
            connectors.remove(connector_name)
            self.in_connectors = connectors
        return True

    def remove_out_connector(self, connector_name: str):
        """ Removes an output connector from the node.
            @param connector_name: The name of the connector to remove.
            @return: True (also when no such connector existed).
        """

        if connector_name in self.out_connectors:
            connectors = self.out_connectors
            connectors.remove(connector_name)
            self.out_connectors = connectors
        return True

    def _next_connector_int(self) -> int:
        """ Returns the next unused connector ID (as an integer). Used for
            filling connectors when adding edges to scopes.

            Only connectors named ``IN_<n>`` or ``OUT_<n>`` with an integral
            suffix ``<n>`` are considered; the result is one more than the
            largest such suffix, or 1 if there is none.
        """
        next_number = 1
        for conn in itertools.chain(self.in_connectors, self.out_connectors):
            if conn.startswith('IN_'):
                cconn = conn[3:]
            elif conn.startswith('OUT_'):
                cconn = conn[4:]
            else:
                continue
            try:
                curconn = int(cconn)
            except ValueError:
                # Suffix is not integral (e.g. "IN_foo"); int() on a str
                # signals this with ValueError, not TypeError.
                continue
            if curconn >= next_number:
                next_number = curconn + 1
        return next_number

    def next_connector(self) -> str:
        """ Returns the next unused connector ID (as a string). Used for
            filling connectors when adding edges to scopes. """
        return str(self._next_connector_int())

    def last_connector(self) -> str:
        """ Returns the last used connector ID (as a string). Used for
            filling connectors when adding edges to scopes. """
        return str(self._next_connector_int() - 1)