コード例 #1
0
def export_zip(what,
               outfile='testzip',
               overwrite=False,
               silent=False,
               use_compression=True,
               **kwargs):
    """Export the given entities into a zip archive.

    The archive is written directly into a ``ZipFolder`` opened on ``outfile``; all the work of collecting the
    entities and writing their data is delegated to ``export_tree``.

    :param what: a list of entity instances; they can belong to different models/entities.
    :type what: list

    :param outfile: the filename (possibly including the absolute path) of the zip file to write.
    :type outfile: str

    :param overwrite: if True, overwrite an existing output file without asking. If False, raise an
        :py:class:`~aiida.tools.importexport.common.exceptions.ArchiveExportError` when the output file already exists.
    :type overwrite: bool

    :param silent: suppress prints.
    :type silent: bool

    :param use_compression: whether or not to compress the zip file.
    :type use_compression: bool

    :param kwargs: forwarded unchanged to ``export_tree``: ``allowed_licenses``, ``forbidden_licenses``,
        ``include_comments``, ``include_logs`` and the graph traversal rules (see
        :const:`aiida.common.links.GraphTraversalRules` for the toggleable rule names and their defaults).

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if the output file exists and
        ``overwrite`` is False, or if there are any internal errors when exporting.
    :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under a forbidden license.
    """
    # Refuse to clobber an existing archive unless explicitly allowed
    if not overwrite:
        if os.path.exists(outfile):
            raise exceptions.ArchiveExportError(
                "the output file '{}' already exists".format(outfile))

    started_at = time.time()
    with ZipFolder(outfile, mode='w',
                   use_compression=use_compression) as zip_folder:
        export_tree(what, folder=zip_folder, silent=silent, **kwargs)

    if silent:
        return
    elapsed = time.time() - started_at
    print('File written in {:10.3g} s.'.format(elapsed))
コード例 #2
0
def _write_node_repositories(
    *, node_pks: Set[int], node_pk_2_uuid_mapping: Dict[int, str], writer: ArchiveWriterAbstract
):
    """Copy the repository folder of every exported node into the archive.

    :param node_pks: PKs of the nodes whose repositories must be exported.
    :param node_pk_2_uuid_mapping: mapping from node PK to node UUID; must contain every PK in ``node_pks``.
    :param writer: archive writer that receives each node's repository folder.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if the repository folder of a node
        cannot be found in the local repository.
    """
    reporter = get_progress_reporter()
    with reporter(total=len(node_pks), desc='Exporting node repositories: ') as progress:

        for node_pk in node_pks:

            node_uuid = node_pk_2_uuid_mapping[node_pk]

            progress.set_description_str(f'Exporting node repositories: {node_pk}', refresh=False)
            progress.update()

            repo_folder = RepositoryFolder(section=Repository._section_name, uuid=node_uuid)  # pylint: disable=protected-access
            if repo_folder.exists():
                writer.write_node_repo_folder(node_uuid, repo_folder._abspath)  # pylint: disable=protected-access
            else:
                # A missing repository folder means the archive would be incomplete: abort the export
                raise exceptions.ArchiveExportError(
                    f'Unable to find the repository folder for Node with UUID={node_uuid} in the local repository'
                )
コード例 #3
0
def export(what,
           outfile='export_data.aiida.tar.gz',
           overwrite=False,
           silent=False,
           **kwargs):
    """Export the entries passed in the 'what' list to a gzipped tarball.

    The export tree is first built in a temporary ``SandboxFolder``, then compressed into ``outfile``. The sandbox
    is always erased afterwards, also when the export or the compression fails.

    :param what: a list of entity instances; they can belong to different models/entities.
    :type what: list

    :param outfile: the filename (possibly including the absolute path) of the file on which to export.
    :type outfile: str

    :param overwrite: if True, overwrite the output file without asking, if it exists. If False, raise an
        :py:class:`~aiida.tools.importexport.common.exceptions.ArchiveExportError` if the output file already exists.
    :type overwrite: bool

    :param silent: suppress prints.
    :type silent: bool

    :param allowed_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type allowed_licenses: list

    :param forbidden_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type forbidden_licenses: list

    :param include_comments: In-/exclude export of comments for given node(s) in ``what``.
        Default: True, *include* comments in export (as well as relevant users).
    :type include_comments: bool

    :param include_logs: In-/exclude export of logs for given node(s) in ``what``.
        Default: True, *include* logs in export.
    :type include_logs: bool

    :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` what rule names
        are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if there are any internal errors when
        exporting.
    :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under forbidden license.
    """
    from aiida.common.folders import SandboxFolder

    if not overwrite and os.path.exists(outfile):
        raise exceptions.ArchiveExportError(
            "The output file '{}' already exists".format(outfile))

    folder = SandboxFolder()
    try:
        time_export_start = time.time()
        export_tree(what, folder=folder, silent=silent, **kwargs)
        time_export_end = time.time()

        if not silent:
            print('COMPRESSING...')

        time_compress_start = time.time()
        # dereference=True stores the content of symlinked files rather than the links themselves
        with tarfile.open(outfile,
                          'w:gz',
                          format=tarfile.PAX_FORMAT,
                          dereference=True) as tar:
            tar.add(folder.abspath, arcname='')
        time_compress_end = time.time()
    finally:
        # The sandbox only holds intermediate files: always remove it so no temporary directory is leaked,
        # regardless of whether the export or the compression succeeded.
        folder.erase()

    if not silent:
        filecr_time = time_export_end - time_export_start
        filecomp_time = time_compress_end - time_compress_start
        print('Exported in {:6.2g}s, compressed in {:6.2g}s, total: {:6.2g}s.'.
              format(filecr_time, filecomp_time, filecr_time + filecomp_time))
        print('DONE.')
