def __init__(self):
    """Create an empty tree.

    The tree is anchored by a single root ``Node`` whose name is ``None``.
    Nodes added later are also recorded by name in ``self._all_nodes``, a
    flat dictionary-like index that makes membership checks cheap.
    """
    # flat name -> Node index, kept alongside the tree for a faster __contains__
    self._all_nodes = _DataIDContainer()
    self._root = Node(None)
def add_child(self, parent, child):
    """Add a child to the tree.

    Links ``child`` under ``parent`` via ``Node.add_child`` and then records
    ``child`` in the flat ``self._all_nodes`` index (except for the shared
    ``self.empty_node`` sentinel, which is never indexed).

    Args:
        parent: Node to attach ``child`` to.
        child: Node being attached.

    Raises:
        RuntimeError: If a *different* Node object is already registered
            under ``child.name``.
    """
    Node.add_child(parent, child)
    # Sanity check: Node objects should be unique. They can be added
    # multiple times if more than one Node depends on them,
    # but they should all map to the same Node object. An explicit raise is
    # used instead of ``assert`` so the check survives ``python -O``.
    if self.contains(child.name) and self._all_nodes[child.name] is not child:
        raise RuntimeError(
            f"A different Node object is already registered in the tree "
            f"under the name {child.name!r}"
        )
    if child is self.empty_node:
        # No need to store "empty" nodes
        return
    self._all_nodes[child.name] = child
def _get_unique_node_from_id(self, result, data): try: # now that we know we have the exact DataID see if we have already created a Node for it return self.getitem(result) except KeyError: # we haven't created a node yet, create it now return Node(result, data)
def add_leaf(self, ds_id, parent=None):
    """Attach the node for ``ds_id`` under ``parent`` and return it.

    When ``parent`` is ``None`` the tree's root is used. A node already
    reachable through ``self[ds_id]`` is reused; otherwise a fresh
    ``Node(ds_id)`` is created before being attached with ``self.add_child``.
    """
    target_parent = self._root if parent is None else parent
    try:
        leaf = self[ds_id]
    except KeyError:
        leaf = Node(ds_id)
    self.add_child(target_parent, leaf)
    return leaf
class Tree:
    """A tree implementation.

    Nodes hang off a single root ``Node`` (named ``None``). Every node added
    through :meth:`add_child` is also indexed by name in a flat
    ``_DataIDContainer`` so membership tests and lookups don't need to walk
    the tree.
    """

    # simplify future logic by only having one "sentinel" empty node
    # making it a class attribute ensures it is the same across instances
    empty_node = Node(EMPTY_LEAF_NAME)

    def __init__(self):
        """Set up the tree."""
        self._root = Node(None)
        # keep a flat dictionary of nodes contained in the tree for better
        # __contains__
        self._all_nodes = _DataIDContainer()

    @staticmethod
    def _merge_node_groups(node_groups: Iterable[Iterable[Node]], unique: bool) -> list[Node]:
        """Concatenate groups of nodes in order, optionally skipping repeats.

        Repeats are detected with ``in`` on the result list (i.e. Node
        equality), and the first occurrence keeps its position.
        """
        merged = []
        for group in node_groups:
            for node in group:
                if not unique or node not in merged:
                    merged.append(node)
        return merged

    def leaves(self, limit_nodes_to: Optional[Iterable[DataID]] = None, unique: bool = True) -> list[Node]:
        """Get the leaves of the tree starting at the root.

        Args:
            limit_nodes_to: Limit leaves to Nodes with the names (DataIDs) specified.
            unique: Only include individual leaf nodes once.

        Returns:
            list of leaf nodes

        Raises:
            KeyError: If a name in ``limit_nodes_to`` is not in the tree.
        """
        if limit_nodes_to is None:
            return self._root.leaves(unique=unique)
        # generator keeps the original lazy, per-ID evaluation order
        groups = (self._all_nodes[child_id].leaves(unique=unique)
                  for child_id in limit_nodes_to)
        return self._merge_node_groups(groups, unique)

    def trunk(
        self,
        limit_nodes_to: Optional[Iterable[DataID]] = None,
        unique: bool = True,
        limit_children_to: Optional[Container[DataID]] = None,
    ) -> list[Node]:
        """Get the trunk nodes of the tree starting at this root.

        Args:
            limit_nodes_to: Limit searching to trunk nodes with the names
                (DataIDs) specified and the children of these nodes.
            unique: Only include individual trunk nodes once
            limit_children_to: Limit searching to the children with the specified
                names. These child nodes will be included in the result,
                but not their children.

        Returns:
            list of trunk nodes

        Raises:
            KeyError: If a name in ``limit_nodes_to`` is not in the tree.
        """
        if limit_nodes_to is None:
            return self._root.trunk(unique=unique, limit_children_to=limit_children_to)
        groups = (self._all_nodes[child_id].trunk(unique=unique,
                                                  limit_children_to=limit_children_to)
                  for child_id in limit_nodes_to)
        return self._merge_node_groups(groups, unique)

    def add_child(self, parent, child):
        """Add a child to the tree.

        Links ``child`` under ``parent`` and indexes it by name, except for
        the shared ``empty_node`` sentinel, which is never indexed.

        Raises:
            RuntimeError: If a *different* Node object is already registered
                under ``child.name``.
        """
        Node.add_child(parent, child)
        # Sanity check: Node objects should be unique. They can be added
        # multiple times if more than one Node depends on them
        # but they should all map to the same Node object.
        if self.contains(child.name) and self._all_nodes[child.name] is not child:
            raise RuntimeError(
                f"A different Node object is already registered in the tree "
                f"under the name {child.name!r}"
            )
        if child is self.empty_node:
            # No need to store "empty" nodes
            return
        self._all_nodes[child.name] = child

    def add_leaf(self, ds_id, parent=None):
        """Add a leaf to the tree.

        Reuses the node found via ``self[ds_id]`` if there is one, otherwise
        creates ``Node(ds_id)``; attaches it under ``parent`` (the root when
        ``parent`` is ``None``) and returns it.
        """
        if parent is None:
            parent = self._root
        try:
            node = self[ds_id]
        except KeyError:
            node = Node(ds_id)
        self.add_child(parent, node)
        return node

    def __contains__(self, item):
        """Check if a item is in the tree."""
        return item in self._all_nodes

    def __getitem__(self, item):
        """Get an item of the tree."""
        return self._all_nodes[item]

    def contains(self, item):
        """Check contains when we know the *exact* DataID or DataQuery.

        Skips ``_DataIDContainer``'s own ``__contains__`` and calls its base
        class implementation directly.
        """
        return super(_DataIDContainer, self._all_nodes).__contains__(item)

    def getitem(self, item):
        """Get Node when we know the *exact* DataID or DataQuery.

        Skips ``_DataIDContainer``'s own ``__getitem__`` and calls its base
        class implementation directly.
        """
        return super(_DataIDContainer, self._all_nodes).__getitem__(item)

    def __str__(self):
        """Render the dependency tree as a string."""
        return self._root.display()