Example #1
0
def import_data(aiida_env):
    """Import the three database dump archives from the relative ``files/`` directory.

    The archives are imported one after the other, in a fixed order.

    :param aiida_env: not used in the body.
        NOTE(review): presumably a fixture requested only so that an AiiDA
        environment exists before importing -- confirm against the caller.
    """
    # Bound under an alias because the plain name ``import_data`` is also the
    # name of this very function.
    from aiida.tools.importexport import import_data as load_archive

    archive_names = (
        'db_dump_kkrcalc.tar.gz',
        'db_dump_kkrflex_create.tar.gz',
        'db_dump_vorocalc.tar.gz',
    )
    for archive_name in archive_names:
        load_archive('files/' + archive_name)
Example #2
0
    def test_kkr_from_kkr(self):
        """
        Continue a KKR calculation from a previous KKR calculation instead of starting from voronoi.

        The parent calculation is imported from an archive; its input parameters and its
        remote folder are handed to a new ``KkrCalculation`` builder, which is then run
        with ``dry_run`` enabled.
        """
        from aiida.engine import run
        from aiida.orm import Code, load_node
        from aiida.tools.importexport import import_data
        from aiida_kkr.calculations.kkr import KkrCalculation

        # load the parent KKR calculation from the db_dump archive
        import_data('files/db_dump_kkrcalc.tar.gz')
        kkr_calc = load_node('3058bd6c-de0b-400e-aff5-2331a5f5d566')

        # prepare computer and code
        # (presumably needed so that the code can be looked up by its label below)
        prepare_code(kkr_codename, codelocation, computername, workdir)

        # reuse the KKR parameters of the parent calculation unchanged
        params_node = kkr_calc.inputs.parameters

        # load code from database and set up the new KKR calculation
        code = Code.get_from_string(kkr_codename+'@'+computername)
        options = {'resources': {'num_machines':1, 'tot_num_mpiprocs':1}, 'queue_name': queuename}
        builder = KkrCalculation.get_builder()
        builder.code = code
        builder.metadata.options = options
        builder.parameters = params_node
        # the remote folder of the parent calculation marks this as a continuation
        builder.parent_folder = kkr_calc.outputs.remote_folder
        builder.metadata.dry_run = True
        run(builder)
 def test_neworder_potential_wf(self):
     """Check the potential file produced by ``neworder_potential_wf``.

     A calculation is loaded from the imported ``db_dump_kkrflex_create`` archive.
     The new ordering is taken from column index 3 of its retrieved ``scoef`` file
     (first row skipped), shifted by -1 and cast to ``int``. Every 1000th
     character of the resulting file content is compared to a reference string.
     """
     from numpy import loadtxt
     from aiida.orm import load_node
     from aiida.plugins import DataFactory
     from aiida_kkr.tools.common_workfunctions import neworder_potential_wf
     from aiida.tools.importexport import import_data
     Dict = DataFactory('dict')
     import_data('files/db_dump_kkrflex_create.tar.gz')
     GF_host_calc = load_node(
         'baabef05-f418-4475-bba5-ef0ee3fd5ca6').outputs
     # Open the retrieved file in a ``with`` block so the handle is closed
     # deterministically instead of being left to the garbage collector.
     with GF_host_calc.retrieved.open('scoef') as scoef_handle:
         scoef_table = loadtxt(scoef_handle, skiprows=1)
     neworder_pot1 = [int(i) for i in scoef_table[:, 3] - 1]
     settings_dict = {
         'pot1': 'out_potential',
         'out_pot': 'potential_imp',
         'neworder': neworder_pot1
     }
     settings = Dict(dict=settings_dict)
     startpot_imp_sfd = neworder_potential_wf(
         settings_node=settings,
         parent_calc_folder=GF_host_calc.remote_folder)
     # Only every 1000th character is compared, which keeps the reference short.
     assert startpot_imp_sfd.get_object_content(
         startpot_imp_sfd.filename
     )[::
       1000] == u'C12807143D556463084.6+55 7D117 9D-87 0+25\n20.70351.75\n0521259.2+491.0-462. 02621D74112D03547T00 4D02116D502 6D39\n96.20261.50941.4944.7+30 98-29 .5-3625D07193.58104D0773D27252285417D341 9.506544D548447094.9+38 91063 54-08 6D28277.60909.98111'
Example #4
0
    def test_kkrflex_writeout_wc(self):
        """
        Run the GF writeout workflow (``kkr_flex_wc``) for Cu bulk on top of an imported KKR calculation.

        The test asserts that the workflow reports success without errors, that the
        ``GF_host_remote`` output is a remote-data node, and that the expected
        ``kkrflex_*`` files are among the retrieved files of the writeout calculation.
        """
        from aiida.engine import run
        from aiida.orm import Code, load_node
        from aiida.plugins import DataFactory
        from aiida.tools.importexport import import_data
        from aiida_kkr.workflows.gf_writeout import kkr_flex_wc

        Dict = DataFactory('dict')

        # prepare computer and code
        # (presumably needed so that the code can be looked up by its label below)
        prepare_code(kkr_codename, codelocation, computername, workdir)

        # The workflow defaults are not modified by this test; the return value is
        # not needed, the call is kept so that the accessor is still exercised.
        kkr_flex_wc.get_wf_defaults()

        # scheduler options passed to the workflow
        options = Dict(dict={
            'queue_name' : queuename,
            'resources': {"num_machines": 1},
            'max_wallclock_seconds' : 5*60,
            'custom_scheduler_commands' : '',
            'use_mpi' : False,
        })

        # The workflow needs the KKR code to be able to run the calculation
        KKRCode = Code.get_from_string(kkr_codename+'@'+computername)

        imp_info = Dict(dict={'Rcut':2.5533, 'ilayer_center': 0, 'Zimp':[29.]})

        label = 'GF_writeout Cu bulk'
        descr = 'GF_writeout workflow for Cu bulk'

        # the remote folder of the imported KKR calculation is the parent of the writeout step
        import_data('files/db_dump_kkrcalc.tar.gz')
        kkr_calc_remote = load_node('3058bd6c-de0b-400e-aff5-2331a5f5d566').outputs.remote_folder

        # create process builder to set parameters
        builder = kkr_flex_wc.get_builder()
        builder.metadata.description = descr
        builder.metadata.label = label
        builder.kkr = KKRCode
        builder.options = options
        builder.remote_data = kkr_calc_remote
        builder.impurity_info = imp_info

        # now run calculation
        out = run(builder)

        n = out['workflow_info']
        n = n.get_dict()

        assert n.get('successful')
        assert n.get('list_of_errors') == []

        d = out['GF_host_remote']
        assert isinstance(d, DataFactory('remote'))

        kkrflex_calc = load_node(n.get('pk_flexcalc'))
        kkrflex_retrieved = kkrflex_calc.outputs.retrieved
        # list the retrieved objects once instead of once per expected file
        retrieved_names = kkrflex_retrieved.list_object_names()
        for name in 'tmat green atominfo intercell_cmoms intercell_ref'.split():
            assert 'kkrflex_'+name in retrieved_names
