Example #1
0
    def get_state_graph_by_idx(
        self,
        idx: int,
        G: nx.MultiDiGraph,
        device: torch.device,
    ) -> Optional[nx.MultiDiGraph]:
        """Populate ``G`` with the state graph of flat sample index ``idx``.

        ``idx`` is split into ``episode = idx // self.num_frame`` and
        ``frame = idx % self.num_frame``. ``G`` is cleared and reused. It
        receives:

        * a zero graph feature of shape ``(1, self.graph_feat_size)``;
        * one node per in-graph node, with ``get_node_feat`` as ``feat``;
        * one edge per row of ``self.edges`` that belongs to
          ``(episode, frame)``, with the matching ``self.edge_feats`` row as
          ``feat``.

        Args:
            idx: Flat sample index.
            G: Graph to clear and fill in place.
            device: Device every feature tensor is moved to.

        Returns:
            ``G`` after population, or ``None`` when
            ``get_delta_node_feat(episode, frame, 0)`` is ``None``. In that
            case ``G`` has still been cleared and given the zero graph
            feature.
        """
        episode = idx // self.num_frame
        frame = idx % self.num_frame

        G.clear()
        G.graph["feat"] = torch.zeros(1, self.graph_feat_size).to(device)

        # NOTE(review): availability is gated on the *delta* feature although
        # node features below come from get_node_feat; presumably so state
        # graphs cover the same samples as delta graphs — confirm.
        if self.get_delta_node_feat(episode, frame, 0) is None:
            return None

        for node_id in range(self.num_ingraph_node):
            node_feat = self.get_node_feat(episode, frame, node_id)
            G.add_node(
                node_id,
                feat=torch.Tensor(node_feat).view(1, -1).to(device),
            )

        # Edges are added once, after all nodes. Previously this ran inside
        # the node loop, and each add_edge on a MultiDiGraph created another
        # parallel edge, duplicating every edge num_ingraph_node times.
        # Columns 0/1 of self.edges hold (episode, frame) and columns 2/3 the
        # sender/receiver node ids; self.edge_feats is row-aligned with it.
        in_sample = np.where((self.edges[:, 0] == episode)
                             & (self.edges[:, 1] == frame))
        edge_rows = self.edges[in_sample]
        edge_feat_rows = self.edge_feats[in_sample]
        for edge_row, edge_feat in zip(edge_rows, edge_feat_rows):
            G.add_edge(
                edge_row[2],
                edge_row[3],
                feat=torch.Tensor(edge_feat).view(1, -1).to(device),
            )
        return G
Example #2
0
    def get_delta_state_graph_by_idx(
        self,
        idx: int,
        G: nx.MultiDiGraph,
        device: torch.device,
    ) -> Optional[nx.MultiDiGraph]:
        """Populate ``G`` with the delta-feature graph of flat sample ``idx``.

        ``idx`` maps to ``(episode, frame)`` via ``divmod(idx,
        self.num_frame)``. ``G`` is cleared, given a zero graph feature of
        shape ``(1, self.graph_feat_size)``, and receives one edge-less node
        per in-graph node whose ``feat`` is ``get_delta_node_feat`` reshaped
        to a single row.

        Args:
            idx: Flat sample index.
            G: Graph to clear and fill in place.
            device: Device every feature tensor is moved to.

        Returns:
            ``G`` after population, or ``None`` when the delta feature of
            node 0 is ``None`` (``G`` is still cleared in that case).
        """
        episode, frame = divmod(idx, self.num_frame)

        G.clear()
        G.graph["feat"] = torch.zeros(1, self.graph_feat_size).to(device)

        # No delta for this sample: nothing to build.
        if self.get_delta_node_feat(episode, frame, 0) is None:
            return None

        for node_id in range(self.num_ingraph_node):
            delta = self.get_delta_node_feat(episode, frame, node_id)
            delta_row = torch.Tensor(delta).view(1, -1).to(device)
            G.add_node(node_id, feat=delta_row)
        return G
