class CodeNode(Node):
    """
    Base class for nodes holding runnable code whose external data
    dependencies are acyclic. Concrete code nodes are tasklets and nested
    SDFGs; both are drawn with an octagonal shape.
    """

    # Human-readable name of the node.
    label = Property(dtype=str, desc="Name of the CodeNode")
    # Maps a location key (e.g., rank, GPU ID) to a value that is parsed
    # into a symbolic expression.
    location = DictProperty(
        key_type=str,
        value_type=dace.symbolic.pystr_to_symbolic,
        desc='Full storage location identifier (e.g., rank, GPU ID)')
    environments = SetProperty(
        str,
        desc="Environments required by CMake to build and run this code node.",
        default=set())

    def __init__(self, label="", location=None, inputs=None, outputs=None):
        """
        :param label: Name of the node.
        :param location: Storage location mapping; an empty dict is used
                         when it is None.
        :param inputs: Input connectors; an empty set is used when falsy.
        :param outputs: Output connectors; an empty set is used when falsy.
        """
        in_conns = inputs if inputs else set()
        out_conns = outputs if outputs else set()
        super().__init__(in_conns, out_conns)

        self.label = label
        if location is None:
            location = {}
        self.location = location

    @property
    def free_symbols(self) -> Set[str]:
        """ Names of all free symbols appearing in the location values. """
        names = set()
        for loc_value in self.location.values():
            symbolic_value = pystr_to_symbolic(loc_value)
            names.update(str(sym) for sym in symbolic_value.free_symbols)
        return names
class CodeObject(object):
    """
    A named piece of generated code bundled with what is needed to build it:
    language, compilation target, extra compiler template variables, a
    linkage flag and the required CMake environments.
    """
    name = Property(dtype=str, desc="Filename to use")
    code = Property(dtype=str, desc="The code attached to this object")
    language = Property(dtype=str,
                        desc="Language used for this code (same " +
                        "as its file extension)")
    target = Property(dtype=type,
                      desc="Target to use for compilation",
                      allow_none=True)
    target_type = Property(
        dtype=str,
        desc="Sub-target within target (e.g., host or device code)",
        default="")
    title = Property(dtype=str, desc="Title of code for GUI")
    extra_compiler_kwargs = DictProperty(key_type=str,
                                         value_type=str,
                                         desc="Additional compiler argument "
                                         "variables to add to template")
    linkable = Property(dtype=bool,
                        desc='Should this file participate in '
                        'overall linkage?')
    environments = SetProperty(
        str,
        desc="Environments required by CMake to build and run this code node.",
        default=set())

    def __init__(self,
                 name,
                 code,
                 language,
                 target,
                 title,
                 target_type="",
                 additional_compiler_kwargs=None,
                 linkable=True,
                 environments=None,
                 sdfg=None):
        """
        :param name: Filename to use for this code.
        :param code: The source code text.
        :param language: Language of the code (same as its file extension).
        :param target: Target type used for compilation.
        :param title: Title of the code for the GUI.
        :param target_type: Sub-target within the target (e.g., host/device).
        :param additional_compiler_kwargs: Extra template variables for the
                                           compiler; an empty dict if falsy.
        :param linkable: Whether the file participates in overall linkage.
        :param environments: Required CMake environments; an empty set if
                             falsy.
        :param sdfg: The originating SDFG. If given (and truthy) for C++
                     code titled "Frame", a source map is created.
        """
        super().__init__()

        self.name = name
        self.code = code
        self.language = language
        self.target = target
        self.target_type = target_type
        self.title = title
        self.extra_compiler_kwargs = (additional_compiler_kwargs
                                      if additional_compiler_kwargs else {})
        self.linkable = linkable
        self.environments = environments if environments else set()

        # Source maps are only built for the C++ frame code, and only when
        # the SDFG it came from was supplied.
        wants_sourcemap = language == 'cpp' and title == 'Frame'
        if wants_sourcemap and sdfg:
            sourcemap.create_maps(sdfg, code, self.target.target_name)

    @property
    def clean_code(self):
        """
        The code with every ``////__DACE:`` / ``////__CODEGEN;`` marker
        (including the spaces/tabs before it, up to the end of its line)
        removed.
        """
        marker_pattern = r'[ \t]*////__(DACE:|CODEGEN;)[^\n]*'
        return re.sub(marker_pattern, '', self.code)