Example #5
0
    def test_complex_workflow_graph_export_sets(self, temp_dir):
        """Export one node of the complex graph per configuration and compare the imported node set.

        For each of the nine graph configurations a single node is exported, the
        database is cleaned, the archive is imported, and the UUIDs of all nodes
        then present are compared with the expected export target.
        """
        archive_path = os.path.join(temp_dir, 'export.aiida')

        for graph_variant in range(9):
            _, (export_node, export_target) = self.construct_complex_graph(graph_variant)
            expected_uuids = {node.uuid for node in export_target}

            export([export_node], filename=archive_path, overwrite=True)
            # Stringify before the database is cleaned, while the node still exists.
            export_node_str = str(export_node)

            self.clean_db()
            import_data(archive_path)

            # Collect the UUID of every node now in the database
            node_query = orm.QueryBuilder()
            node_query.append(orm.Node, project='uuid')
            imported_uuids = {row[0] for row in node_query.all()}

            failure_message = '\n'.join((
                'Problem in comparison of export node: ' + export_node_str,
                'Expected set: ' + str(expected_uuids),
                'Imported set: ' + str(imported_uuids),
                'Difference: ' + str(expected_uuids.symmetric_difference(imported_uuids)),
            ))
            self.assertSetEqual(expected_uuids, imported_uuids, failure_message)
Example #6
0
    def test_critical_log_msg_and_metadata(self, temp_dir):
        """Check that a critical log message and its metadata survive an export/import cycle."""
        log_text = 'Testing logging of critical failure'
        node = orm.CalculationNode()

        # A log fired for an unstored node should not end up in the database,
        # so no log entries are expected at this point.
        node.logger.critical(log_text)
        stored_logs_before = orm.Log.objects.all()
        self.assertEqual(len(stored_logs_before), 0)

        # Once the node is stored (and sealed), the same call is expected to persist a log.
        node.store()
        node.seal()
        node.logger.critical(log_text)

        # Remember the metadata of the log attached to the node for the later comparison
        original_metadata = orm.Log.objects.get(dbnode_id=node.id).metadata

        archive_path = os.path.join(temp_dir, 'export.tar.gz')
        export([node], outfile=archive_path, silent=True)

        self.reset_database()

        import_data(archive_path, silent=True)

        # After the import exactly one log with identical content must exist
        imported_logs = orm.Log.objects.all()

        self.assertEqual(len(imported_logs), 1)
        imported_log = imported_logs[0]
        self.assertEqual(imported_log.message, log_text)
        self.assertEqual(imported_logs[0].metadata, original_metadata)
Example #7
0
    def test_no_node_migration(self, tmp_path, external_archive):
        """Migrate an archive that contains no Node entities, import it and check what arrived."""
        source_archive = get_archive_file('export_v0.3_no_Nodes.aiida', **external_archive)
        migrated_archive = tmp_path / 'output_file.aiida'

        # Pick the migrator class matching the archive type and run the migration
        archive_type = detect_archive_type(source_archive)
        migrator = get_migrator(archive_type)(source_archive)
        migrator.migrate(newest_version, migrated_archive)

        # Import the migrated archive
        import_data(migrated_archive)

        # No nodes at all are expected
        node_count = orm.QueryBuilder().append(orm.Node).count()
        assert node_count == 0

        # Exactly one, known computer
        computer_uuids = orm.QueryBuilder().append(orm.Computer, project=['uuid']).all(flat=True)
        assert computer_uuids == ['4f33c6fd-b624-47df-9ffb-a58f05d323af']

        # The default user plus the user stored in the archive
        user_emails = orm.QueryBuilder().append(orm.User, project=['email']).all(flat=True)
        expected_emails = {orm.User.objects.get_default().email, 'aiida@localhost'}
        assert set(user_emails) == expected_emails
Example #8
0
    def test_group_import_existing(self, temp_dir):
        """
        Test importing a group whose label already exists in the database.

        The import is expected to succeed and to keep both groups apart: the
        imported group ends up under a different label that still starts with the
        original one, and it holds the single exported node. Importing the same
        archive a second time must not create yet another group.
        """
        grouplabel = 'node_group_existing'

        def label_filter():
            """Return a fresh filter matching every group label starting with ``grouplabel``."""
            return {'label': {'like': grouplabel + '%'}}

        # Create another user
        # (the address only has to be distinct; a well-formed placeholder is used)
        new_email = 'newuser@example.com'
        user = orm.User(email=new_email)
        user.store()

        # Create a structure data node
        sd1 = orm.StructureData()
        sd1.user = user
        sd1.label = 'sd'
        sd1.store()

        # Create a group and add the data inside
        group = orm.Group(label=grouplabel)
        group.store()
        group.add_nodes([sd1])

        # At this point we export the generated data
        filename = os.path.join(temp_dir, 'export1.tar.gz')
        export([group], outfile=filename, silent=True)
        self.clean_db()
        self.insert_data()

        # Create an (empty) group with the same label before importing
        group = orm.Group(label=grouplabel)
        group.store()
        import_data(filename, silent=True)

        # The import should have created a second group with a suffixed label
        builder = orm.QueryBuilder().append(orm.Group, filters=label_filter())
        self.assertEqual(builder.count(), 2)

        # Exactly one of the matching groups has a member ...
        builder = orm.QueryBuilder()
        builder.append(orm.Group, filters=label_filter(), tag='g', project='label')
        builder.append(orm.StructureData, with_group='g')
        self.assertEqual(builder.count(), 1)
        # ... and that group does not carry the original label
        self.assertNotEqual(builder.all()[0][0], grouplabel)

        # Importing the same archive again must not add another group
        import_data(filename, silent=True)
        builder = orm.QueryBuilder()
        builder.append(orm.Group, filters=label_filter())
        self.assertEqual(builder.count(), 2)
