def get_state_graph_by_idx(
    self,
    idx: int,
    G: nx.MultiDiGraph,
    device: torch.device,
) -> Optional[nx.MultiDiGraph]:
    """Rebuild ``G`` in place as the state graph for flat sample index ``idx``.

    ``idx`` is split into ``episode = idx // self.num_frame`` and
    ``frame = idx % self.num_frame``. ``G`` is cleared and given:

    * an all-zero graph feature of shape ``(1, self.graph_feat_size)``;
    * one node per in-graph node (``0 .. self.num_ingraph_node - 1``) whose
      ``feat`` is ``get_node_feat(...)`` reshaped to ``(1, -1)``;
    * one edge per row of ``self.edges`` recorded for this (episode, frame),
      whose ``feat`` is the matching row of ``self.edge_feats`` as ``(1, -1)``.

    Args:
        idx: Flat sample index over all (episode, frame) pairs.
        G: Graph to clear and refill. It is cleared (and its graph feature
            reset) even when None is returned.
        device: Device every feature tensor is moved to.

    Returns:
        ``G`` itself, or None when ``get_delta_node_feat`` returns None for
        node 0 of this (episode, frame).
    """
    episode = idx // self.num_frame
    frame = idx % self.num_frame
    G.clear()
    G.graph["feat"] = torch.zeros(1, self.graph_feat_size).to(device)

    # NOTE(review): availability is probed with get_delta_node_feat although the
    # node features below come from get_node_feat; presumably this keeps the set
    # of valid indices aligned with get_delta_state_graph_by_idx — confirm.
    if self.get_delta_node_feat(episode, frame, 0) is None:
        return None

    for node_id in range(self.num_ingraph_node):
        node_feat = self.get_node_feat(episode, frame, node_id)
        G.add_node(node_id, feat=torch.Tensor(node_feat).view(1, -1).to(device))

    # Columns 0/1 of self.edges are matched against episode/frame and columns
    # 2/3 are used as the edge endpoints; edge_feats is indexed with the same
    # row selection, so it is assumed row-aligned with self.edges.
    # Build the row mask once and reuse it for both arrays.
    in_frame = (self.edges[:, 0] == episode) & (self.edges[:, 1] == frame)
    edge_portion = self.edges[in_frame]
    edge_feat_portion = self.edge_feats[in_frame]
    for (snode, rnode), edge_feat in zip(edge_portion[:, 2:4], edge_feat_portion):
        G.add_edge(
            snode,
            rnode,
            feat=torch.Tensor(edge_feat).view(1, -1).to(device),
        )
    return G
def get_delta_state_graph_by_idx(
    self,
    idx: int,
    G: nx.MultiDiGraph,
    device: torch.device,
) -> Optional[nx.MultiDiGraph]:
    """Rebuild ``G`` in place as the node-delta graph for sample ``idx``.

    The flat index is split into ``(episode, frame)`` via ``self.num_frame``.
    ``G`` is cleared and given an all-zero graph feature of shape
    ``(1, self.graph_feat_size)`` plus one node per in-graph node whose
    ``feat`` is ``get_delta_node_feat(...)`` reshaped to ``(1, -1)``.
    No edges are added.

    Args:
        idx: Flat sample index over all (episode, frame) pairs.
        G: Graph to clear and refill. It is cleared (and its graph feature
            reset) even when None is returned.
        device: Device every feature tensor is moved to.

    Returns:
        ``G`` itself, or None when node 0 has no delta feature for this
        (episode, frame).
    """
    episode, frame = divmod(idx, self.num_frame)
    G.clear()
    G.graph["feat"] = torch.zeros(1, self.graph_feat_size).to(device)

    # Guard: a missing delta for node 0 marks this sample as unavailable.
    if self.get_delta_node_feat(episode, frame, 0) is None:
        return None

    for node_id in range(self.num_ingraph_node):
        delta = self.get_delta_node_feat(episode, frame, node_id)
        G.add_node(node_id, feat=torch.Tensor(delta).view(1, -1).to(device))
    return G
