Code example #1
File: utils.py  Project: CasperWA/aiida_core
def update_metadata(metadata, version):
    """Update the metadata with a new version number and a notification of the conversion that was executed.

    :param metadata: the content of an export archive metadata.json file
    :param version: string version number that the updated metadata should get
    """
    from aiida import get_version

    old_version = metadata['export_version']
    conversion_info = metadata.get('conversion_info', [])

    conversion_message = 'Converted from version {} to {} with AiiDA v{}'.format(
        old_version, version, get_version())
    conversion_info.append(conversion_message)

    metadata['aiida_version'] = get_version()
    metadata['export_version'] = version
    metadata['conversion_info'] = conversion_info
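
A minimal usage sketch for update_metadata, assuming a metadata dict as it would be read from an archive's metadata.json (the version strings below are illustrative):

# Illustrative metadata as read from metadata.json of a v0.7 archive
metadata = {
    'export_version': '0.7',
    'aiida_version': '1.0.0',
    # 'conversion_info' may be missing before the first migration; update_metadata handles that
}

update_metadata(metadata, '0.8')

assert metadata['export_version'] == '0.8'
# The last conversion_info entry records the migration and the installed AiiDA version
print(metadata['conversion_info'][-1])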
Code example #2
def test_migrations(migration_data):
    """Test each migration method from the `aiida.tools.importexport.migration` module."""
    version_old, version_new, metadata_old, metadata_new, data_old, data_new = migration_data

    # Remove AiiDA version, since this may change regardless of the migration function
    metadata_old.pop('aiida_version')
    metadata_new.pop('aiida_version')

    # Assert conversion message in `metadata.json` is correct and then remove it for later assertions
    metadata_new.pop('conversion_info')
    message = 'Converted from version {} to {} with AiiDA v{}'.format(
        version_old, version_new, get_version())
    assert metadata_old.pop('conversion_info')[-1] == message, \
        'Conversion message after migration is wrong'

    assert metadata_old == metadata_new
    assert data_old == data_new
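
The migration_data fixture itself is not part of this listing; a simplified, hypothetical sketch of how such a parametrized fixture could be assembled (the helper get_json_files and the version mapping are assumptions):

import pytest

# Hypothetical mapping: old version -> (new version, migration function)
MIGRATION_FUNCTIONS = {
    '0.7': ('0.8', migrate_v7_to_v8),
}


@pytest.fixture(params=sorted(MIGRATION_FUNCTIONS))
def migration_data(request):
    """Yield (version_old, version_new, metadata_old, metadata_new, data_old, data_new)."""
    version_old = request.param
    version_new, migration_function = MIGRATION_FUNCTIONS[version_old]

    # get_json_files is assumed to return the (metadata, data) dicts of an archive
    metadata_old, data_old = get_json_files(f'export_v{version_old}_simple.aiida')
    metadata_new, data_new = get_json_files(f'export_v{version_new}_simple.aiida')

    # Migrate the old archive in place so the test can compare it against the new one
    migration_function(metadata_old, data_old)

    yield version_old, version_new, metadata_old, metadata_new, data_old, data_new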
Code example #3
    def test_migrate_v7_to_v8(self):
        """Test migration for file containing complete v0.7 era possibilities"""
        from aiida import get_version

        # Get metadata.json and data.json as dicts from v0.7 file archive
        metadata_v7, data_v7 = get_json_files('export_v0.7_simple.aiida',
                                              **self.core_archive)
        verify_metadata_version(metadata_v7, version='0.7')

        # Get metadata.json and data.json as dicts from v0.8 file archive
        metadata_v8, data_v8 = get_json_files('export_v0.8_simple.aiida',
                                              **self.core_archive)
        verify_metadata_version(metadata_v8, version='0.8')

        # Migrate to v0.8
        migrate_v7_to_v8(metadata_v7, data_v7)
        verify_metadata_version(metadata_v7, version='0.8')

        # Remove AiiDA version, since this may change regardless of the migration function
        metadata_v7.pop('aiida_version')
        metadata_v8.pop('aiida_version')

        # Assert conversion message in `metadata.json` is correct and then remove it for later assertions
        self.maxDiff = None  # pylint: disable=invalid-name
        conversion_message = 'Converted from version 0.7 to 0.8 with AiiDA v{}'.format(
            get_version())
        self.assertEqual(metadata_v7.pop('conversion_info')[-1],
                         conversion_message,
                         msg='The conversion message after migration is wrong')
        metadata_v8.pop('conversion_info')

        # Assert changes were performed correctly
        self.assertDictEqual(
            metadata_v7,
            metadata_v8,
            msg='After migration, metadata.json should equal intended metadata.json from archives'
        )
        self.assertDictEqual(
            data_v7,
            data_v8,
            msg='After migration, data.json should equal intended data.json from archives'
        )
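
verify_metadata_version is a helper from the test infrastructure; a minimal sketch of the check it performs (the exception type used here is an assumption):

def verify_metadata_version(metadata, version=None):
    """Check that the metadata dict carries the expected export version (simplified sketch)."""
    try:
        metadata_version = metadata['export_version']
    except KeyError:
        raise ValueError('metadata is missing the `export_version` key')
    if version is not None and metadata_version != version:
        raise ValueError(f'expected export version {version}, found {metadata_version}')
    return metadata_version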
Code example #4
    def test_migrate_v2_to_v3(self):
        """Test function migrate_v2_to_v3"""
        from aiida import get_version

        # Get metadata.json and data.json as dicts from v0.2 file archive
        metadata_v2, data_v2 = get_json_files('export_v0.2_simple.aiida',
                                              **self.core_archive)
        verify_metadata_version(metadata_v2, version='0.2')

        # Get metadata.json and data.json as dicts from v0.3 file archive
        metadata_v3, data_v3 = get_json_files('export_v0.3_simple.aiida',
                                              **self.core_archive)
        verify_metadata_version(metadata_v3, version='0.3')

        # Migrate to v0.3
        migrate_v2_to_v3(metadata_v2, data_v2)
        verify_metadata_version(metadata_v2, version='0.3')

        # Remove AiiDA version, since this may change regardless of the migration function
        metadata_v2.pop('aiida_version')
        metadata_v3.pop('aiida_version')

        # Assert conversion message in `metadata.json` is correct and then remove it for later assertions
        conversion_message = 'Converted from version 0.2 to 0.3 with AiiDA v{}'.format(
            get_version())
        self.assertEqual(metadata_v2.pop('conversion_info')[-1],
                         conversion_message,
                         msg='The conversion message after migration is wrong')
        metadata_v3.pop('conversion_info')

        # Assert changes were performed correctly
        self.maxDiff = None  # pylint: disable=invalid-name
        self.assertDictEqual(
            metadata_v2,
            metadata_v3,
            msg='After migration, metadata.json should equal intended metadata.json from archives'
        )
        self.assertDictEqual(
            data_v2,
            data_v3,
            msg='After migration, data.json should equal intended data.json from archives'
        )
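
get_json_files is likewise a test helper; a simplified stand-in that works for zip-format archives (the real helper also accepts keyword arguments such as the archive location, not reproduced here):

import json
import zipfile


def get_json_files_sketch(archive_path):
    """Return (metadata, data) read from a zip-format export archive (simplified sketch)."""
    with zipfile.ZipFile(archive_path, 'r') as archive:
        with archive.open('metadata.json') as handle:
            metadata = json.load(handle)
        with archive.open('data.json') as handle:
            data = json.load(handle)
    return metadata, data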
Code example #5
    def test_migrate_v3_to_v4(self):
        """Test function migrate_v3_to_v4"""
        from aiida import get_version

        # Get metadata.json and data.json as dicts from v0.4 file archive
        metadata_v4, data_v4 = get_json_files('export_v0.4_simple.aiida', **self.core_archive)
        verify_metadata_version(metadata_v4, version='0.4')

        # Get metadata.json and data.json as dicts from v0.3 file archive
        # Cannot use 'get_json_files' for 'export_v0.3_simple.aiida',
        # because we need to pass the SandboxFolder to 'migrate_v3_to_v4'
        dirpath_archive = get_archive_file('export_v0.3_simple.aiida', **self.core_archive)

        with SandboxFolder(sandbox_in_repo=False) as folder:
            if zipfile.is_zipfile(dirpath_archive):
                extract_zip(dirpath_archive, folder, silent=True)
            elif tarfile.is_tarfile(dirpath_archive):
                extract_tar(dirpath_archive, folder, silent=True)
            else:
                raise ValueError('invalid file format, expected either a zip archive or gzipped tarball')

            try:
                with io.open(folder.get_abs_path('data.json'), 'r', encoding='utf8') as fhandle:
                    data_v3 = jsonload(fhandle)
                with io.open(folder.get_abs_path('metadata.json'), 'r', encoding='utf8') as fhandle:
                    metadata_v3 = jsonload(fhandle)
            except IOError:
                raise NotExistent('export archive does not contain the required data.json or metadata.json file')

            verify_metadata_version(metadata_v3, version='0.3')

            # Migrate to v0.4
            migrate_v3_to_v4(metadata_v3, data_v3, folder)
            verify_metadata_version(metadata_v3, version='0.4')

        # Remove AiiDA version, since this may change regardless of the migration function
        metadata_v3.pop('aiida_version')
        metadata_v4.pop('aiida_version')

        # Assert conversion message in `metadata.json` is correct and then remove it for later assertions
        self.maxDiff = None  # pylint: disable=invalid-name
        conversion_message = 'Converted from version 0.3 to 0.4 with AiiDA v{}'.format(get_version())
        self.assertEqual(
            metadata_v3.pop('conversion_info')[-1],
            conversion_message,
            msg='The conversion message after migration is wrong'
        )
        metadata_v4.pop('conversion_info')

        # Assert changes were performed correctly
        self.assertDictEqual(
            metadata_v3,
            metadata_v4,
            msg='After migration, metadata.json should equal intended metadata.json from archives'
        )
        self.assertDictEqual(
            data_v3, data_v4, msg='After migration, data.json should equal intended data.json from archives'
        )
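
The zip-or-tar dispatch used above can be reproduced with the standard library alone; a minimal sketch without AiiDA's SandboxFolder and extract_* helpers:

import tarfile
import tempfile
import zipfile


def extract_archive_sketch(archive_path):
    """Extract a zip archive or tarball into a temporary directory and return that directory (sketch)."""
    destination = tempfile.mkdtemp()
    if zipfile.is_zipfile(archive_path):
        with zipfile.ZipFile(archive_path, 'r') as archive:
            archive.extractall(destination)
    elif tarfile.is_tarfile(archive_path):
        with tarfile.open(archive_path, 'r:*') as archive:
            archive.extractall(destination)
    else:
        raise ValueError('invalid file format, expected either a zip archive or gzipped tarball')
    return destination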
Code example #6
    def test_dos_wc_Cu(self):
        """
        simple Cu noSOC, FP, lmax2 full example using scf workflow
        """
        from aiida import get_version
        from aiida.orm import Code, load_node
        from aiida.plugins import DataFactory
        from aiida.orm import Computer
        from aiida.orm.querybuilder import QueryBuilder
        from masci_tools.io.kkr_params import kkrparams
        from aiida_kkr.workflows.dos import kkr_dos_wc
        from numpy import array

        print('AiiDA version: {}'.format(get_version()))

        Dict = DataFactory('dict')
        StructureData = DataFactory('structure')

        # prepare computer and code (needed so that the workflow can actually be run)
        prepare_code(kkr_codename, codelocation, computername, workdir)

        # Then set up the structure
        alat = 6.83  # in a_Bohr
        abohr = 0.52917721067  # conversion factor to Angstroem units
        bravais = array([[0.5, 0.5, 0.0], [0.5, 0.0, 0.5],
                         [0.0, 0.5, 0.5]])  # bravais vectors
        a = 0.5 * alat * abohr
        Cu = StructureData(cell=[[a, a, 0.0], [a, 0.0, a], [0.0, a, a]])
        Cu.append_atom(position=[0.0, 0.0, 0.0], symbols='Cu')

        Cu.store()
        print(Cu)

        # Here we create a parameter node for the workflow input (workflow-specific parameters) and adjust the DOS parameters.
        wfd = kkr_dos_wc.get_wf_defaults()
        wfd['dos_params']['kmesh'] = [10, 10, 10]
        wfd['dos_params']['nepts'] = 10
        params_dos = Dict(dict=wfd)

        options = {
            'queue_name': queuename,
            'resources': {
                "num_machines": 1
            },
            'max_wallclock_seconds': 5 * 60,
            'use_mpi': False,
            'custom_scheduler_commands': ''
        }
        options = Dict(dict=options)

        # The scf-workflow also needs the voronoi and KKR codes to be able to run the calculations
        KKRCode = Code.get_from_string(kkr_codename + '@' + computername)

        label = 'dos Cu bulk'
        descr = 'DOS workflow for Cu bulk'

        from aiida.tools.importexport import import_data
        import_data('files/db_dump_kkrcalc.tar.gz')
        kkr_calc_remote = load_node(
            '3058bd6c-de0b-400e-aff5-2331a5f5d566').outputs.remote_folder

        # create process builder to set parameters
        builder = kkr_dos_wc.get_builder()
        builder.metadata.description = descr
        builder.metadata.label = label
        builder.kkr = KKRCode
        builder.wf_parameters = params_dos
        builder.options = options
        builder.remote_data = kkr_calc_remote

        # now run calculation
        from aiida.engine import run
        out = run(builder)

        # check outcome
        n = out['results_wf']
        n = n.get_dict()
        assert n.get('successful')
        assert n.get('list_of_errors') == []
        assert n.get('dos_params').get('nepts') == 10

        d = out['dos_data']
        x = d.get_x()
        y = d.get_y()

        assert sum(
            abs(x[1][0] - array([
                -19.24321191, -16.2197246, -13.1962373, -10.17274986,
                -7.14926255, -4.12577525, -1.10228794, 1.9211995, 4.94468681,
                7.96817411
            ]))) < 10**-7
        assert sum(
            abs(y[0][1][0] - array([
                9.86819781e-04, 1.40981029e-03, 2.27894713e-03, 4.79231363e-03,
                3.59368494e-02, 2.32929524e+00, 3.06973485e-01, 4.17629157e-01,
                3.04021941e-01, 1.24897739e-01
            ]))) < 10**-8
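
The tolerance checks above sum absolute deviations; an equivalent formulation with numpy.allclose, shown purely as a stylistic alternative (not part of the original test):

import numpy as np

# x and y are the DOS abscissa/ordinate arrays retrieved above via d.get_x() and d.get_y()
expected_energies = np.array([
    -19.24321191, -16.2197246, -13.1962373, -10.17274986,
    -7.14926255, -4.12577525, -1.10228794, 1.9211995, 4.94468681,
    7.96817411
])
assert np.allclose(x[1][0], expected_energies, atol=1e-8)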
Code example #7
def export_tree(what,
                folder,
                allowed_licenses=None,
                forbidden_licenses=None,
                silent=False,
                include_comments=True,
                include_logs=True,
                **kwargs):
    """Export the entries passed in the 'what' list to a file tree.

    :param what: a list of entity instances; they can belong to different models/entities.
    :type what: list

    :param folder: a temporary folder to build the archive before compression.
    :type folder: :py:class:`~aiida.common.folders.Folder`

    :param allowed_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type allowed_licenses: list

    :param forbidden_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type forbidden_licenses: list

    :param silent: suppress prints.
    :type silent: bool

    :param include_comments: In-/exclude export of comments for given node(s) in ``what``.
        Default: True, *include* comments in export (as well as relevant users).
    :type include_comments: bool

    :param include_logs: In-/exclude export of logs for given node(s) in ``what``.
        Default: True, *include* logs in export.
    :type include_logs: bool

    :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` for what rule names
        are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if there are any internal errors when
        exporting.
    :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under forbidden license.
    """
    from collections import defaultdict

    if not silent:
        print('STARTING EXPORT...')

    all_fields_info, unique_identifiers = get_all_fields_info()

    entities_starting_set = defaultdict(set)

    # The set that contains the nodes ids of the nodes that should be exported
    given_data_entry_ids = set()
    given_calculation_entry_ids = set()
    given_group_entry_ids = set()
    given_computer_entry_ids = set()
    given_groups = set()
    given_log_entry_ids = set()
    given_comment_entry_ids = set()

    # I store a list of the actual dbnodes
    for entry in what:
        # This returns the class name (as in imports). E.g. for a model node:
        # aiida.backends.djsite.db.models.DbNode
        # entry_class_string = get_class_string(entry)
        # Now load the backend-independent name into entry_entity_name, e.g. Node!
        # entry_entity_name = schema_to_entity_names(entry_class_string)
        if issubclass(entry.__class__, orm.Group):
            entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid)
            given_group_entry_ids.add(entry.id)
            given_groups.add(entry)
        elif issubclass(entry.__class__, orm.Node):
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            if issubclass(entry.__class__, orm.Data):
                given_data_entry_ids.add(entry.pk)
            elif issubclass(entry.__class__, orm.ProcessNode):
                given_calculation_entry_ids.add(entry.pk)
        elif issubclass(entry.__class__, orm.Computer):
            entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid)
            given_computer_entry_ids.add(entry.pk)
        else:
            raise exceptions.ArchiveExportError(
                'I was given {} ({}), which is not a Node, Computer, or Group instance'
                .format(entry, type(entry)))

    # Add all the nodes contained within the specified groups
    for group in given_groups:
        for entry in group.nodes:
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            if issubclass(entry.__class__, orm.Data):
                given_data_entry_ids.add(entry.pk)
            elif issubclass(entry.__class__, orm.ProcessNode):
                given_calculation_entry_ids.add(entry.pk)

    for entity, entity_set in entities_starting_set.items():
        entities_starting_set[entity] = list(entity_set)

    # We will iteratively explore the AiiDA graph to find further nodes that
    # should also be exported.
    # At the same time, we will create the links_uuid list of dicts to be exported

    if not silent:
        print('RETRIEVING LINKED NODES AND STORING LINKS...')

    to_be_exported, links_uuid, graph_traversal_rules = retrieve_linked_nodes(
        given_calculation_entry_ids, given_data_entry_ids, **kwargs)

    ## Universal "entities" attributed to all types of nodes
    # Logs
    if include_logs and to_be_exported:
        # Get related log(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Log,
                       filters={'dbnode_id': {
                           'in': to_be_exported
                       }},
                       project='id')
        res = {_[0] for _ in builder.all()}
        given_log_entry_ids.update(res)

    # Comments
    if include_comments and to_be_exported:
        # Get related comment(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Comment,
                       filters={'dbnode_id': {
                           'in': to_be_exported
                       }},
                       project='id')
        res = {_[0] for _ in builder.all()}
        given_comment_entry_ids.update(res)

    # Here we get all the columns that we plan to project per entity that we
    # would like to extract
    given_entities = list()
    if given_group_entry_ids:
        given_entities.append(GROUP_ENTITY_NAME)
    if to_be_exported:
        given_entities.append(NODE_ENTITY_NAME)
    if given_computer_entry_ids:
        given_entities.append(COMPUTER_ENTITY_NAME)
    if given_log_entry_ids:
        given_entities.append(LOG_ENTITY_NAME)
    if given_comment_entry_ids:
        given_entities.append(COMMENT_ENTITY_NAME)

    entries_to_add = dict()
    for given_entity in given_entities:
        project_cols = ['id']
        # The following gets a list of fields that we need,
        # e.g. user, mtime, uuid, computer
        entity_prop = all_fields_info[given_entity].keys()

        # Here we do the necessary renaming of properties
        for prop in entity_prop:
            # nprop contains the list of projections
            nprop = (file_fields_to_model_fields[given_entity][prop] if prop
                     in file_fields_to_model_fields[given_entity] else prop)
            project_cols.append(nprop)

        # Getting the ids that correspond to the right entity
        if given_entity == GROUP_ENTITY_NAME:
            entry_ids_to_add = given_group_entry_ids
        elif given_entity == NODE_ENTITY_NAME:
            entry_ids_to_add = to_be_exported
        elif given_entity == COMPUTER_ENTITY_NAME:
            entry_ids_to_add = given_computer_entry_ids
        elif given_entity == LOG_ENTITY_NAME:
            entry_ids_to_add = given_log_entry_ids
        elif given_entity == COMMENT_ENTITY_NAME:
            entry_ids_to_add = given_comment_entry_ids

        builder = orm.QueryBuilder()
        builder.append(entity_names_to_entities[given_entity],
                       filters={'id': {
                           'in': entry_ids_to_add
                       }},
                       project=project_cols,
                       tag=given_entity,
                       outerjoin=True)
        entries_to_add[given_entity] = builder

    # TODO (Spyros): revisit this, especially for license-checking functions
    # Check the licenses of exported data.
    if allowed_licenses is not None or forbidden_licenses is not None:
        builder = orm.QueryBuilder()
        builder.append(orm.Node,
                       project=['id', 'attributes.source.license'],
                       filters={'id': {
                           'in': to_be_exported
                       }})
        # Skip those nodes where the license is not set (this is the standard behavior with Django)
        node_licenses = list(
            (a, b) for [a, b] in builder.all() if b is not None)
        check_licenses(node_licenses, allowed_licenses, forbidden_licenses)

    ############################################################
    ##### Start automatic recursive export data generation #####
    ############################################################
    if not silent:
        print('STORING DATABASE ENTRIES...')

    export_data = dict()
    entity_separator = '_'
    for entity_name, partial_query in entries_to_add.items():

        foreign_fields = {
            k: v
            for k, v in all_fields_info[entity_name].items()
            # all_fields_info[model_name].items()
            if 'requires' in v
        }

        for value in foreign_fields.values():
            ref_model_name = value['requires']
            fill_in_query(partial_query, entity_name, ref_model_name,
                          [entity_name], entity_separator)

        for temp_d in partial_query.iterdict():
            for k in temp_d.keys():
                # Get current entity
                current_entity = k.split(entity_separator)[-1]

                # This is an empty result of an outer join.
                # It should not be taken into account.
                if temp_d[k]['id'] is None:
                    continue

                temp_d2 = {
                    temp_d[k]['id']:
                    serialize_dict(temp_d[k],
                                   remove_fields=['id'],
                                   rename_fields=model_fields_to_file_fields[
                                       current_entity])
                }
                try:
                    export_data[current_entity].update(temp_d2)
                except KeyError:
                    export_data[current_entity] = temp_d2

    #######################################
    # Manually manage attributes and extras
    #######################################
    # I use .get because there may be no nodes to export
    all_nodes_pk = list()
    if NODE_ENTITY_NAME in export_data:
        all_nodes_pk.extend(export_data.get(NODE_ENTITY_NAME).keys())

    if sum(len(model_data) for model_data in export_data.values()) == 0:
        if not silent:
            print('No nodes to store, exiting...')
        return

    if not silent:
        print('Exporting a total of {} db entries, of which {} nodes.'.format(
            sum(len(model_data) for model_data in export_data.values()),
            len(all_nodes_pk)))

    # ATTRIBUTES and EXTRAS
    if not silent:
        print('STORING NODE ATTRIBUTES AND EXTRAS...')
    node_attributes = {}
    node_extras = {}

    # A second QueryBuilder query to get the attributes and extras. See if this can be optimized
    if all_nodes_pk:
        all_nodes_query = orm.QueryBuilder()
        all_nodes_query.append(orm.Node,
                               filters={'id': {
                                   'in': all_nodes_pk
                               }},
                               project=['id', 'attributes', 'extras'])
        for res_pk, res_attributes, res_extras in all_nodes_query.iterall():
            node_attributes[str(res_pk)] = res_attributes
            node_extras[str(res_pk)] = res_extras

    if not silent:
        print('STORING GROUP ELEMENTS...')
    groups_uuid = dict()
    # If a group is in the exported data, we export the group/node correlation
    if GROUP_ENTITY_NAME in export_data:
        for curr_group in export_data[GROUP_ENTITY_NAME]:
            group_uuid_qb = orm.QueryBuilder()
            group_uuid_qb.append(entity_names_to_entities[GROUP_ENTITY_NAME],
                                 filters={'id': {
                                     '==': curr_group
                                 }},
                                 project=['uuid'],
                                 tag='group')
            group_uuid_qb.append(entity_names_to_entities[NODE_ENTITY_NAME],
                                 project=['uuid'],
                                 with_group='group')
            for res in group_uuid_qb.iterall():
                if str(res[0]) in groups_uuid:
                    groups_uuid[str(res[0])].append(str(res[1]))
                else:
                    groups_uuid[str(res[0])] = [str(res[1])]

    #######################################
    # Final check for unsealed ProcessNodes
    #######################################
    process_nodes = set()
    for node_pk, content in export_data.get(NODE_ENTITY_NAME, {}).items():
        if content['node_type'].startswith('process.'):
            process_nodes.add(node_pk)

    check_process_nodes_sealed(process_nodes)

    ######################################
    # Now I store
    ######################################
    # subfolder inside the export package
    nodesubfolder = folder.get_subfolder(NODES_EXPORT_SUBFOLDER,
                                         create=True,
                                         reset_limit=True)

    if not silent:
        print('STORING DATA...')

    data = {
        'node_attributes': node_attributes,
        'node_extras': node_extras,
        'export_data': export_data,
        'links_uuid': links_uuid,
        'groups_uuid': groups_uuid
    }

    # N.B. We're really calling zipfolder.open (if exporting a zipfile)
    with folder.open('data.json', mode='w') as fhandle:
        # fhandle.write(json.dumps(data, cls=UUIDEncoder))
        fhandle.write(json.dumps(data))

    # Add proper signature to unique identifiers & all_fields_info
    # Ignore if a key doesn't exist in any of the two dictionaries

    metadata = {
        'aiida_version': get_version(),
        'export_version': EXPORT_VERSION,
        'all_fields_info': all_fields_info,
        'unique_identifiers': unique_identifiers,
        'export_parameters': {
            'graph_traversal_rules': graph_traversal_rules,
            'entities_starting_set': entities_starting_set,
            'include_comments': include_comments,
            'include_logs': include_logs
        }
    }

    with folder.open('metadata.json', 'w') as fhandle:
        fhandle.write(json.dumps(metadata))

    if not silent:
        print('STORING REPOSITORY FILES...')

    # If there are no nodes, there are no repository files to store
    if all_nodes_pk:
        # Large speed increase by not getting the node itself and looping in memory in python, but just getting the uuid
        uuid_query = orm.QueryBuilder()
        uuid_query.append(orm.Node,
                          filters={'id': {
                              'in': all_nodes_pk
                          }},
                          project=['uuid'])
        for res in uuid_query.all():
            uuid = str(res[0])
            sharded_uuid = export_shard_uuid(uuid)

            # Important to set create=False, otherwise it creates the subfolder twice. Maybe this is a bug of insert_path?
            thisnodefolder = nodesubfolder.get_subfolder(sharded_uuid,
                                                         create=False,
                                                         reset_limit=True)

            # Make sure the node's repository folder was not deleted
            src = RepositoryFolder(section=Repository._section_name, uuid=uuid)  # pylint: disable=protected-access
            if not src.exists():
                raise exceptions.ArchiveExportError(
                    'Unable to find the repository folder for Node with UUID={} in the local repository'
                    .format(uuid))

            # In this way, I copy the content of the folder, and not the folder itself
            thisnodefolder.insert_path(src=src.abspath, dest_name='.')
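
A hedged usage sketch of this export_tree variant, writing the archive contents into a SandboxFolder (the node PKs below are placeholders):

from aiida import orm
from aiida.common.folders import SandboxFolder

# Placeholder PKs of previously stored nodes to export
nodes_to_export = [orm.load_node(pk) for pk in (123, 456)]

with SandboxFolder() as folder:
    # Builds data.json, metadata.json and the node repository tree inside `folder`
    export_tree(nodes_to_export, folder, silent=True, include_logs=False)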
Code example #8
def export_tree(
    entities: Optional[Iterable[Any]] = None,
    folder: Optional[Union[Folder, ZipFolder]] = None,
    allowed_licenses: Optional[Union[list, Callable]] = None,
    forbidden_licenses: Optional[Union[list, Callable]] = None,
    silent: bool = False,
    include_comments: bool = True,
    include_logs: bool = True,
    **traversal_rules: bool,
) -> None:
    """Export the entries passed in the 'entities' list to a file tree.

    .. deprecated:: 1.2.1
        Support for the parameter `what` will be removed in `v2.0.0`. Please use `entities` instead.

    :param entities: a list of entity instances; they can belong to different models/entities.

    :param folder: a temporary folder to build the archive before compression.

    :param allowed_licenses: List or function.
        If a list, then checks whether all licenses of Data nodes are in the list.
        If a function, then calls function for licenses of Data nodes,
        expecting True if license is allowed, False otherwise.

    :param forbidden_licenses: List or function.
        If a list, then checks whether all licenses of Data nodes are in the list.
        If a function, then calls function for licenses of Data nodes,
        expecting True if license is allowed, False otherwise.

    :param silent: suppress console prints and progress bar.

    :param include_comments: In-/exclude export of comments for given node(s) in ``entities``.
        Default: True, *include* comments in export (as well as relevant users).

    :param include_logs: In-/exclude export of logs for given node(s) in ``entities``.
        Default: True, *include* logs in export.

    :param traversal_rules: graph traversal rules.
        See :const:`aiida.common.links.GraphTraversalRules`
        for what rule names are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`:
        if there are any internal errors when exporting.
    :raises `~aiida.common.exceptions.LicensingException`:
        if any node is licensed under forbidden license.
    """

    if silent:
        logging.disable(level=logging.CRITICAL)

    EXPORT_LOGGER.debug("STARTING EXPORT...")

    # Backwards-compatibility
    entities = cast(
        Iterable[Any],
        deprecated_parameters(
            old={
                "name": "what",
                "value": traversal_rules.pop("what", None)
            },
            new={
                "name": "entities",
                "value": entities
            },
        ),
    )

    type_check(
        entities,
        (list, tuple, set),
        msg="`entities` must be specified and given as a list of AiiDA entities",
    )
    entities = list(entities)

    type_check(
        folder,
        (Folder, ZipFolder),
        msg="`folder` must be specified and given as an AiiDA Folder entity",
    )
    folder = cast(Union[Folder, ZipFolder], folder)

    all_fields_info, unique_identifiers = get_all_fields_info()

    entities_starting_set, given_node_entry_ids = get_starting_node_ids(
        entities, silent)

    (
        node_ids_to_be_exported,
        node_pk_2_uuid_mapping,
        links_uuid,
        traversal_rules,
    ) = collect_export_nodes(given_node_entry_ids, silent, **traversal_rules)

    check_node_licenses(node_ids_to_be_exported, allowed_licenses,
                        forbidden_licenses)

    entries_queries = get_entry_queries(
        node_ids_to_be_exported,
        entities_starting_set,
        node_pk_2_uuid_mapping,
        silent,
        include_comments,
        include_logs,
    )

    export_data = get_export_data(entries_queries, silent)

    # Close progress up until this point in order to print properly
    close_progress_bar(leave=False)

    # note this was originally below the attributes and group_uuid gather
    check_process_nodes_sealed({
        node_pk
        for node_pk, content in export_data.get(NODE_ENTITY_NAME, {}).items()
        if content["node_type"].startswith("process.")
    })

    model_data = sum(len(model_data) for model_data in export_data.values())
    if not model_data:
        EXPORT_LOGGER.log(msg="Nothing to store, exiting...",
                          level=LOG_LEVEL_REPORT)
        return
    EXPORT_LOGGER.log(
        msg=(f"Exporting a total of {model_data} database entries, "
             f"of which {len(node_ids_to_be_exported)} are Nodes."),
        level=LOG_LEVEL_REPORT,
    )

    node_attributes, node_extras = get_node_data(export_data,
                                                 node_ids_to_be_exported,
                                                 silent)
    groups_uuid = get_groups_uuid(export_data, silent)

    # Turn sets into lists to be able to export them as JSON metadata.
    for entity, entity_set in entities_starting_set.items():
        entities_starting_set[entity] = list(entity_set)  # type: ignore

    metadata = {
        "aiida_version": get_version(),
        "export_version": EXPORT_VERSION,
        "all_fields_info": all_fields_info,
        "unique_identifiers": unique_identifiers,
        "export_parameters": {
            "graph_traversal_rules": traversal_rules,
            "entities_starting_set": entities_starting_set,
            "include_comments": include_comments,
            "include_logs": include_logs,
        },
    }

    all_node_uuids = {
        node_pk_2_uuid_mapping[_]
        for _ in node_ids_to_be_exported
    }

    write_to_archive(
        folder,
        metadata,
        all_node_uuids,
        export_data,
        node_attributes,
        node_extras,
        groups_uuid,
        links_uuid,
        silent,
    )

    close_progress_bar(leave=False)

    # Reset logging level
    if silent:
        logging.disable(level=logging.NOTSET)
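
deprecated_parameters handles the backwards-compatible `what`/`entities` switch; a minimal sketch of the pattern it implements (the exact warning class and messages are assumptions):

import warnings


def deprecated_parameters_sketch(old, new):
    """Return the value to use, preferring `new` and warning if the deprecated `old` was given (sketch)."""
    if old.get('value') is None:
        return new.get('value')
    if new.get('value') is not None:
        # Both given: warn and let the new parameter win
        warnings.warn(
            '`{}` is deprecated, the value of `{}` will be used instead'.format(old['name'], new['name']),
            DeprecationWarning,
        )
        return new['value']
    warnings.warn(
        '`{}` is deprecated, please use `{}` instead'.format(old['name'], new['name']),
        DeprecationWarning,
    )
    return old['value']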
Code example #9
def export_tree(entities=None,
                folder=None,
                allowed_licenses=None,
                forbidden_licenses=None,
                silent=False,
                include_comments=True,
                include_logs=True,
                **kwargs):
    """Export the entries passed in the 'entities' list to a file tree.

    .. deprecated:: 1.2.1
        Support for the parameter `what` will be removed in `v2.0.0`. Please use `entities` instead.

    :param entities: a list of entity instances; they can belong to different models/entities.
    :type entities: list

    :param folder: a temporary folder to build the archive before compression.
    :type folder: :py:class:`~aiida.common.folders.Folder`

    :param allowed_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type allowed_licenses: list

    :param forbidden_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type forbidden_licenses: list

    :param silent: suppress console prints and progress bar.
    :type silent: bool

    :param include_comments: In-/exclude export of comments for given node(s) in ``entities``.
        Default: True, *include* comments in export (as well as relevant users).
    :type include_comments: bool

    :param include_logs: In-/exclude export of logs for given node(s) in ``entities``.
        Default: True, *include* logs in export.
    :type include_logs: bool

    :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` for what rule names
        are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if there are any internal errors when
        exporting.
    :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under forbidden license.
    """
    from collections import defaultdict
    from aiida.tools.graph.graph_traversers import get_nodes_export

    if silent:
        logging.disable(level=logging.CRITICAL)

    EXPORT_LOGGER.debug('STARTING EXPORT...')

    # Backwards-compatibility
    entities = deprecated_parameters(
        old={
            'name': 'what',
            'value': kwargs.pop('what', None)
        },
        new={
            'name': 'entities',
            'value': entities
        },
    )

    type_check(
        entities, (list, tuple, set),
        msg='`entities` must be specified and given as a list of AiiDA entities'
    )
    entities = list(entities)

    type_check(
        folder, (Folder, ZipFolder),
        msg='`folder` must be specified and given as an AiiDA Folder entity')

    all_fields_info, unique_identifiers = get_all_fields_info()

    entities_starting_set = defaultdict(set)

    # The set that contains the nodes ids of the nodes that should be exported
    given_node_entry_ids = set()
    given_log_entry_ids = set()
    given_comment_entry_ids = set()

    # Instantiate progress bar - go through list of `entities`
    pbar_total = len(entities) + 1 if entities else 1
    progress_bar = get_progress_bar(total=pbar_total,
                                    leave=False,
                                    disable=silent)
    progress_bar.set_description_str('Collecting chosen entities',
                                     refresh=False)

    # I store a list of the actual dbnodes
    for entry in entities:
        progress_bar.update()

        # This returns the class name (as in imports). E.g. for a model node:
        # aiida.backends.djsite.db.models.DbNode
        # entry_class_string = get_class_string(entry)
        # Now load the backend-independent name into entry_entity_name, e.g. Node!
        # entry_entity_name = schema_to_entity_names(entry_class_string)
        if issubclass(entry.__class__, orm.Group):
            entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid)
        elif issubclass(entry.__class__, orm.Node):
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            given_node_entry_ids.add(entry.pk)
        elif issubclass(entry.__class__, orm.Computer):
            entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid)
        else:
            raise exceptions.ArchiveExportError(
                'I was given {} ({}), which is not a Node, Computer, or Group instance'
                .format(entry, type(entry)))

    # Add all the nodes contained within the specified groups
    if GROUP_ENTITY_NAME in entities_starting_set:

        progress_bar.set_description_str('Retrieving Nodes from Groups ...',
                                         refresh=True)

        # Use single query instead of given_group.nodes iterator for performance.
        qh_groups = orm.QueryBuilder().append(
            orm.Group,
            filters={
                'uuid': {
                    'in': entities_starting_set[GROUP_ENTITY_NAME]
                }
            },
            tag='groups').queryhelp

        # Delete this import once the dbexport.zip module has been renamed
        from builtins import zip  # pylint: disable=redefined-builtin

        node_results = orm.QueryBuilder(**qh_groups).append(
            orm.Node, project=['id', 'uuid'], with_group='groups').all()
        if node_results:
            pks, uuids = map(list, zip(*node_results))
            entities_starting_set[NODE_ENTITY_NAME].update(uuids)
            given_node_entry_ids.update(pks)
            del node_results, pks, uuids

        progress_bar.update()

    # We will iteratively explore the AiiDA graph to find further nodes that should also be exported.
    # At the same time, we will create the links_uuid list of dicts to be exported

    progress_bar = get_progress_bar(total=1, disable=silent)
    progress_bar.set_description_str(
        'Getting provenance and storing links ...', refresh=True)

    traverse_output = get_nodes_export(starting_pks=given_node_entry_ids,
                                       get_links=True,
                                       **kwargs)
    node_ids_to_be_exported = traverse_output['nodes']
    graph_traversal_rules = traverse_output['rules']

    # A utility dictionary for mapping PK to UUID.
    if node_ids_to_be_exported:
        qbuilder = orm.QueryBuilder().append(
            orm.Node,
            project=('id', 'uuid'),
            filters={'id': {
                'in': node_ids_to_be_exported
            }},
        )
        node_pk_2_uuid_mapping = dict(qbuilder.all())
    else:
        node_pk_2_uuid_mapping = {}

    # The set of tuples now has to be transformed to a list of dicts
    links_uuid = [{
        'input': node_pk_2_uuid_mapping[link.source_id],
        'output': node_pk_2_uuid_mapping[link.target_id],
        'label': link.link_label,
        'type': link.link_type
    } for link in traverse_output['links']]

    progress_bar.update()

    # Progress bar initialization - Entities
    progress_bar = get_progress_bar(total=1, disable=silent)
    progress_bar.set_description_str('Initializing export of all entities',
                                     refresh=True)

    ## Universal "entities" attributed to all types of nodes
    # Logs
    if include_logs and node_ids_to_be_exported:
        # Get related log(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Log,
                       filters={'dbnode_id': {
                           'in': node_ids_to_be_exported
                       }},
                       project='uuid')
        res = set(builder.all(flat=True))
        given_log_entry_ids.update(res)

    # Comments
    if include_comments and node_ids_to_be_exported:
        # Get related comment(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Comment,
                       filters={'dbnode_id': {
                           'in': node_ids_to_be_exported
                       }},
                       project='uuid')
        res = set(builder.all(flat=True))
        given_comment_entry_ids.update(res)

    # Here we get all the columns that we plan to project per entity that we would like to extract
    given_entities = set(entities_starting_set.keys())
    if node_ids_to_be_exported:
        given_entities.add(NODE_ENTITY_NAME)
    if given_log_entry_ids:
        given_entities.add(LOG_ENTITY_NAME)
    if given_comment_entry_ids:
        given_entities.add(COMMENT_ENTITY_NAME)

    progress_bar.update()

    if given_entities:
        progress_bar = get_progress_bar(total=len(given_entities),
                                        disable=silent)
        pbar_base_str = 'Preparing entities'

    entries_to_add = dict()
    for given_entity in given_entities:
        progress_bar.set_description_str(pbar_base_str +
                                         ' - {}s'.format(given_entity),
                                         refresh=False)
        progress_bar.update()

        project_cols = ['id']
        # The following gets a list of fields that we need,
        # e.g. user, mtime, uuid, computer
        entity_prop = all_fields_info[given_entity].keys()

        # Here we do the necessary renaming of properties
        for prop in entity_prop:
            # nprop contains the list of projections
            nprop = (file_fields_to_model_fields[given_entity][prop] if prop
                     in file_fields_to_model_fields[given_entity] else prop)
            project_cols.append(nprop)

        # Getting the ids that correspond to the right entity
        entry_uuids_to_add = entities_starting_set.get(given_entity, set())
        if not entry_uuids_to_add:
            if given_entity == LOG_ENTITY_NAME:
                entry_uuids_to_add = given_log_entry_ids
            elif given_entity == COMMENT_ENTITY_NAME:
                entry_uuids_to_add = given_comment_entry_ids
        elif given_entity == NODE_ENTITY_NAME:
            entry_uuids_to_add.update(
                {node_pk_2_uuid_mapping[_]
                 for _ in node_ids_to_be_exported})

        builder = orm.QueryBuilder()
        builder.append(entity_names_to_entities[given_entity],
                       filters={'uuid': {
                           'in': entry_uuids_to_add
                       }},
                       project=project_cols,
                       tag=given_entity,
                       outerjoin=True)
        entries_to_add[given_entity] = builder

    # TODO (Spyros): revisit this, especially for license-checking functions
    # Check the licenses of exported data.
    if allowed_licenses is not None or forbidden_licenses is not None:
        builder = orm.QueryBuilder()
        builder.append(orm.Node,
                       project=['id', 'attributes.source.license'],
                       filters={'id': {
                           'in': node_ids_to_be_exported
                       }})
        # Skip those nodes where the license is not set (this is the standard behavior with Django)
        node_licenses = list(
            (a, b) for [a, b] in builder.all() if b is not None)
        check_licenses(node_licenses, allowed_licenses, forbidden_licenses)

    ############################################################
    ##### Start automatic recursive export data generation #####
    ############################################################
    EXPORT_LOGGER.debug('GATHERING DATABASE ENTRIES...')

    if entries_to_add:
        progress_bar = get_progress_bar(total=len(entries_to_add),
                                        disable=silent)

    export_data = defaultdict(dict)
    entity_separator = '_'
    for entity_name, partial_query in entries_to_add.items():

        progress_bar.set_description_str('Exporting {}s'.format(entity_name),
                                         refresh=False)
        progress_bar.update()

        foreign_fields = {
            k: v
            for k, v in all_fields_info[entity_name].items() if 'requires' in v
        }

        for value in foreign_fields.values():
            ref_model_name = value['requires']
            fill_in_query(partial_query, entity_name, ref_model_name,
                          [entity_name], entity_separator)

        for temp_d in partial_query.iterdict():
            for key in temp_d:
                # Get current entity
                current_entity = key.split(entity_separator)[-1]

                # This is an empty result of an outer join.
                # It should not be taken into account.
                if temp_d[key]['id'] is None:
                    continue

                export_data[current_entity].update({
                    temp_d[key]['id']:
                    serialize_dict(temp_d[key],
                                   remove_fields=['id'],
                                   rename_fields=model_fields_to_file_fields[
                                       current_entity])
                })

    # Close progress up until this point in order to print properly
    close_progress_bar(leave=False)

    #######################################
    # Manually manage attributes and extras
    #######################################
    # Pointer. Renaming, since Nodes have now technically been retrieved and "stored"
    all_node_pks = node_ids_to_be_exported

    model_data = sum(len(model_data) for model_data in export_data.values())
    if not model_data:
        EXPORT_LOGGER.log(msg='Nothing to store, exiting...',
                          level=LOG_LEVEL_REPORT)
        return
    EXPORT_LOGGER.log(
        msg='Exporting a total of {} database entries, of which {} are Nodes.'.
        format(model_data, len(all_node_pks)),
        level=LOG_LEVEL_REPORT)

    # Instantiate new progress bar
    progress_bar = get_progress_bar(total=1, leave=False, disable=silent)

    # ATTRIBUTES and EXTRAS
    EXPORT_LOGGER.debug('GATHERING NODE ATTRIBUTES AND EXTRAS...')
    node_attributes = {}
    node_extras = {}

    # Another QueryBuilder query to get the attributes and extras. TODO: See if this can be optimized
    if all_node_pks:
        all_nodes_query = orm.QueryBuilder().append(
            orm.Node,
            filters={'id': {
                'in': all_node_pks
            }},
            project=['id', 'attributes', 'extras'])

        progress_bar = get_progress_bar(total=all_nodes_query.count(),
                                        disable=silent)
        progress_bar.set_description_str('Exporting Attributes and Extras',
                                         refresh=False)

        for node_pk, attributes, extras in all_nodes_query.iterall():
            progress_bar.update()

            node_attributes[str(node_pk)] = attributes
            node_extras[str(node_pk)] = extras

    EXPORT_LOGGER.debug('GATHERING GROUP ELEMENTS...')
    groups_uuid = defaultdict(list)
    # If a group is in the exported data, we export the group/node correlation
    if GROUP_ENTITY_NAME in export_data:
        group_uuids_with_node_uuids = orm.QueryBuilder().append(
            orm.Group,
            filters={
                'id': {
                    'in': export_data[GROUP_ENTITY_NAME]
                }
            },
            project='uuid',
            tag='groups').append(orm.Node, project='uuid', with_group='groups')

        # This part is _only_ for the progress bar
        total_node_uuids_for_groups = group_uuids_with_node_uuids.count()
        if total_node_uuids_for_groups:
            progress_bar = get_progress_bar(total=total_node_uuids_for_groups,
                                            disable=silent)
            progress_bar.set_description_str('Exporting Groups ...',
                                             refresh=False)

        for group_uuid, node_uuid in group_uuids_with_node_uuids.iterall():
            progress_bar.update()

            groups_uuid[group_uuid].append(node_uuid)

    #######################################
    # Final check for unsealed ProcessNodes
    #######################################
    process_nodes = set()
    for node_pk, content in export_data.get(NODE_ENTITY_NAME, {}).items():
        if content['node_type'].startswith('process.'):
            process_nodes.add(node_pk)

    check_process_nodes_sealed(process_nodes)

    ######################################
    # Now collecting and storing
    ######################################
    # subfolder inside the export package
    nodesubfolder = folder.get_subfolder(NODES_EXPORT_SUBFOLDER,
                                         create=True,
                                         reset_limit=True)

    EXPORT_LOGGER.debug('ADDING DATA TO EXPORT ARCHIVE...')

    data = {
        'node_attributes': node_attributes,
        'node_extras': node_extras,
        'export_data': export_data,
        'links_uuid': links_uuid,
        'groups_uuid': groups_uuid
    }

    # N.B. We're really calling zipfolder.open (if exporting a zipfile)
    with folder.open('data.json', mode='w') as fhandle:
        # fhandle.write(json.dumps(data, cls=UUIDEncoder))
        fhandle.write(json.dumps(data))

    # Turn sets into lists to be able to export them as JSON metadata.
    for entity, entity_set in entities_starting_set.items():
        entities_starting_set[entity] = list(entity_set)

    metadata = {
        'aiida_version': get_version(),
        'export_version': EXPORT_VERSION,
        'all_fields_info': all_fields_info,
        'unique_identifiers': unique_identifiers,
        'export_parameters': {
            'graph_traversal_rules': graph_traversal_rules,
            'entities_starting_set': entities_starting_set,
            'include_comments': include_comments,
            'include_logs': include_logs
        }
    }

    with folder.open('metadata.json', 'w') as fhandle:
        fhandle.write(json.dumps(metadata))

    EXPORT_LOGGER.debug('ADDING REPOSITORY FILES TO EXPORT ARCHIVE...')

    # If there are no nodes, there are no repository files to store
    if all_node_pks:
        all_node_uuids = {node_pk_2_uuid_mapping[_] for _ in all_node_pks}

        progress_bar = get_progress_bar(total=len(all_node_uuids),
                                        disable=silent)
        pbar_base_str = 'Exporting repository - '

        for uuid in all_node_uuids:
            sharded_uuid = export_shard_uuid(uuid)

            progress_bar.set_description_str(
                pbar_base_str + 'UUID={}'.format(uuid.split('-')[0]),
                refresh=False)
            progress_bar.update()

            # Important to set create=False, otherwise it creates the subfolder twice. Maybe this is a bug of insert_path?
            thisnodefolder = nodesubfolder.get_subfolder(sharded_uuid,
                                                         create=False,
                                                         reset_limit=True)

            # Make sure the node's repository folder was not deleted
            src = RepositoryFolder(section=Repository._section_name, uuid=uuid)  # pylint: disable=protected-access
            if not src.exists():
                raise exceptions.ArchiveExportError(
                    'Unable to find the repository folder for Node with UUID={} in the local repository'
                    .format(uuid))

            # In this way, I copy the content of the folder, and not the folder itself
            thisnodefolder.insert_path(src=src.abspath, dest_name='.')

    close_progress_bar(leave=False)

    # Reset logging level
    if silent:
        logging.disable(level=logging.NOTSET)
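
allowed_licenses and forbidden_licenses accept a callable as well as a list; a hedged sketch of passing a license filter (the policy, my_nodes and sandbox are placeholders):

def is_allowed_license(license_string):
    """Illustrative policy: accept any Creative Commons license string."""
    return license_string is not None and license_string.startswith('CC')


# my_nodes is a previously collected list of entities, sandbox a Folder/ZipFolder instance
export_tree(entities=my_nodes, folder=sandbox, allowed_licenses=is_allowed_license)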
Code example #10
def export(
    entities: Optional[Iterable[Any]] = None,
    filename: Optional[str] = None,
    file_format: Union[str, Type[ArchiveWriterAbstract]] = ExportFileFormat.ZIP,
    overwrite: bool = False,
    silent: Optional[bool] = None,
    use_compression: Optional[bool] = None,
    include_comments: bool = True,
    include_logs: bool = True,
    allowed_licenses: Optional[Union[list, Callable]] = None,
    forbidden_licenses: Optional[Union[list, Callable]] = None,
    writer_init: Optional[Dict[str, Any]] = None,
    batch_size: int = 100,
    **traversal_rules: bool,
) -> ArchiveWriterAbstract:
    """Export AiiDA data to an archive file.

    Note, the logging level and progress reporter should be set externally, for example::

        from aiida.common.progress_reporter import set_progress_bar_tqdm

        EXPORT_LOGGER.setLevel('DEBUG')
        set_progress_bar_tqdm(leave=True)
        export(...)

    .. deprecated:: 1.5.0
        Support for the parameter `silent` will be removed in `v2.0.0`.
        Please set the log level and progress bar implementation independently.

    .. deprecated:: 1.5.0
        Support for the parameter `use_compression` will be removed in `v2.0.0`.
        Please use `writer_init={'use_compression': True}`.

    .. deprecated:: 1.2.1
        Support for the parameters `what` and `outfile` will be removed in `v2.0.0`.
        Please use `entities` and `filename` instead, respectively.

    :param entities: a list of entity instances;
        they can belong to different models/entities.

    :param filename: the filename (possibly including the absolute path)
        of the file on which to export.

    :param file_format: 'zip', 'tar.gz' or 'folder' or a specific writer class.

    :param overwrite: if True, overwrite the output file without asking, if it exists.
        If False, raise an
        :py:class:`~aiida.tools.importexport.common.exceptions.ArchiveExportError`
        if the output file already exists.

    :param allowed_licenses: List or function.
        If a list, then checks whether all licenses of Data nodes are in the list. If a function,
        then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.

    :param forbidden_licenses: List or function. If a list,
        then checks whether all licenses of Data nodes are in the list. If a function,
        then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.

    :param include_comments: In-/exclude export of comments for given node(s) in ``entities``.
        Default: True, *include* comments in export (as well as relevant users).

    :param include_logs: In-/exclude export of logs for given node(s) in ``entities``.
        Default: True, *include* logs in export.

    :param writer_init: Additional key-word arguments to pass to the writer class init

    :param batch_size: batch database query results in sub-collections to reduce memory usage

    :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules`
        for what rule names are toggleable and what the defaults are.

    :returns: the writer instance used to write the archive.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`:
        if there are any internal errors when exporting.
    :raises `~aiida.common.exceptions.LicensingException`:
        if any node is licensed under forbidden license.
    """
    # pylint: disable=too-many-locals,too-many-branches,too-many-statements

    # Backwards-compatibility
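    # `what` and `outfile` are the pre-1.2.1 names of `entities` and `filename`; if a caller
    # still passes them as keyword arguments, `deprecated_parameters` maps them onto the new
    # parameters (handling the deprecation warning for the old names).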
    entities = cast(
        Iterable[Any],
        deprecated_parameters(
            old={
                'name': 'what',
                'value': traversal_rules.pop('what', None)
            },
            new={
                'name': 'entities',
                'value': entities
            },
        ),
    )
    filename = cast(
        str,
        deprecated_parameters(
            old={
                'name': 'outfile',
                'value': traversal_rules.pop('outfile', None)
            },
            new={
                'name': 'filename',
                'value': filename
            },
        ),
    )
    if silent is not None:
        warnings.warn(
            'silent keyword is deprecated and will be removed in AiiDA v2.0.0, set the logger level explicitly instead',
            AiidaDeprecationWarning
        )  # pylint: disable=no-member

    type_check(
        entities,
        (list, tuple, set),
        msg='`entities` must be specified and given as a list of AiiDA entities',
    )
    entities = list(entities)
    if type_check(filename, str, allow_none=True) is None:
        filename = 'export_data.aiida'

    if not overwrite and os.path.exists(filename):
        raise exceptions.ArchiveExportError(f"The output file '{filename}' already exists")

    # validate the traversal rules and generate a full set for reporting
    validate_traversal_rules(GraphTraversalRules.EXPORT, **traversal_rules)
    full_traversal_rules = {
        name: traversal_rules.get(name, rule.default) for name, rule in GraphTraversalRules.EXPORT.value.items()
    }
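    # `full_traversal_rules` maps every toggleable rule name to either the caller's override or
    # its default from `GraphTraversalRules.EXPORT`; it is only used for the summary below, while
    # the traversal itself receives the caller-supplied `traversal_rules` unchanged.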

    # setup the archive writer
    writer_init = writer_init or {}
    if use_compression is not None:
        warnings.warn(
            'use_compression argument is deprecated and will be removed in AiiDA v2.0.0 (which will always compress)',
            AiidaDeprecationWarning
        )  # pylint: disable=no-member
        writer_init['use_compression'] = use_compression
    if isinstance(file_format, str):
        writer = get_writer(file_format)(filepath=filename, **writer_init)
    elif issubclass(file_format, ArchiveWriterAbstract):
        writer = file_format(filepath=filename, **writer_init)
    else:
        raise TypeError('file_format must be a string or ArchiveWriterAbstract class')

    summary(
        file_format=writer.file_format_verbose,
        export_version=writer.export_version,
        outfile=filename,
        include_comments=include_comments,
        include_logs=include_logs,
        traversal_rules=full_traversal_rules
    )

    EXPORT_LOGGER.debug('STARTING EXPORT...')

    all_fields_info, unique_identifiers = get_all_fields_info()
    entities_starting_set, given_node_entry_ids = _get_starting_node_ids(entities)

    # Initialize the writer
    with writer as writer_context:

        # Iteratively explore the AiiDA graph to find further nodes that should also be exported
        with get_progress_reporter()(desc='Traversing provenance via links ...', total=1) as progress:
            traverse_output = get_nodes_export(starting_pks=given_node_entry_ids, get_links=True, **traversal_rules)
            progress.update()
        node_ids_to_be_exported = traverse_output['nodes']

        EXPORT_LOGGER.debug('WRITING METADATA...')

        writer_context.write_metadata(
            ArchiveMetadata(
                export_version=EXPORT_VERSION,
                aiida_version=get_version(),
                unique_identifiers=unique_identifiers,
                all_fields_info=all_fields_info,
                graph_traversal_rules=traverse_output['rules'],
                # Turn sets into lists to be able to export them as JSON metadata.
                entities_starting_set={
                    entity: list(entity_set) for entity, entity_set in entities_starting_set.items()
                },
                include_comments=include_comments,
                include_logs=include_logs,
            )
        )

        # Create a mapping of node PK to UUID.
        node_pk_2_uuid_mapping: Dict[int, str] = {}
        if node_ids_to_be_exported:
            qbuilder = orm.QueryBuilder().append(
                orm.Node,
                project=('id', 'uuid'),
                filters={'id': {
                    'in': node_ids_to_be_exported
                }},
            )
            node_pk_2_uuid_mapping = dict(qbuilder.all(batch_size=batch_size))

        # check that no nodes are being exported with incorrect licensing
        _check_node_licenses(node_ids_to_be_exported, allowed_licenses, forbidden_licenses)

        # write the link data
        if traverse_output['links'] is not None:
            with get_progress_reporter()(total=len(traverse_output['links']), desc='Writing links') as progress:
                for link in traverse_output['links']:
                    progress.update()
                    writer_context.write_link({
                        'input': node_pk_2_uuid_mapping[link.source_id],
                        'output': node_pk_2_uuid_mapping[link.target_id],
                        'label': link.link_label,
                        'type': link.link_type,
                    })

        # generate the database queries (one per entity type) that cover all required entities
        entity_queries = _collect_entity_queries(
            node_ids_to_be_exported,
            entities_starting_set,
            node_pk_2_uuid_mapping,
            include_comments,
            include_logs,
        )

        total_entities = sum(query.count() for query in entity_queries.values())

        # write all entity data fields
        if total_entities:
            exported_entity_pks = _write_entity_data(
                total_entities=total_entities,
                entity_queries=entity_queries,
                writer=writer_context,
                batch_size=batch_size
            )
        else:
            exported_entity_pks = defaultdict(set)
            EXPORT_LOGGER.info('No entities were found to export')

        # write mappings of groups to the nodes they contain
        if exported_entity_pks[GROUP_ENTITY_NAME]:

            EXPORT_LOGGER.debug('Writing group UUID -> [nodes UUIDs]')

            _write_group_mappings(
                group_pks=exported_entity_pks[GROUP_ENTITY_NAME], batch_size=batch_size, writer=writer_context
            )

        # copy all required node repositories
        if exported_entity_pks[NODE_ENTITY_NAME]:

            _write_node_repositories(
                node_pks=exported_entity_pks[NODE_ENTITY_NAME],
                node_pk_2_uuid_mapping=node_pk_2_uuid_mapping,
                writer=writer_context
            )

        EXPORT_LOGGER.info('Finalizing Export...')

    # summarize export
    export_summary = '\n  - '.join(f'{name:<6}: {len(pks)}' for name, pks in exported_entity_pks.items())
    if exported_entity_pks:
        EXPORT_LOGGER.info('Exported Entities:\n  - ' + export_summary + '\n')
    # TODO
    # EXPORT_LOGGER.info('Writer Information:\n %s', writer.export_info)

    return writer
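
A minimal usage sketch for the `export` function above, assuming a configured AiiDA profile and
that the function is importable as `aiida.tools.importexport.export` (as in aiida-core 1.x); the
node PK and the output filename are placeholders:

from aiida import load_profile, orm
from aiida.tools.importexport import export  # assumed import path

load_profile()

node = orm.load_node(1234)  # placeholder PK of the node to export

export(
    entities=[node],
    filename='my_export.aiida',
    file_format='zip',                      # or 'tar.gz', 'folder', or a writer class
    overwrite=True,
    include_logs=False,                     # skip log entries for the exported nodes
    writer_init={'use_compression': True},  # replaces the deprecated `use_compression` flag
)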
Code example #11
0
    def test_verdi_version(self):
        """Regression test for #2238: verify that `verdi --version` prints the current version"""
        result = self.cli_runner.invoke(cmd_verdi.verdi, ['--version'])
        self.assertIsNone(result.exception, result.output)
        self.assertIn(get_version(), result.output)
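
The version string the test asserts on comes directly from `aiida.get_version()`; a quick check
(the printed value is of course whatever aiida-core version is installed):

from aiida import get_version

print(get_version())  # e.g. '1.6.0', the string that `verdi --version` includes in its output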
Code example #12
0
"""
Utility functions for aiida plugins.

Useful for:
 * compatibility with different aiida versions

"""
from __future__ import absolute_import

from distutils.version import StrictVersion  # pylint: disable=no-name-in-module,import-error

import aiida

AIIDA_VERSION = StrictVersion(aiida.get_version())
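
# Illustrative only (not part of the original module): plugins can branch on the installed
# aiida-core version via the constant above, for example:
ON_AIIDA_1_OR_LATER = AIIDA_VERSION >= StrictVersion('1.0.0')  # hypothetical flag name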


def load_verdi_data():
    """Load the verdi data click command group for any version since 0.11."""
    verdi_data = None
    import_errors = []

    try:
        from aiida.cmdline.commands import data_cmd as verdi_data
    except ImportError as err:
        import_errors.append(err)

    if not verdi_data:
        try:
            from aiida.cmdline.commands import verdi_data
        except ImportError as err:
            import_errors.append(err)