예제 #1
0
def dictToGraphCPD(graphNoTable: gz.Digraph,
                   variables: Dict[Name, Dict]) -> gz.Digraph:
    """Return a copy of ``graphNoTable`` with one CPD table node per variable.

    For every ``(var, values)`` pair the default node shape is set to
    ``plaintext`` and a node named ``'cpd_' + var`` is added whose label is
    the table produced by ``dictToTable``. That node is tied to ``var`` with
    an invisible edge; the direction of the edge is drawn from a random
    generator seeded with ``hash(table)``.

    The input graph is left untouched: every addition is made on the copy.
    The process-wide ``random`` state is left untouched as well, because a
    private generator is used for the draw.

    :param graphNoTable: graph that is copied and then extended
    :param variables: mapping of variable name to the data handed to
        ``dictToGrid`` and ``dictToTable``
    :returns: the extended copy
    """
    import random

    g = graphNoTable.copy()

    for var, values in variables.items():
        g.attr('node', shape='plaintext')

        grids = dictToGrid(var, values)
        table = dictToTable(var, values, grids)
        cpd_name = 'cpd_' + var

        g.node(cpd_name, label=table, color='gray', fillcolor='white')

        # A dedicated generator yields the same draw as seeding the global
        # one would, without disturbing other users of the ``random`` module.
        # NOTE(review): if ``table`` is a ``str`` its hash changes between
        # interpreter runs unless PYTHONHASHSEED is fixed - confirm whether
        # a run-to-run stable orientation is expected.
        rng = random.Random(hash(table))

        if rng.randint(0, 1):
            g.edge(cpd_name, var, style='invis')
        else:
            g.edge(var, cpd_name, style='invis')

    return g
예제 #2
0
파일: claim.py 프로젝트: JoshC8C7/AFCProj
    def generate_graph(self, oie_subclaims, output=False):
        """Build the graphviz Digraph of verb/argument relations for the document.

        A box node for the whole document is created first. For every claim the
        ``'V'`` entry becomes a verb node, and every other entry that is not an
        already-seen root becomes a box node with an edge to that verb labelled
        by its argument type. Afterwards each span collected in ``verb_set`` gets
        a violet edge to the shortest collected argument span that contains it,
        or to the whole-document node when no such argument exists.

        :param oie_subclaims: iterable of mappings from argument type to span,
            each holding a ``'V'`` key (presumably spaCy spans - TODO confirm)
        :param output: if True, a copy of the graph extended with coreference
            nodes and green coreference edges is built and opened via ``view()``
        :returns: the graph without coreference nodes/edges, also when
            ``output`` is True
        """

        # The graph is directed as all edges are either implication or property-of relations. Use strict Digraphs to
        # prevent any duplicate edges.
        G = Digraph(strict=True, format='pdf')
        arg_set = set()
        verb_set = set()
        coref_nodes, coref_edges = [], []

        # Create the overall leaf, which the conjunction of the root verbs will imply:
        G.node(self.argID(self.doc[:]), self.doc.text, shape='box')

        # Iterate through all extracted relations
        seenRoots = []
        for claim in oie_subclaims:

            # Plot the verb after having converted it to a node.
            root = claim['V']
            seenRoots.append(root)
            # Keep only the alphanumeric characters of the verb text; an empty result means there is nothing usable.
            check = ''.join(filter(str.isalnum, str(root))).replace(' ', '')
            if check == "":  # Sometimes some nonsense can be attributed as a verb by oie.
                continue

            # Add the verb to the graph, with a helpful label for optics.
            G.node(
                self.argID(root), root.text + '/' +
                self.argBaseC[self.argID(root)].get_uvi_output())

            for arg_type, argV in claim.items():
                if arg_type != 'V' and argV not in seenRoots:
                    # Create a node for each argument, and a link to its respective verb labelled by its arg type.
                    G.node(self.argID(argV),
                           argV.text + "/" + str(argV.ents) + "/" +
                           str(list(argV.noun_chunks)),
                           shape='box')
                    G.edge(self.argID(argV),
                           self.argID(root),
                           label=arg_type.replace('-', 'x'),
                           style=get_edge_style(arg_type, argV))
                    # Replace any '-' with 'x' as '-' is a keyword for networkx, but is output by SRL.

                    # Add the argument to the set of arguments eligible to be implied by another subtree.
                    arg_set.add(argV)

                    # Add coreference edges to the graph from the initial text to the entity being coreferenced, only
                    # for the version of the graph that is displayed. Co-references must be omitted from the true
                    # graph as they interfere with splitting into subclaims because the edges become bridges. The
                    # coreferences themselves are not lost as they're properties of the doc/spans. They are useful for
                    # illustrative and debugging purposes, and so can be output when requested with output=True.
                    # This is sound to omit from the networkx graph as a coreference does not result in two claims
                    # being co-dependent e.g. 'My son is 11. He likes to eat cake.' - the coreference bridges the two
                    # otherwise separate components when there should be no co-dependence implied. Both are about the
                    # son, but there is not an iff relation between them.
                    arg_corefs = get_span_coref(argV)
                    if output and len(arg_corefs):
                        for cluster in arg_corefs:
                            canonical_reference = cluster.main
                            for inst in cluster.mentions:
                                # Only mentions lying inside this argument, other than the canonical one itself.
                                if inst.start >= argV.start and inst.end <= argV.end and inst != canonical_reference:
                                    coref_nodes.append(
                                        (self.argID(canonical_reference),
                                         canonical_reference.text + "/" +
                                         str(canonical_reference.ents)))
                                    coref_edges.append(
                                        (self.argID(argV),
                                         self.argID(canonical_reference),
                                         inst.text))

                # The verb itself, and any argument equal to an already-seen root, becomes an eligible subtree root.
                else:
                    verb_set.add(argV)

        # Create 'purple' edges - i.e. edges that link a verb rooting a subtree (of size >= 1) to the argument that they
        # imply. Only one verb can imply any one argument in order to preserve tree-like structure.
        for argV in verb_set:
            # Start from the whole document and shrink to the shortest argument span that contains the verb span.
            shortest_span = self.doc[:]
            for parent in arg_set:
                if argV != parent and argV.start >= parent.start and argV.end <= parent.end and (
                        parent.end - parent.start) < (shortest_span.end -
                                                      shortest_span.start):
                    shortest_span = parent
            G.node(self.argID(shortest_span), shortest_span.text, shape='box')
            G.edge(self.argID(argV), self.argID(shortest_span), color='violet')

        # If visual output requested, then add coref edges determined earlier to a copy of the graph and view that copy.
        # The copy is identical except for nodes created solely as coreference components and the green edges.
        # Note that the graph returned below is always G, never the copy.
        if output:
            H = G.copy()
            for node in coref_nodes:
                H.node(node[0], node[1], shape='box')
            for edge in coref_edges:
                H.edge(edge[0], edge[1], color='green', label=edge[2])
            H.view()
        return G
