Ejemplo n.º 1
0
    def test_queryhelp(self):
        """
        Check that a QueryBuilder rebuilt from its queryhelp returns the same
        results as the instance that was assembled with the append method.
        Passing a tuple of classes to append is covered as well.
        """

        from aiida.orm.data.structure import StructureData
        from aiida.orm.data.parameter import ParameterData
        from aiida.orm.data import Data
        from aiida.orm.querybuilder import QueryBuilder
        from aiida.orm.group import Group
        from aiida.orm.computer import Computer

        # Store one node per class, each carrying the same marker attribute,
        # and collect all of them in a single group.
        group = Group(name='helloworld').store()
        for node_class in (StructureData, ParameterData, Data):
            node = node_class()
            node._set_attr('foo-qh2', 'bar')
            node.store()
            group.add_nodes(node)

        # Each case: (class or tuple of classes, expected count, subclassing)
        cases = (
            (StructureData, 1, True),
            (ParameterData, 1, True),
            (Data, 3, True),
            (Data, 1, False),
            ((ParameterData, StructureData), 2, True),
            ((ParameterData, StructureData), 2, False),
            ((ParameterData, Data), 2, False),
            ((ParameterData, Data), 3, True),
            ((ParameterData, Data, StructureData), 3, False),
        )
        for entity, n_expected, with_subclasses in cases:
            appended = QueryBuilder()
            appended.append(
                entity,
                filters={'attributes.foo-qh2': 'bar'},
                subclassing=with_subclasses,
                project='uuid',
            )
            self.assertEqual(appended.count(), n_expected)

            # Round trip through the JSON-compatible queryhelp.
            rebuilt = QueryBuilder(**appended.get_json_compatible_queryhelp())
            self.assertEqual(rebuilt.count(), n_expected)

            # Every row holds only the projected uuid, hence the 1-tuple unpacking.
            uuids_appended = sorted([uuid for uuid, in appended.all()])
            uuids_rebuilt = sorted([uuid for uuid, in rebuilt.all()])
            self.assertEqual(uuids_appended, uuids_rebuilt)

        # The group has to be found whether the class is passed bare or in a tuple.
        for entity in (Group, (Group, )):
            group_query = QueryBuilder().append(
                entity, filters={'name': 'helloworld'})
            self.assertEqual(group_query.count(), 1)

        # Same for the computer, positionally and through the cls keyword.
        self.assertEqual(QueryBuilder().append(Computer).count(), 1)
        self.assertEqual(QueryBuilder().append(cls=(Computer, )).count(), 1)
