Ejemplo n.º 1
0
    def add_outgoing(self, node, link_types=(), annotate_links=None, return_pks=True):
        """add nodes and edges for outgoing links to a node

        A single traversal iteration (``max_iterations=1``) is performed forwards from
        ``node``; the traversed nodes (including ``node`` itself) and the links between
        them are added to the graph.

        :param node: node or node pk
        :type node: aiida.orm.nodes.node.Node or int
        :param link_types: filter by link types (Default value = ())
        :type link_types: str or tuple[str] or aiida.common.links.LinkType or tuple[aiida.common.links.LinkType]
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = None)
        :type annotate_links: None or bool or str
        :param return_pks: if True return the pks of the traversed nodes, otherwise return the
            node objects themselves (Default value = True)
        :type return_pks: bool
        :returns: list of node pks (``return_pks=True``) or list of nodes
        :raises ValueError: if ``annotate_links`` is not one of None, False, 'label', 'type' or 'both'
        """
        if annotate_links not in (None, False, 'label', 'type', 'both'):
            raise ValueError(
                'annotate_links must be one of False, "label", "type" or "both"\ninstead, it is: {}'.
                format(annotate_links)
            )

        # outgoing nodes are found traversing forwards
        node_pk = self._load_node(node).pk
        valid_link_types = self._default_link_types(link_types)
        valid_link_types = self._convert_link_types(valid_link_types)
        traversed_graph = traverse_graph(
            (node_pk,),
            max_iterations=1,
            get_links=True,
            links_forward=valid_link_types,
        )

        # One query for all traversed nodes, turned into a {pk: Node} lookup table so that
        # edges below can be resolved without further queries
        query = orm.QueryBuilder().append(
            orm.Node,
            filters={'id': {
                'in': traversed_graph['nodes']
            }},
            project=['id', '*'],
            tag='node',
        )
        traversed_nodes = {pk: traversed_node for pk, traversed_node in query.all()}

        for traversed_node in traversed_nodes.values():
            self.add_node(traversed_node, style_override=None)

        # The annotation flags do not depend on the individual link, so compute them once
        add_label = annotate_links in ('label', 'both')
        add_type = annotate_links in ('type', 'both')

        for link in traversed_graph['links']:
            source_node = traversed_nodes[link.source_id]
            target_node = traversed_nodes[link.target_id]
            link_pair = LinkPair(self._convert_link_types(link.link_type)[0], link.link_label)
            link_style = self._link_styles(link_pair, add_label=add_label, add_type=add_type)
            self.add_edge(source_node, target_node, link_pair, style=link_style)

        if return_pks:
            return list(traversed_nodes.keys())
        return list(traversed_nodes.values())
Ejemplo n.º 2
0
    def test_empty_input(self):
        """Traversing from no starting nodes yields no nodes and no links, with or without get_links."""
        every_link_type = [
            LinkType.INPUT_CALC,
            LinkType.CALL_CALC,
            LinkType.CREATE,
            LinkType.INPUT_WORK,
            LinkType.CALL_WORK,
            LinkType.RETURN,
        ]

        # For an empty start, 'links' is None by default and an empty set when get_links=True
        scenarios = (
            ({}, None),
            ({'get_links': True}, set()),
        )
        for extra_options, expected_links in scenarios:
            result = traverse_graph(
                [],
                links_forward=every_link_type,
                links_backward=every_link_type,
                **extra_options
            )
            self.assertEqual(result['nodes'], set())
            self.assertEqual(result['links'], expected_links)
Ejemplo n.º 3
0
    def test_traversal_errors(self):
        """Check that traverse_graph rejects invalid starting nodes and link-type arguments."""
        from aiida.common.exceptions import NotExistent
        from aiida import orm

        stored_node = orm.Data().store()
        missing_pk = -1

        # Each entry: (expected exception, positional arguments, keyword arguments)
        failing_calls = [
            (NotExistent, ([missing_pk],), {}),
            (TypeError, (['not a node'],), {}),
            (TypeError, ('not a list',), {}),
            (TypeError, ([stored_node],), {'links_forward': 1984}),
            (TypeError, ([stored_node],), {'links_backward': ['not a link']}),
        ]
        for expected_error, call_args, call_kwargs in failing_calls:
            with self.assertRaises(expected_error):
                traverse_graph(*call_args, **call_kwargs)
