Exemplo n.º 1
0
 def test_empty_input(self):
     """
     An empty starting basket must produce an empty node set, even when the
     update rule is allowed to iterate without limit.
     """
     query = orm.QueryBuilder()
     query.append(orm.Node).append(orm.Node)
     rule = UpdateRule(query, max_iterations=np.inf)
     empty_basket = Basket(nodes=[])
     self.assertEqual(rule.run(empty_basket.copy())['nodes'].keyset, set())
Exemplo n.º 2
0
    def test_cycle(self):
        """
        Check traversal over a two-node cycle.

        A workflow node takes a data node as input (INPUT_WORK) and returns that same
        data node (RETURN), so the two nodes form a cycle:
        - an UpdateRule with no iteration limit must terminate and collect both nodes;
        - a ReplaceRule must alternate between the two nodes as the number of
          iterations grows (odd counts -> workflow node, even counts -> data node).
        """
        data_node = orm.Data().store()
        work_node = orm.WorkflowNode()
        work_node.add_incoming(data_node, link_type=LinkType.INPUT_WORK, link_label='input_link')
        work_node.store()
        data_node.add_incoming(work_node, link_type=LinkType.RETURN, link_label='return_link')

        basket = Basket(nodes=[data_node.id])
        queryb = orm.QueryBuilder()
        queryb.append(orm.Node).append(orm.Node)

        update_rule = UpdateRule(queryb, max_iterations=np.inf)
        self.assertEqual(update_rule.run(basket.copy())['nodes'].keyset, {data_node.id, work_node.id})

        # Every extra ReplaceRule iteration hops to the other end of the cycle.
        alternating = ((1, work_node), (2, data_node), (3, work_node), (4, data_node))
        for iterations, expected_node in alternating:
            replace_rule = ReplaceRule(queryb, max_iterations=iterations)
            obtained = replace_rule.run(basket.copy())['nodes'].keyset
            self.assertEqual(obtained, {expected_node.id})
Exemplo n.º 3
0
    def test_edges(self):
        """
        Check which links are recorded when an UpdateRule tracks edges.

        Walking forwards from ``data_i`` for two iterations must reach every node of
        the basic graph and record every link that was walked. A single backward step
        from ``data_o`` must reach only its direct ancestors, and the recorded links
        must keep their (source, target) direction.
        """
        nodes = self._create_basic_graph()
        all_ids = {anode.id for _, anode in nodes.items()}

        def link_pairs(rule_result):
            """Reduce the tracked ``nodes_nodes`` keys to (source_id, target_id) pairs."""
            return {(key[0], key[1]) for key in rule_result['nodes_nodes'].keyset}

        def pair(source_key, target_key):
            """Build the (source_id, target_id) pair for two labelled nodes."""
            return (nodes[source_key].id, nodes[target_key].id)

        # Forward traversal (check all nodes and all links)
        forward_basket = Basket(nodes=[nodes['data_i'].id])
        forward_query = orm.QueryBuilder()
        forward_query.append(orm.Node, tag='nodes_in_set')
        forward_query.append(orm.Node, with_incoming='nodes_in_set')
        forward_rule = UpdateRule(forward_query, max_iterations=2, track_edges=True)
        forward_result = forward_rule.run(forward_basket.copy())

        self.assertEqual(forward_result['nodes'].keyset, all_ids)
        self.assertEqual(
            link_pairs(forward_result), {
                pair('data_i', 'calc_0'),
                pair('data_i', 'work_1'),
                pair('data_i', 'work_2'),
                pair('calc_0', 'data_o'),
                pair('work_1', 'data_o'),
                pair('work_2', 'data_o'),
                pair('work_2', 'work_1'),
                pair('work_1', 'calc_0'),
            }
        )

        # Backwards traversal (check partial traversal and link direction)
        backward_basket = Basket(nodes=[nodes['data_o'].id])
        backward_query = orm.QueryBuilder()
        backward_query.append(orm.Node, tag='nodes_in_set')
        backward_query.append(orm.Node, with_outgoing='nodes_in_set')
        backward_rule = UpdateRule(backward_query, max_iterations=1, track_edges=True)
        backward_result = backward_rule.run(backward_basket.copy())

        self.assertEqual(backward_result['nodes'].keyset, all_ids - {nodes['data_i'].id})
        self.assertEqual(
            link_pairs(backward_result), {
                pair('calc_0', 'data_o'),
                pair('work_1', 'data_o'),
                pair('work_2', 'data_o'),
            }
        )