class NetworkXKB(KnowledgeStore):
    """A NetworkX implementation of a knowledge store.

    Every stored element is a node of a ``MultiDiGraph``. Each attribute/value
    pair passed to ``store`` becomes an edge from the element's node to the
    value's node, carrying the attribute name in the edge's ``attribute``
    field. ``inverted_index`` maps each attribute name to the set of element
    IDs that have that attribute.
    """

    def __init__(self, activation_fn=None):
        """Initialize the NetworkXKB.

        Parameters:
            activation_fn (callable): Called as ``activation_fn(graph, mem_id)``
                when an existing element is stored again and whenever an
                element is returned by retrieve/query/next/prev. Defaults to
                a no-op.
        """
        # parameters
        if activation_fn is None:
            activation_fn = (lambda graph, mem_id: None)
        self.activation_fn = activation_fn
        # variables
        self.graph = MultiDiGraph()
        self.inverted_index = defaultdict(set)
        self.query_results = None
        self.result_index = None
        self.clear()

    def clear(self):  # noqa: D102
        self.graph.clear()
        self.inverted_index.clear()
        self.query_results = None
        self.result_index = None

    def store(self, mem_id=None, **kwargs):  # noqa: D102
        if mem_id is None:
            mem_id = uuid()
        if mem_id not in self.graph:
            self.graph.add_node(mem_id, activation=0)
        else:
            self.activation_fn(self.graph, mem_id)
        for attribute, value in kwargs.items():
            # values are nodes too, so they can themselves be retrieved/linked
            if value not in self.graph:
                self.graph.add_node(value, activation=0)
            self.graph.add_edge(mem_id, value, attribute=attribute)
            self.inverted_index[attribute].add(mem_id)
        return True

    def _activate_and_return(self, mem_id):
        """Activate ``mem_id`` and return its attribute/value pairs as a TreeMultiMap."""
        self.activation_fn(self.graph, mem_id)
        result = TreeMultiMap()
        for _, value, data in self.graph.out_edges(mem_id, data=True):
            result.add(data['attribute'], value)
        return result

    def _has_attr_val(self, mem_id, attribute, value):
        """Return True if ``mem_id`` has an edge to ``value`` labelled ``attribute``.

        The graph is a multigraph: one element may point at the same value
        under several attributes, so every parallel edge is checked.
        """
        edges = self.graph.get_edge_data(mem_id, value)
        if edges is None:
            return False
        return any(data['attribute'] == attribute for data in edges.values())

    def retrieve(self, mem_id):  # noqa: D102
        if mem_id not in self.graph:
            return None
        return self._activate_and_return(mem_id)

    def query(self, attr_vals):  # noqa: D102
        # An empty query has nothing to intersect; report it as a failed query
        # rather than letting set.intersection() raise a TypeError.
        if not attr_vals:
            self.query_results = None
            self.result_index = None
            return None
        # first pass: get candidates with all the attributes
        # (.get avoids inserting empty sets into the defaultdict)
        candidates = set.intersection(*(
            self.inverted_index.get(attribute, set())
            for attribute in attr_vals
        ))
        # second pass: get candidates with the correct values
        candidates = set(
            candidate for candidate in candidates
            if all(
                self._has_attr_val(candidate, attribute, value)
                for attribute, value in attr_vals.items()
            )
        )
        # quit early if there are no results
        if not candidates:
            self.query_results = None
            self.result_index = None
            return None
        # final pass: sort results by activation
        self.query_results = sorted(
            candidates,
            key=(lambda mem_id: self.graph.nodes[mem_id]['activation']),
            reverse=True,
        )
        self.result_index = 0
        return self._activate_and_return(self.query_results[self.result_index])

    @property
    def has_prev_result(self):  # noqa: D102
        return (
            self.query_results is not None
            and self.result_index > 0
        )

    def prev_result(self):  # noqa: D102
        self.result_index -= 1
        return self._activate_and_return(self.query_results[self.result_index])

    @property
    def has_next_result(self):  # noqa: D102
        return (
            self.query_results is not None
            and self.result_index < len(self.query_results) - 1
        )

    def next_result(self):  # noqa: D102
        self.result_index += 1
        return self._activate_and_return(self.query_results[self.result_index])

    @staticmethod
    def retrievable(mem_id):  # noqa: D102
        return isinstance(mem_id, Hashable)
