Example #1
    def test_check_for_export_format_version(self):
        """Test the check for the export format version."""
        # Creating a folder for the import/export files
        export_file_tmp_folder = tempfile.mkdtemp()
        unpack_tmp_folder = tempfile.mkdtemp()
        try:
            struct = orm.StructureData()
            struct.store()

            filename = os.path.join(export_file_tmp_folder, 'export.tar.gz')
            export([struct], outfile=filename, silent=True)

            with tarfile.open(filename, 'r:gz', format=tarfile.PAX_FORMAT) as tar:
                tar.extractall(unpack_tmp_folder)

            with open(os.path.join(unpack_tmp_folder, 'metadata.json'), 'r', encoding='utf8') as fhandle:
                metadata = json.load(fhandle)
            metadata['export_version'] = 0.0

            with open(os.path.join(unpack_tmp_folder, 'metadata.json'), 'w', encoding='utf8') as fhandle:
                json.dump(metadata, fhandle)

            with tarfile.open(filename, 'w:gz', format=tarfile.PAX_FORMAT) as tar:
                tar.add(unpack_tmp_folder, arcname='')

            self.tearDownClass()
            self.setUpClass()

            with self.assertRaises(exceptions.IncompatibleArchiveVersionError):
                import_data(filename, silent=True)
        finally:
            # Deleting the created temporary folders
            shutil.rmtree(export_file_tmp_folder, ignore_errors=True)
            shutil.rmtree(unpack_tmp_folder, ignore_errors=True)
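A minimal sketch of the mechanism this test tampers with: the archive records its format
version under the `export_version` key of `metadata.json`, so it can be inspected without a
full import (assuming a tar.gz archive with `metadata.json` at the top level, as above):

import json
import tarfile

def read_export_version(archive_path):
    """Return the format version recorded in the archive's metadata.json."""
    with tarfile.open(archive_path, 'r:gz') as tar:
        with tar.extractfile('metadata.json') as handle:
            return json.load(handle)['export_version']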
Example #2
    def test_group_import_existing(self, temp_dir):
        """
        Testing what happens when I try to import a group that already exists in the
        database. This should raise an appropriate exception
        """
        grouplabel = 'node_group_existing'

        # Create another user
        new_email = '[email protected]'
        user = orm.User(email=new_email)
        user.store()

        # Create a structure data node
        sd1 = orm.StructureData()
        sd1.user = user
        sd1.label = 'sd'
        sd1.store()

        # Create a group and add the data inside
        group = orm.Group(label=grouplabel)
        group.store()
        group.add_nodes([sd1])

        # At this point we export the generated data
        filename = os.path.join(temp_dir, 'export1.tar.gz')
        export([group], outfile=filename, silent=True)
        self.clean_db()
        self.insert_data()

        # Creating a group of the same name
        group = orm.Group(label=grouplabel)
        group.store()
        import_data(filename, silent=True)
        # The import should have created a new group with a suffix
        # I check for this:
        builder = orm.QueryBuilder().append(
            orm.Group, filters={'label': {
                'like': grouplabel + '%'
            }})
        self.assertEqual(builder.count(), 2)
        # Now I check for the group having one member, and whether the name is different:
        builder = orm.QueryBuilder()
        builder.append(orm.Group,
                       filters={'label': {
                           'like': grouplabel + '%'
                       }},
                       tag='g',
                       project='label')
        builder.append(orm.StructureData, with_group='g')
        self.assertEqual(builder.count(), 1)
        # I check that the group name was changed:
        self.assertTrue(builder.all()[0][0] != grouplabel)
        # I import the same archive again; the group should not be duplicated
        import_data(filename, silent=True)
        builder = orm.QueryBuilder()
        builder.append(orm.Group,
                       filters={'label': {
                           'like': grouplabel + '%'
                       }})
        self.assertEqual(builder.count(), 2)
Example #3
    def test_high_level_workflow_links(self, temp_dir):
        """
        This test checks that all the needed links are correctly exported and imported.
        INPUT_CALC, INPUT_WORK, CALL_CALC, CALL_WORK, CREATE, and RETURN
        links connecting Data nodes and high-level Calculation and Workflow nodes:
        CalcJobNode, CalcFunctionNode, WorkChainNode, WorkFunctionNode
        """
        high_level_calc_nodes = [['CalcJobNode', 'CalcJobNode'], ['CalcJobNode', 'CalcFunctionNode'],
                                 ['CalcFunctionNode', 'CalcJobNode'], ['CalcFunctionNode', 'CalcFunctionNode']]

        high_level_work_nodes = [['WorkChainNode', 'WorkChainNode'], ['WorkChainNode', 'WorkFunctionNode'],
                                 ['WorkFunctionNode', 'WorkChainNode'], ['WorkFunctionNode', 'WorkFunctionNode']]

        for calcs in high_level_calc_nodes:
            for works in high_level_work_nodes:
                self.reset_database()

                graph_nodes, _ = self.construct_complex_graph(calc_nodes=calcs, work_nodes=works)

                # Getting the input, create, return and call links
                builder = orm.QueryBuilder()
                builder.append(orm.Node, project='uuid')
                builder.append(
                    orm.Node,
                    project='uuid',
                    edge_project=['label', 'type'],
                    edge_filters={
                        'type': {
                            'in': (
                                LinkType.INPUT_CALC.value, LinkType.INPUT_WORK.value, LinkType.CREATE.value,
                                LinkType.RETURN.value, LinkType.CALL_CALC.value, LinkType.CALL_WORK.value
                            )
                        }
                    }
                )

                self.assertEqual(
                    builder.count(),
                    13,
                    msg='Failed with c1={}, c2={}, w1={}, w2={}'.format(calcs[0], calcs[1], works[0], works[1])
                )

                export_links = builder.all()

                export_file = os.path.join(temp_dir, 'export.tar.gz')
                export(graph_nodes, outfile=export_file, silent=True, overwrite=True)

                self.reset_database()

                import_data(export_file, silent=True)
                import_links = get_all_node_links()

                export_set = [tuple(_) for _ in export_links]
                import_set = [tuple(_) for _ in import_links]

                self.assertSetEqual(
                    set(export_set),
                    set(import_set),
                    msg='Failed with c1={}, c2={}, w1={}, w2={}'.format(calcs[0], calcs[1], works[0], works[1])
                )
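Several of these examples call a `get_all_node_links` helper that is defined elsewhere. A
minimal sketch of what it plausibly does, inferred from the `edge_project` query above (an
assumption, not necessarily the project's actual implementation):

from aiida import orm

def get_all_node_links():
    """Return all node-to-node links as [input_uuid, output_uuid, label, type] rows."""
    builder = orm.QueryBuilder()
    builder.append(orm.Node, project='uuid', tag='input')
    builder.append(orm.Node, project='uuid', with_incoming='input', edge_project=['label', 'type'])
    return builder.all()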
Example #4
    def test_complex_workflow_graph_export_sets(self, temp_dir):
        """Test ex-/import of individual nodes in complex graph"""
        for export_conf in range(0, 9):

            _, (export_node,
                export_target) = self.construct_complex_graph(export_conf)
            export_target_uuids = set(_.uuid for _ in export_target)

            export_file = os.path.join(temp_dir, 'export.aiida')
            export([export_node], filename=export_file, overwrite=True)
            export_node_str = str(export_node)

            self.clean_db()

            import_data(export_file)

            # Get all the nodes of the database
            builder = orm.QueryBuilder()
            builder.append(orm.Node, project='uuid')
            imported_node_uuids = set(_[0] for _ in builder.all())

            self.assertSetEqual(
                export_target_uuids, imported_node_uuids,
                msg='Problem in comparison of export node: {}\n'
                'Expected set: {}\n'
                'Imported set: {}\n'
                'Difference: {}'.format(
                    export_node_str, export_target_uuids, imported_node_uuids,
                    export_target_uuids.symmetric_difference(imported_node_uuids)))
Example #5
def test_check_for_export_format_version(aiida_profile, tmp_path):
    """Test the check for the export format version."""
    # Creating a folder for the archive files
    export_file_tmp_folder = tmp_path / 'export_tmp'
    export_file_tmp_folder.mkdir()
    unpack_tmp_folder = tmp_path / 'unpack_tmp'
    unpack_tmp_folder.mkdir()

    aiida_profile.reset_db()

    struct = orm.StructureData()
    struct.store()

    filename = str(export_file_tmp_folder / 'export.aiida')
    export([struct], filename=filename, file_format='tar.gz')

    with tarfile.open(filename, 'r:gz', format=tarfile.PAX_FORMAT) as tar:
        tar.extractall(unpack_tmp_folder)

    with (unpack_tmp_folder / 'metadata.json').open('r', encoding='utf8') as fhandle:
        metadata = json.load(fhandle)
    metadata['export_version'] = 0.0

    with (unpack_tmp_folder / 'metadata.json').open('w', encoding='utf8') as fhandle:
        json.dump(metadata, fhandle)

    with tarfile.open(filename, 'w:gz', format=tarfile.PAX_FORMAT) as tar:
        tar.add(unpack_tmp_folder, arcname='')

    aiida_profile.reset_db()

    with pytest.raises(exceptions.IncompatibleArchiveVersionError):
        import_data(filename)
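The `aiida_profile` fixture used here ships with AiiDA's pytest plugin; as far as I know it
is enabled by loading the plugin in a conftest.py, along these lines:

# conftest.py
pytest_plugins = ['aiida.manage.tests.pytest_fixtures']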
Example #6
    def test_workcalculation(self, temp_dir):
        """Test simple master/slave WorkChainNodes"""
        from aiida.common.links import LinkType

        master = orm.WorkChainNode()
        slave = orm.WorkChainNode()

        input_1 = orm.Int(3).store()
        input_2 = orm.Int(5).store()
        output_1 = orm.Int(2).store()

        master.add_incoming(input_1, LinkType.INPUT_WORK, 'input_1')
        slave.add_incoming(master, LinkType.CALL_WORK, 'CALL')
        slave.add_incoming(input_2, LinkType.INPUT_WORK, 'input_2')

        master.store()
        slave.store()

        output_1.add_incoming(master, LinkType.RETURN, 'RETURN')

        master.seal()
        slave.seal()

        uuids_values = [(v.uuid, v.value) for v in (output_1, )]
        filename1 = os.path.join(temp_dir, 'export1.tar.gz')
        export([output_1], outfile=filename1, silent=True)
        self.clean_db()
        self.insert_data()
        import_data(filename1, silent=True)

        for uuid, value in uuids_values:
            self.assertEqual(orm.load_node(uuid).value, value)
Example #7
    def test_input_and_create_links(self, temp_dir):
        """
        Simple test that will verify that INPUT and CREATE links are properly exported and
        correctly recreated upon import.
        """
        node_work = orm.CalculationNode()
        node_input = orm.Int(1).store()
        node_output = orm.Int(2).store()

        node_work.add_incoming(node_input, LinkType.INPUT_CALC, 'input')
        node_work.store()
        node_output.add_incoming(node_work, LinkType.CREATE, 'output')

        node_work.seal()

        export_links = get_all_node_links()
        export_file = os.path.join(temp_dir, 'export.aiida')
        export([node_output], filename=export_file)

        self.clean_db()

        import_data(export_file)
        import_links = get_all_node_links()

        export_set = [tuple(_) for _ in export_links]
        import_set = [tuple(_) for _ in import_links]

        self.assertSetEqual(set(export_set), set(import_set))
Example #8
    def test_critical_log_msg_and_metadata(self, temp_dir):
        """ Testing logging of critical message """
        message = 'Testing logging of critical failure'
        calc = orm.CalculationNode()

        # Firing a log for an unstored node should not end up in the database
        calc.logger.critical(message)
        # There should be no log messages for the unstored object
        self.assertEqual(len(orm.Log.objects.all()), 0)

        # After storing the node, messages at or above the configured log level end up in the database
        calc.store()
        calc.seal()
        calc.logger.critical(message)

        # Store Log metadata
        log_metadata = orm.Log.objects.get(dbnode_id=calc.id).metadata

        export_file = os.path.join(temp_dir, 'export.tar.gz')
        export([calc], outfile=export_file, silent=True)

        self.reset_database()

        import_data(export_file, silent=True)

        # Finding all the log messages
        logs = orm.Log.objects.all()

        self.assertEqual(len(logs), 1)
        self.assertEqual(logs[0].message, message)
        self.assertEqual(logs[0].metadata, log_metadata)
Example #9
def test_import(aiida_profile, benchmark, tmp_path, depth, breadth,
                num_objects):
    """Benchmark importing a provenance graph."""
    aiida_profile.reset_db()
    root_node = Dict()
    recursive_provenance(root_node,
                         depth=depth,
                         breadth=breadth,
                         num_objects=num_objects)
    root_uuid = root_node.uuid
    out_path = tmp_path / 'test.aiida'
    kwargs = get_export_kwargs(filename=str(out_path))
    export([root_node], **kwargs)

    def _setup():
        aiida_profile.reset_db()

    def _run():
        import_data(str(out_path), silent=True)

    benchmark.pedantic(_run,
                       setup=_setup,
                       iterations=1,
                       rounds=12,
                       warmup_rounds=1)
    load_node(root_uuid)
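The `depth`, `breadth` and `num_objects` arguments are pytest parameters; a hedged sketch of
how they might be supplied (the values below are hypothetical):

import pytest

@pytest.mark.parametrize('depth,breadth,num_objects', [
    (2, 2, 0),   # hypothetical: shallow, narrow graph without repository objects
    (4, 3, 10),  # hypothetical: deeper graph with repository objects
])
def test_import(aiida_profile, benchmark, tmp_path, depth, breadth, num_objects):
    ...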
Example #10
    def test_missing_node_repo_folder_export(self, temp_dir):
        """
        Make sure `~aiida.tools.importexport.common.exceptions.ArchiveExportError` is raised during export when missing
        Node repository folder.
        Create and store a new Node and manually remove its repository folder.
        Attempt to export it and make sure `~aiida.tools.importexport.common.exceptions.ArchiveExportError` is raised,
        due to the missing folder.
        """
        node = orm.CalculationNode().store()
        node.seal()
        node_uuid = node.uuid

        node_repo = RepositoryFolder(section=Repository._section_name, uuid=node_uuid)  # pylint: disable=protected-access
        self.assertTrue(
            node_repo.exists(), msg='Newly created and stored Node should have had an existing repository folder'
        )

        # Removing the Node's local repository folder
        shutil.rmtree(node_repo.abspath, ignore_errors=True)
        self.assertFalse(
            node_repo.exists(), msg='Newly created and stored Node should have had its repository folder removed'
        )

        # Try to export, check it raises and check the raise message
        filename = os.path.join(temp_dir, 'export.aiida')
        with self.assertRaises(exceptions.ArchiveExportError) as exc:
            export([node], filename=filename)

        self.assertIn(f'Unable to find the repository folder for Node with UUID={node_uuid}', str(exc.exception))
        self.assertFalse(os.path.exists(filename), msg='The archive file should not exist')
Example #11
    def test_exclude_logs_flag(self, temp_dir):
        """Test that the `include_logs` argument for `export` works."""
        log_msg = 'Testing logging of critical failure'

        # Create node
        calc = orm.CalculationNode()
        calc.store()
        calc.seal()

        # Create log message
        calc.logger.critical(log_msg)

        # Save uuids prior to export
        calc_uuid = calc.uuid

        # Export, excluding logs
        export_file = os.path.join(temp_dir, 'export.tar.gz')
        export([calc], outfile=export_file, silent=True, include_logs=False)

        # Clean database and reimport exported data
        self.reset_database()
        import_data(export_file, silent=True)

        # Finding all the log messages
        import_calcs = orm.QueryBuilder().append(orm.CalculationNode,
                                                 project=['uuid']).all()
        import_logs = orm.QueryBuilder().append(orm.Log,
                                                project=['uuid']).all()

        # There should be exactly: 1 orm.CalculationNode, 0 Logs
        self.assertEqual(len(import_calcs), 1)
        self.assertEqual(len(import_logs), 0)

        # Check it's the correct node
        self.assertEqual(str(import_calcs[0][0]), calc_uuid)
Example #12
    def test_calcfunction(self, temp_dir):
        """Test @calcfunction"""
        from aiida.engine import calcfunction
        from aiida.common.exceptions import NotExistent

        @calcfunction
        def add(a, b):
            """Add 2 numbers"""
            return {'res': orm.Float(a + b)}

        def max_(**kwargs):
            """Select the node with the maximum value"""
            return {'res': max(kwargs.values(), key=lambda node: node.value)}

        # Create a handful of numbers
        a, b, c, d, e = (orm.Float(i).store() for i in range(5))
        # This adds the maximum of b, c, d, e to a
        res = add(a=a, b=max_(b=b, c=c, d=d, e=e)['res'])['res']
        # These uuids are exported as well, as parents of the final result
        uuids_values = [(a.uuid, a.value), (e.uuid, e.value), (res.uuid, res.value)]
        # These uuids should not be exported, since max_ leaves no provenance
        not_wanted_uuids = [v.uuid for v in (b, c, d)]
        # At this point we export the generated data
        filename1 = os.path.join(temp_dir, 'export1.tar.gz')
        export([res], outfile=filename1, silent=True, return_backward=True)
        self.clean_db()
        self.insert_data()
        import_data(filename1, silent=True)
        # Check that the imported nodes are correctly imported and that the value is preserved
        for uuid, value in uuids_values:
            self.assertEqual(orm.load_node(uuid).value, value)
        for uuid in not_wanted_uuids:
            with self.assertRaises(NotExistent):
                orm.load_node(uuid)
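Note that `max_` is a plain Python function rather than a `@calcfunction`, so its inputs
leave no trace in the provenance graph; only the winning node enters the graph as an input
of `add`. A sketch of a tracked variant, if one wanted b, c and d exported too (a
calcfunction must return new, unstored nodes, hence the fresh Float):

from aiida import orm
from aiida.engine import calcfunction

@calcfunction
def max_tracked(**kwargs):
    """Select the maximum value; every input appears in the provenance graph."""
    return {'res': orm.Float(max(node.value for node in kwargs.values()))}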
Example #13
    def test_dangling_link_to_existing_db_node(self, temp_dir):
        """A dangling link that references a Node that is not included in the archive should `not` be importable"""
        struct = orm.StructureData()
        struct.store()
        struct_uuid = struct.uuid

        calc = orm.CalculationNode()
        calc.add_incoming(struct, LinkType.INPUT_CALC, 'input')
        calc.store()
        calc.seal()
        calc_uuid = calc.uuid

        filename = os.path.join(temp_dir, 'export.aiida')
        export([struct], filename=filename, file_format='tar.gz')

        unpack = SandboxFolder()
        with tarfile.open(filename, 'r:gz', format=tarfile.PAX_FORMAT) as tar:
            tar.extractall(unpack.abspath)

        with open(unpack.get_abs_path('data.json'), 'r',
                  encoding='utf8') as fhandle:
            data = json.load(fhandle)
        data['links_uuid'].append({
            'output': calc.uuid,
            'input': struct.uuid,
            'label': 'input',
            'type': LinkType.INPUT_CALC.value
        })

        with open(unpack.get_abs_path('data.json'), 'w', encoding='utf8') as fhandle:
            json.dump(data, fhandle)

        with tarfile.open(filename, 'w:gz', format=tarfile.PAX_FORMAT) as tar:
            tar.add(unpack.abspath, arcname='')

        # Make sure the CalculationNode is still in the database
        builder = orm.QueryBuilder().append(orm.CalculationNode,
                                            project='uuid')
        self.assertEqual(
            builder.count(),
            1,
            msg=f'There should be a single CalculationNode, instead {builder.count()} has been found'
        )
        self.assertEqual(builder.all()[0][0], calc_uuid)

        with self.assertRaises(DanglingLinkError):
            import_data(filename)

        # Using the flag `ignore_unknown_nodes` should import it without problems
        import_data(filename, ignore_unknown_nodes=True)
        builder = orm.QueryBuilder().append(orm.StructureData, project='uuid')
        self.assertEqual(
            builder.count(),
            1,
            msg=f'There should be a single StructureData, instead {builder.count()} has been found'
        )
        self.assertEqual(builder.all()[0][0], struct_uuid)
Example #14
    def test_missing_node_repo_folder_import(self, temp_dir):
        """
        Make sure `~aiida.tools.importexport.common.exceptions.CorruptArchive` is raised during import when missing
        Node repository folder.
        Create and export a Node and manually remove its repository folder in the export file.
        Attempt to import it and make sure `~aiida.tools.importexport.common.exceptions.CorruptArchive` is raised,
        due to the missing folder.
        """
        import tarfile

        from aiida.common.folders import SandboxFolder
        from aiida.tools.importexport.common.archive import extract_tar
        from aiida.tools.importexport.common.config import NODES_EXPORT_SUBFOLDER
        from aiida.tools.importexport.common.utils import export_shard_uuid

        node = orm.CalculationNode().store()
        node.seal()
        node_uuid = node.uuid

        node_repo = RepositoryFolder(section=Repository._section_name, uuid=node_uuid)  # pylint: disable=protected-access
        self.assertTrue(
            node_repo.exists(), msg='Newly created and stored Node should have had an existing repository folder'
        )

        # Export and reset db
        filename = os.path.join(temp_dir, 'export.aiida')
        export([node], filename=filename, file_format='tar.gz', silent=True)
        self.reset_database()

        # Untar export file, remove repository folder, re-tar
        node_shard_uuid = export_shard_uuid(node_uuid)
        node_top_folder = node_shard_uuid.split('/')[0]
        with SandboxFolder() as folder:
            extract_tar(filename, folder, silent=True, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER)
            node_folder = folder.get_subfolder(os.path.join(NODES_EXPORT_SUBFOLDER, node_shard_uuid))
            self.assertTrue(
                node_folder.exists(), msg="The Node's repository folder should still exist in the export file"
            )

            # Removing the Node's repository folder from the export file
            shutil.rmtree(
                folder.get_subfolder(os.path.join(NODES_EXPORT_SUBFOLDER, node_top_folder)).abspath, ignore_errors=True
            )
            self.assertFalse(
                node_folder.exists(),
                msg="The Node's repository folder should now have been removed in the export file"
            )

            filename_corrupt = os.path.join(temp_dir, 'export_corrupt.aiida')
            with tarfile.open(filename_corrupt, 'w:gz', format=tarfile.PAX_FORMAT, dereference=True) as tar:
                tar.add(folder.abspath, arcname='')

        # Try to import, check it raises and check the raise message
        with self.assertRaises(exceptions.CorruptArchive) as exc:
            import_data(filename_corrupt, silent=True)

        self.assertIn(
            'Unable to find the repository folder for Node with UUID={}'.format(node_uuid), str(exc.exception)
        )
Example #15
def create(
    output_file, codes, computers, groups, nodes, archive_format, force, input_calc_forward, input_work_forward,
    create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs
):
    """
    Export subsets of the provenance graph to file for sharing.

    Besides Nodes of the provenance graph, you can export Groups, Codes, Computers, Comments and Logs.

    By default, the export file will include not only the entities explicitly provided via the command line but also
    their provenance, according to the rules outlined in the documentation.
    You can modify some of those rules using options of this command.
    """
    from aiida.tools.importexport import export, ExportFileFormat
    from aiida.tools.importexport.common.exceptions import ArchiveExportError

    entities = []

    if codes:
        entities.extend(codes)

    if computers:
        entities.extend(computers)

    if groups:
        entities.extend(groups)

    if nodes:
        entities.extend(nodes)

    kwargs = {
        'input_calc_forward': input_calc_forward,
        'input_work_forward': input_work_forward,
        'create_backward': create_backward,
        'return_backward': return_backward,
        'call_calc_backward': call_calc_backward,
        'call_work_backward': call_work_backward,
        'include_comments': include_comments,
        'include_logs': include_logs,
        'overwrite': force
    }

    if archive_format == 'zip':
        export_format = ExportFileFormat.ZIP
        kwargs.update({'use_compression': True})
    elif archive_format == 'zip-uncompressed':
        export_format = ExportFileFormat.ZIP
        kwargs.update({'use_compression': False})
    elif archive_format == 'tar.gz':
        export_format = ExportFileFormat.TAR_GZIPPED

    try:
        export(entities, filename=output_file, file_format=export_format, **kwargs)
    except ArchiveExportError as exception:
        echo.echo_critical('failed to write the archive file. Exception: {}'.format(exception))
    else:
        echo.echo_success('wrote the export archive file to {}'.format(output_file))
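For comparison, a hedged sketch of calling `export` directly with the same traversal-rule
keywords this command forwards (the PK below is hypothetical):

from aiida import orm
from aiida.tools.importexport import export, ExportFileFormat

node = orm.load_node(1234)  # hypothetical PK
export(
    [node],
    filename='archive.aiida',
    file_format=ExportFileFormat.ZIP,
    use_compression=True,
    create_backward=True,    # follow CREATE links backward (the default)
    return_backward=False,   # do not follow RETURN links backward
    include_comments=True,
    include_logs=True,
)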
Example #16
    def test_empty_repo_folder_export(self, temp_dir):
        """Check a Node's empty repository folder is exported properly"""
        from aiida.common.folders import Folder
        from aiida.tools.importexport.dbexport import export_tree

        node = orm.Dict().store()
        node_uuid = node.uuid

        node_repo = RepositoryFolder(section=Repository._section_name, uuid=node_uuid)  # pylint: disable=protected-access
        self.assertTrue(
            node_repo.exists(), msg='Newly created and stored Node should have had an existing repository folder'
        )
        for filename, is_file in node_repo.get_content_list(only_paths=False):
            abspath_filename = os.path.join(node_repo.abspath, filename)
            if is_file:
                os.remove(abspath_filename)
            else:
                shutil.rmtree(abspath_filename, ignore_errors=False)
        self.assertFalse(
            node_repo.get_content_list(),
            msg='Repository folder should be empty, instead the following was found: {}'.format(
                node_repo.get_content_list()
            )
        )

        archive_variants = {
            'archive folder': os.path.join(temp_dir, 'export_tree'),
            'tar archive': os.path.join(temp_dir, 'export.tar.gz'),
            'zip archive': os.path.join(temp_dir, 'export.zip')
        }

        export_tree([node], folder=Folder(archive_variants['archive folder']), silent=True)
        export([node], filename=archive_variants['tar archive'], file_format='tar.gz', silent=True)
        export([node], filename=archive_variants['zip archive'], file_format='zip', silent=True)

        for variant, filename in archive_variants.items():
            self.reset_database()
            node_count = orm.QueryBuilder().append(orm.Dict, project='uuid').count()
            self.assertEqual(node_count, 0, msg='After DB reset, {} Dict nodes were (wrongly) found'.format(node_count))

            import_data(filename, silent=True)
            builder = orm.QueryBuilder().append(orm.Dict, project='uuid')
            imported_node_count = builder.count()
            self.assertEqual(
                imported_node_count,
                1,
                msg='After {} import a single Dict Node should have been found, '
                'instead {} was/were found'.format(variant, imported_node_count)
            )
            imported_node_uuid = builder.all()[0][0]
            self.assertEqual(
                imported_node_uuid,
                node_uuid,
                msg='The wrong UUID was found for the imported {}: '
                '{}. It should have been: {}'.format(variant, imported_node_uuid, node_uuid)
            )
Example #17
    def test_base_data_type_change(self, temp_dir):
        """ Base Data types type string changed
        Example: Bool: “data.base.Bool.” → “data.bool.Bool.”
        """
        # Test content
        test_content = ('Hello', 6, -1.2399834e12, False)
        test_types = ()
        for node_type in ['str', 'int', 'float', 'bool']:
            test_types += ('data.{}.{}.'.format(node_type, node_type.capitalize()),)

        # List of nodes to be exported
        export_nodes = []

        # Create list of base type nodes
        nodes = [cls(val).store() for val, cls in zip(test_content, (orm.Str, orm.Int, orm.Float, orm.Bool))]
        export_nodes.extend(nodes)

        # Collect uuids for created nodes
        uuids = [n.uuid for n in nodes]

        # Create List() and insert already created nodes into it
        list_node = orm.List()
        list_node.set_list(nodes)
        list_node.store()
        list_node_uuid = list_node.uuid
        export_nodes.append(list_node)

        # Export nodes
        filename = os.path.join(temp_dir, 'export.aiida')
        export(export_nodes, filename=filename, silent=True)

        # Clean the database
        self.reset_database()

        # Import nodes again
        import_data(filename, silent=True)

        # Check whether types are correctly imported
        nlist = orm.load_node(list_node_uuid)  # List
        for uuid, list_value, refval, reftype in zip(uuids, nlist.get_list(), test_content, test_types):
            # Str, Int, Float, Bool
            base = orm.load_node(uuid)
            # Check value/content
            self.assertEqual(base.value, refval)
            # Check type
            msg = "type of node ('{}') is not updated according to db schema v0.4".format(base.node_type)
            self.assertEqual(base.node_type, reftype, msg=msg)

            # List
            # Check value
            self.assertEqual(list_value, refval)

        # Check List type
        msg = "type of node ('{}') is not updated according to db schema v0.4".format(nlist.node_type)
        self.assertEqual(nlist.node_type, 'data.list.List.', msg=msg)
Example #18
    def test_mtime_of_imported_comments(self, temp_dir):
        """
        Test mtime does not change for imported comments
        This is related to correct usage of `comment_mode` when importing.
        """
        # Get user
        user = orm.User.objects.get_default()

        comment_content = 'You get what you give'

        # Create node
        calc = orm.CalculationNode().store()
        calc.seal()

        # Create comment
        orm.Comment(calc, user, comment_content).store()

        # Save UUIDs and mtimes prior to export
        builder = orm.QueryBuilder().append(orm.Comment, project=['uuid', 'mtime']).all()
        comment_uuid = str(builder[0][0])
        comment_mtime = builder[0][1]

        builder = orm.QueryBuilder().append(orm.CalculationNode, project=['uuid', 'mtime']).all()
        calc_uuid = str(builder[0][0])
        calc_mtime = builder[0][1]

        # Export, reset database and reimport
        export_file = os.path.join(temp_dir, 'export.aiida')
        export([calc], filename=export_file, silent=True)
        self.reset_database()
        import_data(export_file, silent=True)

        # Retrieve node and comment
        builder = orm.QueryBuilder().append(orm.CalculationNode, tag='calc', project=['uuid', 'mtime'])
        builder.append(orm.Comment, with_node='calc', project=['uuid', 'mtime'])

        import_entities = builder.all()[0]

        self.assertEqual(len(import_entities), 4)  # Check we have the correct amount of returned values

        import_calc_uuid = str(import_entities[0])
        import_calc_mtime = import_entities[1]
        import_comment_uuid = str(import_entities[2])
        import_comment_mtime = import_entities[3]

        # Check we have the correct UUIDs
        self.assertEqual(import_calc_uuid, calc_uuid)
        self.assertEqual(import_comment_uuid, comment_uuid)

        # Make sure the mtime is the same after import as it was before export
        self.assertEqual(import_comment_mtime, comment_mtime)
        self.assertEqual(import_calc_mtime, calc_mtime)
Example #19
    @classmethod
    def setUpClass(cls, *args, **kwargs):
        """Only run to prepare an export file"""
        super().setUpClass()

        data = orm.Data()
        data.label = 'my_test_data_node'
        data.store()
        data.set_extra_many({'b': 2, 'c': 3})
        cls.tmp_folder = tempfile.mkdtemp()
        cls.export_file = os.path.join(cls.tmp_folder, 'export.aiida')
        export([data], outfile=cls.export_file, silent=True)
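The matching teardown is not shown; a minimal sketch of what it could look like, assuming
the `tmp_folder` created above should be removed afterwards (requires `shutil`):

    @classmethod
    def tearDownClass(cls, *args, **kwargs):
        """Remove the temporary folder created in `setUpClass`."""
        shutil.rmtree(cls.tmp_folder, ignore_errors=True)
        super().tearDownClass()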
Example #20
    def test_nodes_in_group(self, temp_dir):
        """
        This test checks that nodes that belong to a specific group are
        correctly imported and exported.
        """
        from aiida.common.links import LinkType

        # Create another user
        new_email = '[email protected]'
        user = orm.User(email=new_email)
        user.store()

        # Create a structure data node that has a calculation as output
        sd1 = orm.StructureData()
        sd1.user = user
        sd1.label = 'sd1'
        sd1.store()

        jc1 = orm.CalcJobNode()
        jc1.computer = self.computer
        jc1.set_option('resources', {
            'num_machines': 1,
            'num_mpiprocs_per_machine': 1
        })
        jc1.user = user
        jc1.label = 'jc1'
        jc1.add_incoming(sd1, link_type=LinkType.INPUT_CALC, link_label='link')
        jc1.store()
        jc1.seal()

        # Create a group and add the data inside
        gr1 = orm.Group(label='node_group')
        gr1.store()
        gr1.add_nodes([sd1, jc1])
        gr1_uuid = gr1.uuid

        # At this point we export the generated data
        filename1 = os.path.join(temp_dir, 'export1.tar.gz')
        export([sd1, jc1, gr1], outfile=filename1, silent=True)
        n_uuids = [sd1.uuid, jc1.uuid]
        self.clean_db()
        self.insert_data()
        import_data(filename1, silent=True)

        # Check that the imported nodes are correctly imported and that
        # the user assigned to the nodes is the right one
        for uuid in n_uuids:
            self.assertEqual(orm.load_node(uuid).user.email, new_email)

        # Check that the exported group is imported correctly
        builder = orm.QueryBuilder()
        builder.append(orm.Group, filters={'uuid': {'==': gr1_uuid}})
        self.assertEqual(builder.count(), 1, 'The group was not found.')
Example #21
    def test_input_code(self, temp_dir):
        """
        This test checks that when a calculation is exported then the
        corresponding code is also exported. It also checks that the links
        are also in place after the import.
        """
        code_label = 'test_code1'

        code = orm.Code()
        code.set_remote_computer_exec((self.computer, '/bin/true'))
        code.label = code_label
        code.store()

        code_uuid = code.uuid

        calc = orm.CalcJobNode()
        calc.computer = self.computer
        calc.set_option('resources', {
            'num_machines': 1,
            'num_mpiprocs_per_machine': 1
        })

        calc.add_incoming(code, LinkType.INPUT_CALC, 'code')
        calc.store()
        calc.seal()
        links_count = 1

        export_links = get_all_node_links()

        export_file = os.path.join(temp_dir, 'export.aiida')
        export([calc], filename=export_file)

        self.clean_db()

        import_data(export_file)

        # Check that the code node is there
        self.assertEqual(orm.load_node(code_uuid).label, code_label)

        # Check that the link is in place
        import_links = get_all_node_links()
        self.assertListEqual(sorted(export_links), sorted(import_links))
        self.assertEqual(
            len(export_links), links_count,
            'Expected to find only one link from code to '
            'the calculation node before export. {} found.'.format(
                len(export_links)))
        self.assertEqual(
            len(import_links), links_count,
            'Expected to find only one link from code to '
            'the calculation node after import. {} found.'.format(
                len(import_links)))
Example #22
def cmd_export(
    group, max_atoms, max_atomic_number, number_species, partial_occupancies, include_duplicates, no_cod_hydrogen,
    sssp_only, filename
):
    """Pass."""
    from aiida import orm
    from aiida.common.constants import elements
    from aiida.tools.importexport import export

    filters_elements = set()
    filters_structures = {'and': []}

    if no_cod_hydrogen:
        filters_structures['and'].append({'id': {'!in': get_cod_hydrogen_structure_ids()}})

    if max_atoms is not None:
        filters_structures['and'].append({'attributes.sites': {'shorter': max_atoms + 1}})

    if max_atomic_number:
        filters_elements = filters_elements.union({e['symbol'] for z, e in elements.items() if z > max_atomic_number})

    if sssp_only:
        # Exclude all elements heavier than Radon (Z > 86), as well as Astatine (Z = 85)
        filters_elements = filters_elements.union({e['symbol'] for z, e in elements.items() if z > 86 or z == 85})

    builder = orm.QueryBuilder().append(
        orm.Group, filters={'id': group.pk}, tag='group'
    ).append(
        orm.StructureData, with_group='group', filters=filters_structures
    )

    duplicates = []

    if max_atomic_number or sssp_only:
        structures = []
        for structure, in builder.iterall():
            if all(element not in filters_elements for element in structure.get_symbols_set()):
                structures.append(structure)
    else:
        structures = builder.all(flat=True)

    if include_duplicates:
        for structure in structures:
            dupes = []
            for uuids in structure.get_extra('duplicates').values():
                dupes.extend(uuids)
            for duplicate in dupes:
                if duplicate != structure.uuid:
                    duplicates.append(orm.load_node(duplicate))

    export(structures + duplicates, outfile=filename, create_backward=False, return_backward=False)
Example #23
    def test_calc_and_data_nodes_with_comments(self, temp_dir):
        """ Test comments for CalculatioNode and Data node are correctly ex-/imported """
        # Create user, nodes, and comments
        user = orm.User.objects.get_default()

        calc_node = orm.CalculationNode().store()
        calc_node.seal()
        data_node = orm.Data().store()

        comment_one = orm.Comment(calc_node, user, self.comments[0]).store()
        comment_two = orm.Comment(calc_node, user, self.comments[1]).store()

        comment_three = orm.Comment(data_node, user, self.comments[2]).store()
        comment_four = orm.Comment(data_node, user, self.comments[3]).store()

        # Get values prior to export
        calc_uuid = calc_node.uuid
        data_uuid = data_node.uuid
        calc_comments_uuid = [c.uuid for c in [comment_one, comment_two]]
        data_comments_uuid = [c.uuid for c in [comment_three, comment_four]]

        # Export nodes
        export_file = os.path.join(temp_dir, 'export.tar.gz')
        export([calc_node, data_node], outfile=export_file, silent=True)

        # Clean database and reimport exported file
        self.reset_database()
        import_data(export_file, silent=True)

        # Get nodes and comments
        builder = orm.QueryBuilder()
        builder.append(orm.Node, tag='node', project=['uuid'])
        builder.append(orm.Comment, with_node='node', project=['uuid'])
        nodes_and_comments = builder.all()

        self.assertEqual(len(nodes_and_comments), len(self.comments))
        for entry in nodes_and_comments:
            self.assertEqual(len(entry), 2)  # 1 Node + 1 Comment

            import_node_uuid = str(entry[0])
            import_comment_uuid = str(entry[1])

            self.assertIn(import_node_uuid, [calc_uuid, data_uuid])
            if import_node_uuid == calc_uuid:
                # Calc node comments
                self.assertIn(import_comment_uuid, calc_comments_uuid)
            else:
                # Data node comments
                self.assertIn(import_comment_uuid, data_comments_uuid)
Example #24
def test_calc_of_structuredata(aiida_profile, tmp_path, file_format):
    """Simple ex-/import of CalcJobNode with input StructureData"""
    aiida_profile.reset_db()

    struct = orm.StructureData()
    struct.store()

    computer = orm.Computer(
        label='localhost-test',
        description='localhost computer set up by test manager',
        hostname='localhost-test',
        workdir=str(tmp_path / 'workdir'),
        transport_type='local',
        scheduler_type='direct')
    computer.store()
    computer.configure()

    calc = orm.CalcJobNode()
    calc.computer = computer
    calc.set_option('resources', {
        'num_machines': 1,
        'num_mpiprocs_per_machine': 1
    })

    calc.add_incoming(struct, link_type=LinkType.INPUT_CALC, link_label='link')
    calc.store()
    calc.seal()

    pks = [struct.pk, calc.pk]

    attrs = {}
    for pk in pks:
        node = orm.load_node(pk)
        attrs[node.uuid] = dict()
        for k in node.attributes.keys():
            attrs[node.uuid][k] = node.get_attribute(k)

    filename = str(tmp_path / 'export.aiida')

    export([calc], filename=filename, file_format=file_format)

    aiida_profile.reset_db()

    import_data(filename)
    for uuid in attrs:
        node = orm.load_node(uuid)
        for k in attrs[uuid].keys():
            assert attrs[uuid][k] == node.get_attribute(k)
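Here `file_format` is a pytest parameter; a hedged sketch of the parametrization, using the
two formats that appear elsewhere in these examples:

import pytest

@pytest.mark.parametrize('file_format', ('zip', 'tar.gz'))
def test_calc_of_structuredata(aiida_profile, tmp_path, file_format):
    ...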
Example #25
    def prepare_link_flags_export(nodes_to_export, test_data):
        """Helper function"""
        from aiida.common.links import GraphTraversalRules

        export_rules = GraphTraversalRules.EXPORT.value
        traversal_rules = {name: rule.default for name, rule in export_rules.items() if rule.toggleable}

        for export_file, rule_changes, expected_nodes in test_data.values():
            traversal_rules.update(rule_changes)
            export(nodes_to_export[0], outfile=export_file, silent=True, **traversal_rules)

            for node_type in nodes_to_export[1]:
                if node_type in expected_nodes:
                    expected_nodes[node_type].update(nodes_to_export[1][node_type])
                else:
                    expected_nodes[node_type] = nodes_to_export[1][node_type]
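A hedged sketch of how this helper might be called; the names and values below are
hypothetical, chosen only to illustrate the expected shapes of the two arguments:

# nodes_to_export pairs the nodes to export with the node UUIDs they imply;
# each test_data entry is (export_file, rule_changes, expected_nodes).
nodes_to_export = ([calc_node], {'data': {data_node.uuid}})
test_data = {
    'no_create_backward': (
        os.path.join(temp_dir, 'export_no_create.aiida'),
        {'create_backward': False},
        {},  # updated in place with the implied node UUIDs
    ),
}
prepare_link_flags_export(nodes_to_export, test_data)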
Example #26
    def test_exclude_comments_flag(self, temp_dir):
        """Test comments and associated commenting users are not exported when using `include_comments=False`."""
        # Create users, node, and comments
        user_one = orm.User.objects.get_default()
        user_two = orm.User(email='[email protected]').store()

        node = orm.Data().store()

        orm.Comment(node, user_one, self.comments[0]).store()
        orm.Comment(node, user_one, self.comments[1]).store()
        orm.Comment(node, user_two, self.comments[2]).store()
        orm.Comment(node, user_two, self.comments[3]).store()

        # Get values prior to export
        users_email = [u.email for u in [user_one, user_two]]
        node_uuid = node.uuid

        # Check that node belongs to user_one
        self.assertEqual(node.user.email, users_email[0])

        # Export nodes, excluding comments
        export_file = os.path.join(temp_dir, 'export.tar.gz')
        export([node],
               outfile=export_file,
               silent=True,
               include_comments=False)

        # Clean database and reimport exported file
        self.reset_database()
        import_data(export_file, silent=True)

        # Get node, users, and comments
        import_nodes = orm.QueryBuilder().append(orm.Node,
                                                 project=['uuid']).all()
        import_comments = orm.QueryBuilder().append(orm.Comment,
                                                    project=['uuid']).all()
        import_users = orm.QueryBuilder().append(orm.User,
                                                 project=['email']).all()

        # There should be exactly: 1 Node, 0 Comments, 1 User
        self.assertEqual(len(import_nodes), 1)
        self.assertEqual(len(import_comments), 0)
        self.assertEqual(len(import_users), 1)

        # Check it's the correct user (and node)
        self.assertEqual(str(import_nodes[0][0]), node_uuid)
        self.assertEqual(str(import_users[0][0]), users_email[0])
Example #27
    def test_calc_of_structuredata(self, temp_dir):
        """Simple ex-/import of CalcJobNode with input StructureData"""
        from aiida.common.links import LinkType

        struct = orm.StructureData()
        struct.store()

        calc = orm.CalcJobNode()
        calc.computer = self.computer
        calc.set_option('resources', {
            'num_machines': 1,
            'num_mpiprocs_per_machine': 1
        })

        calc.add_incoming(struct,
                          link_type=LinkType.INPUT_CALC,
                          link_label='link')
        calc.store()
        calc.seal()

        pks = [struct.pk, calc.pk]

        attrs = {}
        for pk in pks:
            node = orm.load_node(pk)
            attrs[node.uuid] = dict()
            for k in node.attributes.keys():
                attrs[node.uuid][k] = node.get_attribute(k)

        filename = os.path.join(temp_dir, 'export.aiida')

        export([calc], filename=filename, silent=True)

        self.clean_db()
        self.create_user()

        # NOTE: load the imported nodes by UUID rather than assuming they get
        # the first PKs: primary keys keep incrementing even after rows are
        # deleted, so they are never reused.
        import_data(filename, silent=True)
        for uuid in attrs:
            node = orm.load_node(uuid)
            for k in attrs[uuid].keys():
                self.assertEqual(attrs[uuid][k], node.get_attribute(k))
Example #28
    def test_node_process_type(self, temp_dir):
        """ Column `process_type` added to `Node` entity DB table """
        from aiida.engine import run_get_node
        from tests.utils.processes import AddProcess
        # Node types
        node_type = 'process.workflow.WorkflowNode.'
        node_process_type = 'tests.utils.processes.AddProcess'

        # Run workflow
        inputs = {'a': orm.Int(2), 'b': orm.Int(3)}
        _, node = run_get_node(AddProcess, **inputs)

        # Save node uuid
        node_uuid = str(node.uuid)

        # Assert correct type and process_type strings
        self.assertEqual(node.node_type, node_type)
        self.assertEqual(node.process_type, node_process_type)

        # Export nodes
        filename = os.path.join(temp_dir, 'export.aiida')
        export([node], filename=filename)

        # Clean the database and reimport data
        self.clean_db()
        import_data(filename)

        # Retrieve node and check exactly one node is imported
        builder = orm.QueryBuilder()
        builder.append(orm.ProcessNode, project=['uuid'])

        self.assertEqual(builder.count(), 1)

        # Get node uuid and check it is the same as the one exported
        nodes = builder.all()
        imported_node_uuid = str(nodes[0][0])

        self.assertEqual(imported_node_uuid, node_uuid)

        # Check imported node type and process type
        node = orm.load_node(imported_node_uuid)

        self.assertEqual(node.node_type, node_type)
        self.assertEqual(node.process_type, node_process_type)
Example #29
    def test_double_return_links_for_workflows(self, temp_dir):
        """
        This test checks that double return links to a node can be exported
        and imported without problems,
        """
        work1 = orm.WorkflowNode()
        work2 = orm.WorkflowNode().store()
        data_in = orm.Int(1).store()
        data_out = orm.Int(2).store()

        work1.add_incoming(data_in, LinkType.INPUT_WORK, 'input_i1')
        work1.add_incoming(work2, LinkType.CALL_WORK, 'call')
        work1.store()
        data_out.add_incoming(work1, LinkType.RETURN, 'return1')
        data_out.add_incoming(work2, LinkType.RETURN, 'return2')
        links_count = 4

        work1.seal()
        work2.seal()

        uuids_wanted = set(_.uuid for _ in (work1, data_out, data_in, work2))
        links_wanted = get_all_node_links()

        export_file = os.path.join(temp_dir, 'export.tar.gz')
        export([data_out, work1, work2, data_in],
               outfile=export_file,
               silent=True)

        self.reset_database()

        import_data(export_file, silent=True)

        uuids_in_db = [
            str(uuid) for [uuid] in orm.QueryBuilder().append(
                orm.Node, project='uuid').all()
        ]
        self.assertListEqual(sorted(uuids_wanted), sorted(uuids_in_db))

        links_in_db = get_all_node_links()
        self.assertListEqual(sorted(links_wanted), sorted(links_in_db))

        # Assert number of links, checking both RETURN links are included
        self.assertEqual(len(links_wanted), links_count)  # Before export
        self.assertEqual(len(links_in_db), links_count)  # After import
Example #30
    def test_simple_import(self):
        """
        This is a very simple test which checks that an export file with nodes
        that are not associated to a computer is imported correctly. In Django
        when such nodes are exported, there is an empty set for computers
        in the export file. In SQLA there is such a set only when a computer is
        associated with the exported nodes. When an empty computer set is
        found at the export file (when imported to an SQLA profile), the SQLA
        import code used to crash. This test demonstrates this problem.
        """
        parameters = orm.Dict(
            dict={
                'Pr': {
                    'cutoff': 50.0,
                    'pseudo_type': 'Wentzcovitch',
                    'dual': 8,
                    'cutoff_units': 'Ry'
                },
                'Ru': {
                    'cutoff': 40.0,
                    'pseudo_type': 'SG15',
                    'dual': 4,
                    'cutoff_units': 'Ry'
                },
            }).store()

        with tempfile.NamedTemporaryFile() as handle:
            nodes = [parameters]
            export(nodes, outfile=handle.name, overwrite=True, silent=True)

            # Check that we have the expected number of nodes in the database
            self.assertEqual(orm.QueryBuilder().append(orm.Node).count(),
                             len(nodes))

            # Clean the database and verify there are no nodes left
            self.clean_db()
            self.create_user()
            self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), 0)

            # After importing we should have the original number of nodes again
            import_data(handle.name, silent=True)
            self.assertEqual(orm.QueryBuilder().append(orm.Node).count(),
                             len(nodes))
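Distilled from the examples above, the common round-trip pattern is short (a sketch; the
keyword names vary between API generations, `outfile`/`silent` in the older tests and
`filename`/`file_format` in the newer ones):

from aiida import orm
from aiida.tools.importexport import export, import_data

node = orm.Int(5).store()
export([node], filename='roundtrip.aiida')
# ...reset the database or switch profiles...
import_data('roundtrip.aiida')
assert orm.load_node(node.uuid).value == 5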