Example #3
0
class NetworkXKB(KnowledgeStore):
    """A NetworkX implementation of a knowledge store.

    Memory elements and attribute values are both nodes of a
    ``MultiDiGraph`` and carry an ``activation`` attribute. Each
    attribute/value pair of an element is an out-edge from the element to
    the value node, labelled ``attribute=<name>``. ``inverted_index`` maps an
    attribute name to the IDs of elements stored with that attribute.
    """
    def __init__(self, activation_fn=None):
        """Initialize the NetworkXKB.

        Parameters:
            activation_fn (Callable): Called as ``activation_fn(graph,
                mem_id)`` whenever an existing element is re-stored or
                retrieved. Its return value is ignored. Defaults to a no-op.
        """
        # parameters
        if activation_fn is None:
            activation_fn = (lambda graph, mem_id: None)
        self.activation_fn = activation_fn
        # variables
        self.graph = MultiDiGraph()
        # attribute name -> set of element IDs stored with that attribute
        self.inverted_index = defaultdict(set)
        # element IDs of the last query, most activated first (None if none)
        self.query_results = None
        # cursor into query_results (None if there are no results)
        self.result_index = None
        self.clear()

    def clear(self):  # noqa: D102
        self.graph.clear()
        self.inverted_index.clear()
        self.query_results = None
        self.result_index = None

    def store(self, mem_id=None, **kwargs):  # noqa: D102
        # a fresh ID is generated when none is given
        if mem_id is None:
            mem_id = uuid()
        if mem_id not in self.graph:
            self.graph.add_node(mem_id, activation=0)
        else:
            # re-storing an existing element counts as an access
            self.activation_fn(self.graph, mem_id)
        for attribute, value in kwargs.items():
            if value not in self.graph:
                self.graph.add_node(value, activation=0)
            self.graph.add_edge(mem_id, value, attribute=attribute)
            self.inverted_index[attribute].add(mem_id)
        return True

    def _activate_and_return(self, mem_id):
        """Activate ``mem_id`` and return its attribute/value pairs."""
        self.activation_fn(self.graph, mem_id)
        result = TreeMultiMap()
        for _, value, data in self.graph.out_edges(mem_id, data=True):
            result.add(data['attribute'], value)
        return result

    def _has_attr_val(self, mem_id, attribute, value):
        """Return True if ``mem_id`` links to ``value`` under ``attribute``.

        All parallel edges between the two nodes are checked. One element
        can point at the same value node under several attributes, and
        those edges get distinct keys (0, 1, ...).
        """
        parallel_edges = self.graph.get_edge_data(mem_id, value)
        if parallel_edges is None:
            return False
        return any(data['attribute'] == attribute
                   for data in parallel_edges.values())

    def retrieve(self, mem_id):  # noqa: D102
        if mem_id not in self.graph:
            return None
        return self._activate_and_return(mem_id)

    def query(self, attr_vals):  # noqa: D102
        # NOTE(review): attr_vals must be a non-empty mapping;
        # set.intersection() with no arguments raises TypeError.
        # first pass: get candidates with all the attributes
        candidates = set.intersection(*(self.inverted_index[attribute]
                                        for attribute in attr_vals.keys()))
        # second pass: get candidates with the correct values. Every parallel
        # edge is inspected; checking only edge key 0 missed elements whose
        # matching attribute was stored on a later parallel edge.
        candidates = set(
            candidate for candidate in candidates
            if all(self._has_attr_val(candidate, attribute, value)
                   for attribute, value in attr_vals.items()))
        # quit early if there are no results
        if not candidates:
            self.query_results = None
            self.result_index = None
            return None
        # final pass: sort results by activation
        self.query_results = sorted(
            candidates,
            key=(lambda mem_id: self.graph.nodes[mem_id]['activation']),
            reverse=True,
        )
        self.result_index = 0
        return self._activate_and_return(self.query_results[self.result_index])

    @property
    def has_prev_result(self):  # noqa: D102
        return (self.query_results is not None and self.result_index > 0)

    def prev_result(self):  # noqa: D102
        # caller is expected to check has_prev_result first
        self.result_index -= 1
        return self._activate_and_return(self.query_results[self.result_index])

    @property
    def has_next_result(self):  # noqa: D102
        return (self.query_results is not None
                and self.result_index < len(self.query_results) - 1)

    def next_result(self):  # noqa: D102
        # caller is expected to check has_next_result first
        self.result_index += 1
        return self._activate_and_return(self.query_results[self.result_index])

    @staticmethod
    def retrievable(mem_id):  # noqa: D102
        return isinstance(mem_id, Hashable)