コード例 #4
0
def export_tree(what,
                folder,
                allowed_licenses=None,
                forbidden_licenses=None,
                silent=False,
                include_comments=True,
                include_logs=True,
                **kwargs):
    """Export the entries passed in the 'what' list to a file tree.

    Steps, in order:

    1. Classify the requested entities (groups, nodes, computers) and add the nodes contained in requested groups.
    2. Expand the node set by graph traversal (``retrieve_linked_nodes``), which also yields the links to export.
    3. Optionally add the logs and comments attached to the exported nodes.
    4. Check licenses (if requested), then query and serialize all database entries.
    5. Collect node attributes/extras and group memberships, and check the process nodes with
       ``check_process_nodes_sealed``.
    6. Write ``data.json``, ``metadata.json`` and the node repository folders into ``folder``.

    If no database entries end up being selected, the function prints a message (unless ``silent``) and returns
    without writing anything to ``folder``. The function always returns None.

    :param what: a list of entity instances; they can belong to different models/entities.
    :type what: list

    :param folder: a temporary folder to build the archive before compression.
    :type folder: :py:class:`~aiida.common.folders.Folder`

    :param allowed_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type allowed_licenses: list

    :param forbidden_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type forbidden_licenses: list

    :param silent: suppress prints.
    :type silent: bool

    :param include_comments: In-/exclude export of comments for given node(s) in ``what``.
        Default: True, *include* comments in export (as well as relevant users).
    :type include_comments: bool

    :param include_logs: In-/exclude export of logs for given node(s) in ``what``.
        Default: True, *include* logs in export.
    :type include_logs: bool

    :param kwargs: graph traversal rules, forwarded to ``retrieve_linked_nodes``. See
        :const:`aiida.common.links.GraphTraversalRules` what rule names are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if an entry of ``what`` is not a Node,
        Computer or Group, if a node's repository folder is missing, or for any other internal error when exporting.
    :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under forbidden license.
    """
    from collections import defaultdict

    if not silent:
        print('STARTING EXPORT...')

    all_fields_info, unique_identifiers = get_all_fields_info()

    # Entity name -> UUIDs of the explicitly requested entities; stored in metadata.json as 'entities_starting_set'
    entities_starting_set = defaultdict(set)

    # IDs/PKs of the explicitly requested entities, split by type; used below to build the database queries
    given_data_entry_ids = set()
    given_calculation_entry_ids = set()
    given_group_entry_ids = set()
    given_computer_entry_ids = set()
    given_groups = set()
    given_log_entry_ids = set()
    given_comment_entry_ids = set()

    # Classify each requested entity by its ORM class
    for entry in what:
        if issubclass(entry.__class__, orm.Group):
            entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid)
            given_group_entry_ids.add(entry.id)
            given_groups.add(entry)
        elif issubclass(entry.__class__, orm.Node):
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            # NOTE(review): a Node that is neither Data nor ProcessNode only has its UUID recorded here; its PK is
            # not passed to retrieve_linked_nodes below — confirm this is intended
            if issubclass(entry.__class__, orm.Data):
                given_data_entry_ids.add(entry.pk)
            elif issubclass(entry.__class__, orm.ProcessNode):
                given_calculation_entry_ids.add(entry.pk)
        elif issubclass(entry.__class__, orm.Computer):
            entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid)
            given_computer_entry_ids.add(entry.pk)
        else:
            raise exceptions.ArchiveExportError(
                'I was given {} ({}), which is not a Node, Computer, or Group instance'
                .format(entry, type(entry)))

    # Add all the nodes contained within the specified groups
    for group in given_groups:
        for entry in group.nodes:
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            if issubclass(entry.__class__, orm.Data):
                given_data_entry_ids.add(entry.pk)
            elif issubclass(entry.__class__, orm.ProcessNode):
                given_calculation_entry_ids.add(entry.pk)

    # Convert the sets to lists so they can be passed to json.dumps when metadata.json is written below
    for entity, entity_set in entities_starting_set.items():
        entities_starting_set[entity] = list(entity_set)

    # Iteratively explore the AiiDA graph to find further nodes that should also be exported, building at the same
    # time the links_uuid list of dicts to be exported. ``to_be_exported`` holds the IDs of all nodes to export.

    if not silent:
        print('RETRIEVING LINKED NODES AND STORING LINKS...')

    to_be_exported, links_uuid, graph_traversal_rules = retrieve_linked_nodes(
        given_calculation_entry_ids, given_data_entry_ids, **kwargs)

    ## Universal "entities" attributed to all types of nodes
    # Logs
    if include_logs and to_be_exported:
        # Get the IDs of all logs attached to any of the nodes to be exported
        builder = orm.QueryBuilder()
        builder.append(orm.Log,
                       filters={'dbnode_id': {
                           'in': to_be_exported
                       }},
                       project='id')
        res = {_[0] for _ in builder.all()}
        given_log_entry_ids.update(res)

    # Comments
    if include_comments and to_be_exported:
        # Get the IDs of all comments attached to any of the nodes to be exported
        builder = orm.QueryBuilder()
        builder.append(orm.Comment,
                       filters={'dbnode_id': {
                           'in': to_be_exported
                       }},
                       project='id')
        res = {_[0] for _ in builder.all()}
        given_comment_entry_ids.update(res)

    # Determine which entity types have entries to export; for each of them a query is built below that projects
    # the fields listed in ``all_fields_info`` (renamed to model field names where a mapping exists)
    given_entities = list()
    if given_group_entry_ids:
        given_entities.append(GROUP_ENTITY_NAME)
    if to_be_exported:
        given_entities.append(NODE_ENTITY_NAME)
    if given_computer_entry_ids:
        given_entities.append(COMPUTER_ENTITY_NAME)
    if given_log_entry_ids:
        given_entities.append(LOG_ENTITY_NAME)
    if given_comment_entry_ids:
        given_entities.append(COMMENT_ENTITY_NAME)

    entries_to_add = dict()
    for given_entity in given_entities:
        project_cols = ['id']
        # The following gets a list of fields that we need,
        # e.g. user, mtime, uuid, computer
        entity_prop = all_fields_info[given_entity].keys()

        # Translate export-file field names into model field names where they differ
        for prop in entity_prop:
            # nprop is the name to project in the query
            nprop = (file_fields_to_model_fields[given_entity][prop] if prop
                     in file_fields_to_model_fields[given_entity] else prop)
            project_cols.append(nprop)

        # Getting the ids that correspond to the right entity
        if given_entity == GROUP_ENTITY_NAME:
            entry_ids_to_add = given_group_entry_ids
        elif given_entity == NODE_ENTITY_NAME:
            entry_ids_to_add = to_be_exported
        elif given_entity == COMPUTER_ENTITY_NAME:
            entry_ids_to_add = given_computer_entry_ids
        elif given_entity == LOG_ENTITY_NAME:
            entry_ids_to_add = given_log_entry_ids
        elif given_entity == COMMENT_ENTITY_NAME:
            entry_ids_to_add = given_comment_entry_ids

        # The query is only built here; it is extended with fill_in_query and executed further below
        builder = orm.QueryBuilder()
        builder.append(entity_names_to_entities[given_entity],
                       filters={'id': {
                           'in': entry_ids_to_add
                       }},
                       project=project_cols,
                       tag=given_entity,
                       outerjoin=True)
        entries_to_add[given_entity] = builder

    # TODO (Spyros) To see better! Especially for functional licenses
    # Check the licenses of exported data (before anything is written to ``folder``).
    if allowed_licenses is not None or forbidden_licenses is not None:
        builder = orm.QueryBuilder()
        builder.append(orm.Node,
                       project=['id', 'attributes.source.license'],
                       filters={'id': {
                           'in': to_be_exported
                       }})
        # Skip those nodes where the license is not set (this is the standard behavior with Django)
        node_licenses = list(
            (a, b) for [a, b] in builder.all() if b is not None)
        check_licenses(node_licenses, allowed_licenses, forbidden_licenses)

    ############################################################
    ##### Start automatic recursive export data generation #####
    ############################################################
    if not silent:
        print('STORING DATABASE ENTRIES...')

    # export_data: entity name -> {entry id: serialized fields (without 'id')}
    export_data = dict()
    # Separator of the query tags; the last component of a result key is the entity name
    entity_separator = '_'
    for entity_name, partial_query in entries_to_add.items():

        # Fields of this entity that reference another entity (they carry a 'requires' key)
        foreign_fields = {
            k: v
            for k, v in all_fields_info[entity_name].items()
            if 'requires' in v
        }

        # Extend the query with the referenced entities so that they are exported as well
        for value in foreign_fields.values():
            ref_model_name = value['requires']
            fill_in_query(partial_query, entity_name, ref_model_name,
                          [entity_name], entity_separator)

        for temp_d in partial_query.iterdict():
            for k in temp_d.keys():
                # Get current entity
                current_entity = k.split(entity_separator)[-1]

                # This is an empty result of an outer join.
                # It should not be taken into account.
                if temp_d[k]['id'] is None:
                    continue

                temp_d2 = {
                    temp_d[k]['id']:
                    serialize_dict(temp_d[k],
                                   remove_fields=['id'],
                                   rename_fields=model_fields_to_file_fields[
                                       current_entity])
                }
                try:
                    export_data[current_entity].update(temp_d2)
                except KeyError:
                    export_data[current_entity] = temp_d2

    #######################################
    # Manually manage attributes and extras
    #######################################
    # IDs of all exported nodes; stays empty when there are no nodes to export
    all_nodes_pk = list()
    if NODE_ENTITY_NAME in export_data:
        all_nodes_pk.extend(export_data.get(NODE_ENTITY_NAME).keys())

    # Nothing at all to export: return without writing anything to ``folder``
    if sum(len(model_data) for model_data in export_data.values()) == 0:
        if not silent:
            print('No nodes to store, exiting...')
        return

    if not silent:
        print('Exporting a total of {} db entries, of which {} nodes.'.format(
            sum(len(model_data) for model_data in export_data.values()),
            len(all_nodes_pk)))

    # ATTRIBUTES and EXTRAS
    if not silent:
        print('STORING NODE ATTRIBUTES AND EXTRAS...')
    # Both dicts are keyed by the node PK converted to str
    node_attributes = {}
    node_extras = {}

    # A second QueryBuilder query to get the attributes and extras. See if this can be optimized
    if all_nodes_pk:
        all_nodes_query = orm.QueryBuilder()
        all_nodes_query.append(orm.Node,
                               filters={'id': {
                                   'in': all_nodes_pk
                               }},
                               project=['id', 'attributes', 'extras'])
        for res_pk, res_attributes, res_extras in all_nodes_query.iterall():
            node_attributes[str(res_pk)] = res_attributes
            node_extras[str(res_pk)] = res_extras

    if not silent:
        print('STORING GROUP ELEMENTS...')
    # Group UUID -> list of UUIDs of the nodes in that group
    groups_uuid = dict()
    # If a group is in the exported data, we export the group/node correlation
    if GROUP_ENTITY_NAME in export_data:
        for curr_group in export_data[GROUP_ENTITY_NAME]:
            # One query per group, projecting the group UUID together with the UUID of each of its nodes
            group_uuid_qb = orm.QueryBuilder()
            group_uuid_qb.append(entity_names_to_entities[GROUP_ENTITY_NAME],
                                 filters={'id': {
                                     '==': curr_group
                                 }},
                                 project=['uuid'],
                                 tag='group')
            group_uuid_qb.append(entity_names_to_entities[NODE_ENTITY_NAME],
                                 project=['uuid'],
                                 with_group='group')
            for res in group_uuid_qb.iterall():
                if str(res[0]) in groups_uuid:
                    groups_uuid[str(res[0])].append(str(res[1]))
                else:
                    groups_uuid[str(res[0])] = [str(res[1])]

    #######################################
    # Final check for unsealed ProcessNodes
    #######################################
    # Collect the IDs of all exported nodes whose node_type starts with 'process.'
    process_nodes = set()
    for node_pk, content in export_data.get(NODE_ENTITY_NAME, {}).items():
        if content['node_type'].startswith('process.'):
            process_nodes.add(node_pk)

    # NOTE(review): presumably raises if any of these process nodes is not sealed — confirm in
    # check_process_nodes_sealed. It runs before anything is written to ``folder``.
    check_process_nodes_sealed(process_nodes)

    ######################################
    # Now I store
    ######################################
    # subfolder inside the export package
    nodesubfolder = folder.get_subfolder(NODES_EXPORT_SUBFOLDER,
                                         create=True,
                                         reset_limit=True)

    if not silent:
        print('STORING DATA...')

    data = {
        'node_attributes': node_attributes,
        'node_extras': node_extras,
        'export_data': export_data,
        'links_uuid': links_uuid,
        'groups_uuid': groups_uuid
    }

    # N.B. We're really calling zipfolder.open (if exporting a zipfile)
    with folder.open('data.json', mode='w') as fhandle:
        fhandle.write(json.dumps(data))

    # The metadata records the AiiDA and export versions, the field schema, and the parameters used for this export

    metadata = {
        'aiida_version': get_version(),
        'export_version': EXPORT_VERSION,
        'all_fields_info': all_fields_info,
        'unique_identifiers': unique_identifiers,
        'export_parameters': {
            'graph_traversal_rules': graph_traversal_rules,
            'entities_starting_set': entities_starting_set,
            'include_comments': include_comments,
            'include_logs': include_logs
        }
    }

    with folder.open('metadata.json', 'w') as fhandle:
        fhandle.write(json.dumps(metadata))

    if silent is not True:
        print('STORING REPOSITORY FILES...')

    # If there are no nodes, there are no repository files to store
    if all_nodes_pk:
        # Large speed increase by not getting the node itself and looping in memory in python, but just getting the uuid
        uuid_query = orm.QueryBuilder()
        uuid_query.append(orm.Node,
                          filters={'id': {
                              'in': all_nodes_pk
                          }},
                          project=['uuid'])
        for res in uuid_query.all():
            uuid = str(res[0])
            sharded_uuid = export_shard_uuid(uuid)

            # Important to set create=False, otherwise creates twice a subfolder. Maybe this is a bug of insert_path?
            thisnodefolder = nodesubfolder.get_subfolder(sharded_uuid,
                                                         create=False,
                                                         reset_limit=True)

            # Make sure the node's repository folder was not deleted
            src = RepositoryFolder(section=Repository._section_name, uuid=uuid)  # pylint: disable=protected-access
            if not src.exists():
                raise exceptions.ArchiveExportError(
                    'Unable to find the repository folder for Node with UUID={} in the local repository'
                    .format(uuid))

            # In this way, I copy the content of the folder, and not the folder itself
            thisnodefolder.insert_path(src=src.abspath, dest_name='.')