예제 #3
0
class Graph:
    """a class to create graphviz graphs of the AiiDA node provenance"""
    def __init__(self,
                 engine=None,
                 graph_attr=None,
                 global_node_style=None,
                 global_edge_style=None,
                 include_sublabels=True,
                 link_style_fn=None,
                 node_style_fn=None,
                 node_sublabel_fn=None,
                 node_id_type='pk'):
        """Set up an empty provenance graph backed by a graphviz Digraph.

        The pks of added nodes and the keys of added edges are remembered, so
        that each of them is only created once.

        :param engine: the graphviz engine, e.g. dot, circo (Default value = None)
        :type engine: str or None
        :param graph_attr: attributes for the graphviz graph (Default value = None)
        :type graph_attr: dict or None
        :param global_node_style: styles which will be added to all nodes.
            Note this will override any builtin attributes (Default value = None)
        :type global_node_style: dict or None
        :param global_edge_style: styles which will be added to all edges.
            Note this will override any builtin attributes (Default value = None)
        :type global_edge_style: dict or None
        :param include_sublabels: if True, the node text will include node dependent sub-labels
            (Default value = True)
        :type include_sublabels: bool
        :param link_style_fn: callable mapping a link to a graphviz style dict;
            falls back to ``default_link_styles`` (Default value = None)
        :param node_style_fn: callable mapping a node to a graphviz style dict;
            falls back to ``default_node_styles`` (Default value = None)
        :param node_sublabel_fn: callable mapping a node to a sublabel (e.g. specifying some attribute values);
            falls back to ``default_node_sublabels`` (Default value = None)
        :param node_id_type: the type of identifier to use within the node text ('pk', 'uuid' or 'label')
        :type node_id_type: str
        """
        # pylint: disable=too-many-arguments

        self._graph = Digraph(engine=engine, graph_attr=graph_attr)

        # Bookkeeping of what has already been drawn
        self._nodes, self._edges = set(), set()

        # Styles merged into every node / edge
        self._global_node_style = global_node_style if global_node_style else {}
        self._global_edge_style = global_edge_style if global_edge_style else {}

        # Styling callbacks, each with its module-level fallback
        self._link_styles = link_style_fn if link_style_fn else default_link_styles
        self._node_styles = node_style_fn if node_style_fn else default_node_styles
        self._node_sublabels = node_sublabel_fn if node_sublabel_fn else default_node_sublabels

        self._include_sublabels = include_sublabels
        self._node_id_type = node_id_type

    @property
    def graphviz(self):
        """return a copy of the graphviz.Digraph"""
        return self._graph.copy()

    @property
    def nodes(self):
        """return a copy of the nodes"""
        return self._nodes.copy()

    @property
    def edges(self):
        """return a copy of the edges"""
        return self._edges.copy()

    @staticmethod
    def _load_node(node):
        """ load a node (if not already loaded)

        :param node: node or node pk/uuid
        :type node: int or str or aiida.orm.nodes.node.Node
        :returns: aiida.orm.nodes.node.Node
        """
        if isinstance(node, (int, str)):
            return orm.load_node(node)
        return node

    @staticmethod
    def _default_link_types(link_types):
        """If link_types is empty, it will return all the links_types

        :param links: iterable with the link_types ()
        :returns: list of :py:class:`aiida.common.links.LinkType`
        """
        if not link_types:
            all_link_types = [LinkType.CREATE]
            all_link_types.append(LinkType.RETURN)
            all_link_types.append(LinkType.INPUT_CALC)
            all_link_types.append(LinkType.INPUT_WORK)
            all_link_types.append(LinkType.CALL_CALC)
            all_link_types.append(LinkType.CALL_WORK)
            return all_link_types

        return link_types

    def add_node(self, node, style_override=None, overwrite=False):
        """add single node to the graph

        :param node: node or node pk/uuid
        :type node: int or str or aiida.orm.nodes.node.Node
        :param style_override: graphviz style parameters that will override default values
        :type style_override: dict or None
        :param overwrite: whether to overrite an existing node (Default value = False)
        :type overwrite: bool
        """
        node = self._load_node(node)
        style = {} if style_override is None else style_override
        style.update(self._global_node_style)
        if node.pk not in self._nodes or overwrite:
            _add_graphviz_node(self._graph,
                               node,
                               node_style_func=self._node_styles,
                               node_sublabel_func=self._node_sublabels,
                               style_override=style,
                               include_sublabels=self._include_sublabels,
                               id_type=self._node_id_type)
            self._nodes.add(node.pk)
        return node

    def add_edge(self,
                 in_node,
                 out_node,
                 link_pair=None,
                 style=None,
                 overwrite=False):
        """add single node to the graph

        :param in_node: node or node pk/uuid
        :type in_node: int or aiida.orm.nodes.node.Node
        :param out_node: node or node pk/uuid
        :type out_node: int or str or aiida.orm.nodes.node.Node
        :param link_pair: defining the relationship between the nodes
        :type link_pair: None or aiida.orm.utils.links.LinkPair
        :param style: graphviz style parameters (Default value = None)
        :type style: dict or None
        :param overwrite: whether to overrite existing edge (Default value = False)
        :type overwrite: bool
        """
        in_node = self._load_node(in_node)
        if in_node.pk not in self._nodes:
            raise AssertionError(
                'in_node pk={} must have already been added to the graph'.
                format(in_node.pk))
        out_node = self._load_node(out_node)
        if out_node.pk not in self._nodes:
            raise AssertionError(
                'out_node pk={} must have already been added to the graph'.
                format(out_node.pk))

        if (in_node.pk, out_node.pk,
                link_pair) in self._edges and not overwrite:
            return

        style = {} if style is None else style
        self._edges.add((in_node.pk, out_node.pk, link_pair))
        style.update(self._global_edge_style)

        _add_graphviz_edge(self._graph, in_node, out_node, style)

    @staticmethod
    def _convert_link_types(link_types):
        """convert link types, which may be strings, to a member of LinkType"""
        if link_types is None:
            return None
        if isinstance(link_types, str):
            link_types = [link_types]
        link_types = tuple([
            getattr(LinkType, l.upper()) if isinstance(l, str) else l
            for l in link_types
        ])
        return link_types

    def add_incoming(self,
                     node,
                     link_types=(),
                     annotate_links=None,
                     return_pks=True):
        """add nodes and edges for incoming links to a node

        :param node: node or node pk/uuid
        :type node: aiida.orm.nodes.node.Node or int
        :param link_types: filter by link types (Default value = ())
        :type link_types: str or tuple[str] or aiida.common.links.LinkType or tuple[aiida.common.links.LinkType]
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = None)
        :type annotate_links: bool or str
        :param return_pks: whether to return a list of nodes, or list of node pks (Default value = True)
        :type return_pks: bool
        :returns: list of nodes or node pks
        """
        if annotate_links not in [None, False, 'label', 'type', 'both']:
            raise ValueError(
                'annotate_links must be one of False, "label", "type" or "both"\ninstead, it is: {}'
                .format(annotate_links))

        # incoming nodes are found traversing backwards
        node_pk = node if isinstance(node, int) else node.pk
        valid_link_types = self._default_link_types(link_types)
        valid_link_types = self._convert_link_types(valid_link_types)
        traversed_graph = traverse_graph(
            (node_pk, ),
            max_iterations=1,
            get_links=True,
            links_backward=valid_link_types,
        )

        traversed_nodes = orm.QueryBuilder().append(
            orm.Node,
            filters={'id': {
                'in': traversed_graph['nodes']
            }},
            project=['id', '*'],
            tag='node',
        )
        traversed_nodes = {
            query_result[0]: query_result[1]
            for query_result in traversed_nodes.all()
        }

        for _, traversed_node in traversed_nodes.items():
            self.add_node(traversed_node, style_override=None)

        for link in traversed_graph['links']:
            source_node = traversed_nodes[link.source_id]
            target_node = traversed_nodes[link.target_id]
            link_pair = LinkPair(
                self._convert_link_types(link.link_type)[0], link.link_label)
            link_style = self._link_styles(link_pair,
                                           add_label=annotate_links
                                           in ['label', 'both'],
                                           add_type=annotate_links
                                           in ['type', 'both'])
            self.add_edge(source_node,
                          target_node,
                          link_pair,
                          style=link_style)

        if return_pks:
            return list(traversed_nodes.keys())
        # else:
        return list(traversed_nodes.values())

    def add_outgoing(self,
                     node,
                     link_types=(),
                     annotate_links=None,
                     return_pks=True):
        """add nodes and edges for outgoing links to a node

        :param node: node or node pk
        :type node: aiida.orm.nodes.node.Node or int
        :param link_types: filter by link types (Default value = ())
        :type link_types: str or tuple[str] or aiida.common.links.LinkType or tuple[aiida.common.links.LinkType]
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = None)
        :type annotate_links: bool or str
        :param return_pks: whether to return a list of nodes, or list of node pks (Default value = True)
        :type return_pks: bool
        :returns: list of nodes or node pks
        """
        if annotate_links not in [None, False, 'label', 'type', 'both']:
            raise ValueError(
                'annotate_links must be one of False, "label", "type" or "both"\ninstead, it is: {}'
                .format(annotate_links))

        # outgoing nodes are found traversing forwards
        node_pk = node if isinstance(node, int) else node.pk
        valid_link_types = self._default_link_types(link_types)
        valid_link_types = self._convert_link_types(valid_link_types)
        traversed_graph = traverse_graph(
            (node_pk, ),
            max_iterations=1,
            get_links=True,
            links_forward=valid_link_types,
        )

        traversed_nodes = orm.QueryBuilder().append(
            orm.Node,
            filters={'id': {
                'in': traversed_graph['nodes']
            }},
            project=['id', '*'],
            tag='node',
        )
        traversed_nodes = {
            query_result[0]: query_result[1]
            for query_result in traversed_nodes.all()
        }

        for _, traversed_node in traversed_nodes.items():
            self.add_node(traversed_node, style_override=None)

        for link in traversed_graph['links']:
            source_node = traversed_nodes[link.source_id]
            target_node = traversed_nodes[link.target_id]
            link_pair = LinkPair(
                self._convert_link_types(link.link_type)[0], link.link_label)
            link_style = self._link_styles(link_pair,
                                           add_label=annotate_links
                                           in ['label', 'both'],
                                           add_type=annotate_links
                                           in ['type', 'both'])
            self.add_edge(source_node,
                          target_node,
                          link_pair,
                          style=link_style)

        if return_pks:
            return list(traversed_nodes.keys())
        # else:
        return list(traversed_nodes.values())

    def recurse_descendants(self,
                            origin,
                            depth=None,
                            link_types=(),
                            annotate_links=False,
                            origin_style=None,
                            include_process_inputs=False,
                            print_func=None):
        """Add nodes and edges reachable from an origin by repeatedly following outgoing links.

        :param origin: node or node pk
        :type origin: aiida.orm.nodes.node.Node or int
        :param depth: passed to the traversal as its maximum number of iterations;
            None leaves it unspecified (Default value = None)
        :type depth: None or int
        :param link_types: filter by subset of link types; empty means all link types (Default value = ())
        :type link_types: tuple or str
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False)
        :type annotate_links: bool or str
        :param origin_style: node style map for origin node (Default value = None)
        :type origin_style: None or dict
        :param include_process_inputs: additionally follow INPUT_WORK and INPUT_CALC links one step
            backwards from every traversed node (Default value = False)
        :type include_process_inputs: bool
        :param print_func:
            a function to stream information to, i.e. print_func(str)
            (this feature is deprecated since `v1.1.0` and will be removed in `v2.0.0`)

        """
        # pylint: disable=too-many-arguments,too-many-locals
        import warnings
        from aiida.common.warnings import AiidaDeprecationWarning
        if print_func:
            warnings.warn(  # pylint: disable=no-member
                '`print_func` is deprecated because graph traversal has been refactored',
                AiidaDeprecationWarning)

        show_label = annotate_links in ('label', 'both')
        show_type = annotate_links in ('type', 'both')

        # Walk forwards from the origin along the requested link types only
        if isinstance(origin, int):
            origin_pk = origin
        else:
            origin_pk = origin.pk
        forward_types = self._convert_link_types(
            self._default_link_types(link_types))
        traversal = traverse_graph(
            (origin_pk, ),
            max_iterations=depth,
            get_links=True,
            links_forward=forward_types,
        )
        node_pks = traversal['nodes']
        links = traversal['links']

        # Optionally step once backwards along input_work and input_calc links from everything found so far
        # and merge the result in, so that the inputs of the traversed processes show up in the Graph
        if include_process_inputs:
            inputs_traversal = traverse_graph(
                node_pks,
                max_iterations=1,
                get_links=True,
                links_backward=[LinkType.INPUT_WORK, LinkType.INPUT_CALC])
            node_pks = node_pks.union(inputs_traversal['nodes'])
            links = links.union(inputs_traversal['links'])

        # One central query for all nodes in the Graph, turned into a {id: Node} dictionary
        node_query = orm.QueryBuilder().append(
            orm.Node,
            filters={'id': {
                'in': node_pks
            }},
            project=['id', '*'],
            tag='node',
        )
        nodes_by_pk = {row[0]: row[1] for row in node_query.all()}

        # The origin goes in first, with its custom styling ...
        origin_node = nodes_by_pk[origin_pk]
        self.add_node(origin_node, style_override=origin_style)

        # ... followed by every other traversed node with default styling
        for pk, found_node in nodes_by_pk.items():
            if pk != origin_pk:
                self.add_node(found_node, style_override={})

        # Add all links, resolving their end points through the dictionary so that no further queries are needed
        for link in links:
            source_node = nodes_by_pk[link.source_id]
            target_node = nodes_by_pk[link.target_id]
            link_type = self._convert_link_types(link.link_type)[0]
            link_pair = LinkPair(link_type, link.link_label)
            edge_style = self._link_styles(link_pair,
                                           add_label=show_label,
                                           add_type=show_type)
            self.add_edge(source_node,
                          target_node,
                          link_pair,
                          style=edge_style)

    def recurse_ancestors(self,
                          origin,
                          depth=None,
                          link_types=(),
                          annotate_links=False,
                          origin_style=None,
                          include_process_outputs=False,
                          print_func=None):
        """Add nodes and edges reachable from an origin by repeatedly following incoming links.

        :param origin: node or node pk
        :type origin: aiida.orm.nodes.node.Node or int
        :param depth: passed to the traversal as its maximum number of iterations;
            None leaves it unspecified (Default value = None)
        :type depth: None or int
        :param link_types: filter by subset of link types; empty means all link types (Default value = ())
        :type link_types: tuple or str
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False)
        :type annotate_links: bool or str
        :param origin_style: node style map for origin node (Default value = None)
        :type origin_style: None or dict
        :param include_process_outputs: additionally follow CREATE and RETURN links one step
            forwards from every traversed node (Default value = False)
        :type include_process_outputs: bool
        :param print_func: a function to stream information to, i.e. print_func(str)

        .. deprecated:: 1.1.0
            `print_func` will be removed in `v2.0.0`
        """
        # pylint: disable=too-many-arguments,too-many-locals
        import warnings
        from aiida.common.warnings import AiidaDeprecationWarning
        if print_func:
            warnings.warn(  # pylint: disable=no-member
                '`print_func` is deprecated because graph traversal has been refactored',
                AiidaDeprecationWarning)

        show_label = annotate_links in ('label', 'both')
        show_type = annotate_links in ('type', 'both')

        # Walk backwards from the origin along the requested link types only
        if isinstance(origin, int):
            origin_pk = origin
        else:
            origin_pk = origin.pk
        backward_types = self._convert_link_types(
            self._default_link_types(link_types))
        traversal = traverse_graph(
            (origin_pk, ),
            max_iterations=depth,
            get_links=True,
            links_backward=backward_types,
        )
        node_pks = traversal['nodes']
        links = traversal['links']

        # Optionally step once forwards along create and return links from everything found so far
        # and merge the result in, so that the outputs of the traversed processes show up in the Graph
        if include_process_outputs:
            outputs_traversal = traverse_graph(
                node_pks,
                max_iterations=1,
                get_links=True,
                links_forward=[LinkType.CREATE, LinkType.RETURN])
            node_pks = node_pks.union(outputs_traversal['nodes'])
            links = links.union(outputs_traversal['links'])

        # One central query for all nodes in the Graph, turned into a {id: Node} dictionary
        node_query = orm.QueryBuilder().append(
            orm.Node,
            filters={'id': {
                'in': node_pks
            }},
            project=['id', '*'],
            tag='node',
        )
        nodes_by_pk = {row[0]: row[1] for row in node_query.all()}

        # The origin goes in first, with its custom styling ...
        origin_node = nodes_by_pk[origin_pk]
        self.add_node(origin_node, style_override=origin_style)

        # ... followed by every other traversed node with default styling
        for pk, found_node in nodes_by_pk.items():
            if pk != origin_pk:
                self.add_node(found_node, style_override=None)

        # Add all links, resolving their end points through the dictionary so that no further queries are needed
        for link in links:
            source_node = nodes_by_pk[link.source_id]
            target_node = nodes_by_pk[link.target_id]
            link_type = self._convert_link_types(link.link_type)[0]
            link_pair = LinkPair(link_type, link.link_label)
            edge_style = self._link_styles(link_pair,
                                           add_label=show_label,
                                           add_type=show_type)
            self.add_edge(source_node,
                          target_node,
                          link_pair,
                          style=edge_style)

    def add_origin_to_targets(self,
                              origin,
                              target_cls,
                              target_filters=None,
                              include_target_inputs=False,
                              include_target_outputs=False,
                              origin_style=(),
                              annotate_links=False):
        """Add nodes and edges from an origin node to all nodes of a target node class.

        The targets are queried with ``with_ancestors`` set to the origin, and
        each one is connected to the origin by a dashed grey edge.

        :param origin: node or node pk/uuid
        :type origin: aiida.orm.nodes.node.Node or int
        :param target_cls: target node class
        :param target_filters: query filters applied to the targets (Default value = None)
        :type target_filters: dict or None
        :param include_target_inputs: also call ``add_incoming`` for every target (Default value = False)
        :type include_target_inputs: bool
        :param include_target_outputs: also call ``add_outgoing`` for every target (Default value = False)
        :type include_target_outputs: bool
        :param origin_style: node style map for origin node (Default value = ())
        :type origin_style: dict or tuple
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False)
        :type annotate_links: bool
        """
        # pylint: disable=too-many-arguments
        origin_node = self._load_node(origin)
        filters = {} if target_filters is None else target_filters

        self.add_node(origin_node, style_override=dict(origin_style))

        origin_step = {
            'cls': origin_node.__class__,
            'filters': {
                'id': origin_node.pk
            },
            'tag': 'origin'
        }
        target_step = {
            'cls': target_cls,
            'filters': filters,
            'with_ancestors': 'origin',
            'tag': 'target',
            'project': '*'
        }
        query = orm.QueryBuilder(path=[origin_step, target_step])

        for row in query.iterall():
            (target_node, ) = row
            self.add_node(target_node)
            # a fresh style dict per edge
            shortcut_style = {'style': 'dashed', 'color': 'grey'}
            self.add_edge(origin_node, target_node, style=shortcut_style)

            if include_target_inputs:
                self.add_incoming(target_node, annotate_links=annotate_links)

            if include_target_outputs:
                self.add_outgoing(target_node, annotate_links=annotate_links)

    def add_origins_to_targets(self,
                               origin_cls,
                               target_cls,
                               origin_filters=None,
                               target_filters=None,
                               include_target_inputs=False,
                               include_target_outputs=False,
                               origin_style=(),
                               annotate_links=False):
        """Add nodes and edges from all nodes of an origin class to all nodes of a target node class.

        Every origin matched by the query is handed to ``add_origin_to_targets``
        together with the remaining arguments.

        :param origin_cls: origin node class
        :param target_cls: target node class
        :param origin_filters: query filters applied to the origins (Default value = None)
        :type origin_filters: dict or None
        :param target_filters: query filters applied to the targets (Default value = None)
        :type target_filters: dict or None
        :param include_target_inputs:  (Default value = False)
        :type include_target_inputs: bool
        :param include_target_outputs:  (Default value = False)
        :type include_target_outputs: bool
        :param origin_style: node style map for origin node (Default value = ())
        :type origin_style: dict or tuple
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False)
        :type annotate_links: bool
        """
        # pylint: disable=too-many-arguments
        filters = {} if origin_filters is None else origin_filters

        origin_step = {
            'cls': origin_cls,
            'filters': filters,
            'tag': 'origin',
            'project': '*'
        }
        query = orm.QueryBuilder(path=[origin_step])

        # Options forwarded unchanged for every origin
        forwarded = {
            'target_filters': target_filters,
            'include_target_inputs': include_target_inputs,
            'include_target_outputs': include_target_outputs,
            'origin_style': origin_style,
            'annotate_links': annotate_links,
        }
        for row in query.iterall():
            (origin_node, ) = row
            self.add_origin_to_targets(origin_node, target_cls, **forwarded)