Example #4
0
class NxGraph(BaseGraph):
    """
    NxGraph is a wrapper that provides methods to interact with a networkx.MultiDiGraph.

    NxGraph extends kgx.graph.base_graph.BaseGraph and implements all the methods from BaseGraph.
    The wrapped graph is stored as ``self.graph``; parallel edges between the
    same pair of nodes are distinguished by their edge key.
    """
    def __init__(self):
        super().__init__()
        self.graph = MultiDiGraph()
        self.name = None

    def add_node(self, node: str, **kwargs: Any) -> None:
        """
        Add a node to the graph.

        If a ``data`` keyword is given, its value is used as the node
        properties and all other keyword arguments are ignored.

        Parameters
        ----------
        node: str
            Node identifier
        **kwargs: Any
            Any additional node properties

        """
        if "data" in kwargs:
            data = kwargs["data"]
        else:
            data = kwargs
        self.graph.add_node(node, **data)

    def add_edge(self,
                 subject_node: str,
                 object_node: str,
                 edge_key: Optional[str] = None,
                 **kwargs: Any) -> Any:
        """
        Add an edge to the graph.

        If a ``data`` keyword is given, its value is used as the edge
        properties and all other keyword arguments are ignored.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key. When None, networkx assigns a new key, so a new
            parallel edge is always created.
        kwargs: Any
            Any additional edge properties

        Returns
        -------
        Any
            The key of the added edge, as returned by
            ``networkx.MultiDiGraph.add_edge``

        """
        if "data" in kwargs:
            data = kwargs["data"]
        else:
            data = kwargs
        return self.graph.add_edge(subject_node,
                                   object_node,
                                   key=edge_key,
                                   **data)

    def add_node_attribute(self, node: str, attr_key: str,
                           attr_value: Any) -> None:
        """
        Add an attribute to a given node, creating the node if needed.

        Parameters
        ----------
        node: str
            The node identifier
        attr_key: str
            The key for an attribute
        attr_value: Any
            The value corresponding to the key

        """
        self.graph.add_node(node, **{attr_key: attr_value})

    def add_edge_attribute(
        self,
        subject_node: str,
        object_node: str,
        edge_key: Optional[str],
        attr_key: str,
        attr_value: Any,
    ) -> None:
        """
        Add an attribute to a given edge, creating the edge if needed.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key
        attr_key: str
            The attribute key
        attr_value: Any
            The attribute value

        """
        self.graph.add_edge(subject_node,
                            object_node,
                            key=edge_key,
                            **{attr_key: attr_value})

    def update_node_attribute(self,
                              node: str,
                              attr_key: str,
                              attr_value: Any,
                              preserve: bool = False) -> Dict:
        """
        Update an attribute of a given node.

        Parameters
        ----------
        node: str
            The node identifier
        attr_key: str
            The key for an attribute
        attr_value: Any
            The value corresponding to the key
        preserve: bool
            Whether or not to preserve existing values for the given attr_key

        Returns
        -------
        Dict
            A dictionary corresponding to the updated node properties

        Raises
        ------
        KeyError
            If ``node`` is not in the graph

        """
        node_data = self.graph.nodes[node]
        updated = prepare_data_dict(node_data, {attr_key: attr_value},
                                    preserve=preserve)
        self.graph.add_node(node, **updated)
        return updated

    def update_edge_attribute(
        self,
        subject_node: str,
        object_node: str,
        edge_key: Optional[str],
        attr_key: str,
        attr_value: Any,
        preserve: bool = False,
    ) -> Dict:
        """
        Update an attribute of a given edge.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key. When None, the first parallel edge between
            ``subject_node`` and ``object_node`` is updated.
        attr_key: str
            The attribute key
        attr_value: Any
            The attribute value
        preserve: bool
            Whether or not to preserve existing values for the given attr_key

        Returns
        -------
        Dict
            A dictionary corresponding to the updated edge properties

        Raises
        ------
        KeyError
            If no matching edge exists between ``subject_node`` and
            ``object_node``

        """
        if edge_key is None:
            # Resolve to an existing parallel edge. add_edge(key=None) would
            # otherwise create a brand-new edge instead of updating one.
            parallel = self.graph.get_edge_data(subject_node, object_node,
                                                default={})
            edge_key = next(iter(parallel), None)
        # Look up the exact (subject, object, key) edge. The previous
        # self.graph.edges((u, v, k)) call treated the tuple as a bunch of
        # *nodes* and returned the first out-edge of subject_node, which
        # could point to a different object.
        edge_data = self.graph.get_edge_data(subject_node, object_node,
                                             key=edge_key)
        if edge_data is None:
            raise KeyError((subject_node, object_node, edge_key))
        updated = prepare_data_dict(edge_data, {attr_key: attr_value},
                                    preserve)
        self.graph.add_edge(subject_node, object_node, key=edge_key, **updated)
        return updated

    def get_node(self, node: str) -> Dict:
        """
        Get a node and its properties.

        Parameters
        ----------
        node: str
            The node identifier

        Returns
        -------
        Dict
            The node dictionary, or an empty dict if the node is absent

        """
        n = {}
        if self.graph.has_node(node):
            n = self.graph.nodes[node]
        return n

    def get_edge(self,
                 subject_node: str,
                 object_node: str,
                 edge_key: Optional[str] = None) -> Dict:
        """
        Get an edge and its properties.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key. When None, networkx returns a mapping of edge key
            to properties for all parallel edges between the two nodes.

        Returns
        -------
        Dict
            The edge dictionary, or an empty dict if the edge is absent

        """
        e = {}
        if self.graph.has_edge(subject_node, object_node, edge_key):
            e = self.graph.get_edge_data(subject_node, object_node, edge_key)
        return e

    def nodes(self, data: bool = True) -> Dict:
        """
        Get all nodes in a graph.

        Parameters
        ----------
        data: bool
            Whether or not to fetch node properties

        Returns
        -------
        Dict
            A networkx node view over the nodes

        """
        return self.graph.nodes(data)

    def edges(self, keys: bool = False, data: bool = True) -> Dict:
        """
        Get all edges in a graph.

        Parameters
        ----------
        keys: bool
            Whether or not to include edge keys
        data: bool
            Whether or not to fetch edge properties

        Returns
        -------
        Dict
            A networkx edge view over the edges

        """
        return self.graph.edges(keys=keys, data=data)

    def in_edges(self,
                 node: str,
                 keys: bool = False,
                 data: bool = False) -> List:
        """
        Get all incoming edges for a given node.

        Parameters
        ----------
        node: str
            The node identifier
        keys: bool
            Whether or not to include edge keys
        data: bool
            Whether or not to fetch edge properties

        Returns
        -------
        List
            An iterable view of edges

        """
        return self.graph.in_edges(node, keys=keys, data=data)

    def out_edges(self,
                  node: str,
                  keys: bool = False,
                  data: bool = False) -> List:
        """
        Get all outgoing edges for a given node.

        Parameters
        ----------
        node: str
            The node identifier
        keys: bool
            Whether or not to include edge keys
        data: bool
            Whether or not to fetch edge properties

        Returns
        -------
        List
            An iterable view of edges

        """
        return self.graph.out_edges(node, keys=keys, data=data)

    def nodes_iter(self) -> Generator:
        """
        Get an iterable to traverse through all the nodes in a graph.

        Returns
        -------
        Generator
            A generator for nodes where each element is a Tuple that
            contains (node_id, node_data)

        """
        yield from self.graph.nodes(data=True)

    def edges_iter(self) -> Generator:
        """
        Get an iterable to traverse through all the edges in a graph.

        Returns
        -------
        Generator
            A generator for edges where each element is a 4-tuple that
            contains (subject, object, edge_key, edge_data)

        """
        yield from self.graph.edges(keys=True, data=True)

    def remove_node(self, node: str) -> None:
        """
        Remove a given node (and its incident edges) from the graph.

        Parameters
        ----------
        node: str
            The node identifier

        """
        self.graph.remove_node(node)

    def remove_edge(self,
                    subject_node: str,
                    object_node: str,
                    edge_key: Optional[str] = None) -> None:
        """
        Remove a given edge from the graph.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key

        """
        self.graph.remove_edge(subject_node, object_node, edge_key)

    def has_node(self, node: str) -> bool:
        """
        Check whether a given node exists in the graph.

        Parameters
        ----------
        node: str
            The node identifier

        Returns
        -------
        bool
            Whether or not the given node exists

        """
        return self.graph.has_node(node)

    def has_edge(self,
                 subject_node: str,
                 object_node: str,
                 edge_key: Optional[str] = None) -> bool:
        """
        Check whether a given edge exists in the graph.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key

        Returns
        -------
        bool
            Whether or not the given edge exists

        """
        return self.graph.has_edge(subject_node, object_node, key=edge_key)

    def number_of_nodes(self) -> int:
        """
        Returns the number of nodes in a graph.

        Returns
        -------
        int

        """
        return self.graph.number_of_nodes()

    def number_of_edges(self) -> int:
        """
        Returns the number of edges in a graph.

        Returns
        -------
        int

        """
        return self.graph.number_of_edges()

    def degree(self):
        """
        Get the degree of all the nodes in a graph.

        Returns
        -------
        Any
            The result of ``networkx.MultiDiGraph.degree()``

        """
        return self.graph.degree()

    def clear(self) -> None:
        """
        Remove all the nodes and edges in the graph.
        """
        self.graph.clear()

    @staticmethod
    def set_node_attributes(graph: BaseGraph, attributes: Dict) -> None:
        """
        Set nodes attributes from a dictionary of key-values.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to modify
        attributes: Dict
            A dictionary of node identifier to key-value pairs

        """
        return set_node_attributes(graph.graph, attributes)

    @staticmethod
    def set_edge_attributes(graph: BaseGraph, attributes: Dict) -> None:
        """
        Set edge attributes from a dictionary of key-values.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to modify
        attributes: Dict
            A dictionary of edge identifier to key-value pairs

        Returns
        -------
        Any
            Whatever the underlying networkx ``set_edge_attributes`` returns

        """
        return set_edge_attributes(graph.graph, attributes)

    @staticmethod
    def get_node_attributes(graph: BaseGraph, attr_key: str) -> Dict:
        """
        Get all nodes that have a value for the given attribute ``attr_key``.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to read from
        attr_key: str
            The attribute key

        Returns
        -------
        Dict
            A dictionary where nodes are the keys and the values
            are the attribute values for ``attr_key``

        """
        return get_node_attributes(graph.graph, attr_key)

    @staticmethod
    def get_edge_attributes(graph: BaseGraph, attr_key: str) -> Dict:
        """
        Get all edges that have a value for the given attribute ``attr_key``.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to read from
        attr_key: str
            The attribute key

        Returns
        -------
        Dict
            A dictionary where edges are the keys and the values
            are the attribute values for ``attr_key``

        """
        return get_edge_attributes(graph.graph, attr_key)

    @staticmethod
    def relabel_nodes(graph: BaseGraph, mapping: Dict) -> None:
        """
        Relabel identifiers for a series of nodes based on mappings.

        The graph is relabelled in place (``copy=False``).

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to modify
        mapping: Dict
            A dictionary of mapping where the key is the old identifier
            and the value is the new identifier.

        """
        # resolves to the module-level networkx function, not this method
        relabel_nodes(graph.graph, mapping, copy=False)
