예제 #1
0
    def test_serialize_round_trip(self):
        """
        Round-trip a nested structure containing stored nodes through
        `serialize_data` and `deserialize_data`.

        The serialized form is also passed to `json.dumps` to make sure that it
        is json-serializable.
        """
        first_node = Node().store()
        second_node = Node().store()
        tuple_key = ('Si', )

        original = {
            'test': 1,
            'list': [1, 2, 3, first_node],
            'dict': {
                tuple_key: second_node,
                'foo': 'bar'
            },
            'baz': 'aar'
        }

        serialized = serialize_data(original)
        # The return value is irrelevant here: the call itself must not raise
        json.dumps(serialized)
        restored = deserialize_data(serialized)

        # Compare element by element for now, until there is a general purpose
        # function that can properly equate two node instances
        for plain_key in ('test', 'baz'):
            self.assertEqual(original[plain_key], restored[plain_key])

        self.assertEqual(original['list'][:3], restored['list'][:3])
        self.assertEqual(original['list'][3].uuid, restored['list'][3].uuid)
        self.assertEqual(original['dict'][tuple_key].uuid,
                         restored['dict'][tuple_key].uuid)
예제 #2
0
    def test_put_object_from_file(self):
        """Check that a file added with `put_object_from_file` can be read back from the node."""
        relative_path = os.path.join('subdir', 'a.txt')
        source_path = os.path.join(self.tempdir, relative_path)
        expected = self.get_file_content(relative_path)

        node = Node()
        node.put_object_from_file(source_path, relative_path)

        stored = node.get_object_content(relative_path)
        self.assertEqual(stored, expected)
예제 #3
0
    def test_attribute_existence(self):
        """
        Query for nodes that lack the `whatever` attribute or have it below a threshold.

        A single node with `whatever` set to 3. is stored and the uuids returned
        by the query are compared with an empty set.
        """
        from aiida.orm.node import Node
        from aiida.orm.querybuilder import QueryBuilder

        threshold = 1.
        expected_uuids = set()

        # Store a value under the key `whatever`
        node = Node()
        node._set_attr("whatever", 3.)
        node._set_attr("test_case", "test_attribute_existence")
        node.store()

        # Select all nodes where `whatever` is smaller than the threshold or where there is no such value
        missing_key = {'attributes': {'!has_key': 'whatever'}}
        below_threshold = {'attributes.whatever': {'<': threshold}}

        builder = QueryBuilder()
        builder.append(Node,
                       filters={'or': [missing_key, below_threshold]},
                       project='uuid')

        found_uuids = {str(row[0]) for row in builder.all()}
        self.assertEqual(found_uuids, expected_uuids)
예제 #4
0
    def test_comment_add(self):
        """Test adding a comment to a stored node through the `add` command."""
        from aiida.cmdline.commands.cmd_comment import add
        from aiida.orm import Node

        node = Node()
        node.store()

        result = CliRunner().invoke(
            add, ['-c{}'.format(COMMENT), str(node.pk)],
            catch_exceptions=False)
        self.assertEqual(result.exit_code, 0)

        # Exactly one comment with the expected content should now be attached to the node
        comments = node.get_comments()
        self.assertEqual(len(comments), 1)
        self.assertEqual(comments[0]['content'], COMMENT)
예제 #5
0
    def test_load_node(self):
        """
        Test the functionality of load_node

        A stored node is loaded positionally and by keyword, through its pk, its
        full uuid and two partial uuids. Each time the uuid of the loaded node is
        compared with the uuid of the original node.
        """
        node = Node().store()

        # Load through uuid
        loaded_node = load_node(node.uuid)
        self.assertEqual(loaded_node.uuid, node.uuid)

        # Load through pk
        loaded_node = load_node(node.pk)
        self.assertEqual(loaded_node.uuid, node.uuid)

        # Load through uuid explicitly
        loaded_node = load_node(uuid=node.uuid)
        self.assertEqual(loaded_node.uuid, node.uuid)

        # Load through pk explicitly
        loaded_node = load_node(pk=node.pk)
        self.assertEqual(loaded_node.uuid, node.uuid)

        # Load through partial uuid (first two characters)
        loaded_node = load_node(uuid=node.uuid[:2])
        self.assertEqual(loaded_node.uuid, node.uuid)

        # Load through partial uuid (first ten characters)
        loaded_node = load_node(uuid=node.uuid[:10])
        self.assertEqual(loaded_node.uuid, node.uuid)

        # NOTE(review): this calls `load_group` although the test targets `load_node` - possibly a
        # copy-paste slip; confirm the intended function before changing it
        with self.assertRaises(NotExistent):
            load_group('non-existent-uuid')