class CodeObject(object):
    """
    A named piece of generated code bundled with what is needed to build it:
    language, compilation target, extra compiler template variables, a
    linkage flag and the required CMake environments.
    """
    name = Property(dtype=str, desc="Filename to use")
    code = Property(dtype=str, desc="The code attached to this object")
    language = Property(
        dtype=str,
        desc="Language used for this code (same " +
        "as its file extension)")  # dtype=dtypes.Language?
    target = Property(
        dtype=type, desc="Target to use for compilation", allow_none=True)
    target_type = Property(
        dtype=str,
        desc="Sub-target within target (e.g., host or device code)",
        default="")
    title = Property(dtype=str, desc="Title of code for GUI")
    extra_compiler_kwargs = Property(
        dtype=dict,
        desc="Additional compiler argument "
        "variables to add to template")
    linkable = Property(
        dtype=bool, desc='Should this file participate in '
        'overall linkage?')
    environments = SetProperty(
        str,
        desc="Environments required by CMake to build and run this code node.",
        default=set())

    def __init__(self,
                 name,
                 code,
                 language,
                 target,
                 title,
                 target_type="",
                 additional_compiler_kwargs=None,
                 linkable=True,
                 environments=None):
        """
        :param name: Filename to use for this code.
        :param code: The source code text.
        :param language: Language of the code (same as its file extension).
        :param target: Target type used for compilation.
        :param title: Title of the code for the GUI.
        :param target_type: Sub-target within the target (e.g., host/device).
        :param additional_compiler_kwargs: Extra template variables for the
                                           compiler; an empty dict if falsy.
        :param linkable: Whether the file participates in overall linkage.
        :param environments: Required CMake environments; a new empty set
                             when None.
        """
        super(CodeObject, self).__init__()

        self.name = name
        self.code = code
        self.language = language
        self.target = target
        self.target_type = target_type
        self.title = title
        self.extra_compiler_kwargs = additional_compiler_kwargs or {}
        self.linkable = linkable
        # ``None`` rather than ``set()`` is the default so each instance gets
        # its own set: a set in the signature is created once and would be
        # shared (and mutable) across every CodeObject built without one.
        self.environments = (environments
                             if environments is not None else set())
class Node(object):
    """ Base node class. """

    in_connectors = SetProperty(
        str, default=set(), desc="A set of input connectors for this node.")
    out_connectors = SetProperty(
        str, default=set(), desc="A set of output connectors for this node.")

    def __init__(self, in_connectors=None, out_connectors=None):
        """
        @param in_connectors: Initial input connector names (a new empty set
                              when None).
        @param out_connectors: Initial output connector names (a new empty
                               set when None).
        """
        # Defaults are created per call: add_*/remove_* call .add/.remove on
        # the set returned by the property, so a shared ``set()`` default
        # argument could leak connectors between unrelated nodes.
        self.in_connectors = (in_connectors
                              if in_connectors is not None else set())
        self.out_connectors = (out_connectors
                               if out_connectors is not None else set())

    def __str__(self):
        if hasattr(self, 'label'):
            return self.label
        else:
            return type(self).__name__

    def validate(self, sdfg, state):
        pass

    def toJSON(self, indent=0):
        """
        Returns a JSON object (as a string) with this node's label, type and
        connectors, prefixed by ``indent`` spaces and ending with a newline.

        @param indent: Number of leading spaces.
        @return: The JSON string.
        """
        import json

        def quote(text):
            # Proper JSON string literal (escapes quotes, backslashes and
            # control characters); ensure_ascii=False keeps non-ASCII text
            # verbatim so ordinary labels serialize exactly as before.
            return json.dumps(str(text), ensure_ascii=False)

        labelstr = str(self)
        typestr = str(type(self).__name__)
        inconn = "[" + ",".join(quote(x) for x in self.in_connectors) + "]"
        outconn = "[" + ",".join(quote(x) for x in self.out_connectors) + "]"

        result = " " * indent + "{ \"label\": " + quote(labelstr)
        result += ", \"type\": " + quote(typestr)
        result += ", \"in_connectors\": " + inconn
        result += ", \"out_connectors\" :" + outconn
        result += "}\n"
        return result

    def __repr__(self):
        return type(self).__name__ + ' (' + self.__str__() + ')'

    def add_in_connector(self, connector_name: str):
        """ Adds a new input connector to the node. The operation will fail if
            a connector (either input or output) with the same name already
            exists in the node.

            @param connector_name: The name of the new connector.
            @return: True if the operation is successful, otherwise False.
        """

        if (connector_name in self.in_connectors
                or connector_name in self.out_connectors):
            return False
        connectors = self.in_connectors
        connectors.add(connector_name)
        self.in_connectors = connectors
        return True

    def add_out_connector(self, connector_name: str):
        """ Adds a new output connector to the node. The operation will fail if
            a connector (either input or output) with the same name already
            exists in the node.

            @param connector_name: The name of the new connector.
            @return: True if the operation is successful, otherwise False.
        """

        if (connector_name in self.in_connectors
                or connector_name in self.out_connectors):
            return False
        connectors = self.out_connectors
        connectors.add(connector_name)
        self.out_connectors = connectors
        return True

    def remove_in_connector(self, connector_name: str):
        """ Removes an input connector from the node.
            @param connector_name: The name of the connector to remove.
            @return: True (also when the connector did not exist).
        """

        if connector_name in self.in_connectors:
            connectors = self.in_connectors
            connectors.remove(connector_name)
            self.in_connectors = connectors
        return True

    def remove_out_connector(self, connector_name: str):
        """ Removes an output connector from the node.
            @param connector_name: The name of the connector to remove.
            @return: True (also when the connector did not exist).
        """

        if connector_name in self.out_connectors:
            connectors = self.out_connectors
            connectors.remove(connector_name)
            self.out_connectors = connectors
        return True

    def _next_connector_int(self) -> int:
        """ Returns the next unused connector ID (as an integer). Used for
            filling connectors when adding edges to scopes.

            Only connectors named ``IN_<n>`` / ``OUT_<n>`` with an integral
            suffix are considered; the result is one past the largest such
            ``n`` (and at least 1).
        """
        next_number = 1
        for conn in itertools.chain(self.in_connectors, self.out_connectors):
            if conn.startswith('IN_'):
                cconn = conn[3:]
            elif conn.startswith('OUT_'):
                cconn = conn[4:]
            else:
                continue
            try:
                curconn = int(cconn)
                if curconn >= next_number:
                    next_number = curconn + 1
            except (TypeError, ValueError):  # not integral, e.g. "IN_foo"
                continue
        return next_number

    def next_connector(self) -> str:
        """ Returns the next unused connector ID (as a string). Used for
            filling connectors when adding edges to scopes. """
        return str(self._next_connector_int())

    def last_connector(self) -> str:
        """ Returns the last used connector ID (as a string). Used for
            filling connectors when adding edges to scopes. """
        return str(self._next_connector_int() - 1)
