def test_queryhelp(self):
    """
    Verify that a QueryBuilder rebuilt from its JSON-compatible queryhelp
    returns the same results as the instance that was built with ``append``.

    Passing a tuple of classes to ``append`` is exercised as well.
    """
    from aiida.orm.data.structure import StructureData
    from aiida.orm.data.parameter import ParameterData
    from aiida.orm.data import Data
    from aiida.orm.querybuilder import QueryBuilder
    from aiida.orm.group import Group
    from aiida.orm.computer import Computer

    # One stored node per class, all tagged with the same attribute and
    # collected in a single group.
    group = Group(name='helloworld').store()
    for node_class in (StructureData, ParameterData, Data):
        node = node_class()
        node._set_attr('foo-qh2', 'bar')
        node.store()
        group.add_nodes(node)

    # Each case: (class or tuple of classes, expected count, subclassing)
    cases = (
        (StructureData, 1, True),
        (ParameterData, 1, True),
        (Data, 3, True),
        (Data, 1, False),
        ((ParameterData, StructureData), 2, True),
        ((ParameterData, StructureData), 2, False),
        ((ParameterData, Data), 2, False),
        ((ParameterData, Data), 3, True),
        ((ParameterData, Data, StructureData), 3, False),
    )
    for classes, expected_count, subclassing in cases:
        appended = QueryBuilder()
        appended.append(
            classes,
            filters={'attributes.foo-qh2': 'bar'},
            subclassing=subclassing,
            project='uuid')
        self.assertEqual(appended.count(), expected_count)

        # Round trip through the queryhelp and compare count and content.
        queryhelp = appended.get_json_compatible_queryhelp()
        rebuilt = QueryBuilder(**queryhelp)
        self.assertEqual(rebuilt.count(), expected_count)

        appended_uuids = sorted(uuid for (uuid,) in appended.all())
        rebuilt_uuids = sorted(uuid for (uuid,) in rebuilt.all())
        self.assertEqual(appended_uuids, rebuilt_uuids)

    # A class and a one-element tuple holding that class must be equivalent.
    by_class = QueryBuilder().append(Group, filters={'name': 'helloworld'})
    self.assertEqual(by_class.count(), 1)

    by_tuple = QueryBuilder().append((Group, ), filters={'name': 'helloworld'})
    self.assertEqual(by_tuple.count(), 1)

    # NOTE(review): expects exactly one Computer in the test database,
    # presumably created by the test fixture - confirm.
    by_class = QueryBuilder().append(Computer, )
    self.assertEqual(by_class.count(), 1)

    by_tuple = QueryBuilder().append(cls=(Computer, ))
    self.assertEqual(by_tuple.count(), 1)