예제 #6
0
파일: query.py 프로젝트: asle85/aiida-core
    def test_operators_eq_lt_gt(self):
        """
        Test the comparison operators of the QueryBuilder on a numeric attribute.

        Eight nodes are stored with the attribute `fa` set to 1, 1.0 and 1.01 up to
        1.06, after which the number of matches is checked for several filters.
        """
        from aiida.orm.querybuilder import QueryBuilder
        from aiida.orm import Node

        attribute_values = (1, 1.0, 1.01, 1.02, 1.03, 1.04, 1.05, 1.06)
        nodes = [Node() for _ in attribute_values]

        for node, value in zip(nodes, attribute_values):
            node._set_attr('fa', value)

        # A plain loop instead of a list comprehension, since storing is a pure side effect
        for node in nodes:
            node.store()

        # Each entry is (operator, value to compare with, expected number of matching nodes)
        expected_counts = (
            ('<', 1, 0),
            ('==', 1, 2),
            ('<', 1.02, 3),
            ('<=', 1.02, 4),
            ('>', 1.02, 4),
            ('>=', 1.02, 5),
        )

        for operator, value, expected in expected_counts:
            count = QueryBuilder().append(Node, filters={'attributes.fa': {operator: value}}).count()
            self.assertEqual(count, expected)
예제 #7
0
    def test_joins2(self):
        """
        Test joining nodes over links while filtering on the link label.

        Three advisor nodes are linked to student nodes with the label
        `is_advisor`. One additional link carries a different label and is
        expected not to be counted.
        """
        from aiida.orm import Node, Data, Calculation
        from aiida.orm.querybuilder import QueryBuilder

        students = [Node() for _ in range(10)]
        advisors = [Node() for _ in range(3)]
        for advisor_id, advisor in enumerate(advisors):
            advisor.label = 'advisor {}'.format(advisor_id)
            advisor._set_attr('advisor_id', advisor_id)

        for node in advisors + students:
            node.store()

        # advisor 0 gets students 0, 1; advisor 1 gets students 3, 4; advisor 2 gets students 5, 6, 7
        assignments = ((0, (0, 1)), (1, (3, 4)), (2, (5, 6, 7)))
        for advisor_index, student_indices in assignments:
            for student_index in student_indices:
                students[student_index].add_link_from(advisors[advisor_index], label='is_advisor')

        # Add a relationship that is different from advisor
        students[9].add_link_from(advisors[2], label='lover')

        # Over all advisors there are seven `is_advisor` links
        total = QueryBuilder().append(Node).append(
            Node, edge_filters={'label': 'is_advisor'}, tag='student').count()
        self.assertEqual(total, 7)

        # Per advisor the count has to match the number of assigned students
        for adv_id, number_students in zip(range(3), (2, 2, 3)):
            per_advisor = QueryBuilder().append(
                Node, filters={'attributes.advisor_id': adv_id}).append(
                    Node, edge_filters={'label': 'is_advisor'}, tag='student').count()
            self.assertEqual(per_advisor, number_students)
예제 #8
0
 def test_dynamic_output(self):
     """Validate several values against the outputs of a process spec with dynamic output enabled."""
     node = Node()
     data = Data()

     # Validation is expected to return something other than None for a string, an integer and a `Node`
     for rejected in ('foo', 5, node):
         self.assertIsNotNone(self.spec.outputs.validate({'key': rejected}))

     # ... and None for a `Data` instance
     self.assertIsNone(self.spec.outputs.validate({'key': data}))