Exemplo n.º 4
0
    def test_algebra(self):
        """Test simple addition, in-place addition, simple subtraction, in-place subtraction.

        Two independently created trees are traversed to obtain two node entity sets;
        the algebra on those entity sets is checked against plain Python set
        operations on their keysets.
        """

        depth0 = 4
        branching0 = 2
        tree0 = create_tree(max_depth=depth0, branching=branching0)
        basket0 = Basket(nodes=(tree0['parent'].id,))
        queryb0 = orm.QueryBuilder()
        queryb0.append(orm.Node).append(orm.Node)
        rule0 = UpdateRule(queryb0, max_iterations=depth0)
        res0 = rule0.run(basket0.copy())
        aes0 = res0.nodes

        depth1 = 3
        branching1 = 6
        tree1 = create_tree(max_depth=depth1, branching=branching1)
        basket1 = Basket(nodes=(tree1['parent'].id,))
        queryb1 = orm.QueryBuilder()
        queryb1.append(orm.Node).append(orm.Node)
        rule1 = UpdateRule(queryb1, max_iterations=depth1)
        res1 = rule1.run(basket1.copy())
        aes1 = res1.nodes

        # Addition must behave like set union on the keysets.
        aes2 = aes0 + aes1
        union01 = aes0.keyset | aes1.keyset
        self.assertEqual(aes2.keyset, union01)

        aes0_copy = aes0.copy()
        aes0_copy += aes1
        self.assertEqual(aes0_copy.keyset, union01)

        # Subtracting what was added must give back the original entity set.
        aes3 = aes0_copy - aes1
        self.assertEqual(aes0.keyset, aes3.keyset)
        self.assertEqual(aes0, aes3)

        aes0_copy -= aes1
        # ``assertEqual(first, second, msg)`` treats a third argument as the failure
        # message, so the in-place result has to be compared explicitly.
        self.assertEqual(aes0_copy.keyset, aes0.keyset)
        self.assertEqual(aes0_copy, aes0)

        # Subtracting a set from itself (or an equal set) must leave nothing.
        aes4 = aes0 - aes0
        self.assertEqual(aes4.keyset, set())

        aes0_copy -= aes0
        self.assertEqual(aes0_copy.keyset, set())