class NxGraph(BaseGraph):
    """
    NxGraph is a wrapper that provides methods to interact with a networkx.MultiDiGraph.

    NxGraph extends kgx.graph.base_graph.BaseGraph and implements all the methods from BaseGraph.
    """

    def __init__(self):
        super().__init__()
        self.graph = MultiDiGraph()
        self.name = None

    def _resolve_edge_key(
        self, subject_node: str, object_node: str, edge_key: Optional[str]
    ) -> Optional[str]:
        """
        Resolve a missing edge key to the first existing key between two nodes.

        networkx assigns a fresh key whenever ``add_edge`` is called with
        ``key=None``, so modifying an existing edge without knowing its key
        requires looking the key up first.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key, or None

        Returns
        -------
        Optional[str]
            ``edge_key`` unchanged if it is not None or if no edge exists
            between the two nodes; otherwise the first key between them

        """
        if edge_key is None and self.graph.has_edge(subject_node, object_node):
            return next(iter(self.graph[subject_node][object_node]))
        return edge_key

    def add_node(self, node: str, **kwargs: Any) -> None:
        """
        Add a node to the graph.

        Parameters
        ----------
        node: str
            Node identifier
        **kwargs: Any
            Any additional node properties. If a ``data`` keyword is given,
            its value (a dict) is used as the properties and every other
            keyword is ignored.

        """
        data = kwargs["data"] if "data" in kwargs else kwargs
        self.graph.add_node(node, **data)

    def add_edge(
        self,
        subject_node: str,
        object_node: str,
        edge_key: Optional[str] = None,
        **kwargs: Any
    ) -> Any:
        """
        Add an edge to the graph.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key
        kwargs: Any
            Any additional edge properties. If a ``data`` keyword is given,
            its value (a dict) is used as the properties and every other
            keyword is ignored.

        Returns
        -------
        Any
            The edge key returned by ``networkx.MultiDiGraph.add_edge``

        """
        data = kwargs["data"] if "data" in kwargs else kwargs
        return self.graph.add_edge(subject_node, object_node, key=edge_key, **data)

    def add_node_attribute(self, node: str, attr_key: str, attr_value: Any) -> None:
        """
        Add an attribute to a given node.

        Parameters
        ----------
        node: str
            The node identifier
        attr_key: str
            The key for an attribute
        attr_value: Any
            The value corresponding to the key

        """
        self.graph.add_node(node, **{attr_key: attr_value})

    def add_edge_attribute(
        self,
        subject_node: str,
        object_node: str,
        edge_key: Optional[str],
        attr_key: str,
        attr_value: Any,
    ) -> None:
        """
        Add an attribute to a given edge.

        When ``edge_key`` is None and an edge already exists between the two
        nodes, the attribute is set on the first such edge. If no edge exists,
        a new one is created carrying only this attribute.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key
        attr_key: str
            The attribute key
        attr_value: Any
            The attribute value

        """
        # Without this, add_edge(key=None) would always create a new parallel edge.
        edge_key = self._resolve_edge_key(subject_node, object_node, edge_key)
        self.graph.add_edge(subject_node, object_node, key=edge_key, **{attr_key: attr_value})

    def update_node_attribute(
        self, node: str, attr_key: str, attr_value: Any, preserve: bool = False
    ) -> Dict:
        """
        Update an attribute of a given node.

        Parameters
        ----------
        node: str
            The node identifier
        attr_key: str
            The key for an attribute
        attr_value: Any
            The value corresponding to the key
        preserve: bool
            Whether or not to preserve existing values for the given attr_key

        Returns
        -------
        Dict
            A dictionary corresponding to the updated node properties

        """
        node_data = self.graph.nodes[node]
        updated = prepare_data_dict(node_data, {attr_key: attr_value}, preserve=preserve)
        self.graph.add_node(node, **updated)
        return updated

    def update_edge_attribute(
        self,
        subject_node: str,
        object_node: str,
        edge_key: Optional[str],
        attr_key: str,
        attr_value: Any,
        preserve: bool = False,
    ) -> Dict:
        """
        Update an attribute of a given edge.

        When ``edge_key`` is None, the first edge between the two nodes is
        updated.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key
        attr_key: str
            The attribute key
        attr_value: Any
            The attribute value
        preserve: bool
            Whether or not to preserve existing values for the given attr_key

        Returns
        -------
        Dict
            A dictionary corresponding to the updated edge properties

        Raises
        ------
        KeyError
            If no matching edge exists

        """
        edge_key = self._resolve_edge_key(subject_node, object_node, edge_key)
        # Look up exactly (subject, object, key). graph.edges(nbunch) would treat
        # the triple as a set of nodes and return an arbitrary out-edge.
        edge_data = (
            self.graph.get_edge_data(subject_node, object_node, key=edge_key)
            if edge_key is not None
            else None
        )
        if edge_data is None:
            raise KeyError((subject_node, object_node, edge_key))
        updated = prepare_data_dict(edge_data, {attr_key: attr_value}, preserve)
        self.graph.add_edge(subject_node, object_node, key=edge_key, **updated)
        return updated

    def get_node(self, node: str) -> Dict:
        """
        Get a node and its properties.

        Parameters
        ----------
        node: str
            The node identifier

        Returns
        -------
        Dict
            The node dictionary (empty if the node does not exist)

        """
        n = {}
        if self.graph.has_node(node):
            n = self.graph.nodes[node]
        return n

    def get_edge(
        self, subject_node: str, object_node: str, edge_key: Optional[str] = None
    ) -> Dict:
        """
        Get an edge and its properties.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key. If None, the result maps each key between the two
            nodes to that edge's properties.

        Returns
        -------
        Dict
            The edge dictionary (empty if the edge does not exist)

        """
        e = {}
        if self.graph.has_edge(subject_node, object_node, edge_key):
            e = self.graph.get_edge_data(subject_node, object_node, edge_key)
        return e

    def nodes(self, data: bool = True) -> Dict:
        """
        Get all nodes in a graph.

        Parameters
        ----------
        data: bool
            Whether or not to fetch node properties

        Returns
        -------
        Dict
            A dictionary of nodes

        """
        return self.graph.nodes(data)

    def edges(self, keys: bool = False, data: bool = True) -> Dict:
        """
        Get all edges in a graph.

        Parameters
        ----------
        keys: bool
            Whether or not to include edge keys
        data: bool
            Whether or not to fetch edge properties

        Returns
        -------
        Dict
            A dictionary of edges

        """
        return self.graph.edges(keys=keys, data=data)

    def in_edges(self, node: str, keys: bool = False, data: bool = False) -> List:
        """
        Get all incoming edges for a given node.

        Parameters
        ----------
        node: str
            The node identifier
        keys: bool
            Whether or not to include edge keys
        data: bool
            Whether or not to fetch edge properties

        Returns
        -------
        List
            A list of edges

        """
        return self.graph.in_edges(node, keys=keys, data=data)

    def out_edges(self, node: str, keys: bool = False, data: bool = False) -> List:
        """
        Get all outgoing edges for a given node.

        Parameters
        ----------
        node: str
            The node identifier
        keys: bool
            Whether or not to include edge keys
        data: bool
            Whether or not to fetch edge properties

        Returns
        -------
        List
            A list of edges

        """
        return self.graph.out_edges(node, keys=keys, data=data)

    def nodes_iter(self) -> Generator:
        """
        Get an iterable to traverse through all the nodes in a graph.

        Returns
        -------
        Generator
            A generator for nodes where each element is a Tuple that
            contains (node_id, node_data)

        """
        for n in self.graph.nodes(data=True):
            yield n

    def edges_iter(self) -> Generator:
        """
        Get an iterable to traverse through all the edges in a graph.

        Returns
        -------
        Generator
            A generator for edges where each element is a 4-tuple that
            contains (subject, object, edge_key, edge_data)

        """
        for u, v, k, data in self.graph.edges(keys=True, data=True):
            yield u, v, k, data

    def remove_node(self, node: str) -> None:
        """
        Remove a given node from the graph.

        Parameters
        ----------
        node: str
            The node identifier

        """
        self.graph.remove_node(node)

    def remove_edge(
        self, subject_node: str, object_node: str, edge_key: Optional[str] = None
    ) -> None:
        """
        Remove a given edge from the graph.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key. If None, networkx removes a single edge between
            the two nodes.

        """
        self.graph.remove_edge(subject_node, object_node, edge_key)

    def has_node(self, node: str) -> bool:
        """
        Check whether a given node exists in the graph.

        Parameters
        ----------
        node: str
            The node identifier

        Returns
        -------
        bool
            Whether or not the given node exists

        """
        return self.graph.has_node(node)

    def has_edge(
        self, subject_node: str, object_node: str, edge_key: Optional[str] = None
    ) -> bool:
        """
        Check whether a given edge exists in the graph.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key

        Returns
        -------
        bool
            Whether or not the given edge exists

        """
        return self.graph.has_edge(subject_node, object_node, key=edge_key)

    def number_of_nodes(self) -> int:
        """
        Returns the number of nodes in a graph.

        Returns
        -------
        int

        """
        return self.graph.number_of_nodes()

    def number_of_edges(self) -> int:
        """
        Returns the number of edges in a graph.

        Returns
        -------
        int

        """
        return self.graph.number_of_edges()

    def degree(self):
        """
        Get the degree of all the nodes in a graph.
        """
        return self.graph.degree()

    def clear(self) -> None:
        """
        Remove all the nodes and edges in the graph.
        """
        self.graph.clear()

    @staticmethod
    def set_node_attributes(graph: BaseGraph, attributes: Dict) -> None:
        """
        Set nodes attributes from a dictionary of key-values.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to modify
        attributes: Dict
            A dictionary of node identifier to key-value pairs

        """
        return set_node_attributes(graph.graph, attributes)

    @staticmethod
    def set_edge_attributes(graph: BaseGraph, attributes: Dict) -> None:
        """
        Set edge attributes from a dictionary of key-values.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to modify
        attributes: Dict
            A dictionary of edge identifier (for a multigraph, a
            ``(subject, object, edge_key)`` tuple) to key-value pairs

        """
        return set_edge_attributes(graph.graph, attributes)

    @staticmethod
    def get_node_attributes(graph: BaseGraph, attr_key: str) -> Dict:
        """
        Get all nodes that have a value for the given attribute ``attr_key``.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to read from
        attr_key: str
            The attribute key

        Returns
        -------
        Dict
            A dictionary where nodes are the keys and the values
            are the attribute values for ``attr_key``

        """
        return get_node_attributes(graph.graph, attr_key)

    @staticmethod
    def get_edge_attributes(graph: BaseGraph, attr_key: str) -> Dict:
        """
        Get all edges that have a value for the given attribute ``attr_key``.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to read from
        attr_key: str
            The attribute key

        Returns
        -------
        Dict
            A dictionary where edges are the keys and the values
            are the attribute values for ``attr_key``

        """
        return get_edge_attributes(graph.graph, attr_key)

    @staticmethod
    def relabel_nodes(graph: BaseGraph, mapping: Dict) -> None:
        """
        Relabel identifiers for a series of nodes based on mappings.

        The graph is modified in place.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to modify
        mapping: Dict
            A dictionary of mapping where the key is the old identifier
            and the value is the new identifier.

        """
        relabel_nodes(graph.graph, mapping, copy=False)