예제 #9
0
    def setUpClass(cls):
        """
        Create three stored nodes to test the NodeParamType parameter type for the command line infrastructure

        The first node gets a plain label, while the labels of the second and third node are on purpose set
        to the ID and UUID, respectively, of the first one. This allows testing the rules implemented to
        solve ambiguities that arise when determining the identifier type
        """
        super(TestNodeParamType, cls).setUpClass()

        cls.param = NodeParamType()
        cls.entity_01, cls.entity_02, cls.entity_03 = [Node().store() for _ in range(3)]

        # The labels are set after storing, because they refer to the pk of the first node
        reference = cls.entity_01
        reference.label = 'data_01'
        cls.entity_02.label = str(reference.pk)
        cls.entity_03.label = str(reference.uuid)
예제 #10
0
    def test_put_object_from_filelike(self):
        """Check that objects added with `put_object_from_filelike` can be read back, for two different keys."""
        relative_paths = (
            os.path.join('subdir', 'a.txt'),
            os.path.join('subdir', 'nested', 'deep.txt'),
        )

        for key in relative_paths:
            filepath = os.path.join(self.tempdir, key)
            expected = self.get_file_content(key)

            # A fresh node is used for every key
            with io.open(filepath, 'r') as handle:
                node = Node()
                node.put_object_from_filelike(handle, key)
                self.assertEqual(node.get_object_content(key), expected)
예제 #11
0
    def test_dynamic_output(self):
        """Check `validate_outputs` of the spec for a string, an integer, a `Node` and a `Data` instance."""
        from aiida.orm import Node
        from aiida.orm.data import Data

        base_node = Node()
        data_node = Data()

        # The first element of the returned value is expected to be falsy for these values
        for invalid in ('foo', 5, base_node):
            outcome = self.spec.validate_outputs({'key': invalid})
            self.assertFalse(outcome[0])

        # ... and truthy for a `Data` instance
        outcome = self.spec.validate_outputs({'key': data_node})
        self.assertTrue(outcome[0])
예제 #12
0
    def test_dynamic_output(self):
        """Check the validation of the dynamic output port of the spec for several kinds of values."""
        from aiida.orm import Node
        from aiida.orm.data import Data

        base_node = Node()
        data_node = Data()
        dynamic_port = self.spec.get_dynamic_output()

        # The first element of the returned value is expected to be falsy for these values
        for invalid in ("foo", 5, base_node):
            self.assertFalse(dynamic_port.validate(invalid)[0])

        # ... and truthy for a `Data` instance
        self.assertTrue(dynamic_port.validate(data_node)[0])
예제 #13
0
    def test_statistics_default_class(self):
        """
        Test if the statistics query works properly.

        The expected statistics are derived from the statistics of the current
        database content, so that the test does not depend on the past state.
        """
        from aiida.orm import Node, DataFactory, Calculation
        from collections import defaultdict
        from aiida.backends.general.abstractqueries import AbstractQueryManager

        def store_and_add(n, statistics):
            """Store node `n` and increment the counters in `statistics` that it contributes to."""
            n.store()
            statistics['total'] += 1
            statistics['types'][n._plugin_type_string] += 1
            statistics['ctime_by_day'][n.ctime.strftime('%Y-%m-%d')] += 1

        class QueryManagerDefault(AbstractQueryManager):
            pass

        qmanager_default = QueryManagerDefault()

        # Start from the current statistics; defaultdicts allow incrementing keys that do not exist yet
        current_db_statistics = qmanager_default.get_creation_statistics()
        types = defaultdict(int)
        types.update(current_db_statistics['types'])
        ctime_by_day = defaultdict(int)
        ctime_by_day.update(current_db_statistics['ctime_by_day'])

        expected_db_statistics = {
            'total': current_db_statistics['total'],
            'types': types,
            'ctime_by_day': ctime_by_day
        }

        ParameterData = DataFactory('parameter')

        store_and_add(Node(), expected_db_statistics)
        store_and_add(ParameterData(), expected_db_statistics)
        store_and_add(ParameterData(), expected_db_statistics)
        store_and_add(Calculation(), expected_db_statistics)

        new_db_statistics = qmanager_default.get_creation_statistics()
        # I only check a few fields
        # `items()` is used instead of `iteritems()`, which only exists on Python 2
        new_db_statistics = {
            k: v
            for k, v in new_db_statistics.items()
            if k in expected_db_statistics
        }

        # Convert the defaultdicts to plain dictionaries before comparing
        expected_db_statistics = {
            k: dict(v) if isinstance(v, defaultdict) else v
            for k, v in expected_db_statistics.items()
        }

        self.assertEqual(new_db_statistics, expected_db_statistics)