Exemplo n.º 5
0
    def test_groups(self):
        """
        Check traversal between (aiida-)groups and (aiida-)nodes, treating both kinds
        of entity as (graph-)nodes of the same graph.

        Covered cases:
        - a single rule that hops node -> group -> node only ever collects nodes;
        - two chained rules collect both nodes and groups, starting either from a
          node or from a group;
        - a chain of groups connected through shared nodes is walked end to end, and
          the tracked group-node edges are checked.
        """
        node1, node2, node3, node4 = (orm.Data().store() for _ in range(4))

        group1 = orm.Group(label='group-01').store()
        group1.add_nodes(node1)

        group2 = orm.Group(label='group-02').store()
        group2.add_nodes(node2)
        group2.add_nodes(node3)

        group3 = orm.Group(label='group-03').store()
        group3.add_nodes(node4)
        group4 = orm.Group(label='group-04').store()
        group4.add_nodes(node4)

        def node_to_group_rule(**rule_options):
            """UpdateRule going from the nodes in the set to the user groups containing them."""
            query = orm.QueryBuilder()
            query.append(orm.Node, tag='nodes_in_set')
            query.append(orm.Group, with_node='nodes_in_set', filters={'type_string': 'user'})
            return UpdateRule(query, **rule_options)

        def group_to_node_rule(**rule_options):
            """UpdateRule going from the groups in the set to the nodes they contain."""
            query = orm.QueryBuilder()
            query.append(orm.Group, tag='groups_in_set')
            query.append(orm.Node, with_group='groups_in_set')
            return UpdateRule(query, **rule_options)

        # Rule that only gets nodes connected by the same group
        same_group_query = orm.QueryBuilder()
        same_group_query.append(orm.Node, tag='nodes_in_set')
        same_group_query.append(
            orm.Group, with_node='nodes_in_set', tag='groups_considered', filters={'type_string': 'user'}
        )
        same_group_query.append(orm.Data, with_group='groups_considered')
        same_group_rule = UpdateRule(same_group_query, max_iterations=np.inf)

        result = same_group_rule.run(Basket(nodes=[node2.id]).copy())
        self.assertEqual(result['nodes'].keyset, {node2.id, node3.id})
        self.assertEqual(result['groups'].keyset, set())

        # But two rules chained should get both nodes and groups...
        ruleseq = RuleSequence((node_to_group_rule(), group_to_node_rule()), max_iterations=np.inf)

        # ...both starting with a node
        result = ruleseq.run(Basket(nodes=[node2.id]).copy())
        self.assertEqual(result['nodes'].keyset, {node2.id, node3.id})
        self.assertEqual(result['groups'].keyset, {group2.id})

        # ...and starting with a group
        result = ruleseq.run(Basket(groups=[group3.id]).copy())
        self.assertEqual(result['nodes'].keyset, {node4.id})
        self.assertEqual(result['groups'].keyset, {group3.id, group4.id})

        # Testing a "group chain": consecutive groups share exactly one node.
        total_groups = 10
        groups = [orm.Group(label='group-{}'.format(idx)).store() for idx in range(total_groups)]

        chain_nodes = []
        edges = set()
        for previous_group, current_group in zip(groups, groups[1:]):
            shared_node = orm.Data().store()
            current_group.add_nodes(shared_node)
            previous_group.add_nodes(shared_node)
            chain_nodes.append(shared_node)
            edges.add(GroupNodeEdge(node_id=shared_node.id, group_id=current_group.id))
            edges.add(GroupNodeEdge(node_id=shared_node.id, group_id=previous_group.id))

        chain_sequence = RuleSequence(
            (node_to_group_rule(track_edges=True), group_to_node_rule(track_edges=True)),
            max_iterations=np.inf,
        )
        result = chain_sequence.run(Basket(nodes=[chain_nodes[-1].id]).copy())

        self.assertEqual(result['nodes'].keyset, {chain_node.id for chain_node in chain_nodes})
        self.assertEqual(result['groups'].keyset, {group.id for group in groups})

        # testing the edges between groups and nodes:
        self.assertEqual(result['groups_nodes'].keyset, edges)
