Example 1
def _get_model_fields(entity_name: str) -> List[str]:
    """Return a list of fields to retrieve for a particular entity

    :param entity_name: name of database entity, such as Node

    """
    all_fields_info, _ = get_all_fields_info()
    project_cols = ['id']
    entity_prop = all_fields_info[entity_name].keys()
    # Here we do the necessary renaming of properties
    for prop in entity_prop:
        # nprop is the projection name after any renaming
        nprop = (
            file_fields_to_model_fields[entity_name][prop] if prop in file_fields_to_model_fields[entity_name] else prop
        )
        project_cols.append(nprop)
    return project_cols
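
The renaming step above is easier to see in isolation. Below is a minimal, self-contained sketch of the same pattern with made-up mapping dictionaries (the real ones come from get_all_fields_info and file_fields_to_model_fields); dict.get() expresses the same fallback as the conditional expression used above.

# Minimal sketch of the field-renaming pattern; the mappings here are hypothetical.
all_fields_info = {'Node': {'uuid': {}, 'mtime': {}, 'dbcomputer': {'requires': 'Computer'}}}
file_fields_to_model_fields = {'Node': {'dbcomputer': 'dbcomputer_id'}}

def get_model_fields(entity_name):
    project_cols = ['id']
    for prop in all_fields_info[entity_name]:
        # Rename the file field to its model field, if a rename is defined.
        project_cols.append(file_fields_to_model_fields[entity_name].get(prop, prop))
    return project_cols

print(get_model_fields('Node'))  # ['id', 'uuid', 'mtime', 'dbcomputer_id']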
Example 2
def fill_in_query(partial_query,
                  originating_entity_str,
                  current_entity_str,
                  tag_suffixes=None,
                  entity_separator='_'):
    """
    This function recursively constructs the QueryBuilder queries needed by the
    SQLA export function. To construct these queries, the relationship
    dictionary is consulted (it shows how to reference different AiiDA entities
    in the QueryBuilder). To find the dependencies of the relationships of the
    exported data, get_all_fields_info (which describes the exported schema and
    its dependencies) is consulted.
    """
    if not tag_suffixes:
        tag_suffixes = []

    relationship_dic = {
        'Node': {
            'Computer': 'with_computer',
            'Group': 'with_group',
            'User': 'with_user',
            'Log': 'with_log',
            'Comment': 'with_comment'
        },
        'Group': {
            'Node': 'with_node'
        },
        'Computer': {
            'Node': 'with_node'
        },
        'User': {
            'Node': 'with_node',
            'Group': 'with_group',
            'Comment': 'with_comment',
        },
        'Log': {
            'Node': 'with_node'
        },
        'Comment': {
            'Node': 'with_node',
            'User': 'with_user'
        }
    }

    all_fields_info, _ = get_all_fields_info()

    entity_prop = all_fields_info[current_entity_str].keys()

    project_cols = ['id']
    for prop in entity_prop:
        nprop = prop
        if current_entity_str in file_fields_to_model_fields:
            if prop in file_fields_to_model_fields[current_entity_str]:
                nprop = file_fields_to_model_fields[current_entity_str][prop]

        project_cols.append(nprop)

    # Here we should reference the entity of the main query
    current_entity_mod = entity_names_to_entities[current_entity_str]

    rel_string = relationship_dic[current_entity_str][originating_entity_str]
    mydict = {rel_string: entity_separator.join(tag_suffixes)}

    partial_query.append(current_entity_mod,
                         tag=entity_separator.join(tag_suffixes) +
                         entity_separator + current_entity_str,
                         project=project_cols,
                         outerjoin=True,
                         **mydict)

    # prepare the recursion for the referenced entities
    foreign_fields = {
        k: v
        for k, v in all_fields_info[current_entity_str].items()
        # all_fields_info[model_name].items()
        if 'requires' in v
    }

    new_tag_suffixes = tag_suffixes + [current_entity_str]
    for value in foreign_fields.values():
        ref_model_name = value['requires']
        fill_in_query(partial_query, current_entity_str, ref_model_name,
                      new_tag_suffixes)
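
A rough usage sketch of how the recursion is seeded (this mirrors what export_tree in Example 3 does): append the root entity to a QueryBuilder, then call fill_in_query for every field that references another entity. The import paths assume the AiiDA 1.x layout, and node_ids is a placeholder.

from aiida import orm
# Assumed import location for the schema helpers (AiiDA 1.x layout).
from aiida.tools.importexport.common.config import entity_names_to_entities, get_all_fields_info

all_fields_info, _ = get_all_fields_info()
node_ids = [101, 102]  # hypothetical pks of the nodes to export

builder = orm.QueryBuilder()
builder.append(entity_names_to_entities['Node'],
               filters={'id': {'in': node_ids}},
               project=['id', 'uuid'],
               tag='Node',
               outerjoin=True)

# Recurse into every Node field that references another entity (User, Computer, ...).
for field in all_fields_info['Node'].values():
    if 'requires' in field:
        fill_in_query(builder, 'Node', field['requires'], ['Node'], '_')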
Example 3
def export_tree(what,
                folder,
                allowed_licenses=None,
                forbidden_licenses=None,
                silent=False,
                include_comments=True,
                include_logs=True,
                **kwargs):
    """Export the entries passed in the 'what' list to a file tree.

    :param what: a list of entity instances; they can belong to different models/entities.
    :type what: list

    :param folder: a temporary folder to build the archive before compression.
    :type folder: :py:class:`~aiida.common.folders.Folder`

    :param allowed_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls the function for each Data node's license, expecting True if the license is
        allowed, False otherwise.
    :type allowed_licenses: list

    :param forbidden_licenses: List or function. If a list, then checks that no license of a Data node is in the
        list. If a function, then calls the function for each Data node's license, expecting True if the license is
        forbidden, False otherwise.
    :type forbidden_licenses: list

    :param silent: suppress prints.
    :type silent: bool

    :param include_comments: In-/exclude export of comments for given node(s) in ``what``.
        Default: True, *include* comments in export (as well as relevant users).
    :type include_comments: bool

    :param include_logs: In-/exclude export of logs for given node(s) in ``what``.
        Default: True, *include* logs in export.
    :type include_logs: bool

    :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` for which rule names
        are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if there are any internal errors when
        exporting.
    :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under forbidden license.
    """
    from collections import defaultdict

    if not silent:
        print('STARTING EXPORT...')

    all_fields_info, unique_identifiers = get_all_fields_info()

    entities_starting_set = defaultdict(set)

    # The sets that contain the ids of the entities that should be exported
    given_data_entry_ids = set()
    given_calculation_entry_ids = set()
    given_group_entry_ids = set()
    given_computer_entry_ids = set()
    given_groups = set()
    given_log_entry_ids = set()
    given_comment_entry_ids = set()

    # I store a list of the actual dbnodes
    for entry in what:
        # This returns the class name (as in imports). E.g. for a model node:
        # aiida.backends.djsite.db.models.DbNode
        # entry_class_string = get_class_string(entry)
        # Now load the backend-independent name into entry_entity_name, e.g. Node!
        # entry_entity_name = schema_to_entity_names(entry_class_string)
        if issubclass(entry.__class__, orm.Group):
            entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid)
            given_group_entry_ids.add(entry.id)
            given_groups.add(entry)
        elif issubclass(entry.__class__, orm.Node):
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            if issubclass(entry.__class__, orm.Data):
                given_data_entry_ids.add(entry.pk)
            elif issubclass(entry.__class__, orm.ProcessNode):
                given_calculation_entry_ids.add(entry.pk)
        elif issubclass(entry.__class__, orm.Computer):
            entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid)
            given_computer_entry_ids.add(entry.pk)
        else:
            raise exceptions.ArchiveExportError(
                'I was given {} ({}), which is not a Node, Computer, or Group instance'
                .format(entry, type(entry)))

    # Add all the nodes contained within the specified groups
    for group in given_groups:
        for entry in group.nodes:
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            if issubclass(entry.__class__, orm.Data):
                given_data_entry_ids.add(entry.pk)
            elif issubclass(entry.__class__, orm.ProcessNode):
                given_calculation_entry_ids.add(entry.pk)

    for entity, entity_set in entities_starting_set.items():
        entities_starting_set[entity] = list(entity_set)

    # We will iteratively explore the AiiDA graph to find further nodes that
    # should also be exported.
    # At the same time, we will create the links_uuid list of dicts to be exported

    if not silent:
        print('RETRIEVING LINKED NODES AND STORING LINKS...')

    to_be_exported, links_uuid, graph_traversal_rules = retrieve_linked_nodes(
        given_calculation_entry_ids, given_data_entry_ids, **kwargs)

    ## Universal "entities" attributed to all types of nodes
    # Logs
    if include_logs and to_be_exported:
        # Get related log(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Log,
                       filters={'dbnode_id': {
                           'in': to_be_exported
                       }},
                       project='id')
        res = {_[0] for _ in builder.all()}
        given_log_entry_ids.update(res)

    # Comments
    if include_comments and to_be_exported:
        # Get related comment(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Comment,
                       filters={'dbnode_id': {
                           'in': to_be_exported
                       }},
                       project='id')
        res = {_[0] for _ in builder.all()}
        given_comment_entry_ids.update(res)

    # Here we get all the columns that we plan to project per entity that we
    # would like to extract
    given_entities = list()
    if given_group_entry_ids:
        given_entities.append(GROUP_ENTITY_NAME)
    if to_be_exported:
        given_entities.append(NODE_ENTITY_NAME)
    if given_computer_entry_ids:
        given_entities.append(COMPUTER_ENTITY_NAME)
    if given_log_entry_ids:
        given_entities.append(LOG_ENTITY_NAME)
    if given_comment_entry_ids:
        given_entities.append(COMMENT_ENTITY_NAME)

    entries_to_add = dict()
    for given_entity in given_entities:
        project_cols = ['id']
        # The following gets a list of fields that we need,
        # e.g. user, mtime, uuid, computer
        entity_prop = all_fields_info[given_entity].keys()

        # Here we do the necessary renaming of properties
        for prop in entity_prop:
            # nprop is the projection name after any renaming
            nprop = (file_fields_to_model_fields[given_entity][prop] if prop
                     in file_fields_to_model_fields[given_entity] else prop)
            project_cols.append(nprop)

        # Getting the ids that correspond to the right entity
        if given_entity == GROUP_ENTITY_NAME:
            entry_ids_to_add = given_group_entry_ids
        elif given_entity == NODE_ENTITY_NAME:
            entry_ids_to_add = to_be_exported
        elif given_entity == COMPUTER_ENTITY_NAME:
            entry_ids_to_add = given_computer_entry_ids
        elif given_entity == LOG_ENTITY_NAME:
            entry_ids_to_add = given_log_entry_ids
        elif given_entity == COMMENT_ENTITY_NAME:
            entry_ids_to_add = given_comment_entry_ids

        builder = orm.QueryBuilder()
        builder.append(entity_names_to_entities[given_entity],
                       filters={'id': {
                           'in': entry_ids_to_add
                       }},
                       project=project_cols,
                       tag=given_entity,
                       outerjoin=True)
        entries_to_add[given_entity] = builder

    # TODO (Spyros) To see better! Especially for functional licenses
    # Check the licenses of exported data.
    if allowed_licenses is not None or forbidden_licenses is not None:
        builder = orm.QueryBuilder()
        builder.append(orm.Node,
                       project=['id', 'attributes.source.license'],
                       filters={'id': {
                           'in': to_be_exported
                       }})
        # Skip those nodes where the license is not set (this is the standard behavior with Django)
        node_licenses = list(
            (a, b) for [a, b] in builder.all() if b is not None)
        check_licenses(node_licenses, allowed_licenses, forbidden_licenses)

    ############################################################
    ##### Start automatic recursive export data generation #####
    ############################################################
    if not silent:
        print('STORING DATABASE ENTRIES...')

    export_data = dict()
    entity_separator = '_'
    for entity_name, partial_query in entries_to_add.items():

        foreign_fields = {
            k: v
            for k, v in all_fields_info[entity_name].items()
            # all_fields_info[model_name].items()
            if 'requires' in v
        }

        for value in foreign_fields.values():
            ref_model_name = value['requires']
            fill_in_query(partial_query, entity_name, ref_model_name,
                          [entity_name], entity_separator)

        for temp_d in partial_query.iterdict():
            for k in temp_d.keys():
                # Get current entity
                current_entity = k.split(entity_separator)[-1]

                # This is an empty result of an outer join.
                # It should not be taken into account.
                if temp_d[k]['id'] is None:
                    continue

                temp_d2 = {
                    temp_d[k]['id']:
                    serialize_dict(temp_d[k],
                                   remove_fields=['id'],
                                   rename_fields=model_fields_to_file_fields[
                                       current_entity])
                }
                try:
                    export_data[current_entity].update(temp_d2)
                except KeyError:
                    export_data[current_entity] = temp_d2

    #######################################
    # Manually manage attributes and extras
    #######################################
    # I use .get because there may be no nodes to export
    all_nodes_pk = list()
    if NODE_ENTITY_NAME in export_data:
        all_nodes_pk.extend(export_data.get(NODE_ENTITY_NAME).keys())

    if sum(len(model_data) for model_data in export_data.values()) == 0:
        if not silent:
            print('No nodes to store, exiting...')
        return

    if not silent:
        print('Exporting a total of {} db entries, of which {} nodes.'.format(
            sum(len(model_data) for model_data in export_data.values()),
            len(all_nodes_pk)))

    # ATTRIBUTES and EXTRAS
    if not silent:
        print('STORING NODE ATTRIBUTES AND EXTRAS...')
    node_attributes = {}
    node_extras = {}

    # A second QueryBuilder query to get the attributes and extras. See if this can be optimized
    if all_nodes_pk:
        all_nodes_query = orm.QueryBuilder()
        all_nodes_query.append(orm.Node,
                               filters={'id': {
                                   'in': all_nodes_pk
                               }},
                               project=['id', 'attributes', 'extras'])
        for res_pk, res_attributes, res_extras in all_nodes_query.iterall():
            node_attributes[str(res_pk)] = res_attributes
            node_extras[str(res_pk)] = res_extras

    if not silent:
        print('STORING GROUP ELEMENTS...')
    groups_uuid = dict()
    # If a group is in the exported data, we export the group/node correlation
    if GROUP_ENTITY_NAME in export_data:
        for curr_group in export_data[GROUP_ENTITY_NAME]:
            group_uuid_qb = orm.QueryBuilder()
            group_uuid_qb.append(entity_names_to_entities[GROUP_ENTITY_NAME],
                                 filters={'id': {
                                     '==': curr_group
                                 }},
                                 project=['uuid'],
                                 tag='group')
            group_uuid_qb.append(entity_names_to_entities[NODE_ENTITY_NAME],
                                 project=['uuid'],
                                 with_group='group')
            for res in group_uuid_qb.iterall():
                if str(res[0]) in groups_uuid:
                    groups_uuid[str(res[0])].append(str(res[1]))
                else:
                    groups_uuid[str(res[0])] = [str(res[1])]

    #######################################
    # Final check for unsealed ProcessNodes
    #######################################
    process_nodes = set()
    for node_pk, content in export_data.get(NODE_ENTITY_NAME, {}).items():
        if content['node_type'].startswith('process.'):
            process_nodes.add(node_pk)

    check_process_nodes_sealed(process_nodes)

    ######################################
    # Now I store
    ######################################
    # subfolder inside the export package
    nodesubfolder = folder.get_subfolder(NODES_EXPORT_SUBFOLDER,
                                         create=True,
                                         reset_limit=True)

    if not silent:
        print('STORING DATA...')

    data = {
        'node_attributes': node_attributes,
        'node_extras': node_extras,
        'export_data': export_data,
        'links_uuid': links_uuid,
        'groups_uuid': groups_uuid
    }

    # N.B. We're really calling zipfolder.open (if exporting a zipfile)
    with folder.open('data.json', mode='w') as fhandle:
        # fhandle.write(json.dumps(data, cls=UUIDEncoder))
        fhandle.write(json.dumps(data))

    # Add proper signature to unique identifiers & all_fields_info
    # Ignore if a key doesn't exist in any of the two dictionaries

    metadata = {
        'aiida_version': get_version(),
        'export_version': EXPORT_VERSION,
        'all_fields_info': all_fields_info,
        'unique_identifiers': unique_identifiers,
        'export_parameters': {
            'graph_traversal_rules': graph_traversal_rules,
            'entities_starting_set': entities_starting_set,
            'include_comments': include_comments,
            'include_logs': include_logs
        }
    }

    with folder.open('metadata.json', 'w') as fhandle:
        fhandle.write(json.dumps(metadata))

    if not silent:
        print('STORING REPOSITORY FILES...')

    # If there are no nodes, there are no repository files to store
    if all_nodes_pk:
        # Much faster to project just the UUIDs than to load the node objects and loop over them in Python
        uuid_query = orm.QueryBuilder()
        uuid_query.append(orm.Node,
                          filters={'id': {
                              'in': all_nodes_pk
                          }},
                          project=['uuid'])
        for res in uuid_query.all():
            uuid = str(res[0])
            sharded_uuid = export_shard_uuid(uuid)

            # Important to set create=False, otherwise the subfolder is created twice. Maybe this is a bug of insert_path?
            thisnodefolder = nodesubfolder.get_subfolder(sharded_uuid,
                                                         create=False,
                                                         reset_limit=True)

            # Make sure the node's repository folder was not deleted
            src = RepositoryFolder(section=Repository._section_name, uuid=uuid)  # pylint: disable=protected-access
            if not src.exists():
                raise exceptions.ArchiveExportError(
                    'Unable to find the repository folder for Node with UUID={} in the local repository'
                    .format(uuid))

            # In this way, I copy the content of the folder, and not the folder itself
            thisnodefolder.insert_path(src=src.abspath, dest_name='.')
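
A hedged usage sketch for the export_tree above: export a single node into a temporary folder, accepting only licenses for which a filter function returns True. This illustrates the "function" form of allowed_licenses from the docstring; the pk and the use of SandboxFolder are assumptions made for the example.

from aiida import load_profile, orm
from aiida.common.folders import SandboxFolder

load_profile()

def cc_only(license_string):
    # Return True if the license is allowed; here, any Creative Commons variant.
    return license_string.startswith('CC')

node = orm.load_node(1234)  # hypothetical pk
with SandboxFolder() as folder:
    export_tree([node], folder=folder, allowed_licenses=cc_only, silent=True)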
Example 4
def get_export_data(entries_queries: Dict[str, orm.QueryBuilder],
                    silent: bool) -> Dict[str, Dict[int, dict]]:
    """Start automatic recursive export data generation

    :param entries_queries: partial queries for all entities to export
    :param silent: suppress console prints and progress bar.

    :return: export data mappings by entity type -> pk -> db_columns, e.g.
        {'ENTITY_NAME': {<pk>: {'uuid': 'abc', ...}, ...}, ...}
        Note: this data does not yet contain attributes and extras

    """
    EXPORT_LOGGER.debug("GATHERING DATABASE ENTRIES...")

    all_fields_info, _ = get_all_fields_info()

    if entries_queries:
        progress_bar = get_progress_bar(total=len(entries_queries),
                                        disable=silent)

    export_data = defaultdict(dict)  # type: dict
    entity_separator = "_"
    for entity_name, partial_query in entries_queries.items():

        progress_bar.set_description_str(f"Exporting {entity_name}s",
                                         refresh=False)
        progress_bar.update()

        foreign_fields = {
            k: v
            for k, v in all_fields_info[entity_name].items() if "requires" in v
        }

        for value in foreign_fields.values():
            ref_model_name = value["requires"]
            fill_in_query(
                partial_query,
                entity_name,
                ref_model_name,
                [entity_name],
                entity_separator,
            )

        for temp_d in partial_query.iterdict():
            for key in temp_d:
                # Get current entity
                current_entity = key.split(entity_separator)[-1]

                # This is an empty result of an outer join.
                # It should not be taken into account.
                if temp_d[key]["id"] is None:
                    continue

                export_data[current_entity].update({
                    temp_d[key]["id"]:
                    serialize_dict(
                        temp_d[key],
                        remove_fields=["id"],
                        rename_fields=model_fields_to_file_fields[
                            current_entity],
                    )
                })

    return export_data
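
For orientation, the mapping returned by get_export_data has the shape sketched below; the entity names follow the archive schema, while the pks, uuids and field values are purely illustrative.

# Illustrative shape of the return value of get_export_data (made-up values):
example_export_data = {
    'Node': {
        101: {'uuid': '0e34...', 'node_type': 'data.dict.Dict.', 'label': ''},
        102: {'uuid': '7f1b...', 'node_type': 'process.calculation.calcjob.CalcJobNode.', 'label': ''},
    },
    'User': {
        1: {'email': 'user@example.com', 'first_name': '', 'last_name': ''},
    },
}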
Example 5
def get_entry_queries(
    node_ids_to_be_exported: Set[int],
    entities_starting_set: DefaultDict[str, Set[str]],
    node_pk_2_uuid_mapping: Dict[int, str],
    silent: bool,
    include_comments: bool = True,
    include_logs: bool = True,
) -> Dict[str, orm.QueryBuilder]:
    """Gather partial queries for all entities to export."""
    # Progress bar initialization - Entities
    progress_bar = get_progress_bar(total=1, disable=silent)
    progress_bar.set_description_str("Initializing export of all entities",
                                     refresh=True)

    given_log_entry_ids = set()
    given_comment_entry_ids = set()
    all_fields_info, _ = get_all_fields_info()

    # Universal "entities" attributed to all types of nodes
    # Logs
    if include_logs and node_ids_to_be_exported:
        # Get related log(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(
            orm.Log,
            filters={"dbnode_id": {
                "in": node_ids_to_be_exported
            }},
            project="uuid",
        )
        res = set(builder.all(flat=True))
        given_log_entry_ids.update(res)

    # Comments
    if include_comments and node_ids_to_be_exported:
        # Get related comment(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(
            orm.Comment,
            filters={"dbnode_id": {
                "in": node_ids_to_be_exported
            }},
            project="uuid",
        )
        res = set(builder.all(flat=True))
        given_comment_entry_ids.update(res)

    # Here we get all the columns that we plan to project per entity that we would like to extract
    given_entities = set(entities_starting_set.keys())
    if node_ids_to_be_exported:
        given_entities.add(NODE_ENTITY_NAME)
    if given_log_entry_ids:
        given_entities.add(LOG_ENTITY_NAME)
    if given_comment_entry_ids:
        given_entities.add(COMMENT_ENTITY_NAME)

    progress_bar.update()

    if given_entities:
        progress_bar = get_progress_bar(total=len(given_entities),
                                        disable=silent)
        pbar_base_str = "Preparing entities"

    entries_to_add = {}
    for given_entity in given_entities:
        progress_bar.set_description_str(f"{pbar_base_str} - {given_entity}s",
                                         refresh=False)
        progress_bar.update()

        project_cols = ["id"]
        # The following gets a list of fields that we need,
        # e.g. user, mtime, uuid, computer
        entity_prop = all_fields_info[given_entity].keys()

        # Here we do the necessary renaming of properties
        for prop in entity_prop:
            # nprop is the projection name after any renaming
            nprop = (file_fields_to_model_fields[given_entity][prop] if prop
                     in file_fields_to_model_fields[given_entity] else prop)
            project_cols.append(nprop)

        # Getting the ids that correspond to the right entity
        entry_uuids_to_add = entities_starting_set.get(given_entity, set())
        if not entry_uuids_to_add:
            if given_entity == LOG_ENTITY_NAME:
                entry_uuids_to_add = given_log_entry_ids
            elif given_entity == COMMENT_ENTITY_NAME:
                entry_uuids_to_add = given_comment_entry_ids
        elif given_entity == NODE_ENTITY_NAME:
            entry_uuids_to_add.update(
                {node_pk_2_uuid_mapping[_]
                 for _ in node_ids_to_be_exported})

        builder = orm.QueryBuilder()
        builder.append(
            entity_names_to_entities[given_entity],
            filters={"uuid": {
                "in": entry_uuids_to_add
            }},
            project=project_cols,
            tag=given_entity,
            outerjoin=True,
        )
        entries_to_add[given_entity] = builder

    return entries_to_add
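
A sketch of how this helper chains with get_export_data from Example 4. All argument values below are placeholders; in the real export they come from the starting entities and the graph-traversal step.

from collections import defaultdict

entries_queries = get_entry_queries(
    node_ids_to_be_exported={101, 102},                               # hypothetical node pks
    entities_starting_set=defaultdict(set, {'Group': {'9a1c-...'}}),  # hypothetical group uuid
    node_pk_2_uuid_mapping={101: 'aaaa-...', 102: 'bbbb-...'},
    silent=True,
)
export_data = get_export_data(entries_queries, silent=True)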
Example 6
def export_tree(
    entities: Optional[Iterable[Any]] = None,
    folder: Optional[Union[Folder, ZipFolder]] = None,
    allowed_licenses: Optional[Union[list, Callable]] = None,
    forbidden_licenses: Optional[Union[list, Callable]] = None,
    silent: bool = False,
    include_comments: bool = True,
    include_logs: bool = True,
    **traversal_rules: bool,
) -> None:
    """Export the entries passed in the 'entities' list to a file tree.

    .. deprecated:: 1.2.1
        Support for the parameter `what` will be removed in `v2.0.0`. Please use `entities` instead.

    :param entities: a list of entity instances; they can belong to different models/entities.

    :param folder: a temporary folder to build the archive before compression.

    :param allowed_licenses: List or function.
        If a list, then checks whether all licenses of Data nodes are in the list.
        If a function, then calls the function for each Data node's license,
        expecting True if the license is allowed, False otherwise.

    :param forbidden_licenses: List or function.
        If a list, then checks that no license of a Data node is in the list.
        If a function, then calls the function for each Data node's license,
        expecting True if the license is forbidden, False otherwise.

    :param silent: suppress console prints and progress bar.

    :param include_comments: In-/exclude export of comments for given node(s) in ``entities``.
        Default: True, *include* comments in export (as well as relevant users).

    :param include_logs: In-/exclude export of logs for given node(s) in ``entities``.
        Default: True, *include* logs in export.

    :param traversal_rules: graph traversal rules.
        See :const:`aiida.common.links.GraphTraversalRules` for which rule names
        are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`:
        if there are any internal errors when exporting.
    :raises `~aiida.common.exceptions.LicensingException`:
        if any node is licensed under forbidden license.
    """

    if silent:
        logging.disable(level=logging.CRITICAL)

    EXPORT_LOGGER.debug("STARTING EXPORT...")

    # Backwards-compatibility
    entities = cast(
        Iterable[Any],
        deprecated_parameters(
            old={
                "name": "what",
                "value": traversal_rules.pop("what", None)
            },
            new={
                "name": "entities",
                "value": entities
            },
        ),
    )

    type_check(
        entities,
        (list, tuple, set),
        msg="`entities` must be specified and given as a list of AiiDA entities",
    )
    entities = list(entities)

    type_check(
        folder,
        (Folder, ZipFolder),
        msg="`folder` must be specified and given as an AiiDA Folder entity",
    )
    folder = cast(Union[Folder, ZipFolder], folder)

    all_fields_info, unique_identifiers = get_all_fields_info()

    entities_starting_set, given_node_entry_ids = get_starting_node_ids(
        entities, silent)

    (
        node_ids_to_be_exported,
        node_pk_2_uuid_mapping,
        links_uuid,
        traversal_rules,
    ) = collect_export_nodes(given_node_entry_ids, silent, **traversal_rules)

    check_node_licenses(node_ids_to_be_exported, allowed_licenses,
                        forbidden_licenses)

    entries_queries = get_entry_queries(
        node_ids_to_be_exported,
        entities_starting_set,
        node_pk_2_uuid_mapping,
        silent,
        include_comments,
        include_logs,
    )

    export_data = get_export_data(entries_queries, silent)

    # Close progress up until this point in order to print properly
    close_progress_bar(leave=False)

    # note this was originally below the attributes and group_uuid gather
    check_process_nodes_sealed({
        node_pk
        for node_pk, content in export_data.get(NODE_ENTITY_NAME, {}).items()
        if content["node_type"].startswith("process.")
    })

    model_data = sum(len(model_data) for model_data in export_data.values())
    if not model_data:
        EXPORT_LOGGER.log(msg="Nothing to store, exiting...",
                          level=LOG_LEVEL_REPORT)
        return
    EXPORT_LOGGER.log(
        msg=(f"Exporting a total of {model_data} database entries, "
             f"of which {len(node_ids_to_be_exported)} are Nodes."),
        level=LOG_LEVEL_REPORT,
    )

    node_attributes, node_extras = get_node_data(export_data,
                                                 node_ids_to_be_exported,
                                                 silent)
    groups_uuid = get_groups_uuid(export_data, silent)

    # Turn sets into lists to be able to export them as JSON metadata.
    for entity, entity_set in entities_starting_set.items():
        entities_starting_set[entity] = list(entity_set)  # type: ignore

    metadata = {
        "aiida_version": get_version(),
        "export_version": EXPORT_VERSION,
        "all_fields_info": all_fields_info,
        "unique_identifiers": unique_identifiers,
        "export_parameters": {
            "graph_traversal_rules": traversal_rules,
            "entities_starting_set": entities_starting_set,
            "include_comments": include_comments,
            "include_logs": include_logs,
        },
    }

    all_node_uuids = {
        node_pk_2_uuid_mapping[_]
        for _ in node_ids_to_be_exported
    }

    write_to_archive(
        folder,
        metadata,
        all_node_uuids,
        export_data,
        node_attributes,
        node_extras,
        groups_uuid,
        links_uuid,
        silent,
    )

    close_progress_bar(leave=False)

    # Reset logging level
    if silent:
        logging.disable(level=logging.NOTSET)
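
A hedged usage sketch for this refactored export_tree: export a group without its comments and logs into a temporary folder. The group label and the use of SandboxFolder are assumptions; per the type hints a ZipFolder would work as well.

from aiida import load_profile, orm
from aiida.common.folders import SandboxFolder

load_profile()
group = orm.load_group('my-group')  # hypothetical group label
with SandboxFolder() as folder:
    export_tree(
        entities=[group],
        folder=folder,
        include_comments=False,
        include_logs=False,
        silent=True,
    )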
Example 7
def export_tree(entities=None,
                folder=None,
                allowed_licenses=None,
                forbidden_licenses=None,
                silent=False,
                include_comments=True,
                include_logs=True,
                **kwargs):
    """Export the entries passed in the 'entities' list to a file tree.

    .. deprecated:: 1.2.1
        Support for the parameter `what` will be removed in `v2.0.0`. Please use `entities` instead.

    :param entities: a list of entity instances; they can belong to different models/entities.
    :type entities: list

    :param folder: a temporary folder to build the archive before compression.
    :type folder: :py:class:`~aiida.common.folders.Folder`

    :param allowed_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls the function for each Data node's license, expecting True if the license is
        allowed, False otherwise.
    :type allowed_licenses: list

    :param forbidden_licenses: List or function. If a list, then checks that no license of a Data node is in the
        list. If a function, then calls the function for each Data node's license, expecting True if the license is
        forbidden, False otherwise.
    :type forbidden_licenses: list

    :param silent: suppress console prints and progress bar.
    :type silent: bool

    :param include_comments: In-/exclude export of comments for given node(s) in ``entities``.
        Default: True, *include* comments in export (as well as relevant users).
    :type include_comments: bool

    :param include_logs: In-/exclude export of logs for given node(s) in ``entities``.
        Default: True, *include* logs in export.
    :type include_logs: bool

    :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` for which rule names
        are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if there are any internal errors when
        exporting.
    :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under forbidden license.
    """
    from collections import defaultdict
    from aiida.tools.graph.graph_traversers import get_nodes_export

    if silent:
        logging.disable(level=logging.CRITICAL)

    EXPORT_LOGGER.debug('STARTING EXPORT...')

    # Backwards-compatibility
    entities = deprecated_parameters(
        old={
            'name': 'what',
            'value': kwargs.pop('what', None)
        },
        new={
            'name': 'entities',
            'value': entities
        },
    )

    type_check(
        entities, (list, tuple, set),
        msg='`entities` must be specified and given as a list of AiiDA entities'
    )
    entities = list(entities)

    type_check(
        folder, (Folder, ZipFolder),
        msg='`folder` must be specified and given as an AiiDA Folder entity')

    all_fields_info, unique_identifiers = get_all_fields_info()

    entities_starting_set = defaultdict(set)

    # The sets that contain the ids of the nodes, logs and comments that should be exported
    given_node_entry_ids = set()
    given_log_entry_ids = set()
    given_comment_entry_ids = set()

    # Instantiate progress bar - go through list of `entities`
    pbar_total = len(entities) + 1 if entities else 1
    progress_bar = get_progress_bar(total=pbar_total,
                                    leave=False,
                                    disable=silent)
    progress_bar.set_description_str('Collecting chosen entities',
                                     refresh=False)

    # I store a list of the actual dbnodes
    for entry in entities:
        progress_bar.update()

        # This returns the class name (as in imports). E.g. for a model node:
        # aiida.backends.djsite.db.models.DbNode
        # entry_class_string = get_class_string(entry)
        # Now load the backend-independent name into entry_entity_name, e.g. Node!
        # entry_entity_name = schema_to_entity_names(entry_class_string)
        if issubclass(entry.__class__, orm.Group):
            entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid)
        elif issubclass(entry.__class__, orm.Node):
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            given_node_entry_ids.add(entry.pk)
        elif issubclass(entry.__class__, orm.Computer):
            entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid)
        else:
            raise exceptions.ArchiveExportError(
                'I was given {} ({}), which is not a Node, Computer, or Group instance'
                .format(entry, type(entry)))

    # Add all the nodes contained within the specified groups
    if GROUP_ENTITY_NAME in entities_starting_set:

        progress_bar.set_description_str('Retrieving Nodes from Groups ...',
                                         refresh=True)

        # Use single query instead of given_group.nodes iterator for performance.
        qh_groups = orm.QueryBuilder().append(
            orm.Group,
            filters={
                'uuid': {
                    'in': entities_starting_set[GROUP_ENTITY_NAME]
                }
            },
            tag='groups').queryhelp

        # Delete this import once the dbexport.zip module has been renamed
        from builtins import zip  # pylint: disable=redefined-builtin

        node_results = orm.QueryBuilder(**qh_groups).append(
            orm.Node, project=['id', 'uuid'], with_group='groups').all()
        if node_results:
            pks, uuids = map(list, zip(*node_results))
            entities_starting_set[NODE_ENTITY_NAME].update(uuids)
            given_node_entry_ids.update(pks)
            del node_results, pks, uuids

        progress_bar.update()

    # We will iteratively explore the AiiDA graph to find further nodes that should also be exported.
    # At the same time, we will create the links_uuid list of dicts to be exported

    progress_bar = get_progress_bar(total=1, disable=silent)
    progress_bar.set_description_str(
        'Getting provenance and storing links ...', refresh=True)

    traverse_output = get_nodes_export(starting_pks=given_node_entry_ids,
                                       get_links=True,
                                       **kwargs)
    node_ids_to_be_exported = traverse_output['nodes']
    graph_traversal_rules = traverse_output['rules']

    # A utility dictionary for mapping PK to UUID.
    if node_ids_to_be_exported:
        qbuilder = orm.QueryBuilder().append(
            orm.Node,
            project=('id', 'uuid'),
            filters={'id': {
                'in': node_ids_to_be_exported
            }},
        )
        node_pk_2_uuid_mapping = dict(qbuilder.all())
    else:
        node_pk_2_uuid_mapping = {}

    # The set of tuples now has to be transformed to a list of dicts
    links_uuid = [{
        'input': node_pk_2_uuid_mapping[link.source_id],
        'output': node_pk_2_uuid_mapping[link.target_id],
        'label': link.link_label,
        'type': link.link_type
    } for link in traverse_output['links']]

    progress_bar.update()

    # Progress bar initialization - Entities
    progress_bar = get_progress_bar(total=1, disable=silent)
    progress_bar.set_description_str('Initializing export of all entities',
                                     refresh=True)

    ## Universal "entities" attributed to all types of nodes
    # Logs
    if include_logs and node_ids_to_be_exported:
        # Get related log(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Log,
                       filters={'dbnode_id': {
                           'in': node_ids_to_be_exported
                       }},
                       project='uuid')
        res = set(builder.all(flat=True))
        given_log_entry_ids.update(res)

    # Comments
    if include_comments and node_ids_to_be_exported:
        # Get related comment(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Comment,
                       filters={'dbnode_id': {
                           'in': node_ids_to_be_exported
                       }},
                       project='uuid')
        res = set(builder.all(flat=True))
        given_comment_entry_ids.update(res)

    # Here we get all the columns that we plan to project per entity that we would like to extract
    given_entities = set(entities_starting_set.keys())
    if node_ids_to_be_exported:
        given_entities.add(NODE_ENTITY_NAME)
    if given_log_entry_ids:
        given_entities.add(LOG_ENTITY_NAME)
    if given_comment_entry_ids:
        given_entities.add(COMMENT_ENTITY_NAME)

    progress_bar.update()

    if given_entities:
        progress_bar = get_progress_bar(total=len(given_entities),
                                        disable=silent)
        pbar_base_str = 'Preparing entities'

    entries_to_add = dict()
    for given_entity in given_entities:
        progress_bar.set_description_str(pbar_base_str +
                                         ' - {}s'.format(given_entity),
                                         refresh=False)
        progress_bar.update()

        project_cols = ['id']
        # The following gets a list of fields that we need,
        # e.g. user, mtime, uuid, computer
        entity_prop = all_fields_info[given_entity].keys()

        # Here we do the necessary renaming of properties
        for prop in entity_prop:
            # nprop is the projection name after any renaming
            nprop = (file_fields_to_model_fields[given_entity][prop] if prop
                     in file_fields_to_model_fields[given_entity] else prop)
            project_cols.append(nprop)

        # Getting the ids that correspond to the right entity
        entry_uuids_to_add = entities_starting_set.get(given_entity, set())
        if not entry_uuids_to_add:
            if given_entity == LOG_ENTITY_NAME:
                entry_uuids_to_add = given_log_entry_ids
            elif given_entity == COMMENT_ENTITY_NAME:
                entry_uuids_to_add = given_comment_entry_ids
        elif given_entity == NODE_ENTITY_NAME:
            entry_uuids_to_add.update(
                {node_pk_2_uuid_mapping[_]
                 for _ in node_ids_to_be_exported})

        builder = orm.QueryBuilder()
        builder.append(entity_names_to_entities[given_entity],
                       filters={'uuid': {
                           'in': entry_uuids_to_add
                       }},
                       project=project_cols,
                       tag=given_entity,
                       outerjoin=True)
        entries_to_add[given_entity] = builder

    # TODO (Spyros) To see better! Especially for functional licenses
    # Check the licenses of exported data.
    if allowed_licenses is not None or forbidden_licenses is not None:
        builder = orm.QueryBuilder()
        builder.append(orm.Node,
                       project=['id', 'attributes.source.license'],
                       filters={'id': {
                           'in': node_ids_to_be_exported
                       }})
        # Skip those nodes where the license is not set (this is the standard behavior with Django)
        node_licenses = list(
            (a, b) for [a, b] in builder.all() if b is not None)
        check_licenses(node_licenses, allowed_licenses, forbidden_licenses)

    ############################################################
    ##### Start automatic recursive export data generation #####
    ############################################################
    EXPORT_LOGGER.debug('GATHERING DATABASE ENTRIES...')

    if entries_to_add:
        progress_bar = get_progress_bar(total=len(entries_to_add),
                                        disable=silent)

    export_data = defaultdict(dict)
    entity_separator = '_'
    for entity_name, partial_query in entries_to_add.items():

        progress_bar.set_description_str('Exporting {}s'.format(entity_name),
                                         refresh=False)
        progress_bar.update()

        foreign_fields = {
            k: v
            for k, v in all_fields_info[entity_name].items() if 'requires' in v
        }

        for value in foreign_fields.values():
            ref_model_name = value['requires']
            fill_in_query(partial_query, entity_name, ref_model_name,
                          [entity_name], entity_separator)

        for temp_d in partial_query.iterdict():
            for key in temp_d:
                # Get current entity
                current_entity = key.split(entity_separator)[-1]

                # This is an empty result of an outer join.
                # It should not be taken into account.
                if temp_d[key]['id'] is None:
                    continue

                export_data[current_entity].update({
                    temp_d[key]['id']:
                    serialize_dict(temp_d[key],
                                   remove_fields=['id'],
                                   rename_fields=model_fields_to_file_fields[
                                       current_entity])
                })

    # Close progress up until this point in order to print properly
    close_progress_bar(leave=False)

    #######################################
    # Manually manage attributes and extras
    #######################################
    # Just an alias: the Nodes have now effectively been retrieved and "stored"
    all_node_pks = node_ids_to_be_exported

    model_data = sum(len(model_data) for model_data in export_data.values())
    if not model_data:
        EXPORT_LOGGER.log(msg='Nothing to store, exiting...',
                          level=LOG_LEVEL_REPORT)
        return
    EXPORT_LOGGER.log(
        msg='Exporting a total of {} database entries, of which {} are Nodes.'.
        format(model_data, len(all_node_pks)),
        level=LOG_LEVEL_REPORT)

    # Instantiate new progress bar
    progress_bar = get_progress_bar(total=1, leave=False, disable=silent)

    # ATTRIBUTES and EXTRAS
    EXPORT_LOGGER.debug('GATHERING NODE ATTRIBUTES AND EXTRAS...')
    node_attributes = {}
    node_extras = {}

    # Another QueryBuilder query to get the attributes and extras. TODO: See if this can be optimized
    if all_node_pks:
        all_nodes_query = orm.QueryBuilder().append(
            orm.Node,
            filters={'id': {
                'in': all_node_pks
            }},
            project=['id', 'attributes', 'extras'])

        progress_bar = get_progress_bar(total=all_nodes_query.count(),
                                        disable=silent)
        progress_bar.set_description_str('Exporting Attributes and Extras',
                                         refresh=False)

        for node_pk, attributes, extras in all_nodes_query.iterall():
            progress_bar.update()

            node_attributes[str(node_pk)] = attributes
            node_extras[str(node_pk)] = extras

    EXPORT_LOGGER.debug('GATHERING GROUP ELEMENTS...')
    groups_uuid = defaultdict(list)
    # If a group is in the exported data, we export the group/node correlation
    if GROUP_ENTITY_NAME in export_data:
        group_uuids_with_node_uuids = orm.QueryBuilder().append(
            orm.Group,
            filters={
                'id': {
                    'in': export_data[GROUP_ENTITY_NAME]
                }
            },
            project='uuid',
            tag='groups').append(orm.Node, project='uuid', with_group='groups')

        # This part is _only_ for the progress bar
        total_node_uuids_for_groups = group_uuids_with_node_uuids.count()
        if total_node_uuids_for_groups:
            progress_bar = get_progress_bar(total=total_node_uuids_for_groups,
                                            disable=silent)
            progress_bar.set_description_str('Exporting Groups ...',
                                             refresh=False)

        for group_uuid, node_uuid in group_uuids_with_node_uuids.iterall():
            progress_bar.update()

            groups_uuid[group_uuid].append(node_uuid)

    #######################################
    # Final check for unsealed ProcessNodes
    #######################################
    process_nodes = set()
    for node_pk, content in export_data.get(NODE_ENTITY_NAME, {}).items():
        if content['node_type'].startswith('process.'):
            process_nodes.add(node_pk)

    check_process_nodes_sealed(process_nodes)

    ######################################
    # Now collecting and storing
    ######################################
    # subfolder inside the export package
    nodesubfolder = folder.get_subfolder(NODES_EXPORT_SUBFOLDER,
                                         create=True,
                                         reset_limit=True)

    EXPORT_LOGGER.debug('ADDING DATA TO EXPORT ARCHIVE...')

    data = {
        'node_attributes': node_attributes,
        'node_extras': node_extras,
        'export_data': export_data,
        'links_uuid': links_uuid,
        'groups_uuid': groups_uuid
    }

    # N.B. We're really calling zipfolder.open (if exporting a zipfile)
    with folder.open('data.json', mode='w') as fhandle:
        # fhandle.write(json.dumps(data, cls=UUIDEncoder))
        fhandle.write(json.dumps(data))

    # Turn sets into lists to be able to export them as JSON metadata.
    for entity, entity_set in entities_starting_set.items():
        entities_starting_set[entity] = list(entity_set)

    metadata = {
        'aiida_version': get_version(),
        'export_version': EXPORT_VERSION,
        'all_fields_info': all_fields_info,
        'unique_identifiers': unique_identifiers,
        'export_parameters': {
            'graph_traversal_rules': graph_traversal_rules,
            'entities_starting_set': entities_starting_set,
            'include_comments': include_comments,
            'include_logs': include_logs
        }
    }

    with folder.open('metadata.json', 'w') as fhandle:
        fhandle.write(json.dumps(metadata))

    EXPORT_LOGGER.debug('ADDING REPOSITORY FILES TO EXPORT ARCHIVE...')

    # If there are no nodes, there are no repository files to store
    if all_node_pks:
        all_node_uuids = {node_pk_2_uuid_mapping[_] for _ in all_node_pks}

        progress_bar = get_progress_bar(total=len(all_node_uuids),
                                        disable=silent)
        pbar_base_str = 'Exporting repository - '

        for uuid in all_node_uuids:
            sharded_uuid = export_shard_uuid(uuid)

            progress_bar.set_description_str(
                pbar_base_str + 'UUID={}'.format(uuid.split('-')[0]),
                refresh=False)
            progress_bar.update()

            # Important to set create=False, otherwise the subfolder is created twice. Maybe this is a bug of insert_path?
            thisnodefolder = nodesubfolder.get_subfolder(sharded_uuid,
                                                         create=False,
                                                         reset_limit=True)

            # Make sure the node's repository folder was not deleted
            src = RepositoryFolder(section=Repository._section_name, uuid=uuid)  # pylint: disable=protected-access
            if not src.exists():
                raise exceptions.ArchiveExportError(
                    'Unable to find the repository folder for Node with UUID={} in the local repository'
                    .format(uuid))

            # In this way, I copy the content of the folder, and not the folder itself
            thisnodefolder.insert_path(src=src.abspath, dest_name='.')

    close_progress_bar(leave=False)

    # Reset logging level
    if silent:
        logging.disable(level=logging.NOTSET)
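
Because deprecated_parameters() maps the old 'what' keyword onto 'entities', legacy calls like the sketch below are still accepted by this version (a deprecation warning is emitted); the pk is illustrative.

from aiida import load_profile, orm
from aiida.common.folders import SandboxFolder

load_profile()
with SandboxFolder() as folder:
    # Old-style call: 'what' is popped from **kwargs and forwarded as 'entities'.
    export_tree(what=[orm.load_node(1234)], folder=folder, silent=True)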
Example 8
def export(
    entities: Optional[Iterable[Any]] = None,
    filename: Optional[str] = None,
    file_format: Union[str, Type[ArchiveWriterAbstract]] = ExportFileFormat.ZIP,
    overwrite: bool = False,
    silent: Optional[bool] = None,
    use_compression: Optional[bool] = None,
    include_comments: bool = True,
    include_logs: bool = True,
    allowed_licenses: Optional[Union[list, Callable]] = None,
    forbidden_licenses: Optional[Union[list, Callable]] = None,
    writer_init: Optional[Dict[str, Any]] = None,
    batch_size: int = 100,
    **traversal_rules: bool,
) -> ArchiveWriterAbstract:
    """Export AiiDA data to an archive file.

    Note, the logging level and progress reporter should be set externally, for example::

        from aiida.common.progress_reporter import set_progress_bar_tqdm

        EXPORT_LOGGER.setLevel('DEBUG')
        set_progress_bar_tqdm(leave=True)
        export(...)

    .. deprecated:: 1.5.0
        Support for the parameter `silent` will be removed in `v2.0.0`.
        Please set the log level and progress bar implementation independently.

    .. deprecated:: 1.5.0
        Support for the parameter `use_compression` will be removed in `v2.0.0`.
        Please use `writer_init={'use_compression': True}`.

    .. deprecated:: 1.2.1
        Support for the parameters `what` and `outfile` will be removed in `v2.0.0`.
        Please use `entities` and `filename` instead, respectively.

    :param entities: a list of entity instances;
        they can belong to different models/entities.

    :param filename: the filename (possibly including the absolute path)
        of the file on which to export.

    :param file_format: 'zip', 'tar.gz' or 'folder' or a specific writer class.

    :param overwrite: if True, overwrite the output file without asking, if it exists.
        If False, raise an
        :py:class:`~aiida.tools.importexport.common.exceptions.ArchiveExportError`
        if the output file already exists.

    :param allowed_licenses: List or function.
        If a list, then checks whether all licenses of Data nodes are in the list. If a function,
        then calls the function for each Data node's license, expecting True if the license is
        allowed, False otherwise.

    :param forbidden_licenses: List or function.
        If a list, then checks that no license of a Data node is in the list. If a function,
        then calls the function for each Data node's license, expecting True if the license is
        forbidden, False otherwise.

    :param include_comments: In-/exclude export of comments for given node(s) in ``entities``.
        Default: True, *include* comments in export (as well as relevant users).

    :param include_logs: In-/exclude export of logs for given node(s) in ``entities``.
        Default: True, *include* logs in export.

    :param writer_init: Additional key-word arguments to pass to the writer class init

    :param batch_size: batch database query results in sub-collections to reduce memory usage

    :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` for which
        rule names are toggleable and what the defaults are.

    :returns: the archive writer used for the export (an :py:class:`ArchiveWriterAbstract` instance)

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`:
        if there are any internal errors when exporting.
    :raises `~aiida.common.exceptions.LicensingException`:
        if any node is licensed under forbidden license.
    """
    # pylint: disable=too-many-locals,too-many-branches,too-many-statements

    # Backwards-compatibility
    entities = cast(
        Iterable[Any],
        deprecated_parameters(
            old={
                'name': 'what',
                'value': traversal_rules.pop('what', None)
            },
            new={
                'name': 'entities',
                'value': entities
            },
        ),
    )
    filename = cast(
        str,
        deprecated_parameters(
            old={
                'name': 'outfile',
                'value': traversal_rules.pop('outfile', None)
            },
            new={
                'name': 'filename',
                'value': filename
            },
        ),
    )
    if silent is not None:
        warnings.warn(
            'silent keyword is deprecated and will be removed in AiiDA v2.0.0, set the logger level explicitly instead',
            AiidaDeprecationWarning
        )  # pylint: disable=no-member

    type_check(
        entities,
        (list, tuple, set),
        msg='`entities` must be specified and given as a list of AiiDA entities',
    )
    entities = list(entities)
    if type_check(filename, str, allow_none=True) is None:
        filename = 'export_data.aiida'

    if not overwrite and os.path.exists(filename):
        raise exceptions.ArchiveExportError(f"The output file '{filename}' already exists")

    # validate the traversal rules and generate a full set for reporting
    validate_traversal_rules(GraphTraversalRules.EXPORT, **traversal_rules)
    full_traversal_rules = {
        name: traversal_rules.get(name, rule.default) for name, rule in GraphTraversalRules.EXPORT.value.items()
    }
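    # ``full_traversal_rules`` maps every toggleable rule name to its effective value
    # (the override if given, the default otherwise), roughly {'create_backward': True, ...}
    # with the exact names defined by GraphTraversalRules.EXPORT, so the summary below
    # reports the complete rule set rather than only the overrides.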

    # setup the archive writer
    writer_init = writer_init or {}
    if use_compression is not None:
        warnings.warn(
            'use_compression argument is deprecated and will be removed in AiiDA v2.0.0 (which will always compress)',
            AiidaDeprecationWarning
        )  # pylint: disable=no-member
        writer_init['use_compression'] = use_compression
    if isinstance(file_format, str):
        writer = get_writer(file_format)(filepath=filename, **writer_init)
    elif issubclass(file_format, ArchiveWriterAbstract):
        writer = file_format(filepath=filename, **writer_init)
    else:
        raise TypeError('file_format must be a string or ArchiveWriterAbstract class')
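    # ``get_writer`` resolves a format string such as 'zip' to the matching writer class;
    # either way, the writer is instantiated with the target path plus any ``writer_init``
    # keyword arguments supplied by the caller.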

    summary(
        file_format=writer.file_format_verbose,
        export_version=writer.export_version,
        outfile=filename,
        include_comments=include_comments,
        include_logs=include_logs,
        traversal_rules=full_traversal_rules
    )

    EXPORT_LOGGER.debug('STARTING EXPORT...')

    all_fields_info, unique_identifiers = get_all_fields_info()
    entities_starting_set, given_node_entry_ids = _get_starting_node_ids(entities)

    # Initialize the writer
    with writer as writer_context:

        # Iteratively explore the AiiDA graph to find further nodes that should also be exported
        with get_progress_reporter()(desc='Traversing provenance via links ...', total=1) as progress:
            traverse_output = get_nodes_export(starting_pks=given_node_entry_ids, get_links=True, **traversal_rules)
            progress.update()
        node_ids_to_be_exported = traverse_output['nodes']

        EXPORT_LOGGER.debug('WRITING METADATA...')

        writer_context.write_metadata(
            ArchiveMetadata(
                export_version=EXPORT_VERSION,
                aiida_version=get_version(),
                unique_identifiers=unique_identifiers,
                all_fields_info=all_fields_info,
                graph_traversal_rules=traverse_output['rules'],
                # Turn sets into lists to be able to export them as JSON metadata.
                entities_starting_set={
                    entity: list(entity_set) for entity, entity_set in entities_starting_set.items()
                },
                include_comments=include_comments,
                include_logs=include_logs,
            )
        )

        # Create a mapping of node PK to UUID.
        node_pk_2_uuid_mapping: Dict[int, str] = {}
        if node_ids_to_be_exported:
            qbuilder = orm.QueryBuilder().append(
                orm.Node,
                project=('id', 'uuid'),
                filters={'id': {
                    'in': node_ids_to_be_exported
                }},
            )
            node_pk_2_uuid_mapping = dict(qbuilder.all(batch_size=batch_size))
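            # ``project=('id', 'uuid')`` makes each result row a [pk, uuid] pair, so
            # ``dict()`` over the batched results yields the pk -> uuid lookup used below.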

        # check that no nodes are being exported with incorrect licensing
        _check_node_licenses(node_ids_to_be_exported, allowed_licenses, forbidden_licenses)

        # write the link data
        if traverse_output['links'] is not None:
            with get_progress_reporter()(total=len(traverse_output['links']), desc='Writing links') as progress:
                for link in traverse_output['links']:
                    progress.update()
                    writer_context.write_link({
                        'input': node_pk_2_uuid_mapping[link.source_id],
                        'output': node_pk_2_uuid_mapping[link.target_id],
                        'label': link.link_label,
                        'type': link.link_type,
                    })

        # generate a list of queries to encapsulate all required entities
        entity_queries = _collect_entity_queries(
            node_ids_to_be_exported,
            entities_starting_set,
            node_pk_2_uuid_mapping,
            include_comments,
            include_logs,
        )
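        # ``entity_queries`` maps entity names (e.g. Node, Group, Computer, User, Log, Comment)
        # to the QueryBuilder instances covering everything that must be written out.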

        total_entities = sum(query.count() for query in entity_queries.values())

        # write all entity data fields
        if total_entities:
            exported_entity_pks = _write_entity_data(
                total_entities=total_entities,
                entity_queries=entity_queries,
                writer=writer_context,
                batch_size=batch_size
            )
        else:
            exported_entity_pks = defaultdict(set)
            EXPORT_LOGGER.info('No entities were found to export')

        # write mappings of groups to the nodes they contain
        if exported_entity_pks[GROUP_ENTITY_NAME]:

            EXPORT_LOGGER.debug('Writing group UUID -> [nodes UUIDs]')

            _write_group_mappings(
                group_pks=exported_entity_pks[GROUP_ENTITY_NAME], batch_size=batch_size, writer=writer_context
            )

        # copy all required node repositories
        if exported_entity_pks[NODE_ENTITY_NAME]:

            _write_node_repositories(
                node_pks=exported_entity_pks[NODE_ENTITY_NAME],
                node_pk_2_uuid_mapping=node_pk_2_uuid_mapping,
                writer=writer_context
            )

        EXPORT_LOGGER.info('Finalizing Export...')

    # summarize export
    export_summary = '\n  - '.join(f'{name:<6}: {len(pks)}' for name, pks in exported_entity_pks.items())
    if exported_entity_pks:
        EXPORT_LOGGER.info('Exported Entities:\n  - ' + export_summary + '\n')
    # TODO
    # EXPORT_LOGGER.info('Writer Information:\n %s', writer.export_info)

    return writer
Example n. 9
0
def _write_entity_data(
    total_entities: int, entity_queries: Dict[str, orm.QueryBuilder], writer: ArchiveWriterAbstract, batch_size: int
) -> Dict[str, Set[int]]:
    """Iterate through data returned from entity queries, serialize the DB fields, then write to the export."""
    all_fields_info, unique_identifiers = get_all_fields_info()
    entity_separator = '_'

    exported_entity_pks: Dict[str, Set[int]] = defaultdict(set)
    unsealed_node_pks: Set[int] = set()

    with get_progress_reporter()(total=total_entities, desc='Writing entity data') as progress:

        for entity_name, entity_query in entity_queries.items():

            foreign_fields = {k: v for k, v in all_fields_info[entity_name].items() if 'requires' in v}
            for value in foreign_fields.values():
                ref_model_name = value['requires']
                fill_in_query(
                    entity_query,
                    entity_name,
                    ref_model_name,
                    [entity_name],
                    entity_separator,
                )
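                # ``fill_in_query`` outer-joins the referenced entity (e.g. the Computer or
                # User a Node requires) onto the query under a tag like 'Node_Computer', so
                # its projected fields come back in the same result dictionary.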

            for query_results in entity_query.iterdict(batch_size=batch_size):

                progress.update()

                for key, value in query_results.items():

                    pk = value['id']

                    # This is an empty result of an outer join.
                    # It should not be taken into account.
                    if pk is None:
                        continue

                    # Get current entity
                    current_entity = key.split(entity_separator)[-1]
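                    # The keys are QueryBuilder tags such as 'Node' or 'Node_Computer', so
                    # the token after the last separator names the entity in this column.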

                    # don't allow duplication
                    if pk in exported_entity_pks[current_entity]:
                        continue

                    exported_entity_pks[current_entity].add(pk)

                    fields = serialize_dict(
                        value,
                        remove_fields=['id'],
                        rename_fields=model_fields_to_file_fields[current_entity],
                    )

                    if current_entity == NODE_ENTITY_NAME and fields['node_type'].startswith('process.'):
                        if fields['attributes'].get('sealed', False) is not True:
                            unsealed_node_pks.add(pk)

                    writer.write_entity_data(current_entity, pk, unique_identifiers[current_entity], fields)

    if unsealed_node_pks:
        raise exceptions.ExportValidationError(
            'All ProcessNodes must be sealed before they can be exported. '
            f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed."
        )

    return exported_entity_pks