Example #9
0
    def test_high_level_workflow_links(self, temp_dir):
        """
        Check that links between Data nodes and high-level process nodes survive export and import.

        Covered link types: INPUT_CALC, INPUT_WORK, CALL_CALC, CALL_WORK, CREATE and RETURN.
        Covered node classes: CalcJobNode, CalcFunctionNode, WorkChainNode, WorkFunctionNode.
        Every pairing of two calculation classes with two workflow classes is tested.
        """
        calc_node_pairs = [
            ['CalcJobNode', 'CalcJobNode'],
            ['CalcJobNode', 'CalcFunctionNode'],
            ['CalcFunctionNode', 'CalcJobNode'],
            ['CalcFunctionNode', 'CalcFunctionNode'],
        ]
        work_node_pairs = [
            ['WorkChainNode', 'WorkChainNode'],
            ['WorkChainNode', 'WorkFunctionNode'],
            ['WorkFunctionNode', 'WorkChainNode'],
            ['WorkFunctionNode', 'WorkFunctionNode'],
        ]
        archive_path = os.path.join(temp_dir, 'export.tar.gz')

        for calcs in calc_node_pairs:
            for works in work_node_pairs:
                first_calc, second_calc = calcs[0], calcs[1]
                first_work, second_work = works[0], works[1]
                failure_message = 'Failed with c1={}, c2={}, w1={}, w2={}'.format(
                    first_calc, second_calc, first_work, second_work
                )

                self.reset_database()

                graph_nodes, _ = self.construct_complex_graph(calc_nodes=calcs, work_nodes=works)

                # Query every input, create, return and call link of the fresh graph
                wanted_link_types = (
                    LinkType.INPUT_CALC.value, LinkType.INPUT_WORK.value, LinkType.CREATE.value,
                    LinkType.RETURN.value, LinkType.CALL_CALC.value, LinkType.CALL_WORK.value
                )
                link_query = orm.QueryBuilder()
                link_query.append(orm.Node, project='uuid')
                link_query.append(
                    orm.Node,
                    project='uuid',
                    edge_project=['label', 'type'],
                    edge_filters={'type': {'in': wanted_link_types}}
                )

                self.assertEqual(link_query.count(), 13, msg=failure_message)

                exported_links = link_query.all()

                export(graph_nodes, outfile=archive_path, silent=True, overwrite=True)

                self.reset_database()

                import_data(archive_path, silent=True)
                imported_links = get_all_node_links()

                # Rows are converted to tuples so that they can be put into sets
                self.assertSetEqual(
                    {tuple(link) for link in exported_links},
                    {tuple(link) for link in imported_links},
                    msg=failure_message
                )
Example #10
0
    def _import_with_migrate(filename,
                             tempdir=temp_dir,
                             import_kwargs=None,
                             try_migration=True):
        """Import an archive, migrating it to the current export version if that is required.

        :param filename: path of the archive to import.
        :param tempdir: working directory handed to the migrator.
        :param import_kwargs: keyword arguments forwarded to ``import_data``;
            ``_DEFAULT_IMPORT_KWARGS`` is used when ``None``.
        :param try_migration: when true, an archive with an incompatible version is
            migrated and the import is attempted once more.
        :raises IncompatibleArchiveVersionError: if the archive version is incompatible
            and ``try_migration`` is false. (Previously this case was swallowed
            silently and nothing was imported.)
        """
        from click import echo
        from aiida.tools.importexport import import_data
        from aiida.tools.importexport import EXPORT_VERSION, IncompatibleArchiveVersionError

        if import_kwargs is None:
            import_kwargs = _DEFAULT_IMPORT_KWARGS

        try:
            import_data(filename, **import_kwargs)
        except IncompatibleArchiveVersionError:
            if not try_migration:
                # Do not hide the failure: without migration nothing was imported.
                raise

            # these are only available after aiida >= 1.5.0, maybe rely on verdi import instead;
            # importing them lazily keeps the plain import path usable without them
            from aiida.tools.importexport import detect_archive_type
            from aiida.tools.importexport.archive.migrators import get_migrator

            echo(
                f'incompatible version detected for {filename}, trying migration'
            )
            migrator = get_migrator(detect_archive_type(filename))(filename)
            migrated_path = migrator.migrate(EXPORT_VERSION,
                                             None,
                                             out_compression='none',
                                             work_dir=tempdir)
            import_data(migrated_path, **import_kwargs)
Example #11
0
def _try_import(migration_performed, file_to_import, archive, group, migration, non_interactive, **kwargs):
    """Attempt to import an archive on behalf of `verdi import` and report whether it should be migrated.

    :param migration_performed: Boolean selecting which message is emitted when an
        `~aiida.tools.importexport.common.exceptions.IncompatibleArchiveVersionError` is caught.
    :param file_to_import: Absolute path, including filename, of the file passed to `import_data`.
    :param archive: Filename of the archive, used in the messages shown to the user.
    :param group: AiiDA Group passed on to `import_data`.
    :param migration: Whether a migration may be attempted for an incompatible archive.
    :param non_interactive: Whether the user must not be prompted.
    :param kwargs: Keyword arguments that _must_ contain:
        * `'extras_mode_existing'`: `import_data`'s `'extras_mode_existing'` keyword (import rules for Extras).
        * `'extras_mode_new'`: `import_data`'s `'extras_mode_new'` keyword (import rules for Extras).
        * `'comment_mode'`: `import_data`'s `'comment_mode'` keyword (import rules for Comments).
    :return: `True` (or the answer of the confirmation prompt) if the archive should be migrated,
        `False` otherwise.
    :raises ValueError: if one of the mandatory keys is missing from `kwargs`.
    """
    from aiida.tools.importexport import import_data, IncompatibleArchiveVersionError

    # Every mandatory option has to be present; the first missing one is reported.
    for required_key in ('extras_mode_existing', 'extras_mode_new', 'comment_mode'):
        if required_key in kwargs:
            continue
        raise ValueError(
            "{} needed for utility function '{}' to use in 'import_data'".format(required_key, '_try_import')
        )

    should_migrate = False

    try:
        import_data(file_to_import, group, **kwargs)
    except IncompatibleArchiveVersionError as exception:
        if migration_performed:
            # The archive was already migrated once and is still rejected.
            echo.echo_critical(
                '{} has been migrated, but it still cannot be imported.\n{}'.format(archive, exception)
            )
        elif not migration:
            # Migration is not allowed: report the full error.
            echo.echo_critical(str(exception))
        else:
            # Migration is allowed and has not been tried yet: show the first line of the error only.
            echo.echo_warning(str(exception).splitlines()[0])
            if non_interactive:
                should_migrate = True
            else:
                should_migrate = click.confirm(
                    'Do you want to try and migrate {} to the newest export file version?\n'
                    'Note: This will not change your current file.'.format(archive),
                    default=True,
                    abort=True
                )
    except Exception:
        # Any other failure is reported together with its traceback.
        echo.echo_error('an exception occurred while importing the archive {}'.format(archive))
        echo.echo(traceback.format_exc())
        if not non_interactive:
            click.confirm('do you want to continue?', abort=True)
    else:
        echo.echo_success('imported archive {}'.format(archive))

    return should_migrate