コード例 #5
0
def write_to_archive(
    folder: Union[Folder, ZipFolder],
    metadata: dict,
    all_node_uuids: Set[str],
    export_data: Dict[str, Dict[int, dict]],
    node_attributes: Dict[str, dict],
    node_extras: Dict[str, dict],
    groups_uuid: Dict[str, List[str]],
    links_uuid: List[dict],
    silent: bool,
) -> None:
    """Write the collected export data, the metadata and the node repositories into ``folder``.

    ``data.json`` receives the node attributes/extras, the serialized entities, the links and the group
    memberships; ``metadata.json`` receives ``metadata`` as given. The repository folder of every node in
    ``all_node_uuids`` is then copied into the nodes subfolder of the archive.

    :param folder: destination folder (a ``ZipFolder`` when writing directly into a zip archive).
    :param metadata: the content of ``metadata.json``.
    :param all_node_uuids: UUIDs of all exported nodes, whose repository folders are copied.
    :param export_data: serialized database entries, per entity name.
    :param node_attributes: node attributes, keyed by node PK as string.
    :param node_extras: node extras, keyed by node PK as string.
    :param groups_uuid: group UUID -> UUIDs of the nodes in that group.
    :param links_uuid: the links to export.
    :param silent: disable the progress bar.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if the repository folder of a node
        cannot be found in the local repository.
    """
    # The repository content of every node is stored under this subfolder of the archive
    nodes_root = folder.get_subfolder(NODES_EXPORT_SUBFOLDER, create=True, reset_limit=True)

    EXPORT_LOGGER.debug("ADDING DATA TO EXPORT ARCHIVE...")

    payload = dict(
        node_attributes=node_attributes,
        node_extras=node_extras,
        export_data=export_data,
        links_uuid=links_uuid,
        groups_uuid=groups_uuid,
    )

    # When ``folder`` is a ZipFolder, ``open`` writes straight into the zip archive
    for json_name, content in (("data.json", payload), ("metadata.json", metadata)):
        with folder.open(json_name, mode="w") as handle:
            handle.write(json.dumps(content))

    EXPORT_LOGGER.debug("ADDING REPOSITORY FILES TO EXPORT ARCHIVE...")

    # Without nodes there are no repository files to store
    if not all_node_uuids:
        return

    progress_bar = get_progress_bar(total=len(all_node_uuids), disable=silent)
    description_prefix = "Exporting repository - "

    for node_uuid in all_node_uuids:
        shard_path = export_shard_uuid(node_uuid)

        short_uuid = node_uuid.split('-')[0]
        progress_bar.set_description_str(f"{description_prefix}UUID={short_uuid}", refresh=False)
        progress_bar.update()

        # create=False is required: creating the subfolder here would make insert_path nest it twice
        node_folder = nodes_root.get_subfolder(shard_path, create=False, reset_limit=True)

        # Abort if the node's repository folder has been deleted from the local repository
        src = RepositoryFolder(section=Repository._section_name, uuid=node_uuid)  # pylint: disable=protected-access
        if not src.exists():
            raise exceptions.ArchiveExportError(
                f"Unable to find the repository folder for Node with UUID={node_uuid} in the local repository")

        # dest_name="." copies the content of the folder rather than the folder itself
        node_folder.insert_path(src=src.abspath, dest_name=".")