Ejemplo n.º 4
0
 def _single_test(self,
                  starting_nodes=(),
                  expanded_nodes=(),
                  links_forward=(),
                  links_backward=()):
     """Auxiliary method to perform a single traversal run and assert its node set.

     The traversal starts at ``starting_nodes`` and follows ``links_forward`` /
     ``links_backward``; the obtained nodes must equal the union of
     ``starting_nodes`` and ``expanded_nodes``. The two node collections may be
     of different sequence types (e.g. a list and the default empty tuple).

     :param starting_nodes: node identifiers the traversal starts from
         (pks in the visible tests)
     :param expanded_nodes: additional node identifiers the traversal is expected to reach
     :param links_forward: link types to follow forwards
     :param links_backward: link types to follow backwards
     """
     obtained_nodes = traverse_graph(
         starting_nodes,
         links_forward=links_forward,
         links_backward=links_backward,
     )['nodes']
     # Set union rather than `starting_nodes + expanded_nodes`, which raises TypeError
     # when a list is combined with a tuple (e.g. a list start and the default `()`)
     expected_nodes = set(starting_nodes).union(expanded_nodes)
     self.assertEqual(obtained_nodes, expected_nodes)
Ejemplo n.º 5
0
    def recurse_ancestors(self,
                          origin,
                          depth=None,
                          link_types=(),
                          annotate_links=False,
                          origin_style=None,
                          include_process_outputs=False,
                          print_func=None):
        """add nodes and edges from an origin recursively,
        following incoming links

        :param origin: node or node pk/uuid
        :type origin: aiida.orm.nodes.node.Node or int or str
        :param depth: if not None, stop after travelling a certain depth into the graph (Default value = None)
        :type depth: None or int
        :param link_types: filter by subset of link types (Default value = ())
        :type link_types: tuple or str
        :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False)
        :type annotate_links: bool or str
        :param origin_style: node style map for origin node (Default value = None)
        :type origin_style: None or dict
        :param include_process_outputs: additionally follow CREATE and RETURN links one step forward
            from every traversed node, so process outputs are included (Default value = False)
        :type include_process_outputs: bool
        :param print_func: a function to stream information to, i.e. print_func(str);
            only triggers a deprecation warning, it is otherwise unused

        .. deprecated:: 1.1.0
            `print_func` will be removed in `v2.0.0`
        """
        # pylint: disable=too-many-arguments,too-many-locals
        import warnings
        from aiida.common.warnings import AiidaDeprecationWarning
        if print_func:
            warnings.warn(  # pylint: disable=no-member
                '`print_func` is deprecated because graph traversal has been refactored',
                AiidaDeprecationWarning,
                stacklevel=2)

        # Resolve the origin to a pk: ints are used as-is, strings (UUIDs) are loaded,
        # anything else is treated as a node object
        if isinstance(origin, int):
            origin_pk = origin
        elif isinstance(origin, str):
            origin_pk = orm.load_node(origin).pk
        else:
            origin_pk = origin.pk

        # Get graph traversal rules where the given link types and direction are all set to True,
        # and all others are set to False
        valid_link_types = self._default_link_types(link_types)
        valid_link_types = self._convert_link_types(valid_link_types)
        traversed_graph = traverse_graph(
            (origin_pk, ),
            max_iterations=depth,
            get_links=True,
            links_backward=valid_link_types,
        )

        # Traverse one step forward along CREATE and RETURN links from all nodes traversed in the
        # previous step and join the result with the original traversed graph. This includes
        # process outputs in the Graph
        if include_process_outputs:
            traversed_outputs = traverse_graph(
                traversed_graph['nodes'],
                max_iterations=1,
                get_links=True,
                links_forward=[LinkType.CREATE, LinkType.RETURN])
            traversed_graph['nodes'] = traversed_graph['nodes'].union(
                traversed_outputs['nodes'])
            traversed_graph['links'] = traversed_graph['links'].union(
                traversed_outputs['links'])

        # Do one central query for all nodes in the Graph and generate a {id: Node} dictionary
        query = orm.QueryBuilder().append(
            orm.Node,
            filters={'id': {
                'in': traversed_graph['nodes']
            }},
            project=['id', '*'],
            tag='node',
        )
        traversed_nodes = {pk: traversed_node for pk, traversed_node in query.all()}

        # Pop the origin node and add it to the graph, applying custom styling
        origin_node = traversed_nodes.pop(origin_pk)
        self.add_node(origin_node, style_override=origin_style)

        # Add all traversed nodes to the graph with default styling
        for traversed_node in traversed_nodes.values():
            self.add_node(traversed_node, style_override=None)

        # Add the origin node back into traversed nodes so it can be found for adding edges
        traversed_nodes[origin_pk] = origin_node

        # The annotation flags do not depend on the individual link, so compute them once
        add_label = annotate_links in ['label', 'both']
        add_type = annotate_links in ['type', 'both']

        # Add all links to the Graph, using the {id: Node} dictionary for queryless Node retrieval, applying
        # appropriate styling
        for link in traversed_graph['links']:
            source_node = traversed_nodes[link.source_id]
            target_node = traversed_nodes[link.target_id]
            link_pair = LinkPair(
                self._convert_link_types(link.link_type)[0], link.link_label)
            link_style = self._link_styles(link_pair,
                                           add_label=add_label,
                                           add_type=add_type)
            self.add_edge(source_node,
                          target_node,
                          link_pair,
                          style=link_style)