Example #12
0
    def test_no_node_export(self, temp_dir):
        """Migrate an export file without any Nodes, import it and check the entities present afterwards."""
        source_archive = get_archive_file('export_v0.3_no_Nodes.aiida', **self.external_archive)
        migrated_archive = os.path.join(temp_dir, 'output_file.aiida')

        # Entities that may be present afterwards: the ones of this test class
        # plus the ones known to be stored in the archive.
        expected_computer_uuids = [
            self.computer.uuid,  # pylint: disable=no-member
            '4f33c6fd-b624-47df-9ffb-a58f05d323af',
        ]
        expected_user_emails = [orm.User.objects.get_default().email, 'aiida@localhost']
        expected_node_count = 0
        expected_computer_count = 1 + 1  # localhost is always present

        # Perform the migration
        migrate_archive(source_archive, migrated_archive)

        # Load the migrated file
        import_data(migrated_archive, silent=True)

        # Check known number of entities is present
        self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), expected_node_count)
        self.assertEqual(orm.QueryBuilder().append(orm.Computer).count(), expected_computer_count)

        # Check unique identifiers; only the first row of each query is inspected
        first_computer_uuid = orm.QueryBuilder().append(orm.Computer, project=['uuid']).all()[0][0]
        first_user_email = orm.QueryBuilder().append(orm.User, project=['email']).all()[0][0]
        self.assertIn(first_computer_uuid, expected_computer_uuids)
        self.assertIn(first_user_email, expected_user_emails)
Example #13
0
    def test_workcalculation(self, temp_dir):
        """Export the returned output of a calling WorkChainNode and check its value after re-import.

        Two WorkChainNodes are linked by a CALL_WORK link; the calling one returns
        an ``Int`` which is the only node handed to ``export``.
        """
        from aiida.common.links import LinkType

        caller = orm.WorkChainNode()
        called = orm.WorkChainNode()

        first_input = orm.Int(3).store()
        second_input = orm.Int(5).store()
        returned_output = orm.Int(2).store()

        # Incoming links are attached before the process nodes are stored
        caller.add_incoming(first_input, LinkType.INPUT_WORK, 'input_1')
        called.add_incoming(caller, LinkType.CALL_WORK, 'CALL')
        called.add_incoming(second_input, LinkType.INPUT_WORK, 'input_2')

        caller.store()
        called.store()

        returned_output.add_incoming(caller, LinkType.RETURN, 'RETURN')

        caller.seal()
        called.seal()

        # Remember uuid and value before the database is wiped
        expected = [(returned_output.uuid, returned_output.value)]

        archive_path = os.path.join(temp_dir, 'export1.tar.gz')
        export([returned_output], outfile=archive_path, silent=True)
        self.clean_db()
        self.insert_data()
        import_data(archive_path, silent=True)

        for node_uuid, node_value in expected:
            self.assertEqual(orm.load_node(node_uuid).value, node_value)
Example #14
0
    def test_calcfunction(self, temp_dir):
        """Export the result of a ``@calcfunction`` and check which nodes come back on import.

        The result, its two inputs and their values must be present after the
        import; the numbers that were only candidates for the maximum must not.
        """
        from aiida.engine import calcfunction
        from aiida.common.exceptions import NotExistent

        @calcfunction
        def add(a, b):
            """Add 2 numbers"""
            return {'res': orm.Float(a + b)}

        def max_(**kwargs):
            """select the max value"""
            # Plain helper (not a calcfunction): picks the node with the largest value
            _, largest_node = max((node.value, node) for node in kwargs.values())
            return {'res': largest_node}

        # Five stored floats with the values 0..4
        a, b, c, d, e = [orm.Float(number).store() for number in range(5)]

        # Add the largest of b, c, d, e to a
        largest = max_(b=b, c=c, d=d, e=e)['res']
        res = add(a=a, b=largest)['res']

        # Nodes expected in the archive: the result and the two inputs of ``add``
        expected = [(node.uuid, node.value) for node in (a, e, res)]
        # Nodes that were merely candidates for the maximum are not expected
        unexpected_uuids = [node.uuid for node in (b, c, d)]

        # Export the result, wipe the database and import the archive again
        archive_path = os.path.join(temp_dir, 'export1.tar.gz')
        export([res], outfile=archive_path, silent=True, return_backward=True)
        self.clean_db()
        self.insert_data()
        import_data(archive_path, silent=True)

        # Imported nodes must carry the original values
        for node_uuid, node_value in expected:
            self.assertEqual(orm.load_node(node_uuid).value, node_value)
        # The remaining nodes must not exist
        for node_uuid in unexpected_uuids:
            with self.assertRaises(NotExistent):
                orm.load_node(node_uuid)
Example #15
0
    def test_exclude_logs_flag(self, temp_dir):
        """Check that exporting with ``include_logs=False`` leaves the node's log out of the archive."""
        log_text = 'Testing logging of critical failure'

        # A stored and sealed calculation node carrying one critical log message
        node = orm.CalculationNode()
        node.store()
        node.seal()
        node.logger.critical(log_text)

        # Keep the uuid; the node itself is gone after the reset
        expected_uuid = node.uuid

        # Export without logs
        archive_path = os.path.join(temp_dir, 'export.tar.gz')
        export([node], outfile=archive_path, silent=True, include_logs=False)

        # Start from an empty database and import the archive
        self.reset_database()
        import_data(archive_path, silent=True)

        # Query calculation nodes and logs present after the import
        calc_rows = orm.QueryBuilder().append(orm.CalculationNode, project=['uuid']).all()
        log_rows = orm.QueryBuilder().append(orm.Log, project=['uuid']).all()

        # Expected: exactly one calculation node and no logs
        self.assertEqual(len(calc_rows), 1)
        self.assertEqual(len(log_rows), 0)

        # And the node is the one that was exported
        imported_uuid = str(calc_rows[0][0])
        self.assertEqual(imported_uuid, expected_uuid)