class Node(object):
    """ Common base class for graph nodes; manages input/output connectors. """

    in_connectors = SetProperty(
        str, default=set(), desc="A set of input connectors for this node.")
    out_connectors = SetProperty(
        str, default=set(), desc="A set of output connectors for this node.")

    def __init__(self, in_connectors=None, out_connectors=None):
        """
        :param in_connectors: Initial input connectors (empty set if falsy).
        :param out_connectors: Initial output connectors (empty set if falsy).
        """
        self.in_connectors = in_connectors if in_connectors else set()
        self.out_connectors = out_connectors if out_connectors else set()

    def __str__(self):
        return self.label if hasattr(self, 'label') else type(self).__name__

    def validate(self, sdfg, state):
        """ Validates the node; the base implementation accepts everything. """

    def to_json(self, parent):
        """
        Serializes this node into a JSON-compatible dictionary.

        :param parent: The graph containing this node; used to look up node
                       IDs and scope entry/exit nodes.
        :return: A dictionary with the node type, label, serialized
                 properties, node ID, and the IDs (as strings, or None) of
                 the enclosing scope's entry and exit nodes. For entry nodes,
                 the scope exit is their own matching exit node.
        """
        lookup_errors = (RuntimeError, StopIteration)
        label_text = str(self)
        type_text = getattr(self, '__jsontype__', str(type(self).__name__))

        try:
            entry = parent.entry_node(self)
        except lookup_errors:
            entry = None

        entry_id = None
        exit_id = None
        if entry is not None:
            exit_id = str(parent.node_id(parent.exit_node(entry)))
            entry_id = str(parent.node_id(entry))

        # An entry node's scope exit is its own matching exit node
        if isinstance(self, EntryNode):
            try:
                exit_id = str(parent.node_id(parent.exit_node(self)))
            except lookup_errors:
                exit_id = None

        return {
            "type": type_text,
            "label": label_text,
            "attributes": dace.serialize.all_properties_to_json(self),
            "id": parent.node_id(self),
            "scope_entry": entry_id,
            "scope_exit": exit_id,
        }

    def __repr__(self):
        return type(self).__name__ + ' (' + self.__str__() + ')'

    def add_in_connector(self, connector_name: str):
        """
        Adds an input connector, unless a connector of that name (input or
        output) already exists.

        :param connector_name: The name of the new connector.
        :return: True if the connector was added, otherwise False.
        """
        taken = (connector_name in self.in_connectors
                 or connector_name in self.out_connectors)
        if not taken:
            conns = self.in_connectors
            conns.add(connector_name)
            self.in_connectors = conns
        return not taken

    def add_out_connector(self, connector_name: str):
        """
        Adds an output connector, unless a connector of that name (input or
        output) already exists.

        :param connector_name: The name of the new connector.
        :return: True if the connector was added, otherwise False.
        """
        taken = (connector_name in self.in_connectors
                 or connector_name in self.out_connectors)
        if not taken:
            conns = self.out_connectors
            conns.add(connector_name)
            self.out_connectors = conns
        return not taken

    def remove_in_connector(self, connector_name: str):
        """
        Removes an input connector if present.

        :param connector_name: The name of the connector to remove.
        :return: Always True.
        """
        if connector_name not in self.in_connectors:
            return True
        conns = self.in_connectors
        conns.remove(connector_name)
        self.in_connectors = conns
        return True

    def remove_out_connector(self, connector_name: str):
        """
        Removes an output connector if present.

        :param connector_name: The name of the connector to remove.
        :return: Always True.
        """
        if connector_name not in self.out_connectors:
            return True
        conns = self.out_connectors
        conns.remove(connector_name)
        self.out_connectors = conns
        return True

    def _next_connector_int(self) -> int:
        """
        Computes the next free numeric connector ID: one past the largest
        integer suffix among ``IN_<n>``/``OUT_<n>`` connectors, and at least 1.
        Used when adding edges to scopes.
        """
        highest = 0
        for conn in itertools.chain(self.in_connectors, self.out_connectors):
            if conn.startswith('IN_'):
                suffix = conn[len('IN_'):]
            elif conn.startswith('OUT_'):
                suffix = conn[len('OUT_'):]
            else:
                continue
            try:
                highest = max(highest, int(suffix))
            except (TypeError, ValueError):  # suffix is not an integer
                continue
        return highest + 1

    def next_connector(self) -> str:
        """ The next free connector ID, as a string (for scope edges). """
        next_id = self._next_connector_int()
        return str(next_id)

    def last_connector(self) -> str:
        """ The most recently used connector ID, as a string. """
        last_id = self._next_connector_int() - 1
        return str(last_id)

    @property
    def free_symbols(self) -> Set[str]:
        """ Symbols used in this node's properties; none for the base class. """
        return set()

    def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]:
        """
        Maps symbols defined by this node (e.g., by scope entries) to their
        types; the base class defines none.
        """
        return {}