예제 #14
0
    def setUpClass(cls, *args, **kwargs):
        """Store a `Node`, a true and a false `Bool`, a `Float` and an `Int` as class attributes."""
        super(TestVerdiRehash, cls).setUpClass(*args, **kwargs)
        from aiida.orm import Node
        from aiida.orm.data.bool import Bool
        from aiida.orm.data.float import Float
        from aiida.orm.data.int import Int

        # Each entry is (class attribute name, node class, positional constructor arguments)
        fixtures = (
            ('node_base', Node, ()),
            ('node_bool_true', Bool, (True, )),
            ('node_bool_false', Bool, (False, )),
            ('node_float', Float, (1.0, )),
            ('node_int', Int, (1, )),
        )

        for attribute, node_class, arguments in fixtures:
            setattr(cls, attribute, node_class(*arguments).store())
예제 #15
0
파일: query.py 프로젝트: zooks97/aiida_core
    def test_ordering_limits_offsets_of_results_for_SQLA(self):
        """
        Test ordering, offset and limit of query results.

        Ten nodes are stored with the attribute `foo` set to 0 up to 9. The query
        projects that attribute, ordered with a cast, and the returned values are
        checked without modifiers, with an offset and with an offset plus a limit.
        """
        from aiida.orm import Node
        from aiida.orm.querybuilder import QueryBuilder

        # Creating 10 nodes with an attribute that can be ordered
        for i in range(10):
            n = Node()
            n._set_attr('foo', i)
            n.store()

        qb = QueryBuilder().append(
            Node, project='attributes.foo'
        ).order_by(
            {Node: {'attributes.foo': {'cast': 'i'}}}
        )

        def projected_values():
            """Return the first projected column of all current query results as a list."""
            # Indexing each row works on Python 2 and 3, unlike subscripting the result of `zip`
            return [row[0] for row in qb.all()]

        # `range` is turned into a list explicitly, since on Python 3 a list never equals a range object
        self.assertEqual(projected_values(), list(range(10)))

        # Now applying an offset:
        qb.offset(5)
        self.assertEqual(projected_values(), list(range(5, 10)))

        # Now also applying a limit:
        qb.limit(3)
        self.assertEqual(projected_values(), list(range(5, 8)))
예제 #16
0
    def test_list_object_names(self):
        """Check the names returned by `list_object_names` at the top level and inside `subdir`."""
        node = Node()
        node.put_object_from_tree(self.tempdir, '')

        top_level = sorted(node.list_object_names())
        self.assertEqual(top_level, ['c.txt', 'subdir'])

        in_subdir = sorted(node.list_object_names('subdir'))
        self.assertEqual(in_subdir, ['a.txt', 'b.txt', 'nested'])
예제 #17
0
    def test_date(self):
        """Filter on a datetime attribute that has to lie within one second around the stored value."""
        from aiida.orm.querybuilder import QueryBuilder
        from aiida.utils import timezone
        from datetime import timedelta
        from aiida.orm.node import Node

        node = Node()
        now = timezone.now()
        node._set_attr('now', now)
        node.store()

        # The attribute has to be later than one second ago and earlier than one second from now
        tolerance = timedelta(seconds=1)
        within_tolerance = {
            "and": [
                {">": now - tolerance},
                {"<": now + tolerance},
            ]
        }

        builder = QueryBuilder().append(Node, filters={'attributes.now': within_tolerance})
        self.assertEqual(builder.count(), 1)
예제 #18
0
    def setUpClass(cls, *args, **kwargs):
        """Store a computer, a code that refers to that computer, a group and a node as class attributes."""
        super(TestVerdiExport, cls).setUpClass(*args, **kwargs)
        from aiida.orm import Code, Computer, Group, Node

        computer_settings = {
            'name': 'comp',
            'hostname': 'localhost',
            'transport_type': 'local',
            'scheduler_type': 'direct',
            'workdir': '/tmp/aiida',
        }
        cls.computer = Computer(**computer_settings).store()

        # The code needs the computer that was just stored
        executable = (cls.computer, '/bin/true')
        cls.code = Code(remote_computer_exec=executable).store()

        cls.group = Group(name='test_group').store()
        cls.node = Node().store()