コード例 #6
0
def export(
    entities: Optional[Iterable[Any]] = None,
    filename: Optional[str] = None,
    file_format: str = ExportFileFormat.ZIP,
    overwrite: bool = False,
    silent: bool = False,
    use_compression: bool = True,
    **kwargs: Any,
) -> None:
    """Export AiiDA data

    .. deprecated:: 1.2.1
        Support for the parameters `what` and `outfile` will be removed in `v2.0.0`.
        Please use `entities` and `filename` instead, respectively.

    If the export fails with an ``ArchiveExportError`` or ``LicensingException``, any partially written output file
    is removed before the exception is re-raised. When ``silent`` is True, logging is disabled for the duration of
    the call and re-enabled afterwards, also when the export fails.

    :param entities: a list of entity instances;
        they can belong to different models/entities.

    :param filename: the filename (possibly including the absolute path)
        of the file on which to export (default: 'export_data.aiida').

    :param file_format: See `ExportFileFormat` for complete list of valid values (default: 'zip').

    :param overwrite: if True, overwrite the output file without asking, if it exists.
        If False, raise an
        :py:class:`~aiida.tools.importexport.common.exceptions.ArchiveExportError`
        if the output file already exists.

    :param silent: suppress console prints and progress bar.

    :param use_compression: Whether or not to compress the archive file
        (only valid for the zip file format).

    :param allowed_licenses: List or function.
        If a list, then checks whether all licenses of Data nodes are in the list. If a function,
        then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type allowed_licenses: list

    :param forbidden_licenses: List or function. If a list,
        then checks whether all licenses of Data nodes are in the list. If a function,
        then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type forbidden_licenses: list

    :param include_comments: In-/exclude export of comments for given node(s) in ``entities``.
        Default: True, *include* comments in export (as well as relevant users).
    :type include_comments: bool

    :param include_logs: In-/exclude export of logs for given node(s) in ``entities``.
        Default: True, *include* logs in export.
    :type include_logs: bool

    :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules`
        what rule names are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`:
        if there are any internal errors when exporting.
    :raises `~aiida.common.exceptions.LicensingException`:
        if any node is licensed under forbidden license.
    """
    if file_format not in list(ExportFileFormat):
        raise exceptions.ArchiveExportError(
            'Can only export in the formats: {}, please specify one for "file_format".'
            .format(tuple(_.value for _ in ExportFileFormat)))

    # Backwards-compatibility
    entities = cast(
        Iterable[Any],
        deprecated_parameters(
            old={
                "name": "what",
                "value": kwargs.pop("what", None)
            },
            new={
                "name": "entities",
                "value": entities
            },
        ),
    )
    filename = cast(
        str,
        deprecated_parameters(
            old={
                "name": "outfile",
                "value": kwargs.pop("outfile", None)
            },
            new={
                "name": "filename",
                "value": filename
            },
        ),
    )

    type_check(
        entities,
        (list, tuple, set),
        msg=
        "`entities` must be specified and given as a list of AiiDA entities",
    )
    entities = list(entities)
    if type_check(filename, str, allow_none=True) is None:
        filename = "export_data.aiida"

    if not overwrite and os.path.exists(filename):
        raise exceptions.ArchiveExportError(
            f"The output file '{filename}' already exists")

    if silent:
        logging.disable(level=logging.CRITICAL)

    try:
        if file_format == ExportFileFormat.TAR_GZIPPED:
            file_format_verbose = "Gzipped tarball (compressed)"
        # Must be a zip then
        elif use_compression:
            file_format_verbose = "Zip (compressed)"
        else:
            file_format_verbose = "Zip (uncompressed)"
        summary(file_format_verbose, filename, **kwargs)

        try:
            if file_format == ExportFileFormat.TAR_GZIPPED:
                times = export_tar(entities=entities,
                                   filename=filename,
                                   silent=silent,
                                   **kwargs)
            else:  # zip
                times = export_zip(
                    entities=entities,
                    filename=filename,
                    use_compression=use_compression,
                    silent=silent,
                    **kwargs,
                )
        except (exceptions.ArchiveExportError, LicensingException):
            # Do not leave a partially written archive behind
            if os.path.exists(filename):
                os.remove(filename)
            raise

        if len(times) == 2:
            export_start, export_end = times  # pylint: disable=unbalanced-tuple-unpacking
            EXPORT_LOGGER.debug("Exported in %6.2g s.", export_end - export_start)
        elif len(times) == 4:
            export_start, export_end, compress_start, compress_end = times
            EXPORT_LOGGER.debug(
                "Exported in %6.2g s, compressed in %6.2g s, total: %6.2g s.",
                export_end - export_start,
                compress_end - compress_start,
                compress_end - export_start,
            )
        else:
            EXPORT_LOGGER.debug("No information about the timing of the export.")
    finally:
        # Reset logging level. This must also happen when the export fails: otherwise a failed silent export
        # would leave all logging in the process disabled.
        if silent:
            logging.disable(level=logging.NOTSET)
コード例 #7
0
def get_starting_node_ids(
        entities: List[Any],
        silent: bool) -> Tuple[DefaultDict[str, Set[str]], Set[int]]:
    """Sort the requested entities by type and collect the starting nodes of the export.

    Every Group, Node and Computer in ``entities`` has its UUID recorded under its entity
    name. Nodes also contribute their PK to the set of starting node PKs. If any Group was
    given, all Nodes contained in those Groups are added as starting nodes too. This uses a
    single query over all the given Groups.

    :param entities: a list of entity instances (Node, Group or Computer)
    :param silent: suppress console prints and progress bar.

    :raises exceptions.ArchiveExportError: if an entry is not a Node, Computer or Group
    :return: tuple ``(entities_starting_set, given_node_entry_ids)``: a mapping of entity
        name -> set of UUIDs, and the set of PKs of the starting nodes
    """
    # One tick per given entity plus one for the (optional) group-expansion step.
    total_steps = (len(entities) + 1) if entities else 1
    pbar = get_progress_bar(total=total_steps, leave=False, disable=silent)
    pbar.set_description_str("Collecting chosen entities", refresh=False)

    starting_uuids = defaultdict(set)
    node_pks = set()

    for entity in entities:
        pbar.update()
        entity_class = entity.__class__

        if issubclass(entity_class, orm.Group):
            starting_uuids[GROUP_ENTITY_NAME].add(entity.uuid)
            continue

        if issubclass(entity_class, orm.Node):
            starting_uuids[NODE_ENTITY_NAME].add(entity.uuid)
            node_pks.add(entity.pk)
            continue

        if issubclass(entity_class, orm.Computer):
            starting_uuids[COMPUTER_ENTITY_NAME].add(entity.uuid)
            continue

        raise exceptions.ArchiveExportError(
            f"I was given {entity} ({type(entity)}),"
            " which is not a Node, Computer, or Group instance")

    # Nothing more to do unless Groups were requested.
    if GROUP_ENTITY_NAME not in starting_uuids:
        return starting_uuids, node_pks

    pbar.set_description_str("Retrieving Nodes from Groups ...", refresh=True)

    # One query over all requested groups instead of iterating each group's `.nodes`.
    groups_builder = orm.QueryBuilder().append(
        orm.Group,
        filters={"uuid": {"in": starting_uuids[GROUP_ENTITY_NAME]}},
        tag="groups",
    )
    member_rows = orm.QueryBuilder(**groups_builder.queryhelp).append(
        orm.Node, project=["id", "uuid"], with_group="groups").all()

    # Each row is the projected (id, uuid) pair of a node contained in one of the groups.
    for member_pk, member_uuid in member_rows:
        starting_uuids[NODE_ENTITY_NAME].add(member_uuid)
        node_pks.add(member_pk)

    pbar.update()
    return starting_uuids, node_pks