class SubgraphTransformation(TransformationBase):
    """
    Base class for transformations that apply on arbitrary subgraphs, rather
    than matching a specific pattern. Subclasses need to implement the
    `can_be_applied` and `apply` operations, and must be registered with the
    subclass registry. See the `Transformation` class docstring for more
    information.
    """

    sdfg_id = Property(dtype=int, desc='ID of SDFG to transform')
    state_id = Property(
        dtype=int,
        desc='ID of state to transform subgraph within, or -1 to transform the '
        'SDFG')
    subgraph = SetProperty(element_type=int,
                           desc='Subgraph in transformation instance')

    def __init__(self,
                 subgraph: Union[Set[int], gr.SubgraphView],
                 sdfg_id: int = None,
                 state_id: int = None):
        """
        :param subgraph: A SubgraphView, a whole SDFG/SDFGState (treated as a
                         subgraph of all its nodes), or a set of node IDs.
        :param sdfg_id: ID of the SDFG. Required with a set of node IDs;
                        otherwise derived from the subgraph.
        :param state_id: ID of the state, or -1 for the SDFG itself. Required
                         with a set of node IDs; otherwise derived.
        :raises TypeError: If node IDs are given without both IDs, or if the
                           subgraph's graph type is not recognized.
        """
        if (not isinstance(subgraph, (gr.SubgraphView, SDFG, SDFGState))
                and (sdfg_id is None or state_id is None)):
            raise TypeError(
                'Subgraph transformation either expects a SubgraphView or a '
                'set of node IDs, SDFG ID and state ID (or -1).')

        # An entire graph is given as a subgraph
        if isinstance(subgraph, (SDFG, SDFGState)):
            subgraph = gr.SubgraphView(subgraph, subgraph.nodes())

        if isinstance(subgraph, gr.SubgraphView):
            self.subgraph = set(
                subgraph.graph.node_id(n) for n in subgraph.nodes())

            if isinstance(subgraph.graph, SDFGState):
                sdfg = subgraph.graph.parent
                self.sdfg_id = sdfg.sdfg_id
                self.state_id = sdfg.node_id(subgraph.graph)
            elif isinstance(subgraph.graph, SDFG):
                self.sdfg_id = subgraph.graph.sdfg_id
                self.state_id = -1
            else:
                raise TypeError('Unrecognized graph type "%s"' %
                                type(subgraph.graph).__name__)
        else:
            self.subgraph = subgraph
            self.sdfg_id = sdfg_id
            self.state_id = state_id

    def subgraph_view(self, sdfg: SDFG) -> gr.SubgraphView:
        """
        Reconstructs the stored subgraph as a view into ``sdfg``.

        :param sdfg: The top-level SDFG; ``self.sdfg_id`` indexes its
                     ``sdfg_list``.
        :return: A SubgraphView over the stored node IDs, in the state given
                 by ``self.state_id`` (or in the SDFG itself if it is -1).
        """
        graph = sdfg.sdfg_list[self.sdfg_id]
        if self.state_id != -1:
            graph = graph.node(self.state_id)
        return gr.SubgraphView(graph,
                               [graph.node(idx) for idx in self.subgraph])

    def can_be_applied(self, sdfg: SDFG, subgraph: gr.SubgraphView) -> bool:
        """
        Tries to match the transformation on a given subgraph, returning
        True if this transformation can be applied.
        :param sdfg: The SDFG that includes the subgraph.
        :param subgraph: The SDFG or state subgraph to try to apply the
                         transformation on.
        :return: True if the subgraph can be transformed, or False otherwise.
        """
        pass

    def apply(self, sdfg: SDFG):
        """
        Applies the transformation on the given subgraph.
        :param sdfg: The SDFG that includes the subgraph.
        """
        pass

    @classmethod
    def apply_to(cls,
                 sdfg: SDFG,
                 *where: Union[nd.Node, SDFGState, gr.SubgraphView],
                 verify: bool = True,
                 **options: Any):
        """
        Applies this transformation to a given subgraph, defined by a set of
        nodes. Raises an error if arguments are invalid or transformation is
        not applicable.

        To apply the transformation on a specific subgraph, the `where`
        parameter can be given either a subgraph object (`SubgraphView`), or
        directly a list of subgraph nodes, given as `Node` or `SDFGState`
        objects. Transformation properties can then be given as keyword
        arguments. For example, applying `SubgraphFusion` on a subgraph of
        three nodes can be called in one of two ways:
        ```
        # Subgraph
        SubgraphFusion.apply_to(
            sdfg, SubgraphView(state, [node_a, node_b, node_c]))

        # Simplified API: list of nodes
        SubgraphFusion.apply_to(sdfg, node_a, node_b, node_c)
        ```

        :param sdfg: The SDFG to apply the transformation to.
        :param where: A set of nodes in the SDFG/state, or a subgraph thereof.
        :param verify: Check that `can_be_applied` returns True before
                       applying.
        :param options: A set of parameters to use for applying the
                        transformation.
        :return: Whatever the transformation's `apply` returns.
        :raises ValueError: If no nodes are given, a node is not found in any
                            state of ``sdfg``, an option is not a property of
                            the transformation, or (with ``verify``)
                            `can_be_applied` fails.
        :raises TypeError: If the nodes are neither states nor dataflow nodes.
        """
        subgraph = None
        if len(where) == 1:
            if isinstance(where[0], (list, tuple)):
                where = where[0]
            elif isinstance(where[0], gr.SubgraphView):
                subgraph = where[0]
        if len(where) == 0:
            raise ValueError('At least one node is required')

        # Decide from the first node whether this is an inter-state subgraph
        # (states of the SDFG) or a dataflow subgraph (nodes of one state)
        if subgraph is None:
            sample_node = where[0]
            if isinstance(sample_node, SDFGState):
                graph = sdfg
                state_id = -1
            elif isinstance(sample_node, nd.Node):
                graph = next(
                    (s for s in sdfg.nodes() if sample_node in s.nodes()),
                    None)
                if graph is None:
                    # A bare StopIteration would escape from next() here
                    raise ValueError(
                        'Node "%s" was not found in any state of the given '
                        'SDFG' % (sample_node, ))
                state_id = sdfg.node_id(graph)
            else:
                raise TypeError('Invalid node type "%s"' %
                                type(sample_node).__name__)

            # Construct subgraph and instantiate transformation
            subgraph = gr.SubgraphView(graph, where)
            instance = cls(subgraph, sdfg.sdfg_id, state_id)
        else:
            # Construct instance from subgraph directly
            instance = cls(subgraph)

        # Construct transformation parameters
        for optname, optval in options.items():
            if optname not in cls.__properties__:
                raise ValueError('Property "%s" not found in transformation' %
                                 optname)
            setattr(instance, optname, optval)

        if verify:
            if not instance.can_be_applied(sdfg, subgraph):
                raise ValueError('Transformation cannot be applied on the '
                                 'given subgraph ("can_be_applied" failed)')

        # Apply to SDFG
        return instance.apply(sdfg)

    def to_json(self, parent=None):
        """
        Serializes the transformation's properties, tagged with its type and
        transformation class name.
        """
        props = serialize.all_properties_to_json(self)
        return {
            'type': 'SubgraphTransformation',
            'transformation': type(self).__name__,
            **props
        }

    @staticmethod
    def from_json(json_obj: Dict[str, Any],
                  context: Dict[str, Any] = None) -> 'SubgraphTransformation':
        """
        Reconstructs a registered subgraph transformation from its JSON form.

        :param json_obj: Dictionary produced by `to_json`.
        :param context: Optional deserialization context; the new instance is
                        stored in it under ``'transformation'``.
        :return: The reconstructed transformation instance.
        :raises ValueError: If no registered extension has the given name.
        """
        xform_name = json_obj['transformation']
        xform = next((ext for ext in SubgraphTransformation.extensions().keys()
                      if ext.__name__ == xform_name), None)
        if xform is None:
            raise ValueError('Unknown subgraph transformation "%s"' %
                             xform_name)

        # Reconstruct transformation
        ret = xform(json_obj['subgraph'], json_obj['sdfg_id'],
                    json_obj['state_id'])
        context = context or {}
        context['transformation'] = ret
        serialize.set_properties_from_json(
            ret,
            json_obj,
            context=context,
            ignore_properties={'transformation', 'type'})
        return ret