Ejemplo n.º 6
0
    def test_traversal_cycle(self):
        """
        Check that traversal terminates on a cyclic graph.

        The graph has two data nodes, data_take and data_drop, and a workflow node
        work_select that receives both as inputs but returns only data_take, so
        data_take -> work_select -> data_take forms a cycle.
        """
        from aiida import orm

        data_take = orm.Data().store()
        data_drop = orm.Data().store()
        work_select = orm.WorkflowNode()

        input_links = ((data_take, 'input_take'), (data_drop, 'input_drop'))
        for input_node, input_label in input_links:
            work_select.add_incoming(input_node,
                                     link_type=LinkType.INPUT_WORK,
                                     link_label=input_label)
        work_select.store()

        data_take.add_incoming(work_select,
                               link_type=LinkType.RETURN,
                               link_label='return_link')

        # From here on only the primary keys are needed
        take_pk, drop_pk, select_pk = data_take.pk, data_drop.pk, work_select.pk
        all_pks = [take_pk, drop_pk, select_pk]

        # With the default (no) link types, every node only reaches itself
        for start_pk in all_pks:
            self.assertEqual(traverse_graph([start_pk])['nodes'], {start_pk})

        work_links = [LinkType.INPUT_WORK, LinkType.RETURN]
        no_links = []
        take_and_select = {select_pk, take_pk}

        # Each entry: (start pk, forward link types, backward link types, expected pks)
        expectations = [
            # Forward: data_drop to (input) work_select to (return) data_take
            (drop_pk, work_links, no_links, set(all_pks)),
            # Forward: data_take to (input) work_select (data_drop is not returned)
            (take_pk, work_links, no_links, take_and_select),
            # Forward: work_select to (return) data_take (data_drop is not returned)
            (select_pk, work_links, no_links, take_and_select),
            # Backward: data_drop is not returned so it has no backward link
            (drop_pk, no_links, work_links, {drop_pk}),
            # Backward: data_take to (return) work_select to (input) data_drop
            (take_pk, no_links, work_links, set(all_pks)),
            # Backward: work_select to (input) data_take and data_drop
            (select_pk, no_links, work_links, set(all_pks)),
        ]
        for start_pk, forward, backward, expected in expectations:
            obtained = traverse_graph([start_pk],
                                      links_forward=forward,
                                      links_backward=backward)['nodes']
            self.assertEqual(obtained, expected)