예제 #19
0
    def test_detect_invalid_nodes_unknown_node_type(self):
        """Test `verdi database integrity detect-invalid-nodes` when node type is invalid."""
        # Before the invalid node exists the command is expected to succeed
        result = self.cli_runner.invoke(cmd_database.detect_invalid_nodes, [])
        self.assertEqual(result.exit_code, 0)
        self.assertClickResultNoException(result)

        # Create a node with invalid type: a base Node type string is considered invalid
        # Note that there is guard against storing base Nodes for this reason, which we temporarily disable
        Node._storable = True
        try:
            Node().store()
        finally:
            # Re-enable the guard even if storing fails, so that the class state does not leak into other tests
            Node._storable = False

        result = self.cli_runner.invoke(cmd_database.detect_invalid_nodes, [])
        self.assertNotEqual(result.exit_code, 0)
        self.assertIsNotNone(result.exception)
예제 #20
0
    def test_comment_remove_all(self):
        """Test removing all comments from a node with the `--all` option of the `remove` command."""
        from aiida.cmdline.commands.cmd_comment import remove
        from aiida.orm import Node

        number_of_comments = 10

        node = Node()
        node.store()
        for _ in range(number_of_comments):
            node.add_comment(COMMENT)

        self.assertEqual(len(node.get_comments()), number_of_comments)

        arguments = [str(node.pk), '--all', '--force']
        result = CliRunner().invoke(remove, arguments, catch_exceptions=False)
        self.assertEqual(result.exit_code, 0)

        # No comment should be left on the node
        self.assertEqual(len(node.get_comments()), 0)
예제 #21
0
def test_open_wrapper():
    """Check that the wrapper around the return value of ``Node.open`` supports iteration.

    The wrapper should be removed in v2.0.0, at which point this test should be removed as well.
    """
    filename = 'test'
    node = Node()
    node.put_object_from_filelike(io.StringIO('test'), filename)

    # Neither the builtins nor the corresponding dunder methods should raise
    consumers = (
        next,
        iter,
        lambda handle: handle.__next__(),
        lambda handle: handle.__iter__(),
    )

    # The object is opened anew for every consumer
    for consume in consumers:
        consume(node.open(filename))
예제 #22
0
    def test_erase_unstored(self):
        """
        Test that `_repository.erase` removes the content of an unstored node.
        """
        node = Node()
        node.put_object_from_tree(self.tempdir, '')

        # Before erasing, the objects that were put in the node are listed
        # Each entry is (arguments for `list_object_names`, expected sorted names)
        expected_listing = (
            ((), ['c.txt', 'subdir']),
            (('subdir', ), ['a.txt', 'b.txt', 'nested']),
        )
        for arguments, names in expected_listing:
            self.assertEqual(sorted(node.list_object_names(*arguments)), names)

        node._repository.erase()  # pylint: disable=protected-access

        # After erasing, nothing is listed anymore
        self.assertEqual(node.list_object_names(), [])
예제 #23
0
파일: query.py 프로젝트: asle85/aiida-core
    def test_create_node_and_query(self):
        """
        Iterate over a batched query after storing 100 nodes and check the number of results.

        The index of the last item returned by `iterall` is expected to be 99 and a
        subsequent `all` query is expected to return more than 99 results.
        """
        from aiida.orm import Node
        from aiida.orm.querybuilder import QueryBuilder

        for _ in range(100):
            Node().store()

        for idx, item in enumerate(QueryBuilder().append(Node, project=['id', 'label']).iterall(batch_size=10)):
            # NOTE(review): `idx % 10` is never equal to 10, so this branch is never executed and no node is
            # created during the iteration - presumably `== 0` was intended; confirm how `iterall` should
            # behave when nodes are stored while iterating before enabling it
            if idx % 10 == 10:
                # Written as a function call so that the block is valid on both Python 2 and 3
                print("creating new node")
                n = Node()
                n.store()

        self.assertEqual(idx, 99)
        self.assertTrue(len(QueryBuilder().append(Node, project=['id', 'label']).all(batch_size=10)) > 99)