class SubgraphTransformation(object):
    """
    Base class for transformations that apply on arbitrary subgraphs, rather
    than matching a specific pattern. Subclasses need to implement the `match`
    and `apply` operations.
    """

    sdfg_id = Property(dtype=int, desc='ID of SDFG to transform')
    state_id = Property(
        dtype=int,
        desc='ID of state to transform subgraph within, or -1 to transform the '
        'SDFG')
    subgraph = SetProperty(element_type=int,
                           desc='Subgraph in transformation instance')

    def __init__(self,
                 subgraph: Union[Set[int], SubgraphView],
                 sdfg_id: int = None,
                 state_id: int = None):
        """
        :param subgraph: A SubgraphView, a whole SDFG/SDFGState (treated as a
                         subgraph of all its nodes), or a set of node IDs.
        :param sdfg_id: ID of the SDFG. Required with a set of node IDs;
                        otherwise derived from the subgraph.
        :param state_id: ID of the state, or -1 for the SDFG itself. Required
                         with a set of node IDs; otherwise derived.
        :raises TypeError: If node IDs are given without both IDs, or if the
                           subgraph's graph type is not recognized.
        """
        if (not isinstance(subgraph, (SubgraphView, SDFG, SDFGState))
                and (sdfg_id is None or state_id is None)):
            raise TypeError(
                'Subgraph transformation either expects a SubgraphView or a '
                'set of node IDs, SDFG ID and state ID (or -1).')

        # An entire graph is given as a subgraph
        if isinstance(subgraph, (SDFG, SDFGState)):
            subgraph = SubgraphView(subgraph, subgraph.nodes())

        if isinstance(subgraph, SubgraphView):
            self.subgraph = set(
                subgraph.graph.node_id(n) for n in subgraph.nodes())

            if isinstance(subgraph.graph, SDFGState):
                sdfg = subgraph.graph.parent
                self.sdfg_id = sdfg.sdfg_id
                self.state_id = sdfg.node_id(subgraph.graph)
            elif isinstance(subgraph.graph, SDFG):
                self.sdfg_id = subgraph.graph.sdfg_id
                self.state_id = -1
            else:
                raise TypeError('Unrecognized graph type "%s"' %
                                type(subgraph.graph).__name__)
        else:
            self.subgraph = subgraph
            self.sdfg_id = sdfg_id
            self.state_id = state_id

    def subgraph_view(self, sdfg: SDFG) -> SubgraphView:
        """
        Reconstructs the stored subgraph as a view into ``sdfg``.

        :param sdfg: The top-level SDFG; ``self.sdfg_id`` indexes its
                     ``sdfg_list``.
        :return: A SubgraphView over the stored node IDs, in the state given
                 by ``self.state_id`` (or in the SDFG itself if it is -1).
        """
        graph = sdfg.sdfg_list[self.sdfg_id]
        if self.state_id != -1:
            graph = graph.node(self.state_id)
        return SubgraphView(graph, [graph.node(idx) for idx in self.subgraph])

    @staticmethod
    def match(sdfg: SDFG, subgraph: SubgraphView) -> bool:
        """
        Tries to match the transformation on a given subgraph, returning
        True if this transformation can be applied.
        :param sdfg: The SDFG that includes the subgraph.
        :param subgraph: The SDFG or state subgraph to try to apply the
                         transformation on.
        :return: True if the subgraph can be transformed, or False otherwise.
        """
        pass

    def apply(self, sdfg: SDFG):
        """
        Applies the transformation on the given subgraph.
        :param sdfg: The SDFG that includes the subgraph.
        """
        pass

    def to_json(self, parent=None):
        """
        Serializes the transformation's properties, tagged with its type and
        transformation class name.
        """
        props = dace.serialize.all_properties_to_json(self)
        return {
            'type': 'SubgraphTransformation',
            'transformation': type(self).__name__,
            **props
        }

    @staticmethod
    def from_json(json_obj, context=None):
        """
        Reconstructs a registered subgraph transformation from its JSON form.

        :param json_obj: Dictionary produced by `to_json`.
        :param context: Optional deserialization context; the new instance is
                        stored in it under ``'transformation'``.
        :return: The reconstructed transformation instance.
        :raises ValueError: If no registered extension has the given name.
        """
        xform_name = json_obj['transformation']
        xform = next((ext for ext in SubgraphTransformation.extensions().keys()
                      if ext.__name__ == xform_name), None)
        if xform is None:
            # Previously a bare StopIteration escaped from next() here
            raise ValueError('Unknown subgraph transformation "%s"' %
                             xform_name)

        # Reconstruct transformation
        ret = xform(json_obj['subgraph'], json_obj['sdfg_id'],
                    json_obj['state_id'])
        context = context or {}
        context['transformation'] = ret
        dace.serialize.set_properties_from_json(
            ret,
            json_obj,
            context=context,
            ignore_properties={'transformation', 'type'})
        return ret