コード例 #8
0
def export_tree(entities=None,
                folder=None,
                allowed_licenses=None,
                forbidden_licenses=None,
                silent=False,
                include_comments=True,
                include_logs=True,
                **kwargs):
    """Export the entries passed in the 'entities' list to a file tree.

    .. deprecated:: 1.2.1
        Support for the parameter `what` will be removed in `v2.0.0`. Please use `entities` instead.

    :param entities: a list of entity instances; they can belong to different models/entities.
    :type entities: list

    :param folder: a temporary folder to build the archive before compression.
    :type folder: :py:class:`~aiida.common.folders.Folder`

    :param allowed_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type allowed_licenses: list

    :param forbidden_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type forbidden_licenses: list

    :param silent: suppress console prints and progress bar. Logging is disabled for the duration of the call and
        is always re-enabled afterwards, also on early return or when an exception propagates.
    :type silent: bool

    :param include_comments: In-/exclude export of comments for given node(s) in ``entities``.
        Default: True, *include* comments in export (as well as relevant users).
    :type include_comments: bool

    :param include_logs: In-/exclude export of logs for given node(s) in ``entities``.
        Default: True, *include* logs in export.
    :type include_logs: bool

    :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` what rule names
        are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if there are any internal errors when
        exporting.
    :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under forbidden license.
    """
    if silent:
        logging.disable(level=logging.CRITICAL)

    try:
        _export_tree_contents(entities=entities,
                              folder=folder,
                              allowed_licenses=allowed_licenses,
                              forbidden_licenses=forbidden_licenses,
                              silent=silent,
                              include_comments=include_comments,
                              include_logs=include_logs,
                              **kwargs)
    finally:
        # Reset logging level. This must also happen on the "nothing to store" early return
        # and when an export error propagates; otherwise logging stays disabled process-wide.
        if silent:
            logging.disable(level=logging.NOTSET)