Example #5
0
class NetworkXLTM(LongTermMemory):
    """Long-term memory backed by a NetworkX ``MultiDiGraph``.

    Memory elements and attribute values are both graph nodes that carry an
    ``activation`` attribute. Every attribute/value pair of an element is an
    out-edge from the element to the value node, labelled
    ``attribute=<name>``.
    """

    def __init__(self, **kwargs):
        # type: (**Any) -> None
        """Initialize the NetworkXLTM.

        Parameters:
            **kwargs: Arbitrary keyword arguments, forwarded to LongTermMemory.
        """
        super().__init__(**kwargs)
        self.graph = MultiDiGraph()
        # attribute name -> IDs of elements stored with that attribute
        self.inverted_index = defaultdict(set) # type: Dict[Hashable, Set[Hashable]]
        # element IDs from the latest query, highest activation first
        self.query_results = [] # type: List[Hashable]
        # cursor into query_results; -1 means "no results"
        self.result_index = -1
        self.clear()

    def clear(self): # noqa: D102
        # type: () -> None
        self.graph.clear()
        self.inverted_index.clear()
        self.query_results, self.result_index = [], -1

    def get_activation(self, mem_id):
        # type: (Hashable) -> float
        """Look up the stored activation of a memory element.

        Parameters:
            mem_id (Hashable): The ID of the element.

        Returns:
            float: The ``activation`` attribute of the element's node.
        """
        node_attrs = self.graph.nodes[mem_id]
        return node_attrs['activation']

    def store(self, mem_id=None, time=0, **kwargs): # noqa: D102
        # type: (Hashable, int, **Any) -> bool
        element = uuid() if mem_id is None else mem_id
        if element in self.graph:
            # storing an element that already exists counts as an access
            self.activation_fn(self, element, time)
        else:
            self.graph.add_node(element, activation=0)
        for attr, val in kwargs.items():
            if val not in self.graph:
                self.graph.add_node(val, activation=0)
            self.graph.add_edge(element, val, attribute=attr)
            self.inverted_index[attr].add(element)
        return True

    def _activate_and_return(self, mem_id, time):
        # type: (Hashable, int) -> AVLTree
        self.activation_fn(self, mem_id, time)
        tree = AVLTree()
        outgoing = self.graph.out_edges(mem_id, data=True)
        for _, target, edge_attrs in outgoing:
            tree.add(AttrVal(edge_attrs['attribute'], target))
        return tree

    def retrieve(self, mem_id, time=0): # noqa: D102
        # type: (Hashable, int) -> Optional[AVLTree]
        if mem_id in self.graph:
            return self._activate_and_return(mem_id, time=time)
        return None

    def query(self, attr_vals, time=0): # noqa: D102
        # type: (AbstractSet[AttrVal], int) -> Optional[AVLTree]
        # NOTE(review): an empty attr_vals makes set.intersection() raise TypeError.
        # step 1: elements stored with every queried attribute
        wanted_attrs = {attr for attr, _ in attr_vals}
        candidates = set.intersection(
            *[self.inverted_index[attr] for attr in wanted_attrs]
        )
        # step 2: keep elements that link to each value under that attribute
        matches = set()
        for candidate in candidates:
            satisfied = all(
                (candidate, val) in self.graph.edges
                and any(
                    edge_attrs['attribute'] == attr
                    for edge_attrs in self.graph.get_edge_data(candidate, val).values()
                )
                for attr, val in attr_vals
            )
            if satisfied:
                matches.add(candidate)
        if not matches:
            # nothing matched: reset the result cursor
            self.query_results, self.result_index = [], -1
            return None
        # step 3: rank matches, most activated first
        def activation_of(element):
            return self.graph.nodes[element]['activation']
        self.query_results = sorted(matches, key=activation_of, reverse=True)
        self.result_index = 0
        return self._activate_and_return(self.query_results[0], time=time)

    @property
    def has_prev_result(self): # noqa: D102
        # type: () -> bool
        if not self.query_results:
            return False
        return self.result_index > 0

    def prev_result(self, time=0): # noqa: D102
        # type: (int) -> Optional[AVLTree]
        # caller is expected to check has_prev_result first
        self.result_index -= 1
        current = self.query_results[self.result_index]
        return self._activate_and_return(current, time=time)

    @property
    def has_next_result(self): # noqa: D102
        # type: () -> bool
        if not self.query_results:
            return False
        return -1 < self.result_index < len(self.query_results) - 1

    def next_result(self, time=0): # noqa: D102
        # type: (int) -> Optional[AVLTree]
        # caller is expected to check has_next_result first
        self.result_index += 1
        current = self.query_results[self.result_index]
        return self._activate_and_return(current, time=time)

    @staticmethod
    def retrievable(mem_id): # noqa: D102
        # type: (Any) -> bool
        if mem_id is None:
            return False
        return isinstance(mem_id, Hashable)