Example #16
0
    def test_check_for_export_format_version(self):
        """Check that importing an archive with an unsupported export version is rejected.

        A structure is exported, the archive is unpacked, ``export_version`` in its
        ``metadata.json`` is overwritten with ``0.0`` and the archive is repacked.
        Importing it must raise ``IncompatibleArchiveVersionError``.
        """
        # Scratch folders: one for the archive, one for its unpacked content
        archive_dir = tempfile.mkdtemp()
        unpack_dir = tempfile.mkdtemp()
        try:
            structure = orm.StructureData()
            structure.store()

            archive_path = os.path.join(archive_dir, 'export.tar.gz')
            export([structure], outfile=archive_path, silent=True)

            with tarfile.open(archive_path, 'r:gz', format=tarfile.PAX_FORMAT) as tar:
                tar.extractall(unpack_dir)

            metadata_path = os.path.join(unpack_dir, 'metadata.json')

            with open(metadata_path, 'r', encoding='utf8') as fhandle:
                metadata = json.load(fhandle)
            metadata['export_version'] = 0.0

            # NOTE(review): the handle is opened in binary mode; this only works if
            # ``json`` here is a wrapper that encodes to bytes (the standard library
            # ``json.dump`` writes ``str``) -- confirm which module the file imports.
            with open(metadata_path, 'wb') as fhandle:
                json.dump(metadata, fhandle)

            # Repack the modified content over the original archive
            with tarfile.open(archive_path, 'w:gz', format=tarfile.PAX_FORMAT) as tar:
                tar.add(unpack_dir, arcname='')

            self.tearDownClass()
            self.setUpClass()

            with self.assertRaises(exceptions.IncompatibleArchiveVersionError):
                import_data(archive_path, silent=True)
        finally:
            # Remove both scratch folders, whatever happened above
            for scratch_dir in (archive_dir, unpack_dir):
                shutil.rmtree(scratch_dir, ignore_errors=True)
Example #17
0
    def link_flags_import_helper(self, test_data):
        """Import each archive of ``test_data`` into a clean database and check the node UUIDs found.

        :param test_data: mapping of a test name to a 3-tuple whose first item is the
            archive to import and whose third item maps ``'work'``/``'calc'``/``'data'``
            to the expected node UUIDs; the second item is ignored here.
        """
        for test_name, (export_file, _, expected_nodes) in test_data.items():
            self.clean_db()

            import_data(export_file)

            node_classes = (
                ('work', orm.WorkflowNode),
                ('calc', orm.CalculationNode),
                ('data', orm.Data),
            )
            for node_type, node_cls in node_classes:
                # Node kinds without an expectation are not checked
                if node_type not in expected_nodes:
                    continue

                expected_uuids = expected_nodes[node_type]
                builder = orm.QueryBuilder().append(node_cls, project='uuid')

                # The number of nodes of this kind must match ...
                self.assertEqual(
                    builder.count(),
                    len(expected_uuids),
                    msg='Expected {} {} node(s), but got {}. Test: "{}"'.format(
                        len(expected_uuids), node_type, builder.count(), test_name))

                # ... and each of them must be an expected one
                for row in builder.iterall():
                    self.assertIn(row[0],
                                  expected_uuids,
                                  msg=f'Failed for test: "{test_name}"')
Example #18
0
def test_check_for_export_format_version(aiida_profile, tmp_path):
    """Check that importing an archive with an unsupported ``export_version`` raises.

    A stored ``StructureData`` is exported as ``tar.gz``, the ``metadata.json`` inside the archive is
    rewritten with ``export_version`` set to ``0.0`` and the re-packed archive is passed to ``import_data``,
    which is expected to raise ``IncompatibleArchiveVersionError``.
    """
    # Scratch directories: one holds the archive, the other its unpacked content
    archive_dir = tmp_path / 'export_tmp'
    unpacked_dir = tmp_path / 'unpack_tmp'
    archive_dir.mkdir()
    unpacked_dir.mkdir()

    aiida_profile.reset_db()

    structure = orm.StructureData()
    structure.store()

    archive_path = str(archive_dir / 'export.aiida')
    export([structure], filename=archive_path, file_format='tar.gz')

    with tarfile.open(archive_path, 'r:gz', format=tarfile.PAX_FORMAT) as archive:
        archive.extractall(unpacked_dir)

    # Overwrite the version recorded in the archive metadata
    metadata_path = unpacked_dir / 'metadata.json'
    with metadata_path.open('r', encoding='utf8') as handle:
        metadata = json.load(handle)
    metadata['export_version'] = 0.0

    # NOTE(review): writing JSON to a handle opened in binary mode presumably relies on `json` being a
    # wrapper that encodes to bytes - confirm against the import at the top of the file.
    with metadata_path.open('wb') as handle:
        json.dump(metadata, handle)

    # Re-pack the modified content over the original archive
    with tarfile.open(archive_path, 'w:gz', format=tarfile.PAX_FORMAT) as archive:
        archive.add(unpacked_dir, arcname='')

    aiida_profile.reset_db()

    with pytest.raises(exceptions.IncompatibleArchiveVersionError):
        import_data(archive_path)
Example #19
0
    def test_import_folder(self):
        """Check that an already extracted archive, i.e. a plain folder, can be imported.

        The import must leave the source folder untouched: the list of sub-directories is recorded before
        and after the import and both lists have to match.
        """
        from aiida.common.folders import SandboxFolder
        from tests.utils.archives import get_archive_file
        from aiida.tools.importexport.common.archive import extract_zip

        def list_subfolders(root):
            """Return the path of every directory found below ``root`` by ``os.walk``."""
            found = []
            for current, subdirs, _ in os.walk(root):
                found.extend(os.path.join(current, subdir) for subdir in subdirs)
            return found

        archive = get_archive_file('arithmetic.add.aiida', filepath='calcjob')

        with SandboxFolder() as sandbox:
            extract_zip(archive, sandbox, silent=True)

            # Both JSON files and the nodes sub-folder have to be present before the folder itself
            # is handed to the import function.
            for required in {'metadata.json', 'data.json', 'nodes'}:
                required_path = os.path.join(sandbox.abspath, required)
                self.assertTrue(os.path.exists(required_path))

            # Snapshot of the folder structure prior to the import
            folders_before = list_subfolders(sandbox.abspath)

            import_data(sandbox.abspath, silent=True)

            # Snapshot after the import; nothing may have disappeared from the source
            folders_after = list_subfolders(sandbox.abspath)
            self.maxDiff = None  # pylint: disable=invalid-name
            self.assertListEqual(folders_before, folders_after)