def _export_tree_contents(entities,
                          folder,
                          allowed_licenses,
                          forbidden_licenses,
                          silent,
                          include_comments,
                          include_logs,
                          **kwargs):
    """Write the export file tree into ``folder``; the logging level is managed by :func:`export_tree`.

    Takes the same parameters as :func:`export_tree` (``kwargs`` may still contain the deprecated ``what``).
    Returns ``None``, also when there is nothing to export.
    """
    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
    from collections import defaultdict
    from aiida.tools.graph.graph_traversers import get_nodes_export

    EXPORT_LOGGER.debug('STARTING EXPORT...')

    # Backwards-compatibility
    entities = deprecated_parameters(
        old={
            'name': 'what',
            'value': kwargs.pop('what', None)
        },
        new={
            'name': 'entities',
            'value': entities
        },
    )

    type_check(
        entities, (list, tuple, set),
        msg='`entities` must be specified and given as a list of AiiDA entities'
    )
    entities = list(entities)

    type_check(
        folder, (Folder, ZipFolder),
        msg='`folder` must be specified and given as an AiiDA Folder entity')

    all_fields_info, unique_identifiers = get_all_fields_info()

    entities_starting_set = defaultdict(set)

    # The set that contains the nodes ids of the nodes that should be exported
    given_node_entry_ids = set()
    given_log_entry_ids = set()
    given_comment_entry_ids = set()

    # Instantiate progress bar - go through list of `entities`
    pbar_total = len(entities) + 1 if entities else 1
    progress_bar = get_progress_bar(total=pbar_total,
                                    leave=False,
                                    disable=silent)
    progress_bar.set_description_str('Collecting chosen entities',
                                     refresh=False)

    # I store a list of the actual dbnodes
    for entry in entities:
        progress_bar.update()

        if issubclass(entry.__class__, orm.Group):
            entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid)
        elif issubclass(entry.__class__, orm.Node):
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            given_node_entry_ids.add(entry.pk)
        elif issubclass(entry.__class__, orm.Computer):
            entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid)
        else:
            raise exceptions.ArchiveExportError(
                'I was given {} ({}), which is not a Node, Computer, or Group instance'
                .format(entry, type(entry)))

    # Add all the nodes contained within the specified groups
    if GROUP_ENTITY_NAME in entities_starting_set:

        progress_bar.set_description_str('Retrieving Nodes from Groups ...',
                                         refresh=True)

        # Use single query instead of given_group.nodes iterator for performance.
        qh_groups = orm.QueryBuilder().append(
            orm.Group,
            filters={
                'uuid': {
                    'in': entities_starting_set[GROUP_ENTITY_NAME]
                }
            },
            tag='groups').queryhelp

        # Delete this import once the dbexport.zip module has been renamed
        from builtins import zip  # pylint: disable=redefined-builtin

        node_results = orm.QueryBuilder(**qh_groups).append(
            orm.Node, project=['id', 'uuid'], with_group='groups').all()
        if node_results:
            pks, uuids = map(list, zip(*node_results))
            entities_starting_set[NODE_ENTITY_NAME].update(uuids)
            given_node_entry_ids.update(pks)
            del node_results, pks, uuids

        progress_bar.update()

    # We will iteratively explore the AiiDA graph to find further nodes that should also be exported.
    # At the same time, we will create the links_uuid list of dicts to be exported

    progress_bar = get_progress_bar(total=1, disable=silent)
    progress_bar.set_description_str(
        'Getting provenance and storing links ...', refresh=True)

    traverse_output = get_nodes_export(starting_pks=given_node_entry_ids,
                                       get_links=True,
                                       **kwargs)
    node_ids_to_be_exported = traverse_output['nodes']
    graph_traversal_rules = traverse_output['rules']

    # A utility dictionary for mapping PK to UUID.
    if node_ids_to_be_exported:
        qbuilder = orm.QueryBuilder().append(
            orm.Node,
            project=('id', 'uuid'),
            filters={'id': {
                'in': node_ids_to_be_exported
            }},
        )
        node_pk_2_uuid_mapping = dict(qbuilder.all())
    else:
        node_pk_2_uuid_mapping = {}

    # The set of tuples now has to be transformed to a list of dicts
    links_uuid = [{
        'input': node_pk_2_uuid_mapping[link.source_id],
        'output': node_pk_2_uuid_mapping[link.target_id],
        'label': link.link_label,
        'type': link.link_type
    } for link in traverse_output['links']]

    progress_bar.update()

    # Progress bar initialization - Entities
    progress_bar = get_progress_bar(total=1, disable=silent)
    progress_bar.set_description_str('Initializing export of all entities',
                                     refresh=True)

    ## Universal "entities" attributed to all types of nodes
    # Logs
    if include_logs and node_ids_to_be_exported:
        # Get related log(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Log,
                       filters={'dbnode_id': {
                           'in': node_ids_to_be_exported
                       }},
                       project='uuid')
        res = set(builder.all(flat=True))
        given_log_entry_ids.update(res)

    # Comments
    if include_comments and node_ids_to_be_exported:
        # Get related comment(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Comment,
                       filters={'dbnode_id': {
                           'in': node_ids_to_be_exported
                       }},
                       project='uuid')
        res = set(builder.all(flat=True))
        given_comment_entry_ids.update(res)

    # Here we get all the columns that we plan to project per entity that we would like to extract
    given_entities = set(entities_starting_set.keys())
    if node_ids_to_be_exported:
        given_entities.add(NODE_ENTITY_NAME)
    if given_log_entry_ids:
        given_entities.add(LOG_ENTITY_NAME)
    if given_comment_entry_ids:
        given_entities.add(COMMENT_ENTITY_NAME)

    progress_bar.update()

    if given_entities:
        progress_bar = get_progress_bar(total=len(given_entities),
                                        disable=silent)
        pbar_base_str = 'Preparing entities'

    entries_to_add = dict()
    for given_entity in given_entities:
        progress_bar.set_description_str(pbar_base_str +
                                         ' - {}s'.format(given_entity),
                                         refresh=False)
        progress_bar.update()

        project_cols = ['id']
        # The following gets a list of fields that we need,
        # e.g. user, mtime, uuid, computer
        entity_prop = all_fields_info[given_entity].keys()

        # Here we do the necessary renaming of properties
        for prop in entity_prop:
            # nprop contains the list of projections
            nprop = (file_fields_to_model_fields[given_entity][prop] if prop
                     in file_fields_to_model_fields[given_entity] else prop)
            project_cols.append(nprop)

        # Getting the ids that correspond to the right entity
        # NOTE(review): when the key exists, `.get` returns the very set stored in
        # `entities_starting_set`, so the Node branch below mutates it in place. The
        # `entities_starting_set` written to metadata.json therefore lists all exported
        # node UUIDs, not only the starting ones.
        entry_uuids_to_add = entities_starting_set.get(given_entity, set())
        if not entry_uuids_to_add:
            if given_entity == LOG_ENTITY_NAME:
                entry_uuids_to_add = given_log_entry_ids
            elif given_entity == COMMENT_ENTITY_NAME:
                entry_uuids_to_add = given_comment_entry_ids
        elif given_entity == NODE_ENTITY_NAME:
            entry_uuids_to_add.update(
                {node_pk_2_uuid_mapping[_]
                 for _ in node_ids_to_be_exported})

        builder = orm.QueryBuilder()
        builder.append(entity_names_to_entities[given_entity],
                       filters={'uuid': {
                           'in': entry_uuids_to_add
                       }},
                       project=project_cols,
                       tag=given_entity,
                       outerjoin=True)
        entries_to_add[given_entity] = builder

    # TODO (Spyros) To see better! Especially for functional licenses
    # Check the licenses of exported data.
    if allowed_licenses is not None or forbidden_licenses is not None:
        builder = orm.QueryBuilder()
        builder.append(orm.Node,
                       project=['id', 'attributes.source.license'],
                       filters={'id': {
                           'in': node_ids_to_be_exported
                       }})
        # Skip those nodes where the license is not set (this is the standard behavior with Django)
        node_licenses = list(
            (a, b) for [a, b] in builder.all() if b is not None)
        check_licenses(node_licenses, allowed_licenses, forbidden_licenses)

    ############################################################
    ##### Start automatic recursive export data generation #####
    ############################################################
    EXPORT_LOGGER.debug('GATHERING DATABASE ENTRIES...')

    if entries_to_add:
        progress_bar = get_progress_bar(total=len(entries_to_add),
                                        disable=silent)

    export_data = defaultdict(dict)
    entity_separator = '_'
    for entity_name, partial_query in entries_to_add.items():

        progress_bar.set_description_str('Exporting {}s'.format(entity_name),
                                         refresh=False)
        progress_bar.update()

        foreign_fields = {
            k: v
            for k, v in all_fields_info[entity_name].items() if 'requires' in v
        }

        for value in foreign_fields.values():
            ref_model_name = value['requires']
            fill_in_query(partial_query, entity_name, ref_model_name,
                          [entity_name], entity_separator)

        for temp_d in partial_query.iterdict():
            for key in temp_d:
                # Get current entity
                current_entity = key.split(entity_separator)[-1]

                # This is a empty result of an outer join.
                # It should not be taken into account.
                if temp_d[key]['id'] is None:
                    continue

                export_data[current_entity].update({
                    temp_d[key]['id']:
                    serialize_dict(temp_d[key],
                                   remove_fields=['id'],
                                   rename_fields=model_fields_to_file_fields[
                                       current_entity])
                })

    # Close progress up until this point in order to print properly
    close_progress_bar(leave=False)

    #######################################
    # Manually manage attributes and extras
    #######################################
    # Pointer. Renaming, since Nodes have now technically been retrieved and "stored"
    all_node_pks = node_ids_to_be_exported

    # Total number of database rows collected across all entity types.
    total_entries = sum(len(entity_rows) for entity_rows in export_data.values())
    if not total_entries:
        EXPORT_LOGGER.log(msg='Nothing to store, exiting...',
                          level=LOG_LEVEL_REPORT)
        return
    EXPORT_LOGGER.log(
        msg='Exporting a total of {} database entries, of which {} are Nodes.'.
        format(total_entries, len(all_node_pks)),
        level=LOG_LEVEL_REPORT)

    # Instantiate new progress bar
    progress_bar = get_progress_bar(total=1, leave=False, disable=silent)

    # ATTRIBUTES and EXTRAS
    EXPORT_LOGGER.debug('GATHERING NODE ATTRIBUTES AND EXTRAS...')
    node_attributes = {}
    node_extras = {}

    # Another QueryBuilder query to get the attributes and extras. TODO: See if this can be optimized
    if all_node_pks:
        all_nodes_query = orm.QueryBuilder().append(
            orm.Node,
            filters={'id': {
                'in': all_node_pks
            }},
            project=['id', 'attributes', 'extras'])

        progress_bar = get_progress_bar(total=all_nodes_query.count(),
                                        disable=silent)
        progress_bar.set_description_str('Exporting Attributes and Extras',
                                         refresh=False)

        for node_pk, attributes, extras in all_nodes_query.iterall():
            progress_bar.update()

            # Keys are stringified PKs.
            node_attributes[str(node_pk)] = attributes
            node_extras[str(node_pk)] = extras

    EXPORT_LOGGER.debug('GATHERING GROUP ELEMENTS...')
    groups_uuid = defaultdict(list)
    # If a group is in the exported data, we export the group/node correlation
    if GROUP_ENTITY_NAME in export_data:
        group_uuids_with_node_uuids = orm.QueryBuilder().append(
            orm.Group,
            filters={
                'id': {
                    'in': export_data[GROUP_ENTITY_NAME]
                }
            },
            project='uuid',
            tag='groups').append(orm.Node, project='uuid', with_group='groups')

        # This part is _only_ for the progress bar
        total_node_uuids_for_groups = group_uuids_with_node_uuids.count()
        if total_node_uuids_for_groups:
            progress_bar = get_progress_bar(total=total_node_uuids_for_groups,
                                            disable=silent)
            progress_bar.set_description_str('Exporting Groups ...',
                                             refresh=False)

        for group_uuid, node_uuid in group_uuids_with_node_uuids.iterall():
            progress_bar.update()

            groups_uuid[group_uuid].append(node_uuid)

    #######################################
    # Final check for unsealed ProcessNodes
    #######################################
    process_nodes = set()
    for node_pk, content in export_data.get(NODE_ENTITY_NAME, {}).items():
        if content['node_type'].startswith('process.'):
            process_nodes.add(node_pk)

    check_process_nodes_sealed(process_nodes)

    ######################################
    # Now collecting and storing
    ######################################
    # subfolder inside the export package
    nodesubfolder = folder.get_subfolder(NODES_EXPORT_SUBFOLDER,
                                         create=True,
                                         reset_limit=True)

    EXPORT_LOGGER.debug('ADDING DATA TO EXPORT ARCHIVE...')

    data = {
        'node_attributes': node_attributes,
        'node_extras': node_extras,
        'export_data': export_data,
        'links_uuid': links_uuid,
        'groups_uuid': groups_uuid
    }

    # N.B. We're really calling zipfolder.open (if exporting a zipfile)
    with folder.open('data.json', mode='w') as fhandle:
        fhandle.write(json.dumps(data))

    # Turn sets into lists to be able to export them as JSON metadata.
    for entity, entity_set in entities_starting_set.items():
        entities_starting_set[entity] = list(entity_set)

    metadata = {
        'aiida_version': get_version(),
        'export_version': EXPORT_VERSION,
        'all_fields_info': all_fields_info,
        'unique_identifiers': unique_identifiers,
        'export_parameters': {
            'graph_traversal_rules': graph_traversal_rules,
            'entities_starting_set': entities_starting_set,
            'include_comments': include_comments,
            'include_logs': include_logs
        }
    }

    with folder.open('metadata.json', 'w') as fhandle:
        fhandle.write(json.dumps(metadata))

    EXPORT_LOGGER.debug('ADDING REPOSITORY FILES TO EXPORT ARCHIVE...')

    # If there are no nodes, there are no repository files to store
    if all_node_pks:
        all_node_uuids = {node_pk_2_uuid_mapping[_] for _ in all_node_pks}

        progress_bar = get_progress_bar(total=len(all_node_uuids),
                                        disable=silent)
        pbar_base_str = 'Exporting repository - '

        for uuid in all_node_uuids:
            sharded_uuid = export_shard_uuid(uuid)

            progress_bar.set_description_str(
                pbar_base_str + 'UUID={}'.format(uuid.split('-')[0]),
                refresh=False)
            progress_bar.update()

            # Important to set create=False, otherwise creates twice a subfolder. Maybe this is a bug of insert_path?
            thisnodefolder = nodesubfolder.get_subfolder(sharded_uuid,
                                                         create=False,
                                                         reset_limit=True)

            # Make sure the node's repository folder was not deleted
            src = RepositoryFolder(section=Repository._section_name, uuid=uuid)  # pylint: disable=protected-access
            if not src.exists():
                raise exceptions.ArchiveExportError(
                    'Unable to find the repository folder for Node with UUID={} in the local repository'
                    .format(uuid))

            # In this way, I copy the content of the folder, and not the folder itself
            thisnodefolder.insert_path(src=src.abspath, dest_name='.')

    close_progress_bar(leave=False)