class UpdateRule(Operation):
    """
    Operation that follows the path of a QueryBuilder from one set of
    entities to another.

    The first element of the path defines the entities the rule starts from,
    the last element the entities it arrives at. Both can be either nodes or
    groups.
    """

    def __init__(self, querybuilder, mode=MODES.APPEND, max_iterations=1,
                 track_edges=False, track_visits=True):
        """
        :param querybuilder: QueryBuilder instance describing the path to
            follow. The rule works on its own QueryBuilder, rebuilt from the
            JSON-compatible queryhelp of this instance.
        :param mode: forwarded to ``Operation.__init__``
        :param max_iterations: forwarded to ``Operation.__init__``
        :param track_edges: forwarded to ``Operation.__init__``; when set,
            the edges that were followed are stored as well
        :param track_visits: forwarded to ``Operation.__init__``
        :raises Exception: if the first or last element of the path is
            neither a node-like type nor a group
        """
        queryhelp = querybuilder.get_json_compatible_queryhelp()
        # An empty type string is replaced by the explicit Node type string
        # before the private QueryBuilder copy is created.
        for pathspec in queryhelp['path']:
            if not pathspec['type']:
                pathspec['type'] = 'node.Node.'
        self._querybuilder = QueryBuilder(**queryhelp)

        # Read tags and types back from the copy that is actually queried.
        queryhelp = self._querybuilder.get_json_compatible_queryhelp()
        path = queryhelp['path']
        self._first_tag = path[0]['tag']
        self._last_tag = path[-1]['tag']
        self._entity_from = self._get_spec_from_path(queryhelp, 0)
        self._entity_to = self._get_spec_from_path(queryhelp, -1)
        super(UpdateRule, self).__init__(
            mode, max_iterations,
            track_edges=track_edges, track_visits=track_visits)

    @staticmethod
    def _get_spec_from_path(queryhelp, idx):
        """
        Return the kind of entity ('nodes' or 'groups') of a path element.

        :param queryhelp: the queryhelp dictionary
        :param idx: index of the element inside ``queryhelp['path']``
        :raises Exception: if the type string of the element is not
            recognized
        """
        entity_type = queryhelp['path'][idx]['type']
        if entity_type.startswith(('node', 'data')) or entity_type == '':
            return 'nodes'
        if entity_type == 'group':
            return 'groups'
        # Report the type of the element that was inspected, not
        # unconditionally that of the first element of the path.
        raise Exception(
            "not understood entity from ( {} )".format(entity_type))

    def _edge_set_key(self):
        """
        Return the key under which the edges followed by this rule are
        stored in an entity set.
        """
        return '{}_{}'.format(self._entity_from, self._entity_to)

    def _init_run(self, entity_set):
        """
        Prepare the QueryBuilder projections for a run.

        :param entity_set: the set the rule is applied to; it provides the
            identifiers to project on
        :raises KeyError: if edges are tracked and the projection on the
            edge is rejected by the QueryBuilder
        """
        # Removing all other projections in the QueryBuilder instance:
        projections = self._querybuilder._projections
        for tag in projections:
            projections[tag] = []

        # Priming the querybuilder to add a projection on the key I need:
        self._entity_to_identifier = entity_set[self._entity_to].identifier
        self._querybuilder.add_projection(
            self._last_tag, self._entity_to_identifier)

        if self._track_edges:
            # The first tag belongs to the entity the rule starts from, so
            # it is projected with that entity's identifier. This is the
            # key that _edge_keys reads from the results further below.
            entity_from_identifier = entity_set[self._entity_from].identifier
            self._querybuilder.add_projection(
                self._first_tag, entity_from_identifier)

            edge_set = entity_set._dict[self._edge_set_key()]
            self._edge_label = '{}--{}'.format(
                self._first_tag, self._last_tag)
            # Pairs of (tag, projected key) used to turn every result row
            # into one edge tuple.
            self._edge_keys = tuple(
                [(self._first_tag, entity_from_identifier),
                 (self._last_tag, self._entity_to_identifier)] +
                [(self._edge_label, identifier)
                 for identifier in edge_set._additional_identifiers])
            try:
                self._querybuilder.add_projection(
                    self._edge_label, edge_set._additional_identifiers)
            except InputValidationError:
                raise KeyError(
                    "The key for the edge is invalid.\n"
                    "Are the entities really connected, or have you overwritten the edge-tag?"
                )

    def _load_results(self, target_set, operational_set):
        """
        :param target_set: The set to load the results into
        :param operational_set: Where the results originate from (walkers)
        """
        # I check that I have primary keys
        entities_from = operational_set[self._entity_from]
        primkeys = entities_from.get_keys()
        # Empty the target set, so that only these results are inside
        target_set.empty()
        if not primkeys:
            # Nothing to start from: querying now would not be restricted
            # to the operational set, so the target set is left empty.
            return

        self._querybuilder.add_filter(
            self._first_tag,
            {entities_from.identifier: {'in': primkeys}})
        qres = self._querybuilder.dict()

        # These are the new results returned by the query
        target_set[self._entity_to].add_entities([
            item[self._last_tag][self._entity_to_identifier]
            for item in qres
        ])
        if self._track_edges:
            # Edges go into the same from-to set that _init_run read the
            # additional identifiers from.
            target_set[self._edge_set_key()].add_entities([
                tuple(item[key1][key2] for (key1, key2) in self._edge_keys)
                for item in qres
            ])