예제 #24
0
    def test_comment_show(self):
        """Test that the `show` command prints an existing comment of a node."""
        from aiida.cmdline.commands.cmd_comment import show
        from aiida.orm import Node

        node = Node()
        node.store()
        node.add_comment(COMMENT)

        arguments = [str(node.pk)]
        result = CliRunner().invoke(show, arguments, catch_exceptions=False)

        # The content of the comment has to appear in the output of the command
        self.assertIn(COMMENT, result.output)
        self.assertEqual(result.exit_code, 0)
예제 #25
0
    def test_comment_remove(self):
        """Test removing a single comment from a node through the `remove` command."""
        from aiida.cmdline.commands.cmd_comment import remove
        from aiida.orm import Node

        node = Node()
        node.store()
        comment_id = node.add_comment(COMMENT)

        self.assertEqual(len(node.get_comments()), 1)

        # The comment to remove is identified by the pk of the node and the value returned by `add_comment`
        result = CliRunner().invoke(
            remove, [str(node.pk), str(comment_id), '--force'],
            catch_exceptions=False)
        self.assertEqual(result.exit_code, 0)

        self.assertEqual(len(node.get_comments()), 0)
예제 #26
0
    def test_list_behavior(self):
        """
        Check that every way of retrieving results yields one entry per stored node.

        Four nodes are stored, after which `all`, `dict`, `iterall` and `iterdict`
        are each called for four different projection settings and expected to
        yield four results.
        """
        from aiida.orm import Node
        from aiida.orm.querybuilder import QueryBuilder

        number_of_nodes = 4
        for _ in range(number_of_nodes):
            Node().store()

        for method_name in ('all', 'dict', 'iterall', 'iterdict'):
            # Keyword arguments for `append`, built anew for every method so that no object is shared
            projection_variants = (
                {},
                {'project': '*'},
                {'project': ['*', 'id']},
                {'project': ['id']},
            )
            for append_kwargs in projection_variants:
                builder = QueryBuilder().append(Node, **append_kwargs)
                # `list` exhausts the iterators and leaves the number of list results unchanged
                results = list(getattr(builder, method_name)())
                self.assertEqual(len(results), number_of_nodes)
예제 #27
0
    def setUp(self):
        """Store one node per label in `self.node_labels`, connect them with links and create a CLI runner."""
        from aiida.orm import Node

        nodes = {}
        for key in self.node_labels:
            nodes[key] = Node().store()

        # Each entry is (label of target node, label of source node, type of the link from source to target)
        links = (
            ('wf', 'in1', LinkType.INPUT),
            ('slave1', 'in1', LinkType.INPUT),
            ('slave1', 'in2', LinkType.INPUT),
            ('slave2', 'in2', LinkType.INPUT),
            ('slave1', 'wf', LinkType.CALL),
            ('slave2', 'wf', LinkType.CALL),
            ('outp1', 'slave1', LinkType.CREATE),
            ('outp2', 'slave2', LinkType.CREATE),
            ('outp2', 'wf', LinkType.RETURN),
            ('outp3', 'wf', LinkType.CREATE),
            ('outp4', 'wf', LinkType.RETURN),
        )
        for target, source, link_type in links:
            nodes[target].add_link_from(nodes[source], link_type=link_type)

        self.nodes = nodes
        self.runner = CliRunner()