Exemplo n.º 6
0
    def test_returns_calls(self):
        """Test a RuleSequence assembled from forward provenance rules.

        The sequence follows, in the forward direction:
        - INPUT_CALC/INPUT_WORK links from data nodes to process nodes;
        - CREATE/RETURN links from process nodes to data nodes;
        - CALL_CALC/CALL_WORK links from process nodes to process nodes.
        Two further rules can be switched on with the ``create_reversed`` and
        ``return_reversed`` flags (both disabled here).

        The sequence is run on a freshly stored data node that has no links, so the
        traversal must not reach any node other than the starting one.
        """
        create_reversed = False
        return_reversed = False

        rules = []
        # INPUT_CALC/INPUT_WORK(Data, ProcessNode) - Forward: link all processes to input data
        queryb = orm.QueryBuilder()
        queryb.append(orm.Data, tag='predecessor')
        queryb.append(
            orm.ProcessNode,
            with_incoming='predecessor',
            edge_filters={'type': {
                'in': [LinkType.INPUT_CALC.value, LinkType.INPUT_WORK.value]
            }}
        )
        rules.append(UpdateRule(queryb))

        # CREATE/RETURN(ProcessNode, Data) - Forward
        queryb = orm.QueryBuilder()
        queryb.append(orm.ProcessNode, tag='predecessor')
        queryb.append(
            orm.Data,
            with_incoming='predecessor',
            edge_filters={'type': {
                'in': [LinkType.CREATE.value, LinkType.RETURN.value]
            }}
        )
        rules.append(UpdateRule(queryb))

        # CALL(ProcessNode, ProcessNode) - Forward
        queryb = orm.QueryBuilder()
        queryb.append(orm.ProcessNode, tag='predecessor')
        queryb.append(
            orm.ProcessNode,
            with_incoming='predecessor',
            edge_filters={'type': {
                'in': [LinkType.CALL_CALC.value, LinkType.CALL_WORK.value]
            }}
        )
        rules.append(UpdateRule(queryb))

        # CREATE(ProcessNode, Data) - Reversed
        # NOTE(review): despite the label, this query follows the link forwards
        # (``with_incoming``); confirm the intent before enabling the flag.
        if create_reversed:
            queryb = orm.QueryBuilder()
            queryb.append(orm.ProcessNode, tag='predecessor', project=['id'])
            queryb.append(orm.Data, with_incoming='predecessor', edge_filters={'type': {'in': [LinkType.CREATE.value]}})
            rules.append(UpdateRule(queryb))

        # RETURN(ProcessNode, Data) - Reversed
        # NOTE(review): same caveat as above. The deprecated ``output_of`` keyword has
        # been replaced by its equivalent ``with_incoming``.
        if return_reversed:
            queryb = orm.QueryBuilder()
            queryb.append(orm.ProcessNode, tag='predecessor')
            queryb.append(orm.Data, with_incoming='predecessor', edge_filters={'type': {'in': [LinkType.RETURN.value]}})
            rules.append(UpdateRule(queryb))

        # A freshly stored node has no links, so none of the rules above can match.
        new_node = orm.Data().store()
        basket = Basket(nodes=(new_node.id,))

        ruleseq = RuleSequence(rules, max_iterations=np.inf)
        resulting_set = ruleseq.run(basket.copy())
        # The previous assertion compared ``resulting_set`` with itself and could never
        # fail; an unlinked starting node can only lead back to itself.
        self.assertEqual(resulting_set['nodes'].keyset, {new_node.id})