class Tasklet(CodeNode):
    """ A node that contains a tasklet: a functional computation procedure
        that can only access external data specified using connectors.

        Tasklets may be implemented in Python, C++, or any supported
        language by the code generator.
    """

    code = CodeProperty(desc="Tasklet code", default=CodeBlock(""))
    state_fields = ListProperty(
        element_type=str, desc="Fields that are added to the global state")
    code_global = CodeProperty(
        desc="Global scope code needed for tasklet execution",
        default=CodeBlock("", dtypes.Language.CPP))
    code_init = CodeProperty(
        desc="Extra code that is called on DaCe runtime initialization",
        default=CodeBlock("", dtypes.Language.CPP))
    code_exit = CodeProperty(
        desc="Extra code that is called on DaCe runtime cleanup",
        default=CodeBlock("", dtypes.Language.CPP))
    library_expansion_symbols = SetProperty(
        str,
        desc="Free symbols that get lost in the expansion of a Library Node")
    debuginfo = DebugInfoProperty()

    instrument = EnumProperty(
        dtype=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 label,
                 inputs=None,
                 outputs=None,
                 code="",
                 language=dtypes.Language.Python,
                 state_fields=None,
                 code_global="",
                 code_init="",
                 code_exit="",
                 location=None,
                 debuginfo=None,
                 library_expansion_symbols=None):
        """
        :param label: Name of the tasklet.
        :param inputs: Input connectors.
        :param outputs: Output connectors.
        :param code: Tasklet code, in ``language``.
        :param language: Language of ``code``.
        :param state_fields: Fields added to the global state (empty list if
                             falsy).
        :param code_global: Global-scope C++ code.
        :param code_init: C++ code run on DaCe runtime initialization.
        :param code_exit: C++ code run on DaCe runtime cleanup.
        :param location: Storage location mapping (see CodeNode).
        :param debuginfo: Debug information for the node.
        :param library_expansion_symbols: Free symbols lost in a Library Node
                                          expansion; a new empty set if None.
        """
        super(Tasklet, self).__init__(label, location, inputs, outputs)

        self.code = CodeBlock(code, language)

        self.state_fields = state_fields or []
        self.code_global = CodeBlock(code_global, dtypes.Language.CPP)
        self.code_init = CodeBlock(code_init, dtypes.Language.CPP)
        self.code_exit = CodeBlock(code_exit, dtypes.Language.CPP)
        self.debuginfo = debuginfo
        # A fresh set per instance; a ``set()`` default argument would be
        # shared by every tasklet constructed without this argument.
        self.library_expansion_symbols = (library_expansion_symbols
                                          if library_expansion_symbols
                                          is not None else set())

    @property
    def language(self):
        """ The language of the tasklet's main code block. """
        return self.code.language

    @staticmethod
    def from_json(json_obj, context=None):
        """ Creates a tasklet and populates its properties from JSON. """
        ret = Tasklet("dummylabel")
        dace.serialize.set_properties_from_json(ret, json_obj, context=context)
        return ret

    @property
    def name(self):
        return self._label

    def validate(self, sdfg, state):
        """
        Checks that the tasklet label and all connector names are valid.

        :raises NameError: On the first invalid name found.
        """
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid tasklet name "%s"' % self.label)
        for in_conn in self.in_connectors:
            if not dtypes.validate_name(in_conn):
                raise NameError('Invalid input connector "%s"' % in_conn)
        for out_conn in self.out_connectors:
            if not dtypes.validate_name(out_conn):
                raise NameError('Invalid output connector "%s"' % out_conn)

    @property
    def free_symbols(self) -> Set[str]:
        """
        Free symbols of the node: those of the base class, those used in the
        code (excluding connector names), and the library expansion symbols.
        """
        result = super().free_symbols
        result |= self.code.get_free_symbols(self.in_connectors.keys()
                                             | self.out_connectors.keys())
        result |= self.library_expansion_symbols
        return result

    def infer_connector_types(self, sdfg, state):
        """
        Fills in missing connector types.

        For MLIR tasklets, types are read from the MLIR entry function
        (mismatches with already-set types only emit a warning). For Python
        tasklets, untyped output connectors are inferred from the code, given
        fully typed inputs. Tasklets in other languages are left unchanged.

        :raises TypeError: If a Python tasklet has untyped inputs, or an
                           output type cannot be inferred.
        """
        # If a MLIR tasklet, simply read out the types (it's explicit)
        if self.code.language == dtypes.Language.MLIR:
            # Inline import because mlir.utils depends on pyMLIR which may not be installed
            # Doesn't cause crashes due to missing pyMLIR if a MLIR tasklet is not present
            from dace.codegen.targets.mlir import utils

            mlir_ast = utils.get_ast(self.code.code)
            mlir_is_generic = utils.is_generic(mlir_ast)
            mlir_entry_func = utils.get_entry_func(mlir_ast, mlir_is_generic)

            mlir_result_type = utils.get_entry_result_type(
                mlir_entry_func, mlir_is_generic)
            # Full name of the (first) output connector. Indexing the name
            # with [0] would keep only its first character and break the
            # lookups below for any multi-character connector name.
            mlir_out_name = next(iter(self.out_connectors.keys()))

            if self.out_connectors[
                    mlir_out_name] is None or self.out_connectors[
                        mlir_out_name].ctype == "void":
                self.out_connectors[mlir_out_name] = utils.get_dace_type(
                    mlir_result_type)
            elif self.out_connectors[mlir_out_name] != utils.get_dace_type(
                    mlir_result_type):
                warnings.warn(
                    "Type mismatch between MLIR tasklet out connector and MLIR code"
                )

            for mlir_arg in utils.get_entry_args(mlir_entry_func,
                                                 mlir_is_generic):
                if self.in_connectors[
                        mlir_arg[0]] is None or self.in_connectors[
                            mlir_arg[0]].ctype == "void":
                    self.in_connectors[mlir_arg[0]] = utils.get_dace_type(
                        mlir_arg[1])
                elif self.in_connectors[mlir_arg[0]] != utils.get_dace_type(
                        mlir_arg[1]):
                    warnings.warn(
                        "Type mismatch between MLIR tasklet in connector and MLIR code"
                    )

            return

        # If a Python tasklet, use type inference to figure out all None output
        # connectors
        if all(cval.type is not None for cval in self.out_connectors.values()):
            return
        if self.code.language != dtypes.Language.Python:
            return

        if any(cval.type is None for cval in self.in_connectors.values()):
            raise TypeError('Cannot infer output connectors of tasklet "%s", '
                            'not all input connectors have types' % str(self))

        # Avoid import loop
        from dace.codegen.tools.type_inference import infer_types

        # Get symbols defined at beginning of node, and infer all types in
        # tasklet
        syms = state.symbols_defined_at(self)
        syms.update(self.in_connectors)
        new_syms = infer_types(self.code.code, syms)
        for cname, oconn in self.out_connectors.items():
            if oconn.type is None:
                if cname not in new_syms:
                    raise TypeError('Cannot infer type of tasklet %s output '
                                    '"%s", please specify manually.' %
                                    (self.label, cname))
                self.out_connectors[cname] = new_syms[cname]

    def __str__(self):
        if not self.label:
            return "--Empty--"
        else:
            return self.label