Example #20
0
    def test_input_and_create_links(self, temp_dir):
        """Verify that INPUT and CREATE links survive an export/import round trip.

        A calculation with one input and one output is built, the output is exported, the database is
        cleaned and the archive imported again. The links found afterwards must equal the ones found before.

        :param temp_dir: directory in which the archive file is written
        """
        calculation = orm.CalculationNode()
        input_node = orm.Int(1).store()
        output_node = orm.Int(2).store()

        calculation.add_incoming(input_node, LinkType.INPUT_CALC, 'input')
        calculation.store()
        output_node.add_incoming(calculation, LinkType.CREATE, 'output')

        calculation.seal()

        # Links as they exist before the export
        links_before = get_all_node_links()

        archive = os.path.join(temp_dir, 'export.aiida')
        export([output_node], filename=archive)

        self.clean_db()

        import_data(archive)

        # Links as they exist after the import
        links_after = get_all_node_links()

        # Each link is turned into a tuple so that the two collections can be compared as sets
        expected_links = {tuple(link) for link in links_before}
        found_links = {tuple(link) for link in links_after}

        self.assertSetEqual(expected_links, found_links)
Example #21
0
    def test_dangling_link_to_existing_db_node(self, temp_dir):
        """A dangling link that references a Node that is not included in the archive should `not` be importable.

        Only the ``StructureData`` is exported. A link from it to the ``CalculationNode`` is then added by hand
        to the archive's ``data.json``. The calculation is still present in the database but absent from
        the archive, so a plain import has to raise ``DanglingLinkError``, while an import with
        ``ignore_unknown_nodes=True`` has to succeed.

        :param temp_dir: directory in which the archive file is written
        """
        struct = orm.StructureData()
        struct.store()
        struct_uuid = struct.uuid

        calc = orm.CalculationNode()
        calc.add_incoming(struct, LinkType.INPUT_CALC, 'input')
        calc.store()
        calc.seal()
        calc_uuid = calc.uuid

        filename = os.path.join(temp_dir, 'export.aiida')
        export([struct], filename=filename, file_format='tar.gz')

        # The sandbox is only needed while the archive is being rewritten. Using it as a context manager
        # releases it once the block is left, instead of keeping the scratch folder around.
        with SandboxFolder() as unpack:
            with tarfile.open(filename, 'r:gz', format=tarfile.PAX_FORMAT) as tar:
                tar.extractall(unpack.abspath)

            data_path = unpack.get_abs_path('data.json')
            with open(data_path, 'r', encoding='utf8') as fhandle:
                data = json.load(fhandle)

            # Inject a link whose target (the calculation) is not part of the archive
            data['links_uuid'].append({
                'output': calc_uuid,
                'input': struct_uuid,
                'label': 'input',
                'type': LinkType.INPUT_CALC.value
            })

            with open(data_path, 'wb') as fhandle:
                json.dump(data, fhandle)

            # Re-pack the modified content over the original archive
            with tarfile.open(filename, 'w:gz', format=tarfile.PAX_FORMAT) as tar:
                tar.add(unpack.abspath, arcname='')

        # Make sure the CalculationNode is still in the database
        builder = orm.QueryBuilder().append(orm.CalculationNode,
                                            project='uuid')
        self.assertEqual(
            builder.count(),
            1,
            msg=
            f'There should be a single CalculationNode, instead {builder.count()} has been found'
        )
        self.assertEqual(builder.all()[0][0], calc_uuid)

        with self.assertRaises(DanglingLinkError):
            import_data(filename)

        # Using the flag `ignore_unknown_nodes` should import it without problems
        import_data(filename, ignore_unknown_nodes=True)
        builder = orm.QueryBuilder().append(orm.StructureData, project='uuid')
        self.assertEqual(
            builder.count(),
            1,
            msg=
            f'There should be a single StructureData, instead {builder.count()} has been found'
        )
        self.assertEqual(builder.all()[0][0], struct_uuid)
    def test_missing_node_repo_folder_import(self, temp_dir):
        """Check that importing an archive lacking a Node's repository folder raises ``CorruptArchive``.

        A Node is created and exported, the archive is unpacked, the folder holding the Node's repository is
        deleted and the content is packed again into a second archive. Importing that second archive must
        raise `~aiida.tools.importexport.common.exceptions.CorruptArchive` with a message naming the Node.

        :param temp_dir: directory in which both archive files are written
        """
        import tarfile

        from aiida.common.folders import SandboxFolder
        from aiida.tools.importexport.common.archive import extract_tar
        from aiida.tools.importexport.common.config import NODES_EXPORT_SUBFOLDER
        from aiida.tools.importexport.common.utils import export_shard_uuid

        node = orm.CalculationNode().store()
        node.seal()
        uuid = node.uuid

        repo_folder = RepositoryFolder(section=Repository._section_name, uuid=uuid)  # pylint: disable=protected-access
        self.assertTrue(
            repo_folder.exists(), msg='Newly created and stored Node should have had an existing repository folder'
        )

        # Write the intact archive, then start from an empty database
        intact_archive = os.path.join(temp_dir, 'export.aiida')
        export([node], filename=intact_archive, file_format='tar.gz', silent=True)
        self.reset_database()

        # Location of the Node inside the archive, and the first path component of that location
        sharded_uuid = export_shard_uuid(uuid)
        top_shard = sharded_uuid.split('/')[0]
        corrupt_archive = os.path.join(temp_dir, 'export_corrupt.aiida')

        with SandboxFolder() as sandbox:
            extract_tar(intact_archive, sandbox, silent=True, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER)

            archived_node_folder = sandbox.get_subfolder(os.path.join(NODES_EXPORT_SUBFOLDER, sharded_uuid))
            self.assertTrue(
                archived_node_folder.exists(),
                msg="The Node's repository folder should still exist in the export file"
            )

            # Delete the whole top-level shard folder, which takes the Node's folder with it
            top_shard_folder = sandbox.get_subfolder(os.path.join(NODES_EXPORT_SUBFOLDER, top_shard))
            shutil.rmtree(top_shard_folder.abspath, ignore_errors=True)
            self.assertFalse(
                archived_node_folder.exists(),
                msg="The Node's repository folder should now have been removed in the export file"
            )

            # Pack what is left into the second, corrupt archive
            with tarfile.open(corrupt_archive, 'w:gz', format=tarfile.PAX_FORMAT, dereference=True) as tar:
                tar.add(sandbox.abspath, arcname='')

        # The import has to fail and the error message has to name the affected Node
        with self.assertRaises(exceptions.CorruptArchive) as raised:
            import_data(corrupt_archive, silent=True)

        expected_message = 'Unable to find the repository folder for Node with UUID={}'.format(uuid)
        self.assertIn(expected_message, str(raised.exception))