コード例 #9
0
def export(
    entities: Optional[Iterable[Any]] = None,
    filename: Optional[str] = None,
    file_format: Union[str, Type[ArchiveWriterAbstract]] = ExportFileFormat.ZIP,
    overwrite: bool = False,
    silent: Optional[bool] = None,
    use_compression: Optional[bool] = None,
    include_comments: bool = True,
    include_logs: bool = True,
    allowed_licenses: Optional[Union[list, Callable]] = None,
    forbidden_licenses: Optional[Union[list, Callable]] = None,
    writer_init: Optional[Dict[str, Any]] = None,
    batch_size: int = 100,
    **traversal_rules: bool,
) -> ArchiveWriterAbstract:
    """Export AiiDA data to an archive file.

    Note, the logging level and progress reporter should be set externally, for example::

        from aiida.common.progress_reporter import set_progress_bar_tqdm

        EXPORT_LOGGER.setLevel('DEBUG')
        set_progress_bar_tqdm(leave=True)
        export(...)

    .. deprecated:: 1.5.0
        Support for the parameter `silent` will be removed in `v2.0.0`.
        Please set the log level and progress bar implementation independently.

    .. deprecated:: 1.5.0
        Support for the parameter `use_compression` will be removed in `v2.0.0`.
        Please use `writer_init={'use_compression': True}`.

    .. deprecated:: 1.2.1
        Support for the parameters `what` and `outfile` will be removed in `v2.0.0`.
        Please use `entities` and `filename` instead, respectively.

    :param entities: a list of entity instances;
        they can belong to different models/entities.

    :param filename: the filename (possibly including the absolute path)
        of the file on which to export. Defaults to ``export_data.aiida``.

    :param file_format: 'zip', 'tar.gz' or 'folder' or a specific writer class.

    :param overwrite: if True, overwrite the output file without asking, if it exists.
        If False, raise an
        :py:class:`~aiida.tools.importexport.common.exceptions.ArchiveExportError`
        if the output file already exists.

    :param allowed_licenses: List or function.
        If a list, then checks whether all licenses of Data nodes are in the list. If a function,
        then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.

    :param forbidden_licenses: List or function. If a list,
        then checks whether all licenses of Data nodes are in the list. If a function,
        then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.

    :param include_comments: In-/exclude export of comments for given node(s) in ``entities``.
        Default: True, *include* comments in export (as well as relevant users).

    :param include_logs: In-/exclude export of logs for given node(s) in ``entities``.
        Default: True, *include* logs in export.

    :param writer_init: Additional key-word arguments to pass to the writer class init.
        The given dictionary is copied and never modified.

    :param batch_size: batch database query results in sub-collections to reduce memory usage

    :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules`
        what rule names are toggleable and what the defaults are.

    :returns: the archive writer used for the export; its context has already been exited.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`:
        if there are any internal errors when exporting.
    :raises `~aiida.common.exceptions.LicensingException`:
        if any node is licensed under forbidden license.
    :raises TypeError: if ``file_format`` is neither a string nor an ``ArchiveWriterAbstract`` subclass.
    """
    # pylint: disable=too-many-locals,too-many-branches,too-many-statements

    # Backwards-compatibility
    entities = cast(
        Iterable[Any],
        deprecated_parameters(
            old={
                'name': 'what',
                'value': traversal_rules.pop('what', None)
            },
            new={
                'name': 'entities',
                'value': entities
            },
        ),
    )
    filename = cast(
        str,
        deprecated_parameters(
            old={
                'name': 'outfile',
                'value': traversal_rules.pop('outfile', None)
            },
            new={
                'name': 'filename',
                'value': filename
            },
        ),
    )
    if silent is not None:
        warnings.warn(
            'silent keyword is deprecated and will be removed in AiiDA v2.0.0, set the logger level explicitly instead',
            AiidaDeprecationWarning
        )  # pylint: disable=no-member

    type_check(
        entities,
        (list, tuple, set),
        msg='`entities` must be specified and given as a list of AiiDA entities',
    )
    entities = list(entities)
    if type_check(filename, str, allow_none=True) is None:
        filename = 'export_data.aiida'

    if not overwrite and os.path.exists(filename):
        raise exceptions.ArchiveExportError(f"The output file '{filename}' already exists")

    # validate the traversal rules and generate a full set for reporting
    validate_traversal_rules(GraphTraversalRules.EXPORT, **traversal_rules)
    full_traversal_rules = {
        name: traversal_rules.get(name, rule.default) for name, rule in GraphTraversalRules.EXPORT.value.items()
    }

    # setup the archive writer.
    # Copy the caller's dict so that the deprecated `use_compression` shim below never mutates it.
    writer_init = dict(writer_init) if writer_init else {}
    if use_compression is not None:
        warnings.warn(
            'use_compression argument is deprecated and will be removed in AiiDA v2.0.0 (which will always compress)',
            AiidaDeprecationWarning
        )  # pylint: disable=no-member
        writer_init['use_compression'] = use_compression
    if isinstance(file_format, str):
        writer = get_writer(file_format)(filepath=filename, **writer_init)
    # `issubclass` itself raises an unhelpful TypeError for non-class arguments, so check for a class first.
    elif isinstance(file_format, type) and issubclass(file_format, ArchiveWriterAbstract):
        writer = file_format(filepath=filename, **writer_init)
    else:
        raise TypeError('file_format must be a string or ArchiveWriterAbstract class')

    summary(
        file_format=writer.file_format_verbose,
        export_version=writer.export_version,
        outfile=filename,
        include_comments=include_comments,
        include_logs=include_logs,
        traversal_rules=full_traversal_rules
    )

    EXPORT_LOGGER.debug('STARTING EXPORT...')

    all_fields_info, unique_identifiers = get_all_fields_info()
    entities_starting_set, given_node_entry_ids = _get_starting_node_ids(entities)

    # Initialize the writer
    with writer as writer_context:

        # Iteratively explore the AiiDA graph to find further nodes that should also be exported
        with get_progress_reporter()(desc='Traversing provenance via links ...', total=1) as progress:
            traverse_output = get_nodes_export(starting_pks=given_node_entry_ids, get_links=True, **traversal_rules)
            progress.update()
        node_ids_to_be_exported = traverse_output['nodes']

        EXPORT_LOGGER.debug('WRITING METADATA...')

        writer_context.write_metadata(
            ArchiveMetadata(
                export_version=EXPORT_VERSION,
                aiida_version=get_version(),
                unique_identifiers=unique_identifiers,
                all_fields_info=all_fields_info,
                graph_traversal_rules=traverse_output['rules'],
                # Turn sets into lists to be able to export them as JSON metadata.
                entities_starting_set={
                    entity: list(entity_set) for entity, entity_set in entities_starting_set.items()
                },
                include_comments=include_comments,
                include_logs=include_logs,
            )
        )

        # Create a mapping of node PK to UUID.
        node_pk_2_uuid_mapping: Dict[int, str] = {}
        if node_ids_to_be_exported:
            qbuilder = orm.QueryBuilder().append(
                orm.Node,
                project=('id', 'uuid'),
                filters={'id': {
                    'in': node_ids_to_be_exported
                }},
            )
            node_pk_2_uuid_mapping = dict(qbuilder.all(batch_size=batch_size))

        # check that no nodes are being exported with incorrect licensing
        _check_node_licenses(node_ids_to_be_exported, allowed_licenses, forbidden_licenses)

        # write the link data
        if traverse_output['links'] is not None:
            with get_progress_reporter()(total=len(traverse_output['links']), desc='Writing links') as progress:
                for link in traverse_output['links']:
                    progress.update()
                    writer_context.write_link({
                        'input': node_pk_2_uuid_mapping[link.source_id],
                        'output': node_pk_2_uuid_mapping[link.target_id],
                        'label': link.link_label,
                        'type': link.link_type,
                    })

        # generate a list of queries to encapsulate all required entities
        entity_queries = _collect_entity_queries(
            node_ids_to_be_exported,
            entities_starting_set,
            node_pk_2_uuid_mapping,
            include_comments,
            include_logs,
        )

        total_entities = sum(query.count() for query in entity_queries.values())

        # write all entity data fields
        if total_entities:
            exported_entity_pks = _write_entity_data(
                total_entities=total_entities,
                entity_queries=entity_queries,
                writer=writer_context,
                batch_size=batch_size
            )
        else:
            exported_entity_pks = defaultdict(set)
            EXPORT_LOGGER.info('No entities were found to export')

        # Use `.get` rather than indexing: indexing a defaultdict inserts empty entries, which would
        # make the final summary report zero-count entities even when nothing was exported.
        exported_group_pks = exported_entity_pks.get(GROUP_ENTITY_NAME)
        exported_node_pks = exported_entity_pks.get(NODE_ENTITY_NAME)

        # write mappings of groups to the nodes they contain
        if exported_group_pks:

            EXPORT_LOGGER.debug('Writing group UUID -> [nodes UUIDs]')

            _write_group_mappings(group_pks=exported_group_pks, batch_size=batch_size, writer=writer_context)

        # copy all required node repositories
        if exported_node_pks:

            _write_node_repositories(
                node_pks=exported_node_pks,
                node_pk_2_uuid_mapping=node_pk_2_uuid_mapping,
                writer=writer_context
            )

        EXPORT_LOGGER.info('Finalizing Export...')

    # summarize export
    if exported_entity_pks:
        export_summary = '\n  - '.join(f'{name:<6}: {len(pks)}' for name, pks in exported_entity_pks.items())
        EXPORT_LOGGER.info('Exported Entities:\n  - ' + export_summary + '\n')
    # TODO
    # EXPORT_LOGGER.info('Writer Information:\n %s', writer.export_info)

    return writer