class Node(object):
    """ Base node class. """

    in_connectors = SetProperty(
        str, default=set(), desc="A set of input connectors for this node.")
    out_connectors = SetProperty(
        str, default=set(), desc="A set of output connectors for this node.")

    def __init__(self, in_connectors=None, out_connectors=None):
        """
        @param in_connectors: Initial input connectors (empty set if falsy).
        @param out_connectors: Initial output connectors (empty set if falsy).
        """
        self.in_connectors = in_connectors or set()
        self.out_connectors = out_connectors or set()

    def __str__(self):
        if hasattr(self, 'label'):
            return self.label
        else:
            return type(self).__name__

    def validate(self, sdfg, state):
        pass

    def to_json(self, parent):
        """
        Serializes this node into a JSON-compatible dictionary.

        @param parent: The graph containing this node, used to resolve node
                       IDs and the enclosing scope's entry/exit nodes.
        @return: A dictionary with the node type, label, serialized
                 properties, node ID, the scope entry node ID (a string, or
                 None) and the list of the scope's exit node IDs.
        """
        labelstr = str(self)
        typestr = str(type(self).__name__)

        scope_entry_node = parent.entry_node(self)
        if scope_entry_node is not None:
            ens = parent.exit_nodes(parent.entry_node(self))
            scope_exit_nodes = [str(parent.node_id(x)) for x in ens]
            scope_entry_node = str(parent.node_id(scope_entry_node))
        else:
            scope_entry_node = None
            scope_exit_nodes = []

        retdict = {
            "type": typestr,
            "label": labelstr,
            "attributes": dace.serialize.all_properties_to_json(self),
            "id": parent.node_id(self),
            "scope_entry": scope_entry_node,
            "scope_exits": scope_exit_nodes
        }
        return retdict

    def __repr__(self):
        return type(self).__name__ + ' (' + self.__str__() + ')'

    def add_in_connector(self, connector_name: str):
        """ Adds a new input connector to the node. The operation will fail if
            a connector (either input or output) with the same name already
            exists in the node.

            @param connector_name: The name of the new connector.
            @return: True if the operation is successful, otherwise False.
        """

        if (connector_name in self.in_connectors
                or connector_name in self.out_connectors):
            return False
        connectors = self.in_connectors
        connectors.add(connector_name)
        self.in_connectors = connectors
        return True

    def add_out_connector(self, connector_name: str):
        """ Adds a new output connector to the node. The operation will fail if
            a connector (either input or output) with the same name already
            exists in the node.

            @param connector_name: The name of the new connector.
            @return: True if the operation is successful, otherwise False.
        """

        if (connector_name in self.in_connectors
                or connector_name in self.out_connectors):
            return False
        connectors = self.out_connectors
        connectors.add(connector_name)
        self.out_connectors = connectors
        return True

    def remove_in_connector(self, connector_name: str):
        """ Removes an input connector from the node.
            @param connector_name: The name of the connector to remove.
            @return: True (also when the connector did not exist).
        """

        if connector_name in self.in_connectors:
            connectors = self.in_connectors
            connectors.remove(connector_name)
            self.in_connectors = connectors
        return True

    def remove_out_connector(self, connector_name: str):
        """ Removes an output connector from the node.
            @param connector_name: The name of the connector to remove.
            @return: True (also when the connector did not exist).
        """

        if connector_name in self.out_connectors:
            connectors = self.out_connectors
            connectors.remove(connector_name)
            self.out_connectors = connectors
        return True

    def _next_connector_int(self) -> int:
        """ Returns the next unused connector ID (as an integer). Used for
            filling connectors when adding edges to scopes.

            Only ``IN_<n>`` / ``OUT_<n>`` connectors with an integral suffix
            count; the result is one past the largest ``n`` (at least 1).
        """
        next_number = 1
        for conn in itertools.chain(self.in_connectors, self.out_connectors):
            if conn.startswith('IN_'):
                cconn = conn[3:]
            elif conn.startswith('OUT_'):
                cconn = conn[4:]
            else:
                continue
            try:
                curconn = int(cconn)
                if curconn >= next_number:
                    next_number = curconn + 1
            except (TypeError, ValueError):
                # Not integral (int() raises ValueError for e.g. "IN_foo")
                continue
        return next_number

    def next_connector(self) -> str:
        """ Returns the next unused connector ID (as a string). Used for
            filling connectors when adding edges to scopes. """
        return str(self._next_connector_int())

    def last_connector(self) -> str:
        """ Returns the last used connector ID (as a string). Used for
            filling connectors when adding edges to scopes. """
        return str(self._next_connector_int() - 1)