Example #23
0
    def import_attributes(self):
        """Import ``self.export_file`` and remember the imported test node.

        Exactly one ``Data`` node labelled ``my_test_data_node`` must exist afterwards; it is stored
        in ``self.imported_node``.
        """
        import_data(self.export_file)

        label_filter = {'label': 'my_test_data_node'}
        query = orm.QueryBuilder().append(orm.Data, filters=label_filter)

        self.assertEqual(query.count(), 1)

        first_row = query.all()[0]
        self.imported_node = first_row[0]
Example #24
0
    def import_extras(self, mode_new='import'):
        """Import ``self.export_file`` with the given mode for new extras and remember the imported node.

        Exactly one ``Data`` node labelled ``my_test_data_node`` must exist afterwards; it is stored
        in ``self.imported_node``.

        :param mode_new: value passed to ``import_data`` as ``extras_mode_new``
        """
        import_data(self.export_file, extras_mode_new=mode_new)

        label_filter = {'label': 'my_test_data_node'}
        query = orm.QueryBuilder().append(orm.Data, filters=label_filter)

        self.assertEqual(query.count(), 1)

        first_row = query.all()[0]
        self.imported_node = first_row[0]
Example #25
0
 def test_extras_import_mode_correct(self):
     """Run the import once for every three-letter ``extras_mode_existing`` combination.

     The first letter keeps ('k') or does not keep ('n') old extras, the second creates ('c') or does
     not create ('n') new extras, and the third leaves ('l'), updates ('u') or deletes ('d') collided
     extras. The 'ask' mode is not covered.
     """
     self.import_extras()

     # Same order as three nested loops: the last letter varies fastest
     modes = [
         keep_old + create_new + on_collision
         for keep_old in ['k', 'n']
         for create_new in ['n', 'c']
         for on_collision in ['l', 'u', 'd']
     ]
     for mode in modes:
         import_data(self.export_file, extras_mode_existing=mode)
Example #26
0
    def test_v02_to_newest(self, temp_dir):
        """Migrate a v0.2 archive to the newest export version and verify the imported content.

        After the migration the metadata version is verified and the archive is imported. The node count,
        the ``CalculationNode`` attributes, the ``StructureData`` label, cell and kinds, and two kinds of
        provenance links are then checked against the reference values stored on the test class.

        :param temp_dir: directory in which the migrated archive is written
        """
        # Archive with export version 0.2 and the destination of its migrated copy
        source_archive = get_archive_file('export_v0.2.aiida', **self.external_archive)
        migrated_archive = os.path.join(temp_dir, 'output_file.aiida')

        # Migrate, then confirm the metadata now carries the newest version
        migrate_archive(source_archive, migrated_archive)
        metadata, _ = get_json_files(migrated_archive)
        verify_metadata_version(metadata, version=newest_version)

        # Import the migrated archive
        import_data(migrated_archive, silent=True)

        # Total number of nodes
        all_nodes = orm.QueryBuilder().append(orm.Node)
        self.assertEqual(all_nodes.count(), self.node_count)

        # Every CalculationNode has to expose a non-empty attribute dictionary
        calculations = orm.QueryBuilder().append(orm.CalculationNode)
        for (calculation,) in calculations.iterall():
            self.assertIsInstance(calculation.attributes, dict)
            self.assertNotEqual(len(calculation.attributes), 0)

        # StructureData nodes: expected amount, and unchanged label and cell
        structures = orm.QueryBuilder().append(orm.StructureData)
        count_msg = 'There should be {} StructureData, instead {} were/was found'.format(
            self.struct_count, structures.count()
        )
        self.assertEqual(structures.count(), self.struct_count, msg=count_msg)
        for row in structures.all():
            structure = row[0]
            self.assertEqual(structure.label, self.known_struct_label)
            self.assertEqual(structure.cell, self.known_cell)

        # StructureData nodes: the kinds have to match the known ones
        kinds_query = orm.QueryBuilder().append(orm.StructureData, project=['attributes.kinds'])
        for (kinds,) in kinds_query.iterall():
            self.assertEqual(len(kinds), len(self.known_kinds))
            for kind in kinds:
                missing_msg = "Kind '{}' not found in: {}".format(kind, self.known_kinds)
                self.assertIn(kind, self.known_kinds, msg=missing_msg)

        # At least one StructureData is the input of a CalculationNode
        struct_to_calc = orm.QueryBuilder()
        struct_to_calc.append(orm.StructureData, tag='structure')
        struct_to_calc.append(orm.CalculationNode, with_incoming='structure')
        self.assertGreater(len(struct_to_calc.all()), 0)

        # At least one RemoteData is the output of a CalculationNode
        calc_to_remote = orm.QueryBuilder()
        calc_to_remote.append(orm.CalculationNode, tag='parent')
        calc_to_remote.append(orm.RemoteData, with_incoming='parent')
        self.assertGreater(len(calc_to_remote.all()), 0)
Example #27
0
def _import_archive(archive: str, web_based: bool, import_kwargs: dict, try_migration: bool):
    """Perform the archive import.

    For a web based archive the content is first downloaded into a sandbox folder. The import is then
    attempted; when it fails with ``IncompatibleArchiveVersionError`` and ``try_migration`` is set, the
    archive is migrated to ``EXPORT_VERSION`` inside the sandbox and the import is attempted once more.
    Every failure is reported through ``_echo_exception``.

    :param archive: the path or URL to the archive
    :param web_based: If the archive needs to be downloaded first
    :param import_kwargs: keyword arguments to pass to the import function
    :param try_migration: whether to try a migration if the import raises IncompatibleArchiveVersionError

    """
    from aiida.common.folders import SandboxFolder
    from aiida.tools.importexport import (
        detect_archive_type, EXPORT_VERSION, import_data, IncompatibleArchiveVersionError
    )
    from aiida.tools.importexport.archive.migrators import get_migrator

    with SandboxFolder() as temp_folder:

        archive_path = archive

        if web_based:
            echo.echo_info(f'downloading archive: {archive}')
            try:
                # The response is used as a context manager so that the connection is closed as soon as
                # the content has been copied. Copying happens inside the `try` because reading from the
                # response is part of the download and can fail in the same way as opening the URL.
                with urllib.request.urlopen(archive) as response:
                    temp_folder.create_file_from_filelike(response, 'downloaded_archive.zip')
            except Exception as exception:
                # NOTE(review): `_echo_exception` presumably terminates the command - confirm.
                _echo_exception(f'downloading archive {archive} failed', exception)
            archive_path = temp_folder.get_abs_path('downloaded_archive.zip')
            echo.echo_success('archive downloaded, proceeding with import')

        echo.echo_info(f'starting import: {archive}')
        try:
            import_data(archive_path, **import_kwargs)
        except IncompatibleArchiveVersionError as exception:
            if try_migration:

                echo.echo_info(f'incompatible version detected for {archive}, trying migration')
                try:
                    migrator = get_migrator(detect_archive_type(archive_path))(archive_path)
                    # The migrated archive is written to the sandbox and replaces the path to import
                    archive_path = migrator.migrate(
                        EXPORT_VERSION, None, out_compression='none', work_dir=temp_folder.abspath
                    )
                except Exception as migration_error:
                    _echo_exception(
                        f'an exception occurred while migrating the archive {archive}', migration_error
                    )

                echo.echo_info('proceeding with import of migrated archive')
                try:
                    import_data(archive_path, **import_kwargs)
                except Exception as import_error:
                    _echo_exception(
                        f'an exception occurred while trying to import the migrated archive {archive}', import_error
                    )
            else:
                _echo_exception(f'an exception occurred while trying to import the archive {archive}', exception)
        except Exception as exception:
            _echo_exception(f'an exception occurred while trying to import the archive {archive}', exception)

        echo.echo_success(f'imported archive {archive}')
