Example #1
    def test_no_node_export(self, temp_dir):
        """Test migration of export file that has no Nodes"""
        input_file = get_archive_file('export_v0.3_no_Nodes.aiida', **self.external_archive)
        output_file = os.path.join(temp_dir, 'output_file.aiida')

        # Known entities
        computer_uuids = [self.computer.uuid]  # pylint: disable=no-member
        user_emails = [orm.User.objects.get_default().email]

        # Known export file content used for checks
        node_count = 0
        computer_count = 1 + 1  # localhost is always present
        computer_uuids.append('4f33c6fd-b624-47df-9ffb-a58f05d323af')
        user_emails.append('aiida@localhost')

        # Perform the migration
        migrate_archive(input_file, output_file)

        # Load the migrated file
        import_data(output_file, silent=True)

        # Check known number of entities is present
        self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), node_count)
        self.assertEqual(orm.QueryBuilder().append(orm.Computer).count(), computer_count)

        # Check unique identifiers
        computers = orm.QueryBuilder().append(orm.Computer, project=['uuid']).all()[0][0]
        users = orm.QueryBuilder().append(orm.User, project=['email']).all()[0][0]
        self.assertIn(computers, computer_uuids)
        self.assertIn(users, user_emails)
Example #2
    def test_delete_many_ids(self):
        """Test `delete_many` method filtering on both `id` and `uuid`"""
        comment1 = self.create_comment()
        comment2 = self.create_comment()
        comment3 = self.create_comment()
        comment_uuids = []
        for comment in [comment1, comment2, comment3]:
            comment.store()
            comment_uuids.append(str(comment.uuid))

        # Make sure they exist
        count_comments_found = orm.QueryBuilder().append(orm.Comment, filters={'uuid': {'in': comment_uuids}}).count()
        self.assertEqual(
            count_comments_found,
            len(comment_uuids),
            msg='There should be {} Comments, instead {} Comment(s) was/were found'.format(
                len(comment_uuids), count_comments_found
            )
        )

        # Delete last two comments (comment2, comment3)
        filters = {'or': [{'id': comment2.id}, {'uuid': str(comment3.uuid)}]}
        self.backend.comments.delete_many(filters=filters)

        # Check they were deleted
        builder = orm.QueryBuilder().append(orm.Comment, filters={'uuid': {'in': comment_uuids}}, project='uuid').all()
        found_comments_uuid = [_[0] for _ in builder]
        self.assertEqual([comment_uuids[0]], found_comments_uuid)
Example #3
    def test_delete_many_user_id(self):
        """Test `delete_many` method filtering on `user_id`"""
        # Create comments and separate user
        user_two = self.backend.users.create(
            email='tester_two@localhost').store()
        comment1 = self.create_comment(user=user_two)
        comment2 = self.create_comment()
        comment3 = self.create_comment()
        comment_uuids = []
        for comment in [comment1, comment2, comment3]:
            comment.store()
            comment_uuids.append(str(comment.uuid))

        # Make sure they exist
        builder = orm.QueryBuilder().append(orm.Comment, project='uuid')
        self.assertGreater(builder.count(), 0)
        found_comments_uuid = [_[0] for _ in builder.all()]
        for comment_uuid in comment_uuids:
            self.assertIn(comment_uuid, found_comments_uuid)

        # Delete the comments that belong to `self.user` (comment2 and comment3)
        filters = {'user_id': self.user.id}
        self.backend.comments.delete_many(filters=filters)

        # Check they were deleted
        builder = orm.QueryBuilder().append(orm.Comment, project='uuid')
        found_comments_uuid = [_[0] for _ in builder.all()]
        self.assertGreater(builder.count(), 0)
        for comment_uuid in comment_uuids[1:]:
            self.assertNotIn(comment_uuid, found_comments_uuid)

        # Make sure the first comment (comment1) was not deleted
        self.assertIn(comment_uuids[0], found_comments_uuid)
Example #4
def delete_database_proportion(fraction=0.0):
    """Description pending"""
    from random import shuffle
    from aiida import orm
    from aiida.tools import delete_nodes

    query = orm.QueryBuilder()
    query.append(orm.Node, project=['id'])
    initial_nodecount = query.count()

    nodepk_list = [nodepk for nodepk, in query.all()]
    shuffle(nodepk_list)

    numto_delete = int(fraction * len(nodepk_list))
    numto_delete = min(numto_delete, initial_nodecount)

    delete_nodes(nodepk_list[0:numto_delete], dry_run=False)

    query = orm.QueryBuilder()
    query.append(orm.Node)
    final_nodecount = query.count()

    return {
        'initial_nodecount': initial_nodecount,
        'final_nodecount': final_nodecount
    }
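
A minimal usage sketch for the helper above, assuming a loaded AiiDA profile whose test database may safely lose nodes; the fraction value is illustrative only:

    from aiida import load_profile

    load_profile()  # assumes a default profile is configured

    # Randomly delete roughly 10% of the nodes and report the counts
    counts = delete_database_proportion(fraction=0.1)
    print('nodes before: {initial_nodecount}, after: {final_nodecount}'.format(**counts))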
Example #5
    def test_group_import_existing(self, temp_dir):
        """
        Test what happens when importing a group whose label already exists in the database.
        The import should not fail, but create a new group with a modified (suffixed) label.
        """
        grouplabel = 'node_group_existing'

        # Create another user
        new_email = '[email protected]'
        user = orm.User(email=new_email)
        user.store()

        # Create a structure data node
        sd1 = orm.StructureData()
        sd1.user = user
        sd1.label = 'sd'
        sd1.store()

        # Create a group and add the data inside
        group = orm.Group(label=grouplabel)
        group.store()
        group.add_nodes([sd1])

        # At this point we export the generated data
        filename = os.path.join(temp_dir, 'export1.tar.gz')
        export([group], outfile=filename, silent=True)
        self.clean_db()
        self.insert_data()

        # Creating a group of the same name
        group = orm.Group(label='node_group_existing')
        group.store()
        import_data(filename, silent=True)
        # The import should have created a new group with a suffix
        # I check for this:
        builder = orm.QueryBuilder().append(
            orm.Group, filters={'label': {
                'like': grouplabel + '%'
            }})
        self.assertEqual(builder.count(), 2)
        # Now I check for the group having one member, and whether the name is different:
        builder = orm.QueryBuilder()
        builder.append(orm.Group,
                       filters={'label': {
                           'like': grouplabel + '%'
                       }},
                       tag='g',
                       project='label')
        builder.append(orm.StructureData, with_group='g')
        self.assertEqual(builder.count(), 1)
        # I check that the group name was changed:
        self.assertTrue(builder.all()[0][0] != grouplabel)
        # I import the same file again; the group should not be imported a second time
        import_data(filename, silent=True)
        builder = orm.QueryBuilder()
        builder.append(orm.Group,
                       filters={'label': {
                           'like': grouplabel + '%'
                       }})
        self.assertEqual(builder.count(), 2)
Example #6
    def test_exclude_logs_flag(self, temp_dir):
        """Test that the `include_logs` argument for `export` works."""
        log_msg = 'Testing logging of critical failure'

        # Create node
        calc = orm.CalculationNode()
        calc.store()
        calc.seal()

        # Create log message
        calc.logger.critical(log_msg)

        # Save uuids prior to export
        calc_uuid = calc.uuid

        # Export, excluding logs
        export_file = os.path.join(temp_dir, 'export.tar.gz')
        export([calc], outfile=export_file, silent=True, include_logs=False)

        # Clean database and reimport exported data
        self.reset_database()
        import_data(export_file, silent=True)

        # Finding all the log messages
        import_calcs = orm.QueryBuilder().append(orm.CalculationNode,
                                                 project=['uuid']).all()
        import_logs = orm.QueryBuilder().append(orm.Log,
                                                project=['uuid']).all()

        # There should be exactly: 1 orm.CalculationNode, 0 Logs
        self.assertEqual(len(import_calcs), 1)
        self.assertEqual(len(import_logs), 0)

        # Check it's the correct node
        self.assertEqual(str(import_calcs[0][0]), calc_uuid)
Example #7
    def test_delete_many_ids(self):
        """Test `delete_many` method filtering on both `id` and `uuid`"""
        # Create logs
        log1 = self.create_log()
        log2 = self.create_log()
        log3 = self.create_log()
        log_uuids = []
        for log in [log1, log2, log3]:
            log.store()
            log_uuids.append(str(log.uuid))

        # Make sure they exist
        count_logs_found = orm.QueryBuilder().append(orm.Log, filters={'uuid': {'in': log_uuids}}).count()
        self.assertEqual(
            count_logs_found,
            len(log_uuids),
            msg=f'There should be {len(log_uuids)} Logs, instead {count_logs_found} Log(s) was/were found'
        )

        # Delete last two logs (log2, log3)
        filters = {'or': [{'id': log2.id}, {'uuid': str(log3.uuid)}]}
        self.backend.logs.delete_many(filters=filters)

        # Check they were deleted
        builder = orm.QueryBuilder().append(orm.Log, filters={'uuid': {'in': log_uuids}}, project='uuid').all()
        found_logs_uuid = [_[0] for _ in builder]
        self.assertEqual([log_uuids[0]], found_logs_uuid)
Example #8
    def test_no_node_migration(self, tmp_path, external_archive):
        """Test migration of archive file that has no Node entities."""
        input_file = get_archive_file('export_v0.3_no_Nodes.aiida',
                                      **external_archive)
        output_file = tmp_path / 'output_file.aiida'

        migrator_cls = get_migrator(detect_archive_type(input_file))
        migrator = migrator_cls(input_file)

        # Perform the migration
        migrator.migrate(newest_version, output_file)

        # Load the migrated file
        import_data(output_file)

        # Check known entities
        assert orm.QueryBuilder().append(orm.Node).count() == 0
        computer_query = orm.QueryBuilder().append(orm.Computer,
                                                   project=['uuid'])
        assert computer_query.all(flat=True) == [
            '4f33c6fd-b624-47df-9ffb-a58f05d323af'
        ]
        user_query = orm.QueryBuilder().append(orm.User, project=['email'])
        assert set(user_query.all(flat=True)) == {
            orm.User.objects.get_default().email, 'aiida@localhost'
        }
Example #9
    def test_querying():
        """Test querying for groups with and without subclassing."""
        orm.Group(label='group').store()
        orm.AutoGroup(label='auto-group').store()

        # Fake a subclass by manually setting the type string
        group = orm.Group(label='custom-group')
        group.backend_entity.dbmodel.type_string = 'custom.group'
        group.store()

        assert orm.QueryBuilder().append(orm.AutoGroup).count() == 1
        assert orm.QueryBuilder().append(orm.AutoGroup,
                                         subclassing=False).count() == 1
        assert orm.QueryBuilder().append(orm.Group,
                                         subclassing=False).count() == 1
        assert orm.QueryBuilder().append(orm.Group).count() == 3
        assert orm.QueryBuilder().append(orm.Group,
                                         filters={
                                             'type_string': 'custom.group'
                                         }).count() == 1

        # Removing it as other methods might get a warning instead
        group_pk = group.pk
        del group
        orm.Group.objects.delete(id=group_pk)
Example #10
def get_input_node(cls, value):
    """Return a `Node` of a given class and given value.

    If a `Node` of the given type and value already exists, that will be returned, otherwise a new one will be created,
    stored and returned.

    :param cls: the `Node` class
    :param value: the value of the `Node`
    """
    from aiida import orm

    if cls in (orm.Bool, orm.Float, orm.Int, orm.Str):

        result = orm.QueryBuilder().append(cls, filters={'attributes.value': value}).first()

        if result is None:
            node = cls(value).store()
        else:
            node = result[0]

    elif cls is orm.Dict:
        result = orm.QueryBuilder().append(cls, filters={'attributes': {'==': value}}).first()

        if result is None:
            node = cls(dict=value).store()
        else:
            node = result[0]

    else:
        raise NotImplementedError

    return node
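
A short usage sketch for `get_input_node`; the classes and values are illustrative and assume a loaded profile:

    from aiida import orm

    # Reuses an existing Int node with value 4 if one is stored, otherwise creates and stores it
    cutoff = get_input_node(orm.Int, 4)

    # Dict nodes are matched on their full attribute dictionary
    parameters = get_input_node(orm.Dict, {'max_iterations': 10})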
Example #11
    def test_delete_many_dbnode_id(self):
        """Test `delete_many` method filtering on `dbnode_id`"""
        # Create logs and separate node
        calc = self.backend.nodes.create(
            node_type='', user=self.user, computer=self.computer, label='label', description='description'
        ).store()
        log1 = self.create_log(dbnode_id=calc.id)
        log2 = self.create_log()
        log3 = self.create_log()
        log_uuids = []
        for log in [log1, log2, log3]:
            log.store()
            log_uuids.append(str(log.uuid))

        # Make sure they exist
        count_logs_found = orm.QueryBuilder().append(orm.Log, filters={'uuid': {'in': log_uuids}}).count()
        self.assertEqual(
            count_logs_found,
            len(log_uuids),
            msg=f'There should be {len(log_uuids)} Logs, instead {count_logs_found} Log(s) was/were found'
        )

        # Delete logs for self.node
        filters = {'dbnode_id': self.node.id}
        self.backend.logs.delete_many(filters=filters)

        # Check they were deleted
        builder = orm.QueryBuilder().append(orm.Log, filters={'uuid': {'in': log_uuids}}, project='uuid').all()
        found_logs_uuid = [_[0] for _ in builder]
        self.assertEqual([log_uuids[0]], found_logs_uuid)
Example #12
    def test_group_uuid_hashing_for_querybuidler(self):
        """QueryBuilder results should be reusable and shouldn't brake hashing."""
        group = orm.Group(label='test_group')
        group.store()

        # Search for the UUID of the stored group
        builder = orm.QueryBuilder()
        builder.append(orm.Group,
                       project=['uuid'],
                       filters={'label': {
                           '==': 'test_group'
                       }})
        [uuid] = builder.first()

        # Look up the group with the previously returned UUID
        builder = orm.QueryBuilder()
        builder.append(orm.Group,
                       project=['id'],
                       filters={'uuid': {
                           '==': uuid
                       }})

        # Check that the query doesn't fail
        builder.all()

        # And that the results are correct
        self.assertEqual(builder.count(), 1)
        self.assertEqual(builder.first()[0], group.id)
Example #13
    def test_delete_many_dbnode_id(self):
        """Test `delete_many` method filtering on `dbnode_id`"""
        # Create comments and separate node
        calc = self.backend.nodes.create(
            node_type='', user=self.user, computer=self.computer, label='label', description='description'
        ).store()
        comment1 = self.create_comment(node=calc)
        comment2 = self.create_comment()
        comment3 = self.create_comment()
        comment_uuids = []
        for comment in [comment1, comment2, comment3]:
            comment.store()
            comment_uuids.append(str(comment.uuid))

        # Make sure they exist
        count_comments_found = orm.QueryBuilder().append(orm.Comment, filters={'uuid': {'in': comment_uuids}}).count()
        self.assertEqual(
            count_comments_found,
            len(comment_uuids),
            msg='There should be {} Comments, instead {} Comment(s) was/were found'.format(
                len(comment_uuids), count_comments_found
            )
        )

        # Delete comments for self.node
        filters = {'dbnode_id': self.node.id}
        self.backend.comments.delete_many(filters=filters)

        # Check they were deleted
        builder = orm.QueryBuilder().append(orm.Comment, filters={'uuid': {'in': comment_uuids}}, project='uuid').all()
        found_comments_uuid = [_[0] for _ in builder]
        self.assertEqual([comment_uuids[0]], found_comments_uuid)
Example #14
    def test_dangling_link_to_existing_db_node(self, temp_dir):
        """A dangling link that references a Node that is not included in the archive should `not` be importable"""
        struct = orm.StructureData()
        struct.store()
        struct_uuid = struct.uuid

        calc = orm.CalculationNode()
        calc.add_incoming(struct, LinkType.INPUT_CALC, 'input')
        calc.store()
        calc.seal()
        calc_uuid = calc.uuid

        filename = os.path.join(temp_dir, 'export.aiida')
        export([struct], filename=filename, file_format='tar.gz')

        unpack = SandboxFolder()
        with tarfile.open(filename, 'r:gz', format=tarfile.PAX_FORMAT) as tar:
            tar.extractall(unpack.abspath)

        with open(unpack.get_abs_path('data.json'), 'r',
                  encoding='utf8') as fhandle:
            data = json.load(fhandle)
        data['links_uuid'].append({
            'output': calc.uuid,
            'input': struct.uuid,
            'label': 'input',
            'type': LinkType.INPUT_CALC.value
        })

        # json.dump writes text, so the file must be opened in text mode
        with open(unpack.get_abs_path('data.json'), 'w', encoding='utf8') as fhandle:
            json.dump(data, fhandle)

        with tarfile.open(filename, 'w:gz', format=tarfile.PAX_FORMAT) as tar:
            tar.add(unpack.abspath, arcname='')

        # Make sure the CalculationNode is still in the database
        builder = orm.QueryBuilder().append(orm.CalculationNode,
                                            project='uuid')
        self.assertEqual(
            builder.count(),
            1,
            msg=
            f'There should be a single CalculationNode, instead {builder.count()} has been found'
        )
        self.assertEqual(builder.all()[0][0], calc_uuid)

        with self.assertRaises(DanglingLinkError):
            import_data(filename)

        # Using the flag `ignore_unknown_nodes` should import it without problems
        import_data(filename, ignore_unknown_nodes=True)
        builder = orm.QueryBuilder().append(orm.StructureData, project='uuid')
        self.assertEqual(
            builder.count(),
            1,
            msg=
            f'There should be a single StructureData, instead {builder.count()} has been found'
        )
        self.assertEqual(builder.all()[0][0], struct_uuid)
Example #15
def test_full_type_backwards_compatibility(node_class, restapi_server,
                                           server_url):
    """Functionality test for the compatibility with old process_type entries.

    This will only check the integrity of the shape of the tree; there is no checking on how the
    data should be represented internally in terms of full types, labels, etc. The only thing that
    must work correctly is the internal consistency of `full_type` as it is interpreted by
    `get_full_type_filters` and the QueryBuilder.
    """
    node_empty = node_class()
    node_empty.process_type = ''
    node_empty.store()

    node_nones = node_class()
    node_nones.process_type = None
    node_nones.store()

    server = restapi_server()
    server_thread = Thread(target=server.serve_forever)

    try:
        server_thread.start()
        type_count_response = requests.get(f'{server_url}/nodes/full_types',
                                           timeout=10)
    finally:
        server.shutdown()

    # All nodes (contains either a process branch or data branch)
    # The main branch for all nodes does not currently return a queryable full_type
    current_namespace = type_count_response.json()['data']
    assert len(current_namespace['subspaces']) == 1

    # All subnodes (contains a workflow, calculation or data_type branch)
    current_namespace = current_namespace['subspaces'][0]
    query_all = orm.QueryBuilder().append(orm.Node,
                                          filters=get_full_type_filters(
                                              current_namespace['full_type']))
    assert len(current_namespace['subspaces']) == 1
    assert query_all.count() == 2

    # If this is a process node, there is an extra branch before the leaf
    # (calcfunction, workfunction, calcjob or workchain)
    if issubclass(node_class, orm.ProcessNode):
        current_namespace = current_namespace['subspaces'][0]
        query_all = orm.QueryBuilder().append(
            orm.Node,
            filters=get_full_type_filters(current_namespace['full_type']))
        assert len(current_namespace['subspaces']) == 1
        assert query_all.count() == 2

    # This will be the last leaf: the specific data type or the empty process_type
    current_namespace = current_namespace['subspaces'][0]
    query_all = orm.QueryBuilder().append(orm.Node,
                                          filters=get_full_type_filters(
                                              current_namespace['full_type']))
    assert len(current_namespace['subspaces']) == 0
    assert query_all.count() == 2
Example #16
    def test_edges(self):
        """
        Testing how the links are stored during traversal of the graph.
        """
        nodes = self._create_basic_graph()

        # Forward traversal (check all nodes and all links)
        basket = Basket(nodes=[nodes['data_i'].id])
        queryb = orm.QueryBuilder()
        queryb.append(orm.Node, tag='nodes_in_set')
        queryb.append(orm.Node, with_incoming='nodes_in_set')
        uprule = UpdateRule(queryb, max_iterations=2, track_edges=True)
        uprule_result = uprule.run(basket.copy())

        obtained = uprule_result['nodes'].keyset
        expected = set(anode.id for _, anode in nodes.items())
        self.assertEqual(obtained, expected)

        obtained = set()
        for data in uprule_result['nodes_nodes'].keyset:
            obtained.add((data[0], data[1]))

        expected = {
            (nodes['data_i'].id, nodes['calc_0'].id),
            (nodes['data_i'].id, nodes['work_1'].id),
            (nodes['data_i'].id, nodes['work_2'].id),
            (nodes['calc_0'].id, nodes['data_o'].id),
            (nodes['work_1'].id, nodes['data_o'].id),
            (nodes['work_2'].id, nodes['data_o'].id),
            (nodes['work_2'].id, nodes['work_1'].id),
            (nodes['work_1'].id, nodes['calc_0'].id),
        }
        self.assertEqual(obtained, expected)

        # Backwards traversal (check partial traversal and link direction)
        basket = Basket(nodes=[nodes['data_o'].id])
        queryb = orm.QueryBuilder()
        queryb.append(orm.Node, tag='nodes_in_set')
        queryb.append(orm.Node, with_outgoing='nodes_in_set')
        uprule = UpdateRule(queryb, max_iterations=1, track_edges=True)
        uprule_result = uprule.run(basket.copy())

        obtained = uprule_result['nodes'].keyset
        expected = set(anode.id for _, anode in nodes.items())
        expected = expected.difference(set([nodes['data_i'].id]))
        self.assertEqual(obtained, expected)

        obtained = set()
        for data in uprule_result['nodes_nodes'].keyset:
            obtained.add((data[0], data[1]))

        expected = {
            (nodes['calc_0'].id, nodes['data_o'].id),
            (nodes['work_1'].id, nodes['data_o'].id),
            (nodes['work_2'].id, nodes['data_o'].id),
        }
        self.assertEqual(obtained, expected)
Example #17
    def test_empty_repo_folder_export(self, temp_dir):
        """Check a Node's empty repository folder is exported properly"""
        from aiida.common.folders import Folder
        from aiida.tools.importexport.dbexport import export_tree

        node = orm.Dict().store()
        node_uuid = node.uuid

        node_repo = RepositoryFolder(section=Repository._section_name, uuid=node_uuid)  # pylint: disable=protected-access
        self.assertTrue(
            node_repo.exists(), msg='Newly created and stored Node should have had an existing repository folder'
        )
        for filename, is_file in node_repo.get_content_list(only_paths=False):
            abspath_filename = os.path.join(node_repo.abspath, filename)
            if is_file:
                os.remove(abspath_filename)
            else:
                shutil.rmtree(abspath_filename, ignore_errors=False)
        self.assertFalse(
            node_repo.get_content_list(),
            msg='Repository folder should be empty, instead the following was found: {}'.format(
                node_repo.get_content_list()
            )
        )

        archive_variants = {
            'archive folder': os.path.join(temp_dir, 'export_tree'),
            'tar archive': os.path.join(temp_dir, 'export.tar.gz'),
            'zip archive': os.path.join(temp_dir, 'export.zip')
        }

        export_tree([node], folder=Folder(archive_variants['archive folder']), silent=True)
        export([node], filename=archive_variants['tar archive'], file_format='tar.gz', silent=True)
        export([node], filename=archive_variants['zip archive'], file_format='zip', silent=True)

        for variant, filename in archive_variants.items():
            self.reset_database()
            node_count = orm.QueryBuilder().append(orm.Dict, project='uuid').count()
            self.assertEqual(node_count, 0, msg='After DB reset {} Dict Nodes was (wrongly) found'.format(node_count))

            import_data(filename, silent=True)
            builder = orm.QueryBuilder().append(orm.Dict, project='uuid')
            imported_node_count = builder.count()
            self.assertEqual(
                imported_node_count,
                1,
                msg='After {} import a single Dict Node should have been found, '
                'instead {} was/were found'.format(variant, imported_node_count)
            )
            imported_node_uuid = builder.all()[0][0]
            self.assertEqual(
                imported_node_uuid,
                node_uuid,
                msg='The wrong UUID was found for the imported {}: '
                '{}. It should have been: {}'.format(variant, imported_node_uuid, node_uuid)
            )
Example #18
def cmd_completion_scf(max_atoms, number_species):
    """Determine the completion rate of the reconnaissance SCF step."""
    from aiida import orm

    filters_structure = {}

    if max_atoms is not None:
        filters_structure['attributes.sites'] = {'shorter': max_atoms + 1}

    if number_species is not None:
        filters_structure['attributes.kinds'] = {'of_length': number_species}

    query = orm.QueryBuilder()
    query.append(orm.Group, filters={'label': 'structure/unique'}, tag='group')
    query.append(orm.Node, with_group='group', filters=filters_structure)

    nstructures = query.count()

    query = orm.QueryBuilder()
    query.append(orm.Group, filters={'label': 'workchain/scf'}, tag='group')
    query.append(orm.WorkChainNode,
                 with_group='group',
                 tag='workchain',
                 project='attributes.exit_status')
    query.append(orm.StructureData,
                 with_outgoing='workchain',
                 filters=filters_structure,
                 project='id')

    active = []
    failed = []
    success = []
    submitted = []

    for exit_status, pk in query.iterall():
        if exit_status == 0:
            success.append(pk)
        elif exit_status is None:
            active.append(pk)
        else:
            failed.append(pk)

        submitted.append(pk)

    submitted = set(submitted)

    table = [
        ['Structures', nstructures],
        ['Submitted', len(submitted)],
        ['Submitted success', len(success)],
        ['Submitted failed', len(failed)],
        ['Submitted active', len(active)],
        ['Submittable', nstructures - len(submitted)],
    ]

    click.echo(tabulate.tabulate(table, tablefmt='plain'))
Example #19
def cif_clean(ctx, database, cif_filter, cif_select, concurrent, interval, dry_run):
    """Clean the CIF files imported from an external database.

    Run the `aiida-codtools workflow launch cif-clean` CLI script to clean the imported CIFs of the given database.
    """
    from datetime import datetime
    from time import sleep
    from aiida import orm
    from aiida_codtools.cli.workflows.cif_clean import launch_cif_clean

    def now():
        # Return a fresh UTC timestamp for each log line
        return datetime.utcnow().isoformat()

    group_cif_raw = orm.Group.get(label='{database}/cif/raw'.format(database=database))
    group_cif_clean = orm.Group.get(label='{database}/cif/clean'.format(database=database))
    group_structure = orm.Group.get(label='{database}/structure/primitive'.format(database=database))
    group_workchain = orm.Group.get(label='{database}/workchain/clean'.format(database=database))

    while True:

        filters = {'attributes.process_state': {'or': [{'==': 'excepted'}, {'==': 'killed'}]}}
        builder = orm.QueryBuilder().append(orm.ProcessNode, filters=filters)
        if builder.count() > 0:
            echo.echo_critical('found {} excepted or killed processes, exiting'.format(builder.count()))

        filters = {'attributes.process_state': {'or': [{'==': 'waiting'}, {'==': 'running'}, {'==': 'created'}]}}
        builder = orm.QueryBuilder().append(orm.WorkChainNode, filters=filters)
        current = builder.count()
        max_entries = concurrent - current

        if current < concurrent:
            echo.echo('{} | currently {} running workchains, submitting {} more'.format(now(), current, max_entries))

            inputs = {
                'cif_filter': cif_filter,
                'cif_select': cif_select,
                'group_cif_raw': group_cif_raw,
                'group_cif_clean': group_cif_clean,
                'group_structure': group_structure,
                'group_workchain': group_workchain,
                'max_entries': max_entries,
                'skip_check': False,
                'parse_engine': 'ase',
                'daemon': True,
            }
            if dry_run:
                # pprint() prints to stdout and returns None; use pformat() so the inputs are echoed
                from pprint import pformat
                echo.echo(pformat(inputs))
                return

            ctx.invoke(launch_cif_clean, **inputs)
        else:
            echo.echo('{} | currently {} running workchains, nothing to submit'.format(now(), current))

        echo.echo('{} | sleeping {} seconds'.format(now(), interval))
        sleep(interval)
Example #20
    def test_mtime_of_imported_comments(self, temp_dir):
        """
        Test that the mtime of imported comments does not change.
        This is related to the correct usage of `comment_mode` when importing.
        """
        # Get user
        user = orm.User.objects.get_default()

        comment_content = 'You get what you give'

        # Create node
        calc = orm.CalculationNode().store()
        calc.seal()

        # Create comment
        orm.Comment(calc, user, comment_content).store()
        calc.store()

        # Save UUIDs and mtime
        calc_uuid = calc.uuid
        builder = orm.QueryBuilder().append(orm.Comment, project=['uuid', 'mtime']).all()
        comment_uuid = str(builder[0][0])
        comment_mtime = builder[0][1]

        builder = orm.QueryBuilder().append(orm.CalculationNode, project=['uuid', 'mtime']).all()
        calc_uuid = str(builder[0][0])
        calc_mtime = builder[0][1]

        # Export, reset database and reimport
        export_file = os.path.join(temp_dir, 'export.aiida')
        export([calc], filename=export_file, silent=True)
        self.reset_database()
        import_data(export_file, silent=True)

        # Retrieve node and comment
        builder = orm.QueryBuilder().append(orm.CalculationNode, tag='calc', project=['uuid', 'mtime'])
        builder.append(orm.Comment, with_node='calc', project=['uuid', 'mtime'])

        import_entities = builder.all()[0]

        self.assertEqual(len(import_entities), 4)  # Check we have the correct amount of returned values

        import_calc_uuid = str(import_entities[0])
        import_calc_mtime = import_entities[1]
        import_comment_uuid = str(import_entities[2])
        import_comment_mtime = import_entities[3]

        # Check we have the correct UUIDs
        self.assertEqual(import_calc_uuid, calc_uuid)
        self.assertEqual(import_comment_uuid, comment_uuid)

        # Make sure the mtime is the same after import as it was before export
        self.assertEqual(import_comment_mtime, comment_mtime)
        self.assertEqual(import_calc_mtime, calc_mtime)
Example #21
    def test_delete_many_ctime_mtime(self):
        """Test `delete_many` method filtering on `ctime` and `mtime`"""
        from datetime import timedelta

        # Initialization
        comment_uuids = []
        found_comments_ctime = []
        found_comments_mtime = []
        found_comments_uuid = []

        now = timezone.now()
        two_days_ago = now - timedelta(days=2)
        one_day_ago = now - timedelta(days=1)
        comment_times = [now, one_day_ago, two_days_ago]

        # Create comments
        comment1 = self.create_comment(ctime=now, mtime=now)
        comment2 = self.create_comment(ctime=one_day_ago, mtime=now)
        comment3 = self.create_comment(ctime=two_days_ago, mtime=one_day_ago)
        for comment in [comment1, comment2, comment3]:
            comment.store()
            comment_uuids.append(str(comment.uuid))

        # Make sure they exist with the correct times
        builder = orm.QueryBuilder().append(orm.Comment, project=['ctime', 'mtime', 'uuid'])
        self.assertGreater(builder.count(), 0)
        for comment in builder.all():
            found_comments_ctime.append(comment[0])
            found_comments_mtime.append(comment[1])
            found_comments_uuid.append(comment[2])
        for time, uuid in zip(comment_times, comment_uuids):
            self.assertIn(time, found_comments_ctime)
            self.assertIn(uuid, found_comments_uuid)
            if time != two_days_ago:
                self.assertIn(time, found_comments_mtime)

        # Delete comments that are created more than 1 hour ago,
        # unless they have been modified within 5 hours
        ctime_turning_point = now - timedelta(seconds=60 * 60)
        mtime_turning_point = now - timedelta(seconds=60 * 60 * 5)
        filters = {'and': [{'ctime': {'<': ctime_turning_point}}, {'mtime': {'<': mtime_turning_point}}]}
        self.backend.comments.delete_many(filters=filters)

        # Check only the most stale comment (comment3) was deleted
        builder = orm.QueryBuilder().append(orm.Comment, project='uuid')
        self.assertGreater(builder.count(), 1)  # There should still be at least 2
        found_comments_uuid = [_[0] for _ in builder.all()]
        self.assertNotIn(comment_uuids[2], found_comments_uuid)

        # Make sure the other comments were not deleted
        for comment_uuid in comment_uuids[:-1]:
            self.assertIn(comment_uuid, found_comments_uuid)
Example #22
    def test_querying():
        """Test querying for groups with and without subclassing."""
        orm.Group(label='group').store()
        orm.AutoGroup(label='auto-group').store()

        # Fake a subclass by manually setting the type string
        group = orm.Group(label='custom-group')
        group.backend_entity.dbmodel.type_string = 'custom.group'
        group.store()

        assert orm.QueryBuilder().append(orm.AutoGroup).count() == 1
        assert orm.QueryBuilder().append(orm.AutoGroup, subclassing=False).count() == 1
        assert orm.QueryBuilder().append(orm.Group, subclassing=False).count() == 1
        assert orm.QueryBuilder().append(orm.Group).count() == 3
        assert orm.QueryBuilder().append(orm.Group, filters={'type_string': 'custom.group'}).count() == 1
Example #23
    def test_delete_many_time(self):
        """Test `delete_many` method filtering on `time`"""
        from datetime import timedelta

        # Initialization
        log_uuids = []
        found_logs_time = []
        found_logs_uuid = []

        now = timezone.now()
        two_days_ago = now - timedelta(days=2)
        one_day_ago = now - timedelta(days=1)
        log_times = [now, one_day_ago, two_days_ago]

        # Create logs
        log1 = self.create_log(time=now)
        log2 = self.create_log(time=one_day_ago)
        log3 = self.create_log(time=two_days_ago)
        for log in [log1, log2, log3]:
            log.store()
            log_uuids.append(str(log.uuid))

        # Make sure they exist with the correct times
        builder = orm.QueryBuilder().append(orm.Log, project=['time', 'uuid'])
        self.assertGreater(builder.count(), 0)
        for log in builder.all():
            found_logs_time.append(log[0])
            found_logs_uuid.append(log[1])
        for log_time in log_times:
            self.assertIn(log_time, found_logs_time)
        for log_uuid in log_uuids:
            self.assertIn(log_uuid, found_logs_uuid)

        # Delete logs that are older than 1 hour
        turning_point = now - timedelta(seconds=60 * 60)
        filters = {'time': {'<': turning_point}}
        self.backend.logs.delete_many(filters=filters)

        # Check they were deleted
        builder = orm.QueryBuilder().append(orm.Log, project='uuid')
        self.assertGreater(builder.count(),
                           0)  # There should still be at least 1
        found_logs_uuid = [_[0] for _ in builder.all()]
        for log_uuid in log_uuids[1:]:
            self.assertNotIn(log_uuid, found_logs_uuid)

        # Make sure the newest log (log1) was not deleted
        self.assertIn(log_uuids[0], found_logs_uuid)
Example #24
    def link_flags_import_helper(self, test_data):
        """Helper function"""
        for test, (export_file, _, expected_nodes) in test_data.items():
            self.clean_db()

            import_data(export_file)

            nodes_util = {
                'work': orm.WorkflowNode,
                'calc': orm.CalculationNode,
                'data': orm.Data
            }
            for node_type, node_cls in nodes_util.items():
                if node_type in expected_nodes:
                    builder = orm.QueryBuilder().append(node_cls,
                                                        project='uuid')
                    self.assertEqual(
                        builder.count(),
                        len(expected_nodes[node_type]),
                        msg='Expected {} {} node(s), but got {}. Test: "{}"'.
                        format(len(expected_nodes[node_type]), node_type,
                               builder.count(), test))
                    for node_uuid in builder.iterall():
                        self.assertIn(node_uuid[0],
                                      expected_nodes[node_type],
                                      msg=f'Failed for test: "{test}"')
Example #25
    def test_complex_workflow_graph_export_sets(self, temp_dir):
        """Test ex-/import of individual nodes in complex graph"""
        for export_conf in range(0, 9):

            _, (export_node,
                export_target) = self.construct_complex_graph(export_conf)
            export_target_uuids = set(_.uuid for _ in export_target)

            export_file = os.path.join(temp_dir, 'export.aiida')
            export([export_node], filename=export_file, overwrite=True)
            export_node_str = str(export_node)

            self.clean_db()

            import_data(export_file)

            # Get all the nodes of the database
            builder = orm.QueryBuilder()
            builder.append(orm.Node, project='uuid')
            imported_node_uuids = set(_[0] for _ in builder.all())

            self.assertSetEqual(
                export_target_uuids, imported_node_uuids,
                'Problem in comparison of export node: ' + export_node_str +
                '\n' + 'Expected set: ' + str(export_target_uuids) + '\n' +
                'Imported set: ' + str(imported_node_uuids) + '\n' +
                'Difference: ' + str(
                    export_target_uuids.symmetric_difference(
                        imported_node_uuids)))
Example #26
    def children(self):
        # type: () -> Iterable[GroupPath]
        """Iterate through all (direct) children of this path."""
        query = orm.QueryBuilder()
        filters = {}
        if self.path:
            filters['label'] = {'like': self.path + self.delimiter + '%'}
        query.append(self.cls,
                     subclassing=False,
                     filters=filters,
                     project='label')
        if query.count() == 0 and self.is_virtual:
            raise NoGroupsInPathError(self)

        yielded = []
        for (label, ) in query.iterall():
            path = label.split(self._delimiter)
            if len(path) <= len(self._path_list):
                continue
            path_string = self._delimiter.join(path[:len(self._path_list) + 1])
            if (path_string not in yielded
                    and path[:len(self._path_list)] == self._path_list):
                yielded.append(path_string)
                try:
                    yield GroupPath(
                        path=path_string,
                        cls=self.cls,
                        warn_invalid_child=self._warn_invalid_child)
                except InvalidPath:
                    if self._warn_invalid_child:
                        warnings.warn(
                            'invalid path encountered: {}'.format(path_string))  # pylint: disable=no-member
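
A small sketch of how the generator above could be consumed, assuming `children` is exposed as a property of `GroupPath` (as in recent aiida-core releases) and that a group namespace named 'structures' exists:

    from aiida.tools.groups import GroupPath

    # Print the direct children of the 'structures' namespace,
    # marking those that exist only as intermediate (virtual) path segments
    for child in GroupPath('structures').children:
        print(child.path, child.is_virtual)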
Example #27
    def test_explicit_type_string():
        """Test that passing explicit `type_string` to `Group` constructor is still possible despite being deprecated.

        Both constructing a group while passing explicit `type_string` as well as loading a group with unregistered
        type string should emit a warning, but it should be possible.
        """
        type_string = 'data.potcar'  # An unregistered custom type string

        with pytest.warns(UserWarning):
            group = orm.Group(label='group', type_string=type_string)

        group.store()
        assert group.type_string == type_string

        with pytest.warns(UserWarning):
            loaded = orm.Group.get(label=group.label, type_string=type_string)

        assert isinstance(loaded, orm.Group)
        assert loaded.pk == group.pk
        assert loaded.type_string == group.type_string

        queried = orm.QueryBuilder().append(orm.Group,
                                            filters={
                                                'id': group.pk,
                                                'type_string': type_string
                                            }).one()[0]
        assert isinstance(queried, orm.Group)
        assert queried.pk == group.pk
        assert queried.type_string == group.type_string
Example #28
    def _find(backend: orm.implementation.Backend, entity_type: orm.Entity,
              **kwargs) -> orm.QueryBuilder:
        for key in kwargs:
            if key not in {
                    "filters", "order_by", "limit", "project", "offset"
            }:
                raise ValueError(
                    f"You supplied key {key}. _find() only takes the keys: "
                    '"filters", "order_by", "limit", "project", "offset"')

        filters = kwargs.get("filters", {})
        order_by = kwargs.get("order_by", None)
        order_by = {
            entity_type: order_by
        } if order_by else {
            entity_type: {
                "id": "asc"
            }
        }
        limit = kwargs.get("limit", None)
        offset = kwargs.get("offset", None)
        project = kwargs.get("project", [])

        query = orm.QueryBuilder(backend=backend, limit=limit, offset=offset)
        query.append(entity_type, project=project, filters=filters)
        query.order_by(order_by)

        return query
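
A brief sketch of how the `_find` helper might be called; obtaining the backend handle via `get_manager().get_backend()` and calling the helper directly are assumptions for illustration:

    from aiida import orm
    from aiida.manage.manager import get_manager

    backend = get_manager().get_backend()

    # The ten most recently created Dict nodes, projecting only their UUIDs
    query = _find(backend, orm.Dict, project=['uuid'], order_by={'ctime': 'desc'}, limit=10)
    print(query.all(flat=True))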
Example #29
    def walk_nodes(self, filters=None, node_class=None, query_batch=None):
        # type: () -> Iterable[WalkNodeResult]
        """Recursively iterate through all nodes of this path and its children.

        :param filters: filters to apply to the node query
        :param node_class: return only nodes of a certain class (or list of classes)
        :param int query_batch: The size of the batches in which to ask the backend to return results.
            You can optimize the speed of the query by tuning this parameter.
            Be aware though that this is only safe if no commit takes place during the transaction.
        """
        query = orm.QueryBuilder()
        group_filters = {}
        if self.path:
            group_filters['label'] = {
                'or': [{
                    '==': self.path
                }, {
                    'like': self.path + self.delimiter + '%'
                }]
            }
        query.append(self.cls,
                     subclassing=False,
                     filters=group_filters,
                     project='label',
                     tag='group')
        query.append(
            orm.Node if node_class is None else node_class,
            with_group='group',
            filters=filters,
            project=['*'],
        )
        for (label, node
             ) in query.iterall(query_batch) if query_batch else query.all():
            yield WalkNodeResult(GroupPath(label, cls=self.cls), node)
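
A short usage sketch for `walk_nodes`, assuming `GroupPath` is imported from `aiida.tools.groups` and a group hierarchy rooted at 'structures' exists:

    from aiida import orm
    from aiida.tools.groups import GroupPath

    path = GroupPath('structures')  # hypothetical group namespace

    # Walk every StructureData node stored in 'structures' or any of its sub-groups;
    # each result is a (group_path, node) pair
    for group_path, node in path.walk_nodes(node_class=orm.StructureData):
        print(group_path.path, node.pk)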
Example #30
    def test_import_of_django_sqla_export_file(self):
        """Check that sqla import manages to import the django export file correctly"""
        from tests.utils.archives import import_archive

        for archive in ['django.aiida', 'sqlalchemy.aiida']:
            # Clean the database
            self.reset_database()

            # Import the needed data
            import_archive(archive, filepath='export/compare')

            # The expected metadata
            comp1_metadata = {'workdir': '/tmp/aiida'}

            # Check that we got the correct metadata
            # Make sure to exclude the default computer
            builder = orm.QueryBuilder()
            builder.append(orm.Computer,
                           project=['metadata'],
                           tag='comp',
                           filters={'name': {
                               '!==': self.computer.label
                           }})
            self.assertEqual(builder.count(), 1, 'Expected only one computer')

            res = builder.dict()[0]

            self.assertEqual(res['comp']['metadata'], comp1_metadata)