Example #1
def export_tree(what,
                folder,
                allowed_licenses=None,
                forbidden_licenses=None,
                silent=False,
                include_comments=True,
                include_logs=True,
                **kwargs):
    """Export the entries passed in the 'what' list to a file tree.

    :param what: a list of entity instances; they can belong to different models/entities.
    :type what: list

    :param folder: a temporary folder to build the archive before compression.
    :type folder: :py:class:`~aiida.common.folders.Folder`

    :param allowed_licenses: List or function. If a list, all licenses of Data nodes must be contained in the list.
        If a function, it is called for each license of a Data node and must return True if the license is allowed,
        False otherwise.
    :type allowed_licenses: list

    :param forbidden_licenses: List or function. If a list, no license of a Data node may be contained in the list.
        If a function, it is called for each license of a Data node and must return True if the license is forbidden,
        False otherwise.
    :type forbidden_licenses: list

    :param silent: suppress prints.
    :type silent: bool

    :param include_comments: In-/exclude export of comments for given node(s) in ``what``.
        Default: True, *include* comments in export (as well as relevant users).
    :type include_comments: bool

    :param include_logs: In-/exclude export of logs for given node(s) in ``what``.
        Default: True, *include* logs in export.
    :type include_logs: bool

    :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` for which rule names
        are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if there are any internal errors when
        exporting.
    :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under a forbidden license.
    """
    from collections import defaultdict

    if not silent:
        print('STARTING EXPORT...')

    all_fields_info, unique_identifiers = get_all_fields_info()

    entities_starting_set = defaultdict(set)

    # The sets that will contain the ids of the entities that should be exported
    given_data_entry_ids = set()
    given_calculation_entry_ids = set()
    given_group_entry_ids = set()
    given_computer_entry_ids = set()
    given_groups = set()
    given_log_entry_ids = set()
    given_comment_entry_ids = set()

    # I store a list of the actual dbnodes
    for entry in what:
        # This returns the class name (as in imports). E.g. for a model node:
        # aiida.backends.djsite.db.models.DbNode
        # entry_class_string = get_class_string(entry)
        # Now we load the backend-independent name into entry_entity_name, e.g. Node!
        # entry_entity_name = schema_to_entity_names(entry_class_string)
        if issubclass(entry.__class__, orm.Group):
            entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid)
            given_group_entry_ids.add(entry.id)
            given_groups.add(entry)
        elif issubclass(entry.__class__, orm.Node):
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            if issubclass(entry.__class__, orm.Data):
                given_data_entry_ids.add(entry.pk)
            elif issubclass(entry.__class__, orm.ProcessNode):
                given_calculation_entry_ids.add(entry.pk)
        elif issubclass(entry.__class__, orm.Computer):
            entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid)
            given_computer_entry_ids.add(entry.pk)
        else:
            raise exceptions.ArchiveExportError(
                'I was given {} ({}), which is not a Node, Computer, or Group instance'
                .format(entry, type(entry)))

    # Add all the nodes contained within the specified groups
    for group in given_groups:
        for entry in group.nodes:
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            if issubclass(entry.__class__, orm.Data):
                given_data_entry_ids.add(entry.pk)
            elif issubclass(entry.__class__, orm.ProcessNode):
                given_calculation_entry_ids.add(entry.pk)

    for entity, entity_set in entities_starting_set.items():
        entities_starting_set[entity] = list(entity_set)

    # We will iteratively explore the AiiDA graph to find further nodes that
    # should also be exported.
    # At the same time, we will create the links_uuid list of dicts to be exported

    if not silent:
        print('RETRIEVING LINKED NODES AND STORING LINKS...')

    to_be_exported, links_uuid, graph_traversal_rules = retrieve_linked_nodes(
        given_calculation_entry_ids, given_data_entry_ids, **kwargs)

    ## Universal "entities" attributed to all types of nodes
    # Logs
    if include_logs and to_be_exported:
        # Get related log(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Log,
                       filters={'dbnode_id': {
                           'in': to_be_exported
                       }},
                       project='id')
        res = {_[0] for _ in builder.all()}
        given_log_entry_ids.update(res)

    # Comments
    if include_comments and to_be_exported:
        # Get related comment(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Comment,
                       filters={'dbnode_id': {
                           'in': to_be_exported
                       }},
                       project='id')
        res = {_[0] for _ in builder.all()}
        given_comment_entry_ids.update(res)

    # Here we get all the columns that we plan to project per entity that we
    # would like to extract
    given_entities = list()
    if given_group_entry_ids:
        given_entities.append(GROUP_ENTITY_NAME)
    if to_be_exported:
        given_entities.append(NODE_ENTITY_NAME)
    if given_computer_entry_ids:
        given_entities.append(COMPUTER_ENTITY_NAME)
    if given_log_entry_ids:
        given_entities.append(LOG_ENTITY_NAME)
    if given_comment_entry_ids:
        given_entities.append(COMMENT_ENTITY_NAME)

    entries_to_add = dict()
    for given_entity in given_entities:
        project_cols = ['id']
        # The following gets a list of fields that we need,
        # e.g. user, mtime, uuid, computer
        entity_prop = all_fields_info[given_entity].keys()

        # Here we do the necessary renaming of properties
        for prop in entity_prop:
            # nprop contains the list of projections
            nprop = (file_fields_to_model_fields[given_entity][prop] if prop
                     in file_fields_to_model_fields[given_entity] else prop)
            project_cols.append(nprop)

        # Getting the ids that correspond to the right entity
        if given_entity == GROUP_ENTITY_NAME:
            entry_ids_to_add = given_group_entry_ids
        elif given_entity == NODE_ENTITY_NAME:
            entry_ids_to_add = to_be_exported
        elif given_entity == COMPUTER_ENTITY_NAME:
            entry_ids_to_add = given_computer_entry_ids
        elif given_entity == LOG_ENTITY_NAME:
            entry_ids_to_add = given_log_entry_ids
        elif given_entity == COMMENT_ENTITY_NAME:
            entry_ids_to_add = given_comment_entry_ids

        builder = orm.QueryBuilder()
        builder.append(entity_names_to_entities[given_entity],
                       filters={'id': {
                           'in': entry_ids_to_add
                       }},
                       project=project_cols,
                       tag=given_entity,
                       outerjoin=True)
        entries_to_add[given_entity] = builder

    # TODO (Spyros) To see better! Especially for functional licenses
    # Check the licenses of exported data.
    if allowed_licenses is not None or forbidden_licenses is not None:
        builder = orm.QueryBuilder()
        builder.append(orm.Node,
                       project=['id', 'attributes.source.license'],
                       filters={'id': {
                           'in': to_be_exported
                       }})
        # Skip those nodes where the license is not set (this is the standard behavior with Django)
        node_licenses = list(
            (a, b) for [a, b] in builder.all() if b is not None)
        check_licenses(node_licenses, allowed_licenses, forbidden_licenses)

    ############################################################
    ##### Start automatic recursive export data generation #####
    ############################################################
    if not silent:
        print('STORING DATABASE ENTRIES...')

    export_data = dict()
    entity_separator = '_'
    for entity_name, partial_query in entries_to_add.items():

        foreign_fields = {
            k: v
            for k, v in all_fields_info[entity_name].items()
            if 'requires' in v
        }

        for value in foreign_fields.values():
            ref_model_name = value['requires']
            fill_in_query(partial_query, entity_name, ref_model_name,
                          [entity_name], entity_separator)

        for temp_d in partial_query.iterdict():
            for k in temp_d.keys():
                # Get current entity
                current_entity = k.split(entity_separator)[-1]

                # This is an empty result of an outer join.
                # It should not be taken into account.
                if temp_d[k]['id'] is None:
                    continue

                temp_d2 = {
                    temp_d[k]['id']:
                    serialize_dict(temp_d[k],
                                   remove_fields=['id'],
                                   rename_fields=model_fields_to_file_fields[
                                       current_entity])
                }
                try:
                    export_data[current_entity].update(temp_d2)
                except KeyError:
                    export_data[current_entity] = temp_d2

    #######################################
    # Manually manage attributes and extras
    #######################################
    # I use .get because there may be no nodes to export
    all_nodes_pk = list()
    if NODE_ENTITY_NAME in export_data:
        all_nodes_pk.extend(export_data.get(NODE_ENTITY_NAME).keys())

    if sum(len(model_data) for model_data in export_data.values()) == 0:
        if not silent:
            print('No nodes to store, exiting...')
        return

    if not silent:
        print('Exporting a total of {} db entries, of which {} nodes.'.format(
            sum(len(model_data) for model_data in export_data.values()),
            len(all_nodes_pk)))

    # ATTRIBUTES and EXTRAS
    if not silent:
        print('STORING NODE ATTRIBUTES AND EXTRAS...')
    node_attributes = {}
    node_extras = {}

    # A second QueryBuilder query to get the attributes and extras. See if this can be optimized
    if all_nodes_pk:
        all_nodes_query = orm.QueryBuilder()
        all_nodes_query.append(orm.Node,
                               filters={'id': {
                                   'in': all_nodes_pk
                               }},
                               project=['id', 'attributes', 'extras'])
        for res_pk, res_attributes, res_extras in all_nodes_query.iterall():
            node_attributes[str(res_pk)] = res_attributes
            node_extras[str(res_pk)] = res_extras

    if not silent:
        print('STORING GROUP ELEMENTS...')
    groups_uuid = dict()
    # If a group is in the exported data, we export the group/node correlation
    if GROUP_ENTITY_NAME in export_data:
        for curr_group in export_data[GROUP_ENTITY_NAME]:
            group_uuid_qb = orm.QueryBuilder()
            group_uuid_qb.append(entity_names_to_entities[GROUP_ENTITY_NAME],
                                 filters={'id': {
                                     '==': curr_group
                                 }},
                                 project=['uuid'],
                                 tag='group')
            group_uuid_qb.append(entity_names_to_entities[NODE_ENTITY_NAME],
                                 project=['uuid'],
                                 with_group='group')
            for res in group_uuid_qb.iterall():
                if str(res[0]) in groups_uuid:
                    groups_uuid[str(res[0])].append(str(res[1]))
                else:
                    groups_uuid[str(res[0])] = [str(res[1])]

    #######################################
    # Final check for unsealed ProcessNodes
    #######################################
    process_nodes = set()
    for node_pk, content in export_data.get(NODE_ENTITY_NAME, {}).items():
        if content['node_type'].startswith('process.'):
            process_nodes.add(node_pk)

    check_process_nodes_sealed(process_nodes)

    ######################################
    # Now I store
    ######################################
    # subfolder inside the export package
    nodesubfolder = folder.get_subfolder(NODES_EXPORT_SUBFOLDER,
                                         create=True,
                                         reset_limit=True)

    if not silent:
        print('STORING DATA...')

    data = {
        'node_attributes': node_attributes,
        'node_extras': node_extras,
        'export_data': export_data,
        'links_uuid': links_uuid,
        'groups_uuid': groups_uuid
    }

    # N.B. We're really calling zipfolder.open (if exporting a zipfile)
    with folder.open('data.json', mode='w') as fhandle:
        # fhandle.write(json.dumps(data, cls=UUIDEncoder))
        fhandle.write(json.dumps(data))

    # Add proper signature to unique identifiers & all_fields_info
    # Ignore if a key doesn't exist in any of the two dictionaries

    metadata = {
        'aiida_version': get_version(),
        'export_version': EXPORT_VERSION,
        'all_fields_info': all_fields_info,
        'unique_identifiers': unique_identifiers,
        'export_parameters': {
            'graph_traversal_rules': graph_traversal_rules,
            'entities_starting_set': entities_starting_set,
            'include_comments': include_comments,
            'include_logs': include_logs
        }
    }

    with folder.open('metadata.json', 'w') as fhandle:
        fhandle.write(json.dumps(metadata))

    if not silent:
        print('STORING REPOSITORY FILES...')

    # If there are no nodes, there are no repository files to store
    if all_nodes_pk:
        # Project only the uuid (rather than loading full node objects and looping over them in Python) for a large speed increase
        uuid_query = orm.QueryBuilder()
        uuid_query.append(orm.Node,
                          filters={'id': {
                              'in': all_nodes_pk
                          }},
                          project=['uuid'])
        for res in uuid_query.all():
            uuid = str(res[0])
            sharded_uuid = export_shard_uuid(uuid)

            # Important to set create=False, otherwise the subfolder is created twice. Maybe this is a bug of insert_path?
            thisnodefolder = nodesubfolder.get_subfolder(sharded_uuid,
                                                         create=False,
                                                         reset_limit=True)

            # Make sure the node's repository folder was not deleted
            src = RepositoryFolder(section=Repository._section_name, uuid=uuid)  # pylint: disable=protected-access
            if not src.exists():
                raise exceptions.ArchiveExportError(
                    'Unable to find the repository folder for Node with UUID={} in the local repository'
                    .format(uuid))

            # In this way, I copy the content of the folder, and not the folder itself
            thisnodefolder.insert_path(src=src.abspath, dest_name='.')
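
A minimal usage sketch of this first variant (not part of the listing above): it assumes a loaded AiiDA profile and two existing nodes with hypothetical PKs, and the license-checking callable is purely illustrative.

from aiida import orm
from aiida.common.folders import SandboxFolder

def license_is_allowed(license_string):
    """Illustrative callable for allowed_licenses: return True to accept a license."""
    return license_string in ('CC-BY-4.0', 'MIT')

nodes = [orm.load_node(123), orm.load_node(456)]  # hypothetical PKs

with SandboxFolder() as folder:
    export_tree(
        nodes,
        folder=folder,
        allowed_licenses=license_is_allowed,
        silent=True,
    )
    # `folder` now holds data.json, metadata.json and the nodes/ repository
    # subfolder, ready to be compressed into an export archive.
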
Example #2
def export_tree(entities=None,
                folder=None,
                allowed_licenses=None,
                forbidden_licenses=None,
                silent=False,
                include_comments=True,
                include_logs=True,
                **kwargs):
    """Export the entries passed in the 'entities' list to a file tree.

    .. deprecated:: 1.2.1
        Support for the parameter `what` will be removed in `v2.0.0`. Please use `entities` instead.

    :param entities: a list of entity instances; they can belong to different models/entities.
    :type entities: list

    :param folder: a temporary folder to build the archive before compression.
    :type folder: :py:class:`~aiida.common.folders.Folder`

    :param allowed_licenses: List or function. If a list, all licenses of Data nodes must be contained in the list.
        If a function, it is called for each license of a Data node and must return True if the license is allowed,
        False otherwise.
    :type allowed_licenses: list

    :param forbidden_licenses: List or function. If a list, no license of a Data node may be contained in the list.
        If a function, it is called for each license of a Data node and must return True if the license is forbidden,
        False otherwise.
    :type forbidden_licenses: list

    :param silent: suppress console prints and progress bar.
    :type silent: bool

    :param include_comments: In-/exclude export of comments for given node(s) in ``entities``.
        Default: True, *include* comments in export (as well as relevant users).
    :type include_comments: bool

    :param include_logs: In-/exclude export of logs for given node(s) in ``entities``.
        Default: True, *include* logs in export.
    :type include_logs: bool

    :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` for which rule names
        are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if there are any internal errors when
        exporting.
    :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under a forbidden license.
    """
    from collections import defaultdict
    from aiida.tools.graph.graph_traversers import get_nodes_export

    if silent:
        logging.disable(level=logging.CRITICAL)

    EXPORT_LOGGER.debug('STARTING EXPORT...')

    # Backwards-compatibility
    entities = deprecated_parameters(
        old={
            'name': 'what',
            'value': kwargs.pop('what', None)
        },
        new={
            'name': 'entities',
            'value': entities
        },
    )

    type_check(
        entities, (list, tuple, set),
        msg='`entities` must be specified and given as a list of AiiDA entities'
    )
    entities = list(entities)

    type_check(
        folder, (Folder, ZipFolder),
        msg='`folder` must be specified and given as an AiiDA Folder entity')

    all_fields_info, unique_identifiers = get_all_fields_info()

    entities_starting_set = defaultdict(set)

    # The sets that will contain the ids of the entities that should be exported
    given_node_entry_ids = set()
    given_log_entry_ids = set()
    given_comment_entry_ids = set()

    # Instantiate progress bar - go through list of `entities`
    pbar_total = len(entities) + 1 if entities else 1
    progress_bar = get_progress_bar(total=pbar_total,
                                    leave=False,
                                    disable=silent)
    progress_bar.set_description_str('Collecting chosen entities',
                                     refresh=False)

    # I store a list of the actual dbnodes
    for entry in entities:
        progress_bar.update()

        # This returns the class name (as in imports). E.g. for a model node:
        # aiida.backends.djsite.db.models.DbNode
        # entry_class_string = get_class_string(entry)
        # Now we load the backend-independent name into entry_entity_name, e.g. Node!
        # entry_entity_name = schema_to_entity_names(entry_class_string)
        if issubclass(entry.__class__, orm.Group):
            entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid)
        elif issubclass(entry.__class__, orm.Node):
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            given_node_entry_ids.add(entry.pk)
        elif issubclass(entry.__class__, orm.Computer):
            entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid)
        else:
            raise exceptions.ArchiveExportError(
                'I was given {} ({}), which is not a Node, Computer, or Group instance'
                .format(entry, type(entry)))

    # Add all the nodes contained within the specified groups
    if GROUP_ENTITY_NAME in entities_starting_set:

        progress_bar.set_description_str('Retrieving Nodes from Groups ...',
                                         refresh=True)

        # Use single query instead of given_group.nodes iterator for performance.
        qh_groups = orm.QueryBuilder().append(
            orm.Group,
            filters={
                'uuid': {
                    'in': entities_starting_set[GROUP_ENTITY_NAME]
                }
            },
            tag='groups').queryhelp

        # Delete this import once the dbexport.zip module has been renamed
        from builtins import zip  # pylint: disable=redefined-builtin

        node_results = orm.QueryBuilder(**qh_groups).append(
            orm.Node, project=['id', 'uuid'], with_group='groups').all()
        if node_results:
            pks, uuids = map(list, zip(*node_results))
            entities_starting_set[NODE_ENTITY_NAME].update(uuids)
            given_node_entry_ids.update(pks)
            del node_results, pks, uuids

        progress_bar.update()

    # We will iteratively explore the AiiDA graph to find further nodes that should also be exported.
    # At the same time, we will create the links_uuid list of dicts to be exported

    progress_bar = get_progress_bar(total=1, disable=silent)
    progress_bar.set_description_str(
        'Getting provenance and storing links ...', refresh=True)

    traverse_output = get_nodes_export(starting_pks=given_node_entry_ids,
                                       get_links=True,
                                       **kwargs)
    node_ids_to_be_exported = traverse_output['nodes']
    graph_traversal_rules = traverse_output['rules']

    # A utility dictionary for mapping PK to UUID.
    if node_ids_to_be_exported:
        qbuilder = orm.QueryBuilder().append(
            orm.Node,
            project=('id', 'uuid'),
            filters={'id': {
                'in': node_ids_to_be_exported
            }},
        )
        node_pk_2_uuid_mapping = dict(qbuilder.all())
    else:
        node_pk_2_uuid_mapping = {}

    # The set of tuples now has to be transformed to a list of dicts
    links_uuid = [{
        'input': node_pk_2_uuid_mapping[link.source_id],
        'output': node_pk_2_uuid_mapping[link.target_id],
        'label': link.link_label,
        'type': link.link_type
    } for link in traverse_output['links']]

    progress_bar.update()

    # Progress bar initialization - Entities
    progress_bar = get_progress_bar(total=1, disable=silent)
    progress_bar.set_description_str('Initializing export of all entities',
                                     refresh=True)

    ## Universal "entities" attributed to all types of nodes
    # Logs
    if include_logs and node_ids_to_be_exported:
        # Get related log(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Log,
                       filters={'dbnode_id': {
                           'in': node_ids_to_be_exported
                       }},
                       project='uuid')
        res = set(builder.all(flat=True))
        given_log_entry_ids.update(res)

    # Comments
    if include_comments and node_ids_to_be_exported:
        # Get related comment(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Comment,
                       filters={'dbnode_id': {
                           'in': node_ids_to_be_exported
                       }},
                       project='uuid')
        res = set(builder.all(flat=True))
        given_comment_entry_ids.update(res)

    # Here we get all the columns that we plan to project per entity that we would like to extract
    given_entities = set(entities_starting_set.keys())
    if node_ids_to_be_exported:
        given_entities.add(NODE_ENTITY_NAME)
    if given_log_entry_ids:
        given_entities.add(LOG_ENTITY_NAME)
    if given_comment_entry_ids:
        given_entities.add(COMMENT_ENTITY_NAME)

    progress_bar.update()

    if given_entities:
        progress_bar = get_progress_bar(total=len(given_entities),
                                        disable=silent)
        pbar_base_str = 'Preparing entities'

    entries_to_add = dict()
    for given_entity in given_entities:
        progress_bar.set_description_str(pbar_base_str +
                                         ' - {}s'.format(given_entity),
                                         refresh=False)
        progress_bar.update()

        project_cols = ['id']
        # The following gets a list of fields that we need,
        # e.g. user, mtime, uuid, computer
        entity_prop = all_fields_info[given_entity].keys()

        # Here we do the necessary renaming of properties
        for prop in entity_prop:
            # nprop contains the list of projections
            nprop = (file_fields_to_model_fields[given_entity][prop] if prop
                     in file_fields_to_model_fields[given_entity] else prop)
            project_cols.append(nprop)

        # Getting the uuids that correspond to the right entity
        entry_uuids_to_add = entities_starting_set.get(given_entity, set())
        if not entry_uuids_to_add:
            if given_entity == LOG_ENTITY_NAME:
                entry_uuids_to_add = given_log_entry_ids
            elif given_entity == COMMENT_ENTITY_NAME:
                entry_uuids_to_add = given_comment_entry_ids
        elif given_entity == NODE_ENTITY_NAME:
            entry_uuids_to_add.update(
                {node_pk_2_uuid_mapping[_]
                 for _ in node_ids_to_be_exported})

        builder = orm.QueryBuilder()
        builder.append(entity_names_to_entities[given_entity],
                       filters={'uuid': {
                           'in': entry_uuids_to_add
                       }},
                       project=project_cols,
                       tag=given_entity,
                       outerjoin=True)
        entries_to_add[given_entity] = builder

    # TODO (Spyros) To see better! Especially for functional licenses
    # Check the licenses of exported data.
    if allowed_licenses is not None or forbidden_licenses is not None:
        builder = orm.QueryBuilder()
        builder.append(orm.Node,
                       project=['id', 'attributes.source.license'],
                       filters={'id': {
                           'in': node_ids_to_be_exported
                       }})
        # Skip those nodes where the license is not set (this is the standard behavior with Django)
        node_licenses = list(
            (a, b) for [a, b] in builder.all() if b is not None)
        check_licenses(node_licenses, allowed_licenses, forbidden_licenses)

    ############################################################
    ##### Start automatic recursive export data generation #####
    ############################################################
    EXPORT_LOGGER.debug('GATHERING DATABASE ENTRIES...')

    if entries_to_add:
        progress_bar = get_progress_bar(total=len(entries_to_add),
                                        disable=silent)

    export_data = defaultdict(dict)
    entity_separator = '_'
    for entity_name, partial_query in entries_to_add.items():

        progress_bar.set_description_str('Exporting {}s'.format(entity_name),
                                         refresh=False)
        progress_bar.update()

        foreign_fields = {
            k: v
            for k, v in all_fields_info[entity_name].items() if 'requires' in v
        }

        for value in foreign_fields.values():
            ref_model_name = value['requires']
            fill_in_query(partial_query, entity_name, ref_model_name,
                          [entity_name], entity_separator)

        for temp_d in partial_query.iterdict():
            for key in temp_d:
                # Get current entity
                current_entity = key.split(entity_separator)[-1]

                # This is an empty result of an outer join.
                # It should not be taken into account.
                if temp_d[key]['id'] is None:
                    continue

                export_data[current_entity].update({
                    temp_d[key]['id']:
                    serialize_dict(temp_d[key],
                                   remove_fields=['id'],
                                   rename_fields=model_fields_to_file_fields[
                                       current_entity])
                })

    # Close progress up until this point in order to print properly
    close_progress_bar(leave=False)

    #######################################
    # Manually manage attributes and extras
    #######################################
    # Pointer. Renaming, since Nodes have now technically been retrieved and "stored"
    all_node_pks = node_ids_to_be_exported

    model_data = sum(len(model_data) for model_data in export_data.values())
    if not model_data:
        EXPORT_LOGGER.log(msg='Nothing to store, exiting...',
                          level=LOG_LEVEL_REPORT)
        return
    EXPORT_LOGGER.log(
        msg='Exporting a total of {} database entries, of which {} are Nodes.'.
        format(model_data, len(all_node_pks)),
        level=LOG_LEVEL_REPORT)

    # Instantiate new progress bar
    progress_bar = get_progress_bar(total=1, leave=False, disable=silent)

    # ATTRIBUTES and EXTRAS
    EXPORT_LOGGER.debug('GATHERING NODE ATTRIBUTES AND EXTRAS...')
    node_attributes = {}
    node_extras = {}

    # Another QueryBuilder query to get the attributes and extras. TODO: See if this can be optimized
    if all_node_pks:
        all_nodes_query = orm.QueryBuilder().append(
            orm.Node,
            filters={'id': {
                'in': all_node_pks
            }},
            project=['id', 'attributes', 'extras'])

        progress_bar = get_progress_bar(total=all_nodes_query.count(),
                                        disable=silent)
        progress_bar.set_description_str('Exporting Attributes and Extras',
                                         refresh=False)

        for node_pk, attributes, extras in all_nodes_query.iterall():
            progress_bar.update()

            node_attributes[str(node_pk)] = attributes
            node_extras[str(node_pk)] = extras

    EXPORT_LOGGER.debug('GATHERING GROUP ELEMENTS...')
    groups_uuid = defaultdict(list)
    # If a group is in the exported data, we export the group/node correlation
    if GROUP_ENTITY_NAME in export_data:
        group_uuids_with_node_uuids = orm.QueryBuilder().append(
            orm.Group,
            filters={
                'id': {
                    'in': export_data[GROUP_ENTITY_NAME]
                }
            },
            project='uuid',
            tag='groups').append(orm.Node, project='uuid', with_group='groups')

        # This part is _only_ for the progress bar
        total_node_uuids_for_groups = group_uuids_with_node_uuids.count()
        if total_node_uuids_for_groups:
            progress_bar = get_progress_bar(total=total_node_uuids_for_groups,
                                            disable=silent)
            progress_bar.set_description_str('Exporting Groups ...',
                                             refresh=False)

        for group_uuid, node_uuid in group_uuids_with_node_uuids.iterall():
            progress_bar.update()

            groups_uuid[group_uuid].append(node_uuid)

    #######################################
    # Final check for unsealed ProcessNodes
    #######################################
    process_nodes = set()
    for node_pk, content in export_data.get(NODE_ENTITY_NAME, {}).items():
        if content['node_type'].startswith('process.'):
            process_nodes.add(node_pk)

    check_process_nodes_sealed(process_nodes)

    ######################################
    # Now collecting and storing
    ######################################
    # subfolder inside the export package
    nodesubfolder = folder.get_subfolder(NODES_EXPORT_SUBFOLDER,
                                         create=True,
                                         reset_limit=True)

    EXPORT_LOGGER.debug('ADDING DATA TO EXPORT ARCHIVE...')

    data = {
        'node_attributes': node_attributes,
        'node_extras': node_extras,
        'export_data': export_data,
        'links_uuid': links_uuid,
        'groups_uuid': groups_uuid
    }

    # N.B. We're really calling zipfolder.open (if exporting a zipfile)
    with folder.open('data.json', mode='w') as fhandle:
        # fhandle.write(json.dumps(data, cls=UUIDEncoder))
        fhandle.write(json.dumps(data))

    # Turn sets into lists to be able to export them as JSON metadata.
    for entity, entity_set in entities_starting_set.items():
        entities_starting_set[entity] = list(entity_set)

    metadata = {
        'aiida_version': get_version(),
        'export_version': EXPORT_VERSION,
        'all_fields_info': all_fields_info,
        'unique_identifiers': unique_identifiers,
        'export_parameters': {
            'graph_traversal_rules': graph_traversal_rules,
            'entities_starting_set': entities_starting_set,
            'include_comments': include_comments,
            'include_logs': include_logs
        }
    }

    with folder.open('metadata.json', 'w') as fhandle:
        fhandle.write(json.dumps(metadata))

    EXPORT_LOGGER.debug('ADDING REPOSITORY FILES TO EXPORT ARCHIVE...')

    # If there are no nodes, there are no repository files to store
    if all_node_pks:
        all_node_uuids = {node_pk_2_uuid_mapping[_] for _ in all_node_pks}

        progress_bar = get_progress_bar(total=len(all_node_uuids),
                                        disable=silent)
        pbar_base_str = 'Exporting repository - '

        for uuid in all_node_uuids:
            sharded_uuid = export_shard_uuid(uuid)

            progress_bar.set_description_str(
                pbar_base_str + 'UUID={}'.format(uuid.split('-')[0]),
                refresh=False)
            progress_bar.update()

            # Important to set create=False, otherwise the subfolder is created twice. Maybe this is a bug of insert_path?
            thisnodefolder = nodesubfolder.get_subfolder(sharded_uuid,
                                                         create=False,
                                                         reset_limit=True)

            # Make sure the node's repository folder was not deleted
            src = RepositoryFolder(section=Repository._section_name, uuid=uuid)  # pylint: disable=protected-access
            if not src.exists():
                raise exceptions.ArchiveExportError(
                    'Unable to find the repository folder for Node with UUID={} in the local repository'
                    .format(uuid))

            # In this way, I copy the content of the folder, and not the folder itself
            thisnodefolder.insert_path(src=src.abspath, dest_name='.')

    close_progress_bar(leave=False)

    # Reset logging level
    if silent:
        logging.disable(level=logging.NOTSET)
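
A brief sketch of how this second variant might be called, assuming a loaded profile and an existing group with a hypothetical PK; it also shows the deprecated `what` keyword, which `deprecated_parameters` forwards to `entities` while emitting a deprecation warning.

from aiida import orm
from aiida.common.folders import SandboxFolder

group = orm.load_group(10)  # hypothetical group PK

with SandboxFolder() as folder:
    # Preferred form since 1.2.1: pass the entities explicitly.
    export_tree(entities=[group], folder=folder, include_logs=False)

# Still accepted for backwards compatibility, but deprecated and scheduled
# for removal in v2.0.0:
#     export_tree(what=[group], folder=folder)
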
Example #3
def get_export_data(entries_queries: Dict[str, orm.QueryBuilder],
                    silent: bool) -> Dict[str, Dict[int, dict]]:
    """Start automatic recursive export data generation

    :param entries_queries: partial queries for all entities to export
    :param silent: suppress console prints and progress bar.

    :return: export data mappings by entity type -> pk -> db_columns, e.g.
        {'ENTITY_NAME': {<pk>: {'uuid': 'abc', ...}, ...}, ...}
        Note: this data does not yet contain attributes and extras

    """
    EXPORT_LOGGER.debug("GATHERING DATABASE ENTRIES...")

    all_fields_info, _ = get_all_fields_info()

    if entries_queries:
        progress_bar = get_progress_bar(total=len(entries_queries),
                                        disable=silent)

    export_data = defaultdict(dict)  # type: dict
    entity_separator = "_"
    for entity_name, partial_query in entries_queries.items():

        progress_bar.set_description_str(f"Exporting {entity_name}s",
                                         refresh=False)
        progress_bar.update()

        foreign_fields = {
            k: v
            for k, v in all_fields_info[entity_name].items() if "requires" in v
        }

        for value in foreign_fields.values():
            ref_model_name = value["requires"]
            fill_in_query(
                partial_query,
                entity_name,
                ref_model_name,
                [entity_name],
                entity_separator,
            )

        for temp_d in partial_query.iterdict():
            for key in temp_d:
                # Get current entity
                current_entity = key.split(entity_separator)[-1]

                # This is an empty result of an outer join.
                # It should not be taken into account.
                if temp_d[key]["id"] is None:
                    continue

                export_data[current_entity].update({
                    temp_d[key]["id"]:
                    serialize_dict(
                        temp_d[key],
                        remove_fields=["id"],
                        rename_fields=model_fields_to_file_fields[
                            current_entity],
                    )
                })

    return export_data
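
A minimal sketch of how the partial queries passed to get_export_data might be assembled; the projected columns and node PKs below are illustrative only, not the exact production query.

from aiida import orm

node_pks = [123, 456]  # hypothetical PKs of nodes already selected for export

node_query = orm.QueryBuilder()
node_query.append(
    orm.Node,
    filters={'id': {'in': node_pks}},
    project=['id', 'uuid', 'node_type', 'label', 'ctime', 'mtime'],
    tag=NODE_ENTITY_NAME,
    outerjoin=True,
)

export_data = get_export_data({NODE_ENTITY_NAME: node_query}, silent=True)
# export_data now maps entity name -> pk -> serialized DB columns; attributes
# and extras are gathered in a separate step.
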
Example #4
def _write_entity_data(
    total_entities: int, entity_queries: Dict[str, orm.QueryBuilder], writer: ArchiveWriterAbstract, batch_size: int
) -> Dict[str, Set[int]]:
    """Iterate through data returned from entity queries, serialize the DB fields, then write to the export."""
    all_fields_info, unique_identifiers = get_all_fields_info()
    entity_separator = '_'

    exported_entity_pks: Dict[str, Set[int]] = defaultdict(set)
    unsealed_node_pks: Set[int] = set()

    with get_progress_reporter()(total=total_entities, desc='Writing entity data') as progress:

        for entity_name, entity_query in entity_queries.items():

            foreign_fields = {k: v for k, v in all_fields_info[entity_name].items() if 'requires' in v}
            for value in foreign_fields.values():
                ref_model_name = value['requires']
                fill_in_query(
                    entity_query,
                    entity_name,
                    ref_model_name,
                    [entity_name],
                    entity_separator,
                )

            for query_results in entity_query.iterdict(batch_size=batch_size):

                progress.update()

                for key, value in query_results.items():

                    pk = value['id']

                    # This is an empty result of an outer join.
                    # It should not be taken into account.
                    if pk is None:
                        continue

                    # Get current entity
                    current_entity = key.split(entity_separator)[-1]

                    # don't allow duplication
                    if pk in exported_entity_pks[current_entity]:
                        continue

                    exported_entity_pks[current_entity].add(pk)

                    fields = serialize_dict(
                        value,
                        remove_fields=['id'],
                        rename_fields=model_fields_to_file_fields[current_entity],
                    )

                    if current_entity == NODE_ENTITY_NAME and fields['node_type'].startswith('process.'):
                        if fields['attributes'].get('sealed', False) is not True:
                            unsealed_node_pks.add(pk)

                    writer.write_entity_data(current_entity, pk, unique_identifiers[current_entity], fields)

    if unsealed_node_pks:
        raise exceptions.ExportValidationError(
            'All ProcessNodes must be sealed before they can be exported. '
            f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed."
        )

    return exported_entity_pks
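
Roughly how this writer-based variant might be driven. The query, the PKs, and the stand-in writer below are assumptions for illustration; in practice the writer is a concrete ArchiveWriterAbstract implementation.

from aiida import orm

class PrintWriter:
    """Duck-typed stand-in exposing only the write_entity_data call used above."""

    def write_entity_data(self, name, pk, id_type, fields):
        print(name, pk, id_type, fields)

node_pks = [123, 456]  # hypothetical PKs already selected for export

entity_queries = {
    NODE_ENTITY_NAME: orm.QueryBuilder().append(
        orm.Node,
        filters={'id': {'in': node_pks}},
        project=['id', 'uuid', 'node_type', 'attributes'],
        tag=NODE_ENTITY_NAME,
        outerjoin=True,
    )
}
total_entities = sum(query.count() for query in entity_queries.values())

exported_pks = _write_entity_data(
    total_entities=total_entities,
    entity_queries=entity_queries,
    writer=PrintWriter(),  # stand-in; normally an ArchiveWriterAbstract instance
    batch_size=100,
)
# exported_pks maps each entity name to the set of PKs written to the archive.
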