Example #28
0
    def test_base_data_type_change(self, temp_dir):
        """Check the type strings of base data nodes after an export/import round trip.

        Base Data types type string changed, for example Bool: "data.base.Bool." -> "data.bool.Bool."

        :param temp_dir: directory in which the archive file is written
        """
        # Reference values and the type string expected for each of them
        reference_values = ('Hello', 6, -1.2399834e12, False)
        reference_types = tuple(
            'data.{}.{}.'.format(name, name.capitalize()) for name in ['str', 'int', 'float', 'bool']
        )

        # One stored node per reference value, in the same order
        node_classes = (orm.Str, orm.Int, orm.Float, orm.Bool)
        base_nodes = [node_cls(value).store() for value, node_cls in zip(reference_values, node_classes)]
        base_uuids = [node.uuid for node in base_nodes]

        # A List node that holds the nodes created above
        list_node = orm.List()
        list_node.set_list(base_nodes)
        list_node.store()
        list_uuid = list_node.uuid

        # Round trip: export everything, wipe the database, import again
        archive = os.path.join(temp_dir, 'export.aiida')
        export(base_nodes + [list_node], filename=archive, silent=True)
        self.reset_database()
        import_data(archive, silent=True)

        # Compare every imported base node, and the matching List entry, with the references
        imported_list = orm.load_node(list_uuid)
        checks = zip(base_uuids, imported_list.get_list(), reference_values, reference_types)
        for uuid, list_value, expected_value, expected_type in checks:
            imported = orm.load_node(uuid)

            # Value of the Str, Int, Float or Bool node
            self.assertEqual(imported.value, expected_value)

            # Type string of that node
            type_msg = "type of node ('{}') is not updated according to db schema v0.4".format(imported.node_type)
            self.assertEqual(imported.node_type, expected_type, msg=type_msg)

            # Value stored in the List at the same position
            self.assertEqual(list_value, expected_value)

        # Type string of the List node itself
        list_type_msg = "type of node ('{}') is not updated according to db schema v0.4".format(
            imported_list.node_type
        )
        self.assertEqual(imported_list.node_type, 'data.list.List.', msg=list_type_msg)
    def test_empty_repo_folder_export(self, temp_dir):
        """Check that a Node with an empty repository folder is exported and re-imported correctly.

        The repository folder of a stored ``Dict`` node is emptied by hand. The node is then exported as an
        archive folder, a tar archive and a zip archive, and each variant is imported into a reset database,
        where exactly that node has to show up again.

        :param temp_dir: directory in which all archive variants are written
        """
        from aiida.common.folders import Folder
        from aiida.tools.importexport.dbexport import export_tree

        node = orm.Dict().store()
        node_uuid = node.uuid

        node_repo = RepositoryFolder(section=Repository._section_name, uuid=node_uuid)  # pylint: disable=protected-access
        self.assertTrue(
            node_repo.exists(), msg='Newly created and stored Node should have had an existing repository folder'
        )

        # Remove everything the repository folder contains, files and sub-folders alike
        for entry, is_file in node_repo.get_content_list(only_paths=False):
            entry_path = os.path.join(node_repo.abspath, entry)
            if not is_file:
                shutil.rmtree(entry_path, ignore_errors=False)
                continue
            os.remove(entry_path)

        leftover_msg = 'Repository folder should be empty, instead the following was found: {}'.format(
            node_repo.get_content_list()
        )
        self.assertFalse(node_repo.get_content_list(), msg=leftover_msg)

        # One export per supported archive variant
        folder_path = os.path.join(temp_dir, 'export_tree')
        tar_path = os.path.join(temp_dir, 'export.tar.gz')
        zip_path = os.path.join(temp_dir, 'export.zip')

        export_tree([node], folder=Folder(folder_path), silent=True)
        export([node], filename=tar_path, file_format='tar.gz', silent=True)
        export([node], filename=zip_path, file_format='zip', silent=True)

        variants = (
            ('archive folder', folder_path),
            ('tar archive', tar_path),
            ('zip archive', zip_path),
        )
        for variant, archive in variants:
            # Every variant is imported into an empty database
            self.reset_database()
            count_after_reset = orm.QueryBuilder().append(orm.Dict, project='uuid').count()
            self.assertEqual(
                count_after_reset,
                0,
                msg='After DB reset {} Dict Nodes was (wrongly) found'.format(count_after_reset)
            )

            import_data(archive, silent=True)

            query = orm.QueryBuilder().append(orm.Dict, project='uuid')
            imported_count = query.count()
            count_msg = (
                'After {} import a single Dict Node should have been found, '
                'instead {} was/were found'.format(variant, imported_count)
            )
            self.assertEqual(imported_count, 1, msg=count_msg)

            imported_uuid = query.all()[0][0]
            uuid_msg = (
                'The wrong UUID was found for the imported {}: '
                '{}. It should have been: {}'.format(variant, imported_uuid, node_uuid)
            )
            self.assertEqual(imported_uuid, node_uuid, msg=uuid_msg)
Example #30
0
    def pull(self, *, profile: Optional[str] = None, **kwargs: Any) -> None:
        """Import the nodes of the archive found at ``self.location``.

        :param profile: The AiiDA profile name (or use default)
        :raises NotImplementedError: if ``uuids`` is passed as a keyword argument

        """
        # Restricting the import to selected UUIDs is rejected before anything else happens
        uuids_requested = "uuids" in kwargs
        if uuids_requested:
            raise NotImplementedError("Cannot specify particular UUIDs")

        load_profile(profile)

        archive_location = str(self.location)
        import_data(archive_location)