class NetworkXLTM(LongTermMemory):
    """A NetworkX implementation of LTM.

    Every stored element is a node of a ``MultiDiGraph`` with an
    ``activation`` value. Each attribute/value pair passed to ``store``
    becomes an edge from the element's node to the value's node, carrying the
    attribute name in the edge's ``attribute`` field. ``inverted_index`` maps
    each attribute to the set of element IDs that have it.
    """

    def __init__(self, **kwargs):
        # type: (**Any) -> None
        """Initialize the NetworkXLTM.

        Parameters:
            **kwargs: Arbitrary keyword arguments, forwarded to LongTermMemory.
        """
        super().__init__(**kwargs)
        self.graph = MultiDiGraph()
        self.inverted_index = defaultdict(set)  # type: Dict[Hashable, Set[Hashable]]
        self.query_results = []  # type: List[Hashable]
        self.result_index = -1
        self.clear()

    def clear(self):  # noqa: D102
        # type: () -> None
        self.graph.clear()
        self.inverted_index.clear()
        self.query_results = []
        self.result_index = -1

    def get_activation(self, mem_id):
        # type: (Hashable) -> float
        """Get the activation of a memory element.

        Parameters:
            mem_id (Hashable): The ID of the element.

        Returns:
            float: The activation of the element.
        """
        return self.graph.nodes[mem_id]['activation']

    def store(self, mem_id=None, time=0, **kwargs):  # noqa: D102
        # type: (Hashable, int, **Any) -> bool
        if mem_id is None:
            mem_id = uuid()
        if mem_id not in self.graph:
            self.graph.add_node(mem_id, activation=0)
        else:
            self.activation_fn(self, mem_id, time)
        for attribute, value in kwargs.items():
            # values are nodes too, so they can themselves be retrieved/linked
            if value not in self.graph:
                self.graph.add_node(value, activation=0)
            self.graph.add_edge(mem_id, value, attribute=attribute)
            self.inverted_index[attribute].add(mem_id)
        return True

    def _activate_and_return(self, mem_id, time):
        # type: (Hashable, int) -> AVLTree
        """Activate ``mem_id`` at ``time`` and return its AttrVals in an AVLTree."""
        self.activation_fn(self, mem_id, time)
        result = AVLTree()
        for _, value, data in self.graph.out_edges(mem_id, data=True):
            result.add(AttrVal(data['attribute'], value))
        return result

    def _has_attr_val(self, mem_id, attr, val):
        # type: (Hashable, Hashable, Hashable) -> bool
        """Return True if ``mem_id`` has an edge to ``val`` labelled ``attr``.

        Every parallel edge between the two nodes is checked, since one
        element may point at the same value under several attributes.
        """
        edges = self.graph.get_edge_data(mem_id, val)
        if edges is None:
            return False
        return any(attr_dict['attribute'] == attr for attr_dict in edges.values())

    def retrieve(self, mem_id, time=0):  # noqa: D102
        # type: (Hashable, int) -> Optional[AVLTree]
        if mem_id not in self.graph:
            return None
        return self._activate_and_return(mem_id, time=time)

    def query(self, attr_vals, time=0):  # noqa: D102
        # type: (AbstractSet[AttrVal], int) -> Optional[AVLTree]
        # The cue is scanned twice below, so materialize it once.
        attr_vals = tuple(attr_vals)
        # An empty cue has nothing to intersect; report a failed query instead
        # of letting set.intersection() raise a TypeError.
        if not attr_vals:
            self.query_results = []
            self.result_index = -1
            return None
        # first pass: get candidates with all the attributes
        # (.get avoids inserting empty sets into the defaultdict)
        attrs = set(attr for attr, _ in attr_vals)
        candidates = set.intersection(*(
            self.inverted_index.get(attribute, set()) for attribute in attrs
        ))
        # second pass: get candidates with the correct values
        candidates = set(
            candidate for candidate in candidates
            if all(
                self._has_attr_val(candidate, attr, val)
                for attr, val in attr_vals
            )
        )
        # quit early if there are no results
        if not candidates:
            self.query_results = []
            self.result_index = -1
            return None
        # final pass: sort results by activation
        self.query_results = sorted(
            candidates,
            key=(lambda mem_id: self.graph.nodes[mem_id]['activation']),
            reverse=True,
        )
        self.result_index = 0
        return self._activate_and_return(self.query_results[self.result_index], time=time)

    @property
    def has_prev_result(self):  # noqa: D102
        # type: () -> bool
        return bool(self.query_results) and self.result_index > 0

    def prev_result(self, time=0):  # noqa: D102
        # type: (int) -> Optional[AVLTree]
        self.result_index -= 1
        return self._activate_and_return(self.query_results[self.result_index], time=time)

    @property
    def has_next_result(self):  # noqa: D102
        # type: () -> bool
        return (
            bool(self.query_results)
            and -1 < self.result_index < len(self.query_results) - 1
        )

    def next_result(self, time=0):  # noqa: D102
        # type: (int) -> Optional[AVLTree]
        self.result_index += 1
        return self._activate_and_return(self.query_results[self.result_index], time=time)

    @staticmethod
    def retrievable(mem_id):  # noqa: D102
        # type: (Any) -> bool
        return mem_id is not None and isinstance(mem_id, Hashable)