Exemplo n.º 7
0
    def test_stash(self):
        """Testing sequences and 'stashing'.

        Testing the dependency on the order of the operations in RuleSequence and the
        'stash' functionality. This will be performed in a graph that has a calculation
        (calc_ca) with two input data nodes (data_i1 and data_i2) and two output data
        nodes (data_o1 and data_o2), and another calculation (calc_cb) which takes both
        one of the inputs and one of the outputs of the first one (data_i2 and data_o2)
        as inputs to produce a final output (data_o3).

        NOTE(review): the labels above do not match the keys this test reads from
        ``_create_branchy_graph`` ('data_1', 'calc_1', 'calc_2', 'data_i', 'data_o');
        confirm the correspondence against that helper.
        """
        nodes = self._create_branchy_graph()
        basket = Basket(nodes=[nodes['data_1'].id])

        # One step towards the nodes that link *into* the current set.
        queryb_inp = orm.QueryBuilder()
        queryb_inp.append(orm.Node, tag='nodes_in_set')
        queryb_inp.append(orm.Node, with_outgoing='nodes_in_set')
        uprule_inp = UpdateRule(queryb_inp)
        # One step towards the nodes that the current set links *to*.
        queryb_out = orm.QueryBuilder()
        queryb_out.append(orm.Node, tag='nodes_in_set')
        queryb_out.append(orm.Node, with_incoming='nodes_in_set')
        uprule_out = UpdateRule(queryb_out)

        expect_base = set([nodes['calc_1'].id, nodes['data_1'].id, nodes['calc_2'].id])

        # First get outputs, then inputs.
        rule_seq = RuleSequence((uprule_out, uprule_inp))
        obtained = rule_seq.run(basket.copy())['nodes'].keyset
        expected = expect_base.union(set([nodes['data_i'].id]))
        self.assertEqual(obtained, expected)

        # First get inputs, then outputs.
        rule_seq = RuleSequence((uprule_inp, uprule_out))
        obtained = rule_seq.run(basket.copy())['nodes'].keyset
        expected = expect_base.union(set([nodes['data_o'].id]))
        self.assertEqual(obtained, expected)

        # Now using the stash option in either order.
        # Checking whether RuleSaveWalkers does the right thing on its own: after
        # stashing, the operational set must be the original one, and the stash must
        # contain the same data as the starting point.
        stash = basket.get_template()
        rule_save = RuleSaveWalkers(stash)
        obtained = rule_save.run(basket.copy())
        expected = basket.copy()
        self.assertEqual(obtained, expected)
        self.assertEqual(stash, basket)

        # Save the walkers, apply the first rule, restore the saved walkers and apply
        # the second rule: both rules then start from the same set, so the result must
        # not depend on their order.
        stash = basket.get_template()
        rule_save = RuleSaveWalkers(stash)
        rule_load = RuleSetWalkers(stash)
        serule_io = RuleSequence((rule_save, uprule_inp, rule_load, uprule_out))
        result_io = serule_io.run(basket.copy())['nodes'].keyset
        self.assertEqual(result_io, expect_base)

        stash = basket.get_template()
        rule_save = RuleSaveWalkers(stash)
        rule_load = RuleSetWalkers(stash)
        serule_oi = RuleSequence((rule_save, uprule_out, rule_load, uprule_inp))
        result_oi = serule_oi.run(basket.copy())['nodes'].keyset
        self.assertEqual(result_oi, expect_base)
Exemplo n.º 8
0
    def test_basic_graph(self):
        """
        Exercise the basic explorer operations on the basic graph:
        - collecting ascendants and descendants;
        - restricting the walk to specific link types;
        - restricting the walk to specific node classes;
        - taking only a given generation of descendants with a ReplaceRule.
        """

        nodes = self._create_basic_graph()
        basket_w1 = Basket(nodes=[nodes['work_1'].id])
        basket_w2 = Basket(nodes=[nodes['work_2'].id])

        def build_query(target_cls, relation, edge_filters=None):
            """Query the neighbours of class ``target_cls`` reached through ``relation``."""
            query = orm.QueryBuilder()
            query.append(orm.Node, tag='nodes_in_set')
            append_options = {relation: 'nodes_in_set'}
            if edge_filters is not None:
                append_options['edge_filters'] = edge_filters
            query.append(target_cls, **append_options)
            return query

        def ids_of(*keys):
            """Translate graph labels into the set of their node ids."""
            return {nodes[key].id for key in keys}

        update_cases = (
            # All the descendants of work_1
            (orm.Node, 'with_incoming', None, ('work_1', 'calc_0', 'data_o')),
            # Descendants of work_1 through call_calc (calc_0)
            (orm.Node, 'with_incoming', {'type': {'in': [LinkType.CALL_CALC.value]}}, ('work_1', 'calc_0')),
            # Descendants of work_1 that are data nodes (data_o)
            (orm.Data, 'with_incoming', None, ('work_1', 'data_o')),
            # All the ascendants of work_1
            (orm.Node, 'with_outgoing', None, ('work_1', 'work_2', 'data_i')),
            # Ascendants of work_1 through input_work (data_i)
            (orm.Node, 'with_outgoing', {'type': {'in': [LinkType.INPUT_WORK.value]}}, ('work_1', 'data_i')),
            # Ascendants of work_1 that are process nodes (work_2)
            (orm.ProcessNode, 'with_outgoing', None, ('work_1', 'work_2')),
        )
        for target_cls, relation, edge_filters, expected_keys in update_cases:
            update_rule = UpdateRule(build_query(target_cls, relation, edge_filters), max_iterations=10)
            obtained = update_rule.run(basket_w1.copy())['nodes'].keyset
            self.assertEqual(obtained, ids_of(*expected_keys))

        # Only one generation of descendants of work_2:
        # 1st level -> (work_1, data_o); 2nd level -> (calc_0, data_o)
        replace_cases = ((1, ('work_1', 'data_o')), (2, ('calc_0', 'data_o')))
        for iterations, expected_keys in replace_cases:
            replace_rule = ReplaceRule(build_query(orm.Node, 'with_incoming'), max_iterations=iterations)
            obtained = replace_rule.run(basket_w2.copy())['nodes'].keyset
            self.assertEqual(obtained, ids_of(*expected_keys))