コード例 #10
0
def _get_starting_node_ids(entities: List[Any]) -> Tuple[DefaultDict[str, Set[str]], Set[int]]:
    """Get the starting node UUIDs and PKs.

    Every entry of ``entities`` is sorted by type and its UUID is recorded under the matching entity name
    (``GROUP_ENTITY_NAME``, ``NODE_ENTITY_NAME`` or ``COMPUTER_ENTITY_NAME``). The PKs of the given nodes are
    recorded as well. If any groups were given, the nodes contained in those groups are added to both the node
    UUIDs and the node PKs.

    :param entities: a list of ``orm.Group``, ``orm.Node`` or ``orm.Computer`` instances. An empty list (or
        ``None``) yields empty results.

    :raises exceptions.ArchiveExportError: if an entry is not a Node, Computer or Group instance.
    :return: ``(entities_starting_set, given_node_entry_ids)``, where ``entities_starting_set`` maps an entity
        name to the set of UUIDs to export, and ``given_node_entry_ids`` is the set of node PKs (given directly
        or contained in one of the given groups).
    """
    entities_starting_set: DefaultDict[str, Set[str]] = defaultdict(set)
    given_node_entry_ids: Set[int] = set()

    # Nothing was given: return the empty containers straight away (this also tolerates ``None``).
    if not entities:
        return entities_starting_set, given_node_entry_ids

    for entry in entities:
        if isinstance(entry, orm.Group):
            entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid)
        elif isinstance(entry, orm.Node):
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            given_node_entry_ids.add(entry.pk)
        elif isinstance(entry, orm.Computer):
            entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid)
        else:
            raise exceptions.ArchiveExportError(
                f'I was given {entry} ({type(entry)}),'
                ' which is not a Node, Computer, or Group instance'
            )

    # Add all the nodes contained within the specified groups
    if GROUP_ENTITY_NAME in entities_starting_set:

        # Use a single query for all groups instead of iterating over each group's ``nodes`` (performance).
        node_query = orm.QueryBuilder().append(
            orm.Group,
            filters={'uuid': {
                'in': entities_starting_set[GROUP_ENTITY_NAME]
            }},
            tag='groups',
        ).append(orm.Node, project=['id', 'uuid'], with_group='groups')
        node_count = node_query.count()

        if node_count:
            pks, uuids = [], []
            with get_progress_reporter()(desc='Collecting nodes in groups', total=node_count) as progress:
                for pk, uuid in node_query.all():
                    progress.update()
                    pks.append(pk)
                    uuids.append(uuid)

            # A node may belong to several of the given groups; the sets de-duplicate it.
            entities_starting_set[NODE_ENTITY_NAME].update(uuids)
            given_node_entry_ids.update(pks)

    return entities_starting_set, given_node_entry_ids