Ejemplo n.º 2
0
class UpdateRule(Operation):
    """
    Operation that is driven by a QueryBuilder path.

    The first entry of the path is filtered on the keys held by the
    operational set, and the identifiers projected on the last entry of the
    path are loaded into the target set. When edge tracking is enabled, the
    (first entity, last entity, additional edge identifiers) tuples are
    loaded into the matching edge set as well.
    """

    def __init__(self,
                 querybuilder,
                 mode=MODES.APPEND,
                 max_iterations=1,
                 track_edges=False,
                 track_visits=True):
        """
        :param querybuilder: QueryBuilder instance describing the path to
            follow. The instance used internally is rebuilt from its
            JSON-compatible queryhelp.
        :param mode: forwarded to ``Operation.__init__``
        :param max_iterations: forwarded to ``Operation.__init__``
        :param track_edges: forwarded to ``Operation.__init__``
        :param track_visits: forwarded to ``Operation.__init__``
        :raises Exception: if the type of the first or the last path entry
            is neither empty, node-like, data-like nor ``'group'``
        """
        def get_spec_from_path(queryhelp, idx):
            """Map the type of path entry ``idx`` to 'nodes' or 'groups'."""
            entity_type = queryhelp['path'][idx]['type']
            if entity_type == '' or entity_type.startswith(('node', 'data')):
                return 'nodes'
            if entity_type == 'group':
                return 'groups'
            # Report the entry that was actually inspected, not always the
            # first one of the path.
            raise Exception(
                "not understood entity from ( {} )".format(entity_type))

        queryhelp = querybuilder.get_json_compatible_queryhelp()
        # Entries without a type are given the generic node type before the
        # internal QueryBuilder is rebuilt.
        for pathspec in queryhelp['path']:
            if not pathspec['type']:
                pathspec['type'] = 'node.Node.'
        self._querybuilder = QueryBuilder(**queryhelp)
        queryhelp = self._querybuilder.get_json_compatible_queryhelp()
        self._first_tag = queryhelp['path'][0]['tag']
        self._last_tag = queryhelp['path'][-1]['tag']

        self._entity_from = get_spec_from_path(queryhelp, 0)
        self._entity_to = get_spec_from_path(queryhelp, -1)
        super(UpdateRule, self).__init__(mode,
                                         max_iterations,
                                         track_edges=track_edges,
                                         track_visits=track_visits)

    def _init_run(self, entity_set):
        """
        Reset the projections of the internal QueryBuilder and add the ones
        needed to load results (and edges, if they are tracked).

        :param entity_set: set providing, for each entity kind, the
            identifier to project and, for the edge set, the additional
            identifiers to project on the edge
        :raises KeyError: if the QueryBuilder rejects the edge projection
        """
        # Removing all other projections in the QueryBuilder instance:
        for tag in list(self._querybuilder._projections):
            self._querybuilder._projections[tag] = []
        # priming querybuilder to add projection on the key I need:
        self._entity_to_identifier = entity_set[self._entity_to].identifier
        self._querybuilder.add_projection(self._last_tag,
                                          self._entity_to_identifier)
        if self._track_edges:
            entity_from_identifier = entity_set[self._entity_from].identifier
            # The first tag must be projected with the identifier of the
            # origin entity, because that is the key self._edge_keys (and
            # therefore _load_results) reads from the first tag.
            self._querybuilder.add_projection(self._first_tag,
                                              entity_from_identifier)
            edge_set = entity_set._dict['{}_{}'.format(self._entity_from,
                                                       self._entity_to)]
            self._edge_label = '{}--{}'.format(self._first_tag, self._last_tag)
            # (tag, projection) pairs, in the order in which the values are
            # packed into each edge tuple by _load_results.
            self._edge_keys = tuple(
                [(self._first_tag, entity_from_identifier),
                 (self._last_tag, self._entity_to_identifier)] +
                [(self._edge_label, identifier)
                 for identifier in edge_set._additional_identifiers])
            try:
                self._querybuilder.add_projection(
                    self._edge_label, edge_set._additional_identifiers)
            except InputValidationError:
                raise KeyError(
                    "The key for the edge is invalid.\n"
                    "Are the entities really connected, or have you overwritten the edge-tag?"
                )

    def _load_results(self, target_set, operational_set):
        """
        :param target_set: The set to load the results into
        :param operational_set: Where the results originate from (walkers)
        """
        # I check that I have primary keys
        primkeys = operational_set[self._entity_from].get_keys()
        # Empty the target set, so that only these results are inside
        target_set.empty()
        if primkeys:
            self._querybuilder.add_filter(self._first_tag, {
                operational_set[self._entity_from].identifier: {
                    'in': primkeys
                }
            })
            qres = self._querybuilder.dict()
            # These are the new results returned by the query
            target_set[self._entity_to].add_entities([
                item[self._last_tag][self._entity_to_identifier]
                for item in qres
            ])
            if self._track_edges:
                # Same '<from>_<to>' key that _init_run uses to look up the
                # edge set, so that edges end up in the set they belong to.
                edge_key = '{}_{}'.format(self._entity_from, self._entity_to)
                target_set[edge_key].add_entities([
                    tuple(item[key1][key2]
                          for (key1, key2) in self._edge_keys)
                    for item in qres
                ])