Exemplo n.º 9
0
def traverse_graph(starting_pks, max_iterations=None, get_links=False, links_forward=(), links_backward=()):
    """
    This function will return the set of all nodes that can be connected
    to a list of initial nodes through any sequence of specified links.
    Optionally, it may also return the links that connect these nodes.

    :type starting_pks: list or tuple or set
    :param starting_pks: Contains the (valid) pks of the starting nodes.

    :type max_iterations: int or None
    :param max_iterations:
        The number of iterations to apply the set of rules (a value of 'None' will
        iterate until no new nodes are added).

    :param bool get_links:
        Pass True to also return the links between all nodes (found + initial).

    :type links_forward: list or tuple of aiida.common.links.LinkType
    :param links_forward:
        List with all the links that should be traversed in the forward direction.

    :type links_backward: list or tuple of aiida.common.links.LinkType
    :param links_backward:
        List with all the links that should be traversed in the backward direction.

    :return: a dictionary with key ``'nodes'`` (set of reached node pks, including the
        starting ones) and key ``'links'`` (the keyset of the traversed node-node edges
        when ``get_links`` is True, otherwise ``None``; an empty set for empty input
        when ``get_links`` is True).

    :raises TypeError: if ``max_iterations`` is neither an int nor infinity, if a link
        is not a ``LinkType``, if ``starting_pks`` is not a list/set/tuple, or if one of
        the pks is not an int.
    :raises aiida.common.exceptions.NotExistent: if some of the starting pks are not in
        the database.
    """
    # pylint: disable=too-many-locals,too-many-statements,too-many-branches
    from aiida import orm
    from aiida.tools.graph.age_entities import Basket
    from aiida.tools.graph.age_rules import UpdateRule, RuleSequence, RuleSaveWalkers, RuleSetWalkers
    from aiida.common import exceptions

    if max_iterations is None:
        max_iterations = inf
    elif not (isinstance(max_iterations, int) or max_iterations is inf):
        raise TypeError('Max_iterations has to be an integer or infinity')

    linktype_list = []
    for linktype in links_forward:
        if not isinstance(linktype, LinkType):
            raise TypeError('links_forward should contain links, but one of them is: {}'.format(type(linktype)))
        linktype_list.append(linktype.value)
    filters_forwards = {'type': {'in': linktype_list}}

    linktype_list = []
    for linktype in links_backward:
        if not isinstance(linktype, LinkType):
            raise TypeError('links_backward should contain links, but one of them is: {}'.format(type(linktype)))
        linktype_list.append(linktype.value)
    filters_backwards = {'type': {'in': linktype_list}}

    if not isinstance(starting_pks, (list, set, tuple)):
        raise TypeError('starting_pks must be of type list, set or tuple\ninstead, it is {}'.format(type(starting_pks)))

    if not starting_pks:
        # Nothing to traverse: keep the shape of the output consistent with ``get_links``.
        return {'nodes': set(), 'links': set() if get_links else None}

    if any(not isinstance(pk, int) for pk in starting_pks):
        raise TypeError('one of the starting_pks is not of type int:\n {}'.format(starting_pks))
    operational_set = set(starting_pks)

    query_nodes = orm.QueryBuilder()
    query_nodes.append(orm.Node, project=['id'], filters={'id': {'in': operational_set}})
    existing_pks = set(query_nodes.all(flat=True))
    missing_pks = operational_set.difference(existing_pks)
    if missing_pks:
        raise exceptions.NotExistent(
            'The following pks are not in the database and must be pruned before this call: {}'.format(missing_pks)
        )

    rules = []
    basket = Basket(nodes=operational_set)

    # When max_iterations is finite, the order of traversal may affect the result
    # (it is not the same to first go backwards and then forwards as vice-versa).
    # In order to make it order-independent, the result of the first operation needs
    # to be stashed and the second operation must be performed only on the nodes
    # that were already in the set at the beginning of the iteration: this way, both
    # rules are applied on the same set of nodes and the order doesn't matter.
    # The way to do this is saving and setting the walkers at the right moments only
    # when both forwards and backwards rules are present.
    if links_forward and links_backward:
        stash = basket.get_template()
        rules += [RuleSaveWalkers(stash)]

    if links_forward:
        query_outgoing = orm.QueryBuilder()
        query_outgoing.append(orm.Node, tag='sources')
        query_outgoing.append(orm.Node, edge_filters=filters_forwards, with_incoming='sources')
        rule_outgoing = UpdateRule(query_outgoing, max_iterations=1, track_edges=get_links)
        rules += [rule_outgoing]

    if links_forward and links_backward:
        rules += [RuleSetWalkers(stash)]

    if links_backward:
        query_incoming = orm.QueryBuilder()
        query_incoming.append(orm.Node, tag='sources')
        query_incoming.append(orm.Node, edge_filters=filters_backwards, with_outgoing='sources')
        rule_incoming = UpdateRule(query_incoming, max_iterations=1, track_edges=get_links)
        rules += [rule_incoming]

    rulesequence = RuleSequence(rules, max_iterations=max_iterations)

    results = rulesequence.run(basket)

    output = {}
    output['nodes'] = results.nodes.keyset
    output['links'] = None
    if get_links:
        output['links'] = results['nodes_nodes'].keyset

    return output
Exemplo n.º 10
0
def traverse_graph(
    starting_pks: Iterable[int],
    max_iterations: Optional[int] = None,
    get_links: bool = False,
    links_forward: Iterable[LinkType] = (),
    links_backward: Iterable[LinkType] = (),
    missing_callback: Optional[Callable[[Iterable[int]], None]] = None
) -> TraverseGraphOutput:
    """
    This function will return the set of all nodes that can be connected
    to a list of initial nodes through any sequence of specified links.
    Optionally, it may also return the links that connect these nodes.

    :param starting_pks: Contains the (valid) pks of the starting nodes. Any iterable
        is accepted, including one-shot iterators such as generators.

    :param max_iterations:
        The number of iterations to apply the set of rules (a value of 'None' will
        iterate until no new nodes are added).

    :param get_links: Pass True to also return the links between all nodes (found + initial).

    :param links_forward: List with all the links that should be traversed in the forward direction.
    :param links_backward: List with all the links that should be traversed in the backward direction.

    :param missing_callback: A callback to handle missing starting_pks or if None raise NotExistent

    :return: a dictionary with key ``'nodes'`` (set of reached node pks) and key ``'links'``
        (the keyset of the traversed node-node edges when ``get_links`` is True,
        otherwise ``None``; an empty set for empty input when ``get_links`` is True).

    :raises TypeError: if ``max_iterations`` is neither an int nor infinity, if a link is
        not a ``LinkType``, if ``starting_pks`` is not iterable, or if one of the pks is
        not an int.
    :raises aiida.common.exceptions.NotExistent: if some starting pks are not in the
        database and no ``missing_callback`` is given.
    """
    # pylint: disable=too-many-locals,too-many-statements,too-many-branches

    if max_iterations is None:
        max_iterations = cast(int, inf)
    elif not (isinstance(max_iterations, int) or max_iterations is inf):
        raise TypeError('Max_iterations has to be an integer or infinity')

    # Collect the link values into lists: this validates every entry, and the rules
    # below are then decided on the collected values, so one-shot iterables (e.g.
    # generators) are consumed only once and an empty one adds no rule.
    forward_values = []
    for linktype in links_forward:
        if not isinstance(linktype, LinkType):
            raise TypeError(
                f'links_forward should contain links, but one of them is: {type(linktype)}'
            )
        forward_values.append(linktype.value)
    filters_forwards = {'type': {'in': forward_values}}

    backward_values = []
    for linktype in links_backward:
        if not isinstance(linktype, LinkType):
            raise TypeError(
                f'links_backward should contain links, but one of them is: {type(linktype)}'
            )
        backward_values.append(linktype.value)
    filters_backwards = {'type': {'in': backward_values}}

    if not isinstance(starting_pks, Iterable):  # pylint: disable=isinstance-second-argument-not-valid-type
        raise TypeError(
            f'starting_pks must be an iterable\ninstead, it is {type(starting_pks)}'
        )

    # Validate and collect in a single pass: iterating ``starting_pks`` twice would
    # silently produce an empty set when a generator or other iterator is passed.
    operational_set = set()
    for pk in starting_pks:
        if not isinstance(pk, int):
            raise TypeError(
                f'one of the starting_pks is not of type int:\n {starting_pks}')
        operational_set.add(pk)

    if not operational_set:
        if get_links:
            return {'nodes': set(), 'links': set()}
        return {'nodes': set(), 'links': None}

    query_nodes = orm.QueryBuilder()
    query_nodes.append(orm.Node,
                       project=['id'],
                       filters={'id': {
                           'in': operational_set
                       }})
    existing_pks = set(query_nodes.all(flat=True))
    missing_pks = operational_set.difference(existing_pks)
    if missing_pks:
        if missing_callback is None:
            raise exceptions.NotExistent(
                f'The following pks are not in the database and must be pruned before this call: {missing_pks}'
            )
        missing_callback(missing_pks)

    rules = []
    basket = Basket(nodes=existing_pks)

    # When max_iterations is finite, the order of traversal may affect the result
    # (it is not the same to first go backwards and then forwards as vice-versa).
    # In order to make it order-independent, the result of the first operation needs
    # to be stashed and the second operation must be performed only on the nodes
    # that were already in the set at the beginning of the iteration: this way, both
    # rules are applied on the same set of nodes and the order doesn't matter.
    # The way to do this is saving and setting the walkers at the right moments only
    # when both forwards and backwards rules are present.
    traverse_both = bool(forward_values) and bool(backward_values)
    if traverse_both:
        stash = basket.get_template()
        rules += [RuleSaveWalkers(stash)]

    if forward_values:
        query_outgoing = orm.QueryBuilder()
        query_outgoing.append(orm.Node, tag='sources')
        query_outgoing.append(orm.Node,
                              edge_filters=filters_forwards,
                              with_incoming='sources')
        rule_outgoing = UpdateRule(query_outgoing,
                                   max_iterations=1,
                                   track_edges=get_links)
        rules += [rule_outgoing]

    if traverse_both:
        rules += [RuleSetWalkers(stash)]

    if backward_values:
        query_incoming = orm.QueryBuilder()
        query_incoming.append(orm.Node, tag='sources')
        query_incoming.append(orm.Node,
                              edge_filters=filters_backwards,
                              with_outgoing='sources')
        rule_incoming = UpdateRule(query_incoming,
                                   max_iterations=1,
                                   track_edges=get_links)
        rules += [rule_incoming]

    rulesequence = RuleSequence(rules, max_iterations=max_iterations)

    results = rulesequence.run(basket)

    output = {}
    output['nodes'] = results.nodes.keyset
    output['links'] = None
    if get_links:
        output['links'] = results['nodes_nodes'].keyset

    return cast(TraverseGraphOutput, output)