예제 #4
0
class Graph(object):
    """Build graphviz graphs of the AiiDA node provenance.

    The pks of the nodes and the ``(in_pk, out_pk, link_pair)`` keys of the edges that have
    been drawn are remembered, so that each one is only added to the underlying
    ``Digraph`` once (unless ``overwrite`` is requested).
    """

    def __init__(
        self,
        engine=None,
        graph_attr=None,
        global_node_style=None,
        global_edge_style=None,
        include_sublabels=True,
        link_style_fn=None,
        node_style_fn=None,
        node_sublabel_fn=None,
        node_id_type='pk'
    ):
        """Create an empty provenance graph.

        Nodes and edges are cached, so that they are only created once.

        :param engine: the graphviz engine, e.g. dot, circo (Default value = None)
        :type engine: str or None
        :param graph_attr: attributes for the graphviz graph (Default value = None)
        :type graph_attr: dict or None
        :param global_node_style: styles which will be added to all nodes.
            Note this will override any builtin attributes (Default value = None)
        :type global_node_style: dict or None
        :param global_edge_style: styles which will be added to all edges.
            Note this will override any builtin attributes (Default value = None)
        :type global_edge_style: dict or None
        :param include_sublabels: if True, the node text will include node dependant sub-labels (Default value = True)
        :type include_sublabels: bool
        :param link_style_fn: callable mapping a link to a graphviz style dict; it is called as
            link_style_fn(link_pair, add_label=..., add_type=...) -> dict (Default value = None)
        :param node_style_fn: callable mapping nodes to a graphviz style dict;
            node_style_fn(node) -> dict (Default value = None)
        :param node_sublabel_fn: callable mapping data node to a sublabel (e.g. specifying some attribute values)
            node_sublabel_fn(node) -> str (Default value = None)
        :param node_id_type: the type of identifier to use within the node text ('pk', 'uuid' or 'label')
        :type node_id_type: str

        """
        # pylint: disable=too-many-arguments
        self._graph = Digraph(engine=engine, graph_attr=graph_attr)
        # pks of the nodes already drawn
        self._nodes = set()
        # (in_node.pk, out_node.pk, link_pair) keys of the edges already drawn
        self._edges = set()
        self._global_node_style = global_node_style or {}
        self._global_edge_style = global_edge_style or {}
        self._include_sublabels = include_sublabels
        # fall back to the module level default callables when none are supplied
        self._link_styles = link_style_fn or default_link_styles
        self._node_styles = node_style_fn or default_node_styles
        self._node_sublabels = node_sublabel_fn or default_node_sublabels
        self._node_id_type = node_id_type

    @property
    def graphviz(self):
        """return a copy of the graphviz.Digraph"""
        return self._graph.copy()

    @property
    def nodes(self):
        """return a copy of the nodes"""
        return self._nodes.copy()

    @property
    def edges(self):
        """return a copy of the edges"""
        return self._edges.copy()

    @staticmethod
    def _load_node(node):
        """Load a node (if not already loaded).

        :param node: node or node pk/uuid
        :type node: int or str or aiida.orm.nodes.node.Node
        :returns: aiida.orm.nodes.node.Node

        """
        if isinstance(node, (int, six.string_types)):
            return load_node(node)
        return node

    @staticmethod
    def _validate_annotate_links(annotate_links):
        """Check that ``annotate_links`` is one of the supported options.

        :param annotate_links: the value to validate
        :raises AssertionError: if the value is not None, False, 'label', 'type' or 'both'

        """
        if annotate_links not in [None, False, 'label', 'type', 'both']:
            raise AssertionError('annotate_links must be one of False, "label", "type" or "both"')

    def add_node(self, node, style_override=None, overwrite=False):
        """Add a single node to the graph.

        :param node: node or node pk/uuid
        :type node: int or str or aiida.orm.nodes.node.Node
        :param style_override: graphviz style parameters that will override default values.
            The mapping is copied, it is never modified in place.
        :type style_override: dict or None
        :param overwrite: whether to overwrite an existing node (Default value = False)
        :type overwrite: bool
        :returns: the loaded node

        """
        node = self._load_node(node)
        # copy, so that the global style is not leaked into the caller's dict
        style = {} if style_override is None else dict(style_override)
        # the global style takes precedence over the per-node override
        style.update(self._global_node_style)
        if node.pk not in self._nodes or overwrite:
            _add_graphviz_node(
                self._graph,
                node,
                node_style_func=self._node_styles,
                node_sublabel_func=self._node_sublabels,
                style_override=style,
                include_sublabels=self._include_sublabels,
                id_type=self._node_id_type
            )
            self._nodes.add(node.pk)
        return node

    def add_edge(self, in_node, out_node, link_pair=None, style=None, overwrite=False):
        """Add a single edge to the graph.

        Both end points must already have been added with :meth:`add_node`.

        :param in_node: node or node pk/uuid
        :type in_node: int or str or aiida.orm.nodes.node.Node
        :param out_node: node or node pk/uuid
        :type out_node: int or str or aiida.orm.nodes.node.Node
        :param link_pair: defining the relationship between the nodes
        :type link_pair: None or aiida.orm.utils.links.LinkPair
        :param style: graphviz style parameters (Default value = None).
            The mapping is copied, it is never modified in place.
        :type style: dict or None
        :param overwrite: whether to overwrite an existing edge (Default value = False)
        :type overwrite: bool
        :raises AssertionError: if either node has not yet been added to the graph

        """
        in_node = self._load_node(in_node)
        if in_node.pk not in self._nodes:
            raise AssertionError('in_node pk={} must have already been added to the graph'.format(in_node.pk))
        out_node = self._load_node(out_node)
        if out_node.pk not in self._nodes:
            raise AssertionError('out_node pk={} must have already been added to the graph'.format(out_node.pk))

        edge_key = (in_node.pk, out_node.pk, link_pair)
        if edge_key in self._edges and not overwrite:
            return

        # copy, so that the global style is not leaked into the caller's dict
        edge_style = {} if style is None else dict(style)
        self._edges.add(edge_key)
        # the global style takes precedence over the per-edge style
        edge_style.update(self._global_edge_style)

        _add_graphviz_edge(self._graph, in_node, out_node, edge_style)

    @staticmethod
    def _convert_link_types(link_types):
        """Convert link types, which may be strings, to members of LinkType.

        :param link_types: None, a single string, or an iterable of strings and/or LinkType members
        :returns: None if ``link_types`` is None, otherwise a tuple in which every string
            has been replaced by the LinkType attribute of the same (upper-cased) name

        """
        if link_types is None:
            return None
        if isinstance(link_types, six.string_types):
            link_types = [link_types]
        return tuple(
            getattr(LinkType, link_type.upper()) if isinstance(link_type, six.string_types) else link_type
            for link_type in link_types
        )

    def add_incoming(self, node, link_types=(), annotate_links=None, return_pks=True):
        """Add nodes and edges for the incoming links to a node.

        :param node: node or node pk/uuid
        :type node: aiida.orm.nodes.node.Node or int
        :param link_types: filter by link types (Default value = ())
        :type link_types: str or tuple[str] or aiida.common.links.LinkType or tuple[aiida.common.links.LinkType]
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = None)
        :type annotate_links: bool or str
        :param return_pks: whether to return a list of nodes, or list of node pks (Default value = True)
        :type return_pks: bool
        :returns: list of nodes or node pks
        :raises AssertionError: if ``annotate_links`` is not a supported option

        """
        self._validate_annotate_links(annotate_links)

        node = self.add_node(node)

        nodes = []
        for link_triple in node.get_incoming(link_type=self._convert_link_types(link_types)).link_triples:
            self.add_node(link_triple.node)
            link_pair = LinkPair(link_triple.link_type, link_triple.link_label)
            style = self._link_styles(
                link_pair, add_label=annotate_links in ['label', 'both'], add_type=annotate_links in ['type', 'both']
            )
            # incoming link: the linked node is the source of the edge
            self.add_edge(link_triple.node, node, link_pair, style=style)
            nodes.append(link_triple.node.pk if return_pks else link_triple.node)

        return nodes

    def add_outgoing(self, node, link_types=(), annotate_links=None, return_pks=True):
        """Add nodes and edges for the outgoing links of a node.

        :param node: node or node pk/uuid
        :type node: aiida.orm.nodes.node.Node or int
        :param link_types: filter by link types (Default value = ())
        :type link_types: str or tuple[str] or aiida.common.links.LinkType or tuple[aiida.common.links.LinkType]
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = None)
        :type annotate_links: bool or str
        :param return_pks: whether to return a list of nodes, or list of node pks (Default value = True)
        :type return_pks: bool
        :returns: list of nodes or node pks
        :raises AssertionError: if ``annotate_links`` is not a supported option

        """
        self._validate_annotate_links(annotate_links)

        node = self.add_node(node)

        nodes = []
        for link_triple in node.get_outgoing(link_type=self._convert_link_types(link_types)).link_triples:
            self.add_node(link_triple.node)
            link_pair = LinkPair(link_triple.link_type, link_triple.link_label)
            style = self._link_styles(
                link_pair, add_label=annotate_links in ['label', 'both'], add_type=annotate_links in ['type', 'both']
            )
            # outgoing link: the linked node is the destination of the edge
            self.add_edge(node, link_triple.node, link_pair, style=style)
            nodes.append(link_triple.node.pk if return_pks else link_triple.node)

        return nodes

    def recurse_descendants(
        self,
        origin,
        depth=None,
        link_types=(),
        annotate_links=False,
        origin_style=(),
        include_process_inputs=False,
        print_func=None
    ):
        """Add nodes and edges from an origin recursively, following outgoing links.

        :param origin: node or node pk/uuid
        :type origin: aiida.orm.nodes.node.Node or int
        :param depth: if not None, stop after travelling a certain depth into the graph (Default value = None)
        :type depth: None or int
        :param link_types: filter by subset of link types (Default value = ())
        :type link_types: tuple or str
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False)
        :type annotate_links: bool or str
        :param origin_style: node style map for origin node (Default value = ())
        :type origin_style: dict or tuple
        :param include_process_inputs: include incoming links for all processes (Default value = False)
        :type include_process_inputs: bool
        :param print_func: a function to stream information to, i.e. print_func(str)

        """
        # pylint: disable=too-many-arguments
        origin_node = self._load_node(origin)

        self.add_node(origin_node, style_override=dict(origin_style))

        leaf_nodes = [origin_node]
        # a set gives constant time membership checks; only membership is ever queried
        traversed_pks = {origin_node.pk}
        cur_depth = 0
        while leaf_nodes:
            cur_depth += 1
            # check whether a maximum descendant depth is set and has been exceeded
            if depth is not None and cur_depth > depth:
                break
            if print_func:
                print_func('- Depth: {}'.format(cur_depth))
            new_nodes = []
            for node in leaf_nodes:
                outgoing_nodes = self.add_outgoing(
                    node, link_types=link_types, annotate_links=annotate_links, return_pks=False
                )
                if outgoing_nodes and print_func:
                    print_func('  {} -> {}'.format(node.pk, [on.pk for on in outgoing_nodes]))
                new_nodes.extend(outgoing_nodes)

                if include_process_inputs and isinstance(node, ProcessNode):
                    self.add_incoming(node, link_types=link_types, annotate_links=annotate_links)

            # ensure the same path isn't traversed multiple times
            leaf_nodes = []
            for new_node in new_nodes:
                if new_node.pk in traversed_pks:
                    continue
                leaf_nodes.append(new_node)
                traversed_pks.add(new_node.pk)

    def recurse_ancestors(
        self,
        origin,
        depth=None,
        link_types=(),
        annotate_links=False,
        origin_style=(),
        include_process_outputs=False,
        print_func=None
    ):
        """Add nodes and edges from an origin recursively, following incoming links.

        :param origin: node or node pk/uuid
        :type origin: aiida.orm.nodes.node.Node or int
        :param depth: if not None, stop after travelling a certain depth into the graph (Default value = None)
        :type depth: None or int
        :param link_types: filter by subset of link types (Default value = ())
        :type link_types: tuple or str
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False)
        :type annotate_links: bool or str
        :param origin_style: node style map for origin node (Default value = ())
        :type origin_style: dict or tuple
        :param include_process_outputs:  include outgoing links for all processes (Default value = False)
        :type include_process_outputs: bool
        :param print_func: a function to stream information to, i.e. print_func(str)

        """
        # pylint: disable=too-many-arguments
        origin_node = self._load_node(origin)

        self.add_node(origin_node, style_override=dict(origin_style))

        last_nodes = [origin_node]
        # a set gives constant time membership checks; only membership is ever queried
        traversed_pks = {origin_node.pk}
        cur_depth = 0
        while last_nodes:
            cur_depth += 1
            # check whether a maximum ancestor depth is set and has been exceeded
            if depth is not None and cur_depth > depth:
                break
            if print_func:
                print_func('- Depth: {}'.format(cur_depth))
            new_nodes = []
            for node in last_nodes:
                incoming_nodes = self.add_incoming(
                    node, link_types=link_types, annotate_links=annotate_links, return_pks=False
                )
                if incoming_nodes and print_func:
                    print_func('  {} -> {}'.format(node.pk, [n.pk for n in incoming_nodes]))
                new_nodes.extend(incoming_nodes)

                if include_process_outputs and isinstance(node, ProcessNode):
                    self.add_outgoing(node, link_types=link_types, annotate_links=annotate_links)

            # ensure the same path isn't traversed multiple times
            last_nodes = []
            for new_node in new_nodes:
                if new_node.pk in traversed_pks:
                    continue
                last_nodes.append(new_node)
                traversed_pks.add(new_node.pk)

    def add_origin_to_targets(
        self,
        origin,
        target_cls,
        target_filters=None,
        include_target_inputs=False,
        include_target_outputs=False,
        origin_style=(),
        annotate_links=False
    ):
        """Add nodes and edges from an origin node to all nodes of a target node class.

        The targets are found with a query for nodes of ``target_cls`` that have the origin
        as an ancestor; each one is connected to the origin with a dashed grey edge.

        :param origin: node or node pk/uuid
        :type origin: aiida.orm.nodes.node.Node or int
        :param target_cls: target node class
        :param target_filters:  (Default value = None)
        :type target_filters: dict or None
        :param include_target_inputs: also add the incoming links of every target (Default value = False)
        :type include_target_inputs: bool
        :param include_target_outputs: also add the outgoing links of every target (Default value = False)
        :type include_target_outputs: bool
        :param origin_style: node style map for origin node (Default value = ())
        :type origin_style: dict or tuple
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False)
        :type annotate_links: bool or str

        """
        # pylint: disable=too-many-arguments
        origin_node = self._load_node(origin)

        if target_filters is None:
            target_filters = {}

        self.add_node(origin_node, style_override=dict(origin_style))

        query = QueryBuilder(
            **{
                'path': [{
                    'cls': origin_node.__class__,
                    'filters': {
                        'id': origin_node.pk
                    },
                    'tag': 'origin'
                }, {
                    'cls': target_cls,
                    'filters': target_filters,
                    'with_ancestors': 'origin',
                    'tag': 'target',
                    'project': '*'
                }]
            }
        )

        for (target_node,) in query.iterall():
            self.add_node(target_node)
            # no link_pair: this edge is a shortcut from origin to target, not a stored link
            self.add_edge(origin_node, target_node, style={'style': 'dashed', 'color': 'grey'})

            if include_target_inputs:
                self.add_incoming(target_node, annotate_links=annotate_links)

            if include_target_outputs:
                self.add_outgoing(target_node, annotate_links=annotate_links)

    def add_origins_to_targets(
        self,
        origin_cls,
        target_cls,
        origin_filters=None,
        target_filters=None,
        include_target_inputs=False,
        include_target_outputs=False,
        origin_style=(),
        annotate_links=False
    ):
        """Add nodes and edges from all nodes of an origin class to all nodes of a target node class.

        Every node of ``origin_cls`` matching ``origin_filters`` is passed to
        :meth:`add_origin_to_targets` with the remaining arguments.

        :param origin_cls: origin node class
        :param target_cls: target node class
        :param origin_filters:  (Default value = None)
        :type origin_filters: dict or None
        :param target_filters:  (Default value = None)
        :type target_filters: dict or None
        :param include_target_inputs:  (Default value = False)
        :type include_target_inputs: bool
        :param include_target_outputs:  (Default value = False)
        :type include_target_outputs: bool
        :param origin_style: node style map for origin node (Default value = ())
        :type origin_style: dict or tuple
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False)
        :type annotate_links: bool or str

        """
        # pylint: disable=too-many-arguments
        if origin_filters is None:
            origin_filters = {}

        query = QueryBuilder(
            **{'path': [{
                'cls': origin_cls,
                'filters': origin_filters,
                'tag': 'origin',
                'project': '*'
            }]}
        )

        for (node,) in query.iterall():
            self.add_origin_to_targets(
                node,
                target_cls,
                target_filters=target_filters,
                include_target_inputs=include_target_inputs,
                include_target_outputs=include_target_outputs,
                origin_style=origin_style,
                annotate_links=annotate_links
            )