예제 #28
0
파일: query.py 프로젝트: zooks97/aiida_core
    def test_subclassing(self):
        """
        Test how the class passed to the QueryBuilder affects the number of matches.

        One `StructureData`, one `Data`, one `ParameterData` and one `Node` are
        stored with the attribute `cat` set to `miau`, after which the matches are
        counted per class, with the default subclassing and with it disabled.
        """
        from aiida.orm.data.structure import StructureData
        from aiida.orm.data.parameter import ParameterData
        from aiida.orm import Node, Data
        from aiida.orm.querybuilder import QueryBuilder

        structure = StructureData()
        structure._set_attr('cat', 'miau')
        structure.store()

        data = Data()
        data._set_attr('cat', 'miau')
        data.store()

        parameters = ParameterData(dict=dict(cat='miau'))
        parameters.store()

        node = Node()
        node._set_attr('cat', 'miau')
        node.store()

        def count_cats(node_class, **append_kwargs):
            """Count the entities of `node_class` whose attribute `cat` equals `miau`."""
            builder = QueryBuilder().append(node_class, filters={'attributes.cat': 'miau'}, **append_kwargs)
            return builder.count()

        # When asking for a node with the attribute `cat` equal to `miau`, 4 results are expected
        self.assertEqual(count_cats(Node), 4)

        self.assertEqual(count_cats(Data), 3)

        # When asking for the specific lowest subclass, one result is expected
        for node_class in (StructureData, ParameterData):
            self.assertEqual(count_cats(node_class), 1)

        # When subclassing is not allowed, each class should give 1 result
        for node_class in (StructureData, ParameterData, Node, Data):
            self.assertEqual(count_cats(node_class, subclassing=False), 1)
예제 #29
0
def get_closest_parents(pks, *args, **kwargs):
    """
    Get the closest parents dbnodes of a set of nodes.

    :param pks: one pk or an iterable of pks of nodes
    :param chunk_size: we chunk the pks into groups of this size,
        to optimize the speed (default=50)
    :param print_progress: print the progression if True (default=False).
    :param args: additional query parameters, used to filter the candidate parents
    :param kwargs: additional query parameters, used to filter the candidate parents
    :returns: a dictionary of the form
        pk1: pk of closest parent of node with pk1,
        pk2: pk of closest parent of node with pk2

    .. note:: It works also if pks is a list of nodes rather than their pks

    .. todo:: find a way to always get a parent (when there is one) from each pk.
        Now, when the same parent has several children in pks, only
        one of them is kept. This is a BUG, related to the use of a dictionary
        (children_dict, see below...).
        For now a work around is to use chunk_size=1.

    """
    from aiida.orm import Node
    from aiida.backends.djsite.db import models
    from aiida.common.utils import grouper

    try:
        the_pks = list(pks)
    except TypeError:
        # A single, non-iterable value was passed
        the_pks = [pks]

    # These two options are removed so that only query parameters remain in kwargs
    chunk_size = kwargs.pop('chunk_size', 50)
    print_progress = kwargs.pop('print_progress', False)

    result_dict = {}
    all_chunk_pks = grouper(chunk_size, the_pks)
    if print_progress:
        print("Chunk size: {}".format(chunk_size))

    for i, chunk_pks in enumerate(all_chunk_pks):
        if print_progress:
            print("Dealing with chunk # {}".format(i))
        result_chunk_dict = {}
        q_pks = Node.query(pk__in=chunk_pks).values_list('pk', flat=True)
        # Start by looking for the direct inputs of the nodes in the chunk
        q_inputs = models.DbNode.objects.filter(
            outputs__pk__in=q_pks).distinct()
        # Maps the pk of each input to the pk of the chunk member it leads to. When an input has
        # several children in the chunk only the last one is kept (see the todo in the docstring)
        children_dict = {
            k: v
            for k, v in q_inputs.values_list('pk', 'outputs__pk')
            if v in q_pks
        }
        # While I haven't found a closest ancestor for every member of chunk_pks:
        while q_inputs.count() > 0 and len(result_chunk_dict) < len(chunk_pks):
            # Only the inputs that match the additional query parameters qualify as parents
            q = q_inputs.filter(*args, **kwargs)
            if q.count() > 0:
                # Members that already have a result are skipped, so a closer parent is never overwritten
                result_chunk_dict.update({
                    children_dict[k]: k
                    for k in q.values_list('pk', flat=True)
                    if children_dict[k] not in result_chunk_dict
                })
            # Move one level up: the inputs of the current inputs
            inputs = list(q_inputs.values_list('pk', flat=True))
            q_inputs = models.DbNode.objects.filter(
                outputs__pk__in=inputs).distinct()
            # Carry the chunk member over from each node at the previous level to its own inputs
            children_dict = {
                k: children_dict[v]
                for k, v in q_inputs.values_list('pk', 'outputs__pk')
                if v in inputs
            }

        result_dict.update(result_chunk_dict)

    return result_dict
예제 #30
0
파일: query.py 프로젝트: asle85/aiida-core
    def test_query_path(self):
        """
        Test queries for ancestors and descendants while links are added to a graph.

        Nine labelled nodes are stored and connected step by step. After each
        step, the number of results joining n1 (ancestor) and n8 (descendant) is
        checked in both directions, with and without `with_dbpath`. The expanded
        paths and the depth of the edges are checked as well.
        """
        from aiida.orm.querybuilder import QueryBuilder
        from aiida.orm import Node

        nodes = []
        for index in range(1, 10):
            node = Node()
            node.label = 'n{}'.format(index)
            node.store()
            nodes.append(node)
        n1, n2, n3, n4, n5, n6, n7, n8, n9 = nodes

        def count_results(with_dbpath, from_ancestor):
            """Count the results joining n1 and n8, querying from the ancestor n1 or from the descendant n8."""
            if from_ancestor:
                return QueryBuilder(with_dbpath=with_dbpath).append(
                    Node, filters={'id': n1.pk}, tag='anc'
                ).append(
                    Node, descendant_of='anc', filters={'id': n8.pk}
                ).count()
            return QueryBuilder(with_dbpath=with_dbpath).append(
                Node, filters={'id': n8.pk}, tag='desc'
            ).append(
                Node, ancestor_of='desc', filters={'id': n1.pk}
            ).count()

        def assert_number_of_results(expected):
            """Assert the count in both directions, with and without `with_dbpath`."""
            for with_dbpath in (True, False):
                self.assertEqual(count_results(with_dbpath, from_ancestor=True), expected)
                self.assertEqual(count_results(with_dbpath, from_ancestor=False), expected)

        # I create a strange graph, inserting links in a order
        # such that I often have to create the transitive closure
        # between two graphs
        n3.add_link_from(n2)
        n2.add_link_from(n1)
        n5.add_link_from(n3)
        n5.add_link_from(n4)
        n4.add_link_from(n2)

        n7.add_link_from(n6)
        n8.add_link_from(n7)

        # Yet, no links from 1 to 8
        assert_number_of_results(0)

        n6.add_link_from(n5)
        # Yet, now 2 links from 1 to 8
        assert_number_of_results(2)

        # The two expected paths differ only in going through n3 or through n4
        paths_there_should_be = {
            frozenset([n1.pk, n2.pk, n3.pk, n5.pk, n6.pk, n7.pk, n8.pk]),
            frozenset([n1.pk, n2.pk, n4.pk, n5.pk, n6.pk, n7.pk, n8.pk]),
        }

        qb = QueryBuilder(with_dbpath=False, expand_path=True).append(
            Node, filters={'id': n8.pk}, tag='desc',
        ).append(Node, ancestor_of='desc', edge_project='path', filters={'id': n1.pk})
        queried_path_set = {frozenset(p) for p, in qb.all()}
        self.assertEqual(queried_path_set, paths_there_should_be)

        qb = QueryBuilder(with_dbpath=False, expand_path=True).append(
            Node, filters={'id': n1.pk}, tag='anc'
        ).append(
            Node, descendant_of='anc', filters={'id': n8.pk}, edge_project='path'
        )
        queried_path_set = {frozenset(p) for p, in qb.all()}
        self.assertEqual(queried_path_set, paths_there_should_be)

        n7.add_link_from(n9)
        # Still two links...
        assert_number_of_results(2)

        n9.add_link_from(n6)
        # And now there should be 4 results
        assert_number_of_results(4)

        # NOTE(review): the loop variable is not used and `with_dbpath=True` is hard-coded, so the same
        # check runs twice - confirm whether `with_dbpath=with_dbpath` was intended
        for with_dbpath in (True, False):
            qb = QueryBuilder(with_dbpath=True).append(
                Node, filters={'id': n1.pk}, tag='anc'
            ).append(
                Node, descendant_of='anc', filters={'id': n8.pk}, edge_tag='edge'
            )
            qb.add_projection('edge', 'depth')
            # `assertEqual` is required here: `assertTrue(a, b)` treats `b` as the failure message
            # and passes for any non-empty `a`
            self.assertEqual({row[0] for row in qb.all()}, {5, 6})
            qb.add_filter('edge', {'depth': 6})
            self.assertEqual({row[0] for row in qb.all()}, {6})