def test_missing_node_repo_folder_import(self, temp_dir):
        """
        Make sure `~aiida.tools.importexport.common.exceptions.CorruptArchive` is raised during import when the
        Node repository folder is missing.

        Create and export a Node, then manually remove its repository folder from the export archive.
        Attempt to import it and make sure `~aiida.tools.importexport.common.exceptions.CorruptArchive` is raised
        due to the missing folder.
        """
        import tarfile

        from aiida.common.folders import SandboxFolder
        from aiida.tools.importexport.common.archive import extract_tar
        from aiida.tools.importexport.common.config import NODES_EXPORT_SUBFOLDER
        from aiida.tools.importexport.common.utils import export_shard_uuid

        node = orm.CalculationNode().store()
        node.seal()
        node_uuid = node.uuid

        node_repo = RepositoryFolder(section=Repository._section_name, uuid=node_uuid)  # pylint: disable=protected-access
        self.assertTrue(
            node_repo.exists(), msg='Newly created and stored Node should have had an existing repository folder'
        )

        # Export and reset db
        filename = os.path.join(temp_dir, 'export.aiida')
        export([node], filename=filename, file_format='tar.gz', silent=True)
        self.reset_database()

        # Untar export file, remove repository folder, re-tar
        node_shard_uuid = export_shard_uuid(node_uuid)
        node_top_folder = node_shard_uuid.split('/')[0]
        with SandboxFolder() as folder:
            extract_tar(filename, folder, silent=True, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER)
            node_folder = folder.get_subfolder(os.path.join(NODES_EXPORT_SUBFOLDER, node_shard_uuid))
            self.assertTrue(
                node_folder.exists(), msg="The Node's repository folder should still exist in the export file"
            )

            # Removing the Node's repository folder from the export file
            shutil.rmtree(
                folder.get_subfolder(os.path.join(NODES_EXPORT_SUBFOLDER, node_top_folder)).abspath, ignore_errors=True
            )
            self.assertFalse(
                node_folder.exists(),
                msg="The Node's repository folder should now have been removed in the export file"
            )

            filename_corrupt = os.path.join(temp_dir, 'export_corrupt.aiida')
            with tarfile.open(filename_corrupt, 'w:gz', format=tarfile.PAX_FORMAT, dereference=True) as tar:
                tar.add(folder.abspath, arcname='')

        # Try to import, check it raises and check the raise message
        with self.assertRaises(exceptions.CorruptArchive) as exc:
            import_data(filename_corrupt, silent=True)

        self.assertIn(
            'Unable to find the repository folder for Node with UUID={}'.format(node_uuid), str(exc.exception)
        )
Example no. 2
    def test_empty_repo_folder_export(self, temp_dir):
        """Check a Node's empty repository folder is exported properly"""
        from aiida.common.folders import Folder
        from aiida.tools.importexport.dbexport import export_tree

        node = orm.Dict().store()
        node_uuid = node.uuid

        node_repo = RepositoryFolder(section=Repository._section_name, uuid=node_uuid)  # pylint: disable=protected-access
        self.assertTrue(
            node_repo.exists(), msg='Newly created and stored Node should have had an existing repository folder'
        )
        for filename, is_file in node_repo.get_content_list(only_paths=False):
            abspath_filename = os.path.join(node_repo.abspath, filename)
            if is_file:
                os.remove(abspath_filename)
            else:
                shutil.rmtree(abspath_filename, ignore_errors=False)
        self.assertFalse(
            node_repo.get_content_list(),
            msg=f'Repository folder should be empty, instead the following was found: {node_repo.get_content_list()}'
        )

        archive_variants = {
            'archive folder': os.path.join(temp_dir, 'export_tree'),
            'tar archive': os.path.join(temp_dir, 'export.tar.gz'),
            'zip archive': os.path.join(temp_dir, 'export.zip')
        }

        export_tree([node], folder=Folder(archive_variants['archive folder']))
        export([node], filename=archive_variants['tar archive'], file_format='tar.gz')
        export([node], filename=archive_variants['zip archive'], file_format='zip')

        for variant, filename in archive_variants.items():
            self.reset_database()
            node_count = orm.QueryBuilder().append(orm.Dict, project='uuid').count()
            self.assertEqual(node_count, 0, msg=f'After DB reset, {node_count} Dict Nodes were (wrongly) found')

            import_data(filename)
            builder = orm.QueryBuilder().append(orm.Dict, project='uuid')
            imported_node_count = builder.count()
            self.assertEqual(
                imported_node_count,
                1,
                msg='After {} import a single Dict Node should have been found, '
                'instead {} was/were found'.format(variant, imported_node_count)
            )
            imported_node_uuid = builder.all()[0][0]
            self.assertEqual(
                imported_node_uuid,
                node_uuid,
                msg='The wrong UUID was found for the imported {}: '
                '{}. It should have been: {}'.format(variant, imported_node_uuid, node_uuid)
            )
Example no. 3
    def __init__(self, **kwargs):
        super(Node, self).__init__()

        self._temp_folder = None

        dbnode = kwargs.pop('dbnode', None)

        # Set the internal parameters
        # Can be redefined in the subclasses
        self._init_internal_params()

        if dbnode is not None:
            if not isinstance(dbnode, DbNode):
                raise TypeError("dbnode is not a DbNode instance")
            if dbnode.id is None:
                raise ValueError("If cannot load an aiida.orm.Node instance "
                                 "from an unsaved DbNode object.")
            if kwargs:
                raise ValueError("If you pass a dbnode, you cannot pass any "
                                 "further parameter")

            # If I am loading, I cannot modify it
            self._to_be_stored = False

            self._dbnode = dbnode

            # If this is changed, fix also the importer
            self._repo_folder = RepositoryFolder(section=self._section_name,
                                                 uuid=self._dbnode.uuid)

        else:
            # TODO: allow to get the user from the parameters
            user = get_automatic_user()

            self._dbnode = DbNode(user=user,
                                  uuid=get_new_uuid(),
                                  type=self._plugin_type_string)

            self._to_be_stored = True

            # As creating the temp folder may require some time on slow
            # filesystems, we defer its creation
            self._temp_folder = None
            # Used only before the first save
            self._attrs_cache = {}
            # If this is changed, fix also the importer
            self._repo_folder = RepositoryFolder(section=self._section_name,
                                                 uuid=self.uuid)

            # Automatically set all *other* attributes, if possible, otherwise
            # stop
            self._set_with_defaults(**kwargs)
Example no. 4
def _write_node_repositories(
    *, node_pks: Set[int], node_pk_2_uuid_mapping: Dict[int, str], writer: ArchiveWriterAbstract
):
    """Write all exported node repositories to the archive file."""
    with get_progress_reporter()(total=len(node_pks), desc='Exporting node repositories: ') as progress:

        for pk in node_pks:

            uuid = node_pk_2_uuid_mapping[pk]

            progress.set_description_str(f'Exporting node repositories: {pk}', refresh=False)
            progress.update()

            src = RepositoryFolder(section=Repository._section_name, uuid=uuid)  # pylint: disable=protected-access
            if not src.exists():
                raise exceptions.ArchiveExportError(
                    f'Unable to find the repository folder for Node with UUID={uuid} '
                    'in the local repository'
                )
            writer.write_node_repo_folder(uuid, src._abspath)  # pylint: disable=protected-access
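A hedged sketch of how the arguments to this helper might be assembled: the PK-to-UUID lookup mirrors the `QueryBuilder` projection used in `export_tree` further below, while `nodes_to_export` and `writer` are illustrative placeholders for whatever the surrounding export pipeline provides.

# Sketch only (assumed names: `nodes_to_export`, `writer`): build the PK -> UUID
# mapping with a QueryBuilder projection and hand everything to the helper above.
from aiida import orm

node_pks = {node.pk for node in nodes_to_export}
qbuilder = orm.QueryBuilder().append(
    orm.Node, project=('id', 'uuid'), filters={'id': {'in': node_pks}}
)
node_pk_2_uuid_mapping = dict(qbuilder.all())

_write_node_repositories(
    node_pks=node_pks,
    node_pk_2_uuid_mapping=node_pk_2_uuid_mapping,
    writer=writer,
)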
Example no. 5
def get_paths_to_tar(query, verbosity=1):
    """
    Construct the collection of repository paths that have to be tarred.

    :param query: the QueryBuilder instance yielding the nodes to back up
    :param verbosity: 0 is silent, 1 prints a summary, >1 prints every node
    :returns: a sorted list of shard directory paths
    """
    paths_to_tar = set()
    res = query.all()
    if verbosity:
        print('{} items to back up'.format(len(res)))
    for node, in res:
        if verbosity > 1:
            print("Backing up {}".format(node))
        path = RepositoryFolder(
            section='node',  # Node._section_name, hardcoded to be independent of the AiiDA version
            uuid=node.uuid
        ).abspath
        # path: [REPO]/repository/node/24/64/fd5a-749e-4572-b390-53efb766256a
        path_level_up = '/'.join(path.split('/')[:-1])
        # path_level_up: [REPO]/repository/node/24/64
        paths_to_tar.add(path_level_up)
    return sorted(paths_to_tar)
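A minimal usage sketch for the helper above, assuming a configured AiiDA profile; the query covers all Nodes and the archive name is illustrative.

# Sketch only: tar every shard directory returned by get_paths_to_tar.
import tarfile

from aiida import orm

query = orm.QueryBuilder().append(orm.Node)
paths = get_paths_to_tar(query, verbosity=1)

with tarfile.open('repository_backup.tar.gz', 'w:gz', dereference=True) as tar:
    for path in paths:
        tar.add(path)  # tarfile strips the leading '/' from member names automatically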
Example no. 6
    def _get_source_directory(self, item):
        """

        :param item:
        :return:
        """
        if type(item) == DbWorkflow:
            source_dir = os.path.normpath(
                RepositoryFolder(section=Workflow._section_name,
                                 uuid=item.uuid).abspath)
        elif type(item) == DbNode:
            source_dir = os.path.normpath(
                RepositoryFolder(section=Node._section_name,
                                 uuid=item.uuid).abspath)
        else:
            # Raise exception
            self._logger.error("Unexpected item type to backup: {}".format(
                type(item)))
            raise BackupError("Unexpected item type to backup: {}".format(
                type(item)))
        return source_dir
Example no. 7
    def test_missing_node_repo_folder_export(self, temp_dir):
        """
        Make sure `~aiida.tools.importexport.common.exceptions.ArchiveExportError` is raised during export when the
        Node repository folder is missing.

        Create and store a new Node and manually remove its repository folder.
        Attempt to export it and make sure `~aiida.tools.importexport.common.exceptions.ArchiveExportError` is raised
        due to the missing folder.
        """
        node = orm.CalculationNode().store()
        node.seal()
        node_uuid = node.uuid

        node_repo = RepositoryFolder(section=Repository._section_name, uuid=node_uuid)  # pylint: disable=protected-access
        self.assertTrue(
            node_repo.exists(), msg='Newly created and stored Node should have had an existing repository folder'
        )

        # Remove the Node's local repository folder
        shutil.rmtree(node_repo.abspath, ignore_errors=True)
        self.assertFalse(
            node_repo.exists(), msg="The Node's repository folder should now have been removed"
        )

        # Try to export, check it raises and check the raise message
        filename = os.path.join(temp_dir, 'export.tar.gz')
        with self.assertRaises(exceptions.ArchiveExportError) as exc:
            export([node], outfile=filename, silent=True)

        self.assertIn(
            'Unable to find the repository folder for Node with UUID={}'.format(node_uuid), str(exc.exception)
        )
        self.assertFalse(os.path.exists(filename), msg='The export file should not exist')
Example no. 8
    def store(self):
        """
        Stores the DbWorkflow object data in the database
        """
        if self._to_be_stored:
            self._dbworkflowinstance.save()

            if hasattr(self, '_params'):
                self.dbworkflowinstance.add_parameters(self._params, force=False)

            self._repo_folder = RepositoryFolder(section=self._section_name, uuid=self.uuid)
            self.repo_folder.replace_with_folder(self.get_temp_folder().abspath, move=True, overwrite=True)

            self._temp_folder = None
            self._to_be_stored = False

        # Important: this allows one to do w = WorkflowSubClass().store()
        return self
Example no. 9
    def _get_source_directory(self, item):
        """Retrieve the node repository folder

        :param item: Subclasses of Node.
        :type item: :class:`~aiida.orm.nodes.node.Node`

        :return: Normalized path to the Node's repository folder.
        :rtype: str
        """
        # pylint: disable=protected-access
        if isinstance(item, Node):
            source_dir = os.path.normpath(RepositoryFolder(section=Repository._section_name, uuid=item.uuid).abspath)
        else:
            # Raise exception
            msg = 'Unexpected item type to backup: %s'
            self._logger.error(msg, type(item))
            raise BackupError(msg % type(item))
        return source_dir
Example no. 10
def export_tree(entities=None,
                folder=None,
                allowed_licenses=None,
                forbidden_licenses=None,
                silent=False,
                include_comments=True,
                include_logs=True,
                **kwargs):
    """Export the entries passed in the 'entities' list to a file tree.

    .. deprecated:: 1.2.1
        Support for the parameter `what` will be removed in `v2.0.0`. Please use `entities` instead.

    :param entities: a list of entity instances; they can belong to different models/entities.
    :type entities: list

    :param folder: a temporary folder to build the archive before compression.
    :type folder: :py:class:`~aiida.common.folders.Folder`

    :param allowed_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type allowed_licenses: list

    :param forbidden_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type forbidden_licenses: list

    :param silent: suppress console prints and progress bar.
    :type silent: bool

    :param include_comments: In-/exclude export of comments for given node(s) in ``entities``.
        Default: True, *include* comments in export (as well as relevant users).
    :type include_comments: bool

    :param include_logs: In-/exclude export of logs for given node(s) in ``entities``.
        Default: True, *include* logs in export.
    :type include_logs: bool

    :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` for which rule names
        are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if there are any internal errors when
        exporting.
    :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under forbidden license.
    """
    from collections import defaultdict
    from aiida.tools.graph.graph_traversers import get_nodes_export

    if silent:
        logging.disable(level=logging.CRITICAL)

    EXPORT_LOGGER.debug('STARTING EXPORT...')

    # Backwards-compatibility
    entities = deprecated_parameters(
        old={
            'name': 'what',
            'value': kwargs.pop('what', None)
        },
        new={
            'name': 'entities',
            'value': entities
        },
    )

    type_check(
        entities, (list, tuple, set),
        msg='`entities` must be specified and given as a list of AiiDA entities'
    )
    entities = list(entities)

    type_check(
        folder, (Folder, ZipFolder),
        msg='`folder` must be specified and given as an AiiDA Folder entity')

    all_fields_info, unique_identifiers = get_all_fields_info()

    entities_starting_set = defaultdict(set)

    # The set that contains the nodes ids of the nodes that should be exported
    given_node_entry_ids = set()
    given_log_entry_ids = set()
    given_comment_entry_ids = set()

    # Instantiate progress bar - go through list of `entities`
    pbar_total = len(entities) + 1 if entities else 1
    progress_bar = get_progress_bar(total=pbar_total,
                                    leave=False,
                                    disable=silent)
    progress_bar.set_description_str('Collecting chosen entities',
                                     refresh=False)

    # I store a list of the actual dbnodes
    for entry in entities:
        progress_bar.update()

        # This returns the class name (as in imports). E.g. for a model node:
        # aiida.backends.djsite.db.models.DbNode
        # entry_class_string = get_class_string(entry)
        # Now load the backend-independent name into entry_entity_name, e.g. Node!
        # entry_entity_name = schema_to_entity_names(entry_class_string)
        if issubclass(entry.__class__, orm.Group):
            entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid)
        elif issubclass(entry.__class__, orm.Node):
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            given_node_entry_ids.add(entry.pk)
        elif issubclass(entry.__class__, orm.Computer):
            entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid)
        else:
            raise exceptions.ArchiveExportError(
                'I was given {} ({}), which is not a Node, Computer, or Group instance'
                .format(entry, type(entry)))

    # Add all the nodes contained within the specified groups
    if GROUP_ENTITY_NAME in entities_starting_set:

        progress_bar.set_description_str('Retrieving Nodes from Groups ...',
                                         refresh=True)

        # Use single query instead of given_group.nodes iterator for performance.
        qh_groups = orm.QueryBuilder().append(
            orm.Group,
            filters={
                'uuid': {
                    'in': entities_starting_set[GROUP_ENTITY_NAME]
                }
            },
            tag='groups').queryhelp

        # Delete this import once the dbexport.zip module has been renamed
        from builtins import zip  # pylint: disable=redefined-builtin

        node_results = orm.QueryBuilder(**qh_groups).append(
            orm.Node, project=['id', 'uuid'], with_group='groups').all()
        if node_results:
            pks, uuids = map(list, zip(*node_results))
            entities_starting_set[NODE_ENTITY_NAME].update(uuids)
            given_node_entry_ids.update(pks)
            del node_results, pks, uuids

        progress_bar.update()

    # We will iteratively explore the AiiDA graph to find further nodes that should also be exported.
    # At the same time, we will create the links_uuid list of dicts to be exported

    progress_bar = get_progress_bar(total=1, disable=silent)
    progress_bar.set_description_str(
        'Getting provenance and storing links ...', refresh=True)

    traverse_output = get_nodes_export(starting_pks=given_node_entry_ids,
                                       get_links=True,
                                       **kwargs)
    node_ids_to_be_exported = traverse_output['nodes']
    graph_traversal_rules = traverse_output['rules']

    # A utility dictionary for mapping PK to UUID.
    if node_ids_to_be_exported:
        qbuilder = orm.QueryBuilder().append(
            orm.Node,
            project=('id', 'uuid'),
            filters={'id': {
                'in': node_ids_to_be_exported
            }},
        )
        node_pk_2_uuid_mapping = dict(qbuilder.all())
    else:
        node_pk_2_uuid_mapping = {}

    # The set of tuples now has to be transformed to a list of dicts
    links_uuid = [{
        'input': node_pk_2_uuid_mapping[link.source_id],
        'output': node_pk_2_uuid_mapping[link.target_id],
        'label': link.link_label,
        'type': link.link_type
    } for link in traverse_output['links']]

    progress_bar.update()

    # Progress bar initialization - Entities
    progress_bar = get_progress_bar(total=1, disable=silent)
    progress_bar.set_description_str('Initializing export of all entities',
                                     refresh=True)

    ## Universal "entities" attributed to all types of nodes
    # Logs
    if include_logs and node_ids_to_be_exported:
        # Get related log(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Log,
                       filters={'dbnode_id': {
                           'in': node_ids_to_be_exported
                       }},
                       project='uuid')
        res = set(builder.all(flat=True))
        given_log_entry_ids.update(res)

    # Comments
    if include_comments and node_ids_to_be_exported:
        # Get related comment(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Comment,
                       filters={'dbnode_id': {
                           'in': node_ids_to_be_exported
                       }},
                       project='uuid')
        res = set(builder.all(flat=True))
        given_comment_entry_ids.update(res)

    # Here we get all the columns that we plan to project per entity that we would like to extract
    given_entities = set(entities_starting_set.keys())
    if node_ids_to_be_exported:
        given_entities.add(NODE_ENTITY_NAME)
    if given_log_entry_ids:
        given_entities.add(LOG_ENTITY_NAME)
    if given_comment_entry_ids:
        given_entities.add(COMMENT_ENTITY_NAME)

    progress_bar.update()

    if given_entities:
        progress_bar = get_progress_bar(total=len(given_entities),
                                        disable=silent)
        pbar_base_str = 'Preparing entities'

    entries_to_add = dict()
    for given_entity in given_entities:
        progress_bar.set_description_str(pbar_base_str +
                                         ' - {}s'.format(given_entity),
                                         refresh=False)
        progress_bar.update()

        project_cols = ['id']
        # The following gets a list of fields that we need,
        # e.g. user, mtime, uuid, computer
        entity_prop = all_fields_info[given_entity].keys()

        # Here we do the necessary renaming of properties
        for prop in entity_prop:
            # nprop contains the list of projections
            nprop = (file_fields_to_model_fields[given_entity][prop] if prop
                     in file_fields_to_model_fields[given_entity] else prop)
            project_cols.append(nprop)

        # Getting the ids that correspond to the right entity
        entry_uuids_to_add = entities_starting_set.get(given_entity, set())
        if not entry_uuids_to_add:
            if given_entity == LOG_ENTITY_NAME:
                entry_uuids_to_add = given_log_entry_ids
            elif given_entity == COMMENT_ENTITY_NAME:
                entry_uuids_to_add = given_comment_entry_ids
        elif given_entity == NODE_ENTITY_NAME:
            entry_uuids_to_add.update(
                {node_pk_2_uuid_mapping[_]
                 for _ in node_ids_to_be_exported})

        builder = orm.QueryBuilder()
        builder.append(entity_names_to_entities[given_entity],
                       filters={'uuid': {
                           'in': entry_uuids_to_add
                       }},
                       project=project_cols,
                       tag=given_entity,
                       outerjoin=True)
        entries_to_add[given_entity] = builder

    # TODO (Spyros) To see better! Especially for functional licenses
    # Check the licenses of exported data.
    if allowed_licenses is not None or forbidden_licenses is not None:
        builder = orm.QueryBuilder()
        builder.append(orm.Node,
                       project=['id', 'attributes.source.license'],
                       filters={'id': {
                           'in': node_ids_to_be_exported
                       }})
        # Skip those nodes where the license is not set (this is the standard behavior with Django)
        node_licenses = list(
            (a, b) for [a, b] in builder.all() if b is not None)
        check_licenses(node_licenses, allowed_licenses, forbidden_licenses)

    ############################################################
    ##### Start automatic recursive export data generation #####
    ############################################################
    EXPORT_LOGGER.debug('GATHERING DATABASE ENTRIES...')

    if entries_to_add:
        progress_bar = get_progress_bar(total=len(entries_to_add),
                                        disable=silent)

    export_data = defaultdict(dict)
    entity_separator = '_'
    for entity_name, partial_query in entries_to_add.items():

        progress_bar.set_description_str('Exporting {}s'.format(entity_name),
                                         refresh=False)
        progress_bar.update()

        foreign_fields = {
            k: v
            for k, v in all_fields_info[entity_name].items() if 'requires' in v
        }

        for value in foreign_fields.values():
            ref_model_name = value['requires']
            fill_in_query(partial_query, entity_name, ref_model_name,
                          [entity_name], entity_separator)

        for temp_d in partial_query.iterdict():
            for key in temp_d:
                # Get current entity
                current_entity = key.split(entity_separator)[-1]

                # This is an empty result of an outer join.
                # It should not be taken into account.
                if temp_d[key]['id'] is None:
                    continue

                export_data[current_entity].update({
                    temp_d[key]['id']:
                    serialize_dict(temp_d[key],
                                   remove_fields=['id'],
                                   rename_fields=model_fields_to_file_fields[
                                       current_entity])
                })

    # Close progress up until this point in order to print properly
    close_progress_bar(leave=False)

    #######################################
    # Manually manage attributes and extras
    #######################################
    # Pointer. Renaming, since Nodes have now technically been retrieved and "stored"
    all_node_pks = node_ids_to_be_exported

    model_data = sum(len(entries) for entries in export_data.values())
    if not model_data:
        EXPORT_LOGGER.log(msg='Nothing to store, exiting...',
                          level=LOG_LEVEL_REPORT)
        return
    EXPORT_LOGGER.log(
        msg='Exporting a total of {} database entries, of which {} are Nodes.'.
        format(model_data, len(all_node_pks)),
        level=LOG_LEVEL_REPORT)

    # Instantiate new progress bar
    progress_bar = get_progress_bar(total=1, leave=False, disable=silent)

    # ATTRIBUTES and EXTRAS
    EXPORT_LOGGER.debug('GATHERING NODE ATTRIBUTES AND EXTRAS...')
    node_attributes = {}
    node_extras = {}

    # Another QueryBuilder query to get the attributes and extras. TODO: See if this can be optimized
    if all_node_pks:
        all_nodes_query = orm.QueryBuilder().append(
            orm.Node,
            filters={'id': {
                'in': all_node_pks
            }},
            project=['id', 'attributes', 'extras'])

        progress_bar = get_progress_bar(total=all_nodes_query.count(),
                                        disable=silent)
        progress_bar.set_description_str('Exporting Attributes and Extras',
                                         refresh=False)

        for node_pk, attributes, extras in all_nodes_query.iterall():
            progress_bar.update()

            node_attributes[str(node_pk)] = attributes
            node_extras[str(node_pk)] = extras

    EXPORT_LOGGER.debug('GATHERING GROUP ELEMENTS...')
    groups_uuid = defaultdict(list)
    # If a group is in the exported data, we export the group/node correlation
    if GROUP_ENTITY_NAME in export_data:
        group_uuids_with_node_uuids = orm.QueryBuilder().append(
            orm.Group,
            filters={
                'id': {
                    'in': export_data[GROUP_ENTITY_NAME]
                }
            },
            project='uuid',
            tag='groups').append(orm.Node, project='uuid', with_group='groups')

        # This part is _only_ for the progress bar
        total_node_uuids_for_groups = group_uuids_with_node_uuids.count()
        if total_node_uuids_for_groups:
            progress_bar = get_progress_bar(total=total_node_uuids_for_groups,
                                            disable=silent)
            progress_bar.set_description_str('Exporting Groups ...',
                                             refresh=False)

        for group_uuid, node_uuid in group_uuids_with_node_uuids.iterall():
            progress_bar.update()

            groups_uuid[group_uuid].append(node_uuid)

    #######################################
    # Final check for unsealed ProcessNodes
    #######################################
    process_nodes = set()
    for node_pk, content in export_data.get(NODE_ENTITY_NAME, {}).items():
        if content['node_type'].startswith('process.'):
            process_nodes.add(node_pk)

    check_process_nodes_sealed(process_nodes)

    ######################################
    # Now collecting and storing
    ######################################
    # subfolder inside the export package
    nodesubfolder = folder.get_subfolder(NODES_EXPORT_SUBFOLDER,
                                         create=True,
                                         reset_limit=True)

    EXPORT_LOGGER.debug('ADDING DATA TO EXPORT ARCHIVE...')

    data = {
        'node_attributes': node_attributes,
        'node_extras': node_extras,
        'export_data': export_data,
        'links_uuid': links_uuid,
        'groups_uuid': groups_uuid
    }

    # N.B. We're really calling zipfolder.open (if exporting a zipfile)
    with folder.open('data.json', mode='w') as fhandle:
        # fhandle.write(json.dumps(data, cls=UUIDEncoder))
        fhandle.write(json.dumps(data))

    # Turn sets into lists to be able to export them as JSON metadata.
    for entity, entity_set in entities_starting_set.items():
        entities_starting_set[entity] = list(entity_set)

    metadata = {
        'aiida_version': get_version(),
        'export_version': EXPORT_VERSION,
        'all_fields_info': all_fields_info,
        'unique_identifiers': unique_identifiers,
        'export_parameters': {
            'graph_traversal_rules': graph_traversal_rules,
            'entities_starting_set': entities_starting_set,
            'include_comments': include_comments,
            'include_logs': include_logs
        }
    }

    with folder.open('metadata.json', 'w') as fhandle:
        fhandle.write(json.dumps(metadata))

    EXPORT_LOGGER.debug('ADDING REPOSITORY FILES TO EXPORT ARCHIVE...')

    # If there are no nodes, there are no repository files to store
    if all_node_pks:
        all_node_uuids = {node_pk_2_uuid_mapping[_] for _ in all_node_pks}

        progress_bar = get_progress_bar(total=len(all_node_uuids),
                                        disable=silent)
        pbar_base_str = 'Exporting repository - '

        for uuid in all_node_uuids:
            sharded_uuid = export_shard_uuid(uuid)

            progress_bar.set_description_str(
                pbar_base_str + 'UUID={}'.format(uuid.split('-')[0]),
                refresh=False)
            progress_bar.update()

            # Important to set create=False, otherwise the subfolder is created twice. Maybe this is a bug of insert_path?
            thisnodefolder = nodesubfolder.get_subfolder(sharded_uuid,
                                                         create=False,
                                                         reset_limit=True)

            # Make sure the node's repository folder was not deleted
            src = RepositoryFolder(section=Repository._section_name, uuid=uuid)  # pylint: disable=protected-access
            if not src.exists():
                raise exceptions.ArchiveExportError(
                    'Unable to find the repository folder for Node with UUID={} in the local repository'
                    .format(uuid))

            # In this way, I copy the content of the folder, and not the folder itself
            thisnodefolder.insert_path(src=src.abspath, dest_name='.')

    close_progress_bar(leave=False)

    # Reset logging level
    if silent:
        logging.disable(level=logging.NOTSET)
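For reference, a minimal invocation of the function above might look as follows (the target directory is illustrative); it writes `data.json`, `metadata.json` and the node repository subfolder (`NODES_EXPORT_SUBFOLDER`) into the given folder rather than a compressed archive.

# Minimal usage sketch, assuming a configured AiiDA profile.
from aiida import orm
from aiida.common.folders import Folder

node = orm.Dict().store()
export_tree([node], folder=Folder('/tmp/export_tree'))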
Example no. 11
    def __init__(self, uuid, is_stored, base_path=None):
        self._is_stored = is_stored
        self._base_path = base_path
        self._temp_folder = None
        self._repo_folder = RepositoryFolder(section=self._section_name,
                                             uuid=uuid)
Example no. 12
class Repository:
    """Class that represents the repository of a `Node` instance.

        .. deprecated:: 1.4.0
            This class has been deprecated and will be removed in `v2.0.0`.
    """

    # Name to be used for the Repository section
    _section_name = 'node'

    def __init__(self, uuid, is_stored, base_path=None):
        self._is_stored = is_stored
        self._base_path = base_path
        self._temp_folder = None
        self._repo_folder = RepositoryFolder(section=self._section_name,
                                             uuid=uuid)

    def __del__(self):
        """Clean the sandboxfolder if it was instantiated."""
        if getattr(self, '_temp_folder', None) is not None:
            self._temp_folder.erase()

    def validate_mutability(self):
        """Raise if the repository is immutable.

        :raises aiida.common.ModificationNotAllowed: if repository is marked as immutable because the corresponding node
            is stored
        """
        if self._is_stored:
            raise exceptions.ModificationNotAllowed(
                'cannot modify the repository after the node has been stored')

    @staticmethod
    def validate_object_key(key):
        """Validate the key of an object.

        :param key: an object key in the repository
        :raises ValueError: if the key is not a valid object key
        """
        if key and os.path.isabs(key):
            raise ValueError('the key must be a relative path')

    def list_objects(self, key=None):
        """Return a list of the objects contained in this repository, optionally in the given sub directory.

        :param key: fully qualified identifier for the object within the repository
        :return: a list of `File` named tuples representing the objects present in directory with the given key
        """
        folder = self._get_base_folder()

        if key:
            folder = folder.get_subfolder(key)

        objects = []

        for filename in folder.get_content_list():
            if os.path.isdir(os.path.join(folder.abspath, filename)):
                objects.append(File(filename, FileType.DIRECTORY))
            else:
                objects.append(File(filename, FileType.FILE))

        return sorted(objects, key=lambda x: x.name)

    def list_object_names(self, key=None):
        """Return a list of the object names contained in this repository, optionally in the given sub directory.

        :param key: fully qualified identifier for the object within the repository
        :return: a list of `File` named tuples representing the objects present in directory with the given key
        """
        return [entry.name for entry in self.list_objects(key)]

    def open(self, key, mode='r'):
        """Open a file handle to an object stored under the given key.

        :param key: fully qualified identifier for the object within the repository
        :param mode: the mode under which to open the handle
        """
        return open(self._get_base_folder().get_abs_path(key), mode=mode)

    def get_object(self, key):
        """Return the object identified by key.

        :param key: fully qualified identifier for the object within the repository
        :return: a `File` named tuple representing the object located at key
        :raises IOError: if no object with the given key exists
        """
        self.validate_object_key(key)

        try:
            directory, filename = key.rsplit(os.sep, 1)
        except ValueError:
            directory, filename = None, key

        folder = self._get_base_folder()

        if directory:
            folder = folder.get_subfolder(directory)

        filepath = os.path.join(folder.abspath, filename)

        if os.path.isdir(filepath):
            return File(filename, FileType.DIRECTORY)

        if os.path.isfile(filepath):
            return File(filename, FileType.FILE)

        raise IOError('object {} does not exist'.format(key))

    def get_object_content(self, key, mode='r'):
        """Return the content of a object identified by key.

        :param key: fully qualified identifier for the object within the repository
        :param mode: the mode under which to open the handle
        """
        with self.open(key, mode=mode) as handle:
            return handle.read()

    def put_object_from_tree(self,
                             path,
                             key=None,
                             contents_only=True,
                             force=False):
        """Store a new object under `key` with the contents of the directory located at `path` on this file system.

        .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised.
            This check can be avoided by using the `force` flag, but this should be used with extreme caution!

        :param path: absolute path of directory whose contents to copy to the repository
        :param key: fully qualified identifier for the object within the repository
        :param contents_only: boolean, if True, omit the top level directory of the path and only copy its contents.
        :param force: boolean, if True, will skip the mutability check
        :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False`
        """
        if not force:
            self.validate_mutability()

        self.validate_object_key(key)

        if not os.path.isabs(path):
            raise ValueError('the `path` must be an absolute path')

        folder = self._get_base_folder()

        if key:
            folder = folder.get_subfolder(key, create=True)

        if contents_only:
            for entry in os.listdir(path):
                folder.insert_path(os.path.join(path, entry))
        else:
            folder.insert_path(path)

    def put_object_from_file(self,
                             path,
                             key,
                             mode=None,
                             encoding=None,
                             force=False):
        """Store a new object under `key` with contents of the file located at `path` on this file system.

        .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised.
            This check can be avoided by using the `force` flag, but this should be used with extreme caution!

        :param path: absolute path of file whose contents to copy to the repository
        :param key: fully qualified identifier for the object within the repository
        :param mode: the file mode with which the object will be written
            Deprecated: will be removed in `v2.0.0`
        :param encoding: the file encoding with which the object will be written
            Deprecated: will be removed in `v2.0.0`
        :param force: boolean, if True, will skip the mutability check
        :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False`
        """
        # pylint: disable=unused-argument,no-member
        # Note that the defaults of `mode` and `encoding` had to be changed to `None` from `w` and `utf-8` respectively,
        # in order to detect when they were being passed, such that the deprecation warning can be emitted. The defaults
        # did not make sense, so ignoring them is justified: the side effect of this function, a file being copied,
        # will continue working the same.
        if mode is not None:
            warnings.warn(
                'the `mode` argument is deprecated and will be removed in `v2.0.0`',
                AiidaDeprecationWarning)

        if encoding is not None:
            warnings.warn(
                'the `encoding` argument is deprecated and will be removed in `v2.0.0`',
                AiidaDeprecationWarning)

        if not force:
            self.validate_mutability()

        self.validate_object_key(key)

        with open(path, mode='rb') as handle:
            self.put_object_from_filelike(handle,
                                          key,
                                          mode='wb',
                                          encoding=None)

    def put_object_from_filelike(self,
                                 handle,
                                 key,
                                 mode='w',
                                 encoding='utf8',
                                 force=False):
        """Store a new object under `key` with contents of filelike object `handle`.

        .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised.
            This check can be avoided by using the `force` flag, but this should be used with extreme caution!

        :param handle: filelike object with the content to be stored
        :param key: fully qualified identifier for the object within the repository
        :param mode: the file mode with which the object will be written
        :param encoding: the file encoding with which the object will be written
        :param force: boolean, if True, will skip the mutability check
        :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False`
        """
        if not force:
            self.validate_mutability()

        self.validate_object_key(key)

        folder = self._get_base_folder()

        while os.sep in key:
            basepath, key = key.split(os.sep, 1)
            folder = folder.get_subfolder(basepath, create=True)

        folder.create_file_from_filelike(handle,
                                         key,
                                         mode=mode,
                                         encoding=encoding)

    def delete_object(self, key, force=False):
        """Delete the object from the repository.

        .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised.
            This check can be avoided by using the `force` flag, but this should be used with extreme caution!

        :param key: fully qualified identifier for the object within the repository
        :param force: boolean, if True, will skip the mutability check
        :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False`
        """
        if not force:
            self.validate_mutability()

        self.validate_object_key(key)

        self._get_base_folder().remove_path(key)

    def erase(self, force=False):
        """Delete the repository folder.

        .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised.
            This check can be avoided by using the `force` flag, but this should be used with extreme caution!

        :param force: boolean, if True, will skip the mutability check
        :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False`
        """
        if not force:
            self.validate_mutability()

        self._get_base_folder().erase()

    def store(self):
        """Store the contents of the sandbox folder into the repository folder."""
        if self._is_stored:
            raise exceptions.ModificationNotAllowed(
                'repository is already stored')

        self._repo_folder.replace_with_folder(self._get_temp_folder().abspath,
                                              move=True,
                                              overwrite=True)
        self._is_stored = True

    def restore(self):
        """Move the contents from the repository folder back into the sandbox folder."""
        if not self._is_stored:
            raise exceptions.ModificationNotAllowed(
                'repository is not yet stored')

        self._temp_folder.replace_with_folder(self._repo_folder.abspath,
                                              move=True,
                                              overwrite=True)
        self._is_stored = False

    def _get_base_folder(self):
        """Return the base sub folder in the repository.

        :return: a Folder object.
        """
        if self._is_stored:
            folder = self._repo_folder
        else:
            folder = self._get_temp_folder()

        if self._base_path is not None:
            folder = folder.get_subfolder(self._base_path, reset_limit=True)
            folder.create()

        return folder

    def _get_temp_folder(self):
        """Return the temporary sandbox folder.

        :return: a SandboxFolder object holding the temporary repository contents of the node.
        """
        if self._temp_folder is None:
            self._temp_folder = SandboxFolder()

        return self._temp_folder
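A short, hedged sketch of how this (now deprecated) `Repository` class is driven: while the node is unstored, objects go into the sandbox folder, and `store()` then moves them into the permanent `RepositoryFolder`. The UUID below is generated on the spot purely for illustration; in practice the owning `Node` supplies it.

# Usage sketch only; RepositoryFolder requires a configured AiiDA profile.
import io
import uuid as uuid_module

repo = Repository(uuid=str(uuid_module.uuid4()), is_stored=False)
repo.put_object_from_filelike(io.StringIO('hello world'), 'greetings/hello.txt')
assert repo.list_object_names('greetings') == ['hello.txt']
assert repo.get_object_content('greetings/hello.txt') == 'hello world'

repo.store()  # move the sandbox contents into the permanent repository folder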
Example no. 13
    def __init__(self, **kwargs):
        from aiida.backends.djsite.db.models import DbNode
        super(Node, self).__init__()

        self._temp_folder = None

        dbnode = kwargs.pop('dbnode', None)

        # Set the internal parameters
        # Can be redefined in the subclasses
        self._init_internal_params()

        if dbnode is not None:
            if not isinstance(dbnode, DbNode):
                raise TypeError("dbnode is not a DbNode instance")
            if dbnode.pk is None:
                raise ValueError("If cannot load an aiida.orm.Node instance "
                                 "from an unsaved Django DbNode object.")
            if kwargs:
                raise ValueError("If you pass a dbnode, you cannot pass any "
                                 "further parameter")

            # If I am loading, I cannot modify it
            self._to_be_stored = False

            self._dbnode = dbnode

            # If this is changed, fix also the importer
            self._repo_folder = RepositoryFolder(section=self._section_name,
                                                 uuid=self._dbnode.uuid)

        # NO VALIDATION ON __init__ BY DEFAULT, IT IS TOO SLOW SINCE IT OFTEN
        # REQUIRES MULTIPLE DB HITS
        # try:
        #                # Note: the validation often requires to load at least one
        #                # attribute, and therefore it will take a lot of time
        #                # because it has to cache every attribute.
        #                self._validate()
        #            except ValidationError as e:
        #                raise DbContentError("The data in the DB with UUID={} is not "
        #                                     "valid for class {}: {}".format(
        #                    uuid, self.__class__.__name__, e.message))
        else:
            # TODO: allow to get the user from the parameters
            user = get_automatic_user()
            self._dbnode = DbNode(user=user,
                                  uuid=get_new_uuid(),
                                  type=self._plugin_type_string)

            self._to_be_stored = True

            # As creating the temp folder may require some time on slow
            # filesystems, we defer its creation
            self._temp_folder = None
            # Used only before the first save
            self._attrs_cache = {}
            # If this is changed, fix also the importer
            self._repo_folder = RepositoryFolder(section=self._section_name,
                                                 uuid=self.uuid)

            # Automatically set all *other* attributes, if possible, otherwise
            # stop
            self._set_with_defaults(**kwargs)
Example no. 14
    def __init__(self, **kwargs):
        """
        Initializes the Workflow superclass, stores the instance in the DB and, if given,
        stores the starting parameters.

        If initialized with a uuid, the Workflow is loaded from the DB; if not, a new
        workflow is generated and added to the DB, with the caller information taken from
        the call stack. This means that only modules inside aiida.workflows are allowed to
        implement Workflow subclasses and have them stored. The caller names, modules and
        files are retrieved from the stack.

        :param uuid: a string with the uuid of the object to be loaded.
        :param params: a dictionary of storable objects to initialize the specific workflow
        :raise NotExistent: if there is no entry of the desired workflow kind with
            the given uuid.
        """

        from aiida.backends.djsite.db.models import DbWorkflow
        self._to_be_stored = True

        self._logger = logger.getChild(self.__class__.__name__)

        uuid = kwargs.pop('uuid', None)

        if uuid is not None:
            self._to_be_stored = False
            if kwargs:
                raise ValueError("If you pass a UUID, you cannot pass any further parameter")

            try:
                self._dbworkflowinstance = DbWorkflow.objects.get(uuid=uuid)

                # self.logger.info("Workflow found in the database, now retrieved")
                self._repo_folder = RepositoryFolder(section=self._section_name, uuid=self.uuid)

            except ObjectDoesNotExist:
                raise NotExistent("No entry with the UUID {} found".format(uuid))

        else:
            # ATTENTION: Do not move this code outside or encapsulate it in a function
            import inspect

            stack = inspect.stack()

            # cur_fr  = inspect.currentframe()
            #call_fr = inspect.getouterframes(cur_fr, 2)

            # Get all the caller data
            caller_frame = stack[1][0]
            caller_file = stack[1][1]
            caller_funct = stack[1][3]

            caller_module = inspect.getmodule(caller_frame)
            caller_module_class = caller_frame.f_locals.get('self', None).__class__

            if caller_funct != "__init__":
                raise SystemError("A workflow must implement the __init__ method explicitly")

            # Test if the launcher is another workflow

            # print "caller_module", caller_module
            # print "caller_module_class", caller_module_class
            # print "caller_file", caller_file
            # print "caller_funct", caller_funct

            # Accept only the aiida.workflows packages
            if caller_module is None or not caller_module.__name__.startswith("aiida.workflows"):
                raise SystemError("The superclass can't be called directly")

            self.caller_module = caller_module.__name__
            self.caller_module_class = caller_module_class.__name__
            self.caller_file = caller_file
            self.caller_funct = caller_funct

            self._temp_folder = SandboxFolder()
            self.current_folder.insert_path(self.caller_file, self.caller_module_class)
            # self.store()

            # Test if there are parameters as input
            params = kwargs.pop('params', None)

            if params is not None:
                if type(params) is dict:
                    self.set_params(params)

            # This stores the MD5 as well, to test in case the workflow has
            # been modified after the launch
            self._dbworkflowinstance = DbWorkflow(user=get_automatic_user(),
                                                  module=self.caller_module,
                                                  module_class=self.caller_module_class,
                                                  script_path=self.caller_file,
                                                  script_md5=md5_file(self.caller_file))

        self.attach_calc_lazy_storage = {}
        self.attach_subwf_lazy_storage = {}
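Given the checks performed above (the caller must define `__init__` itself and must live in a module under `aiida.workflows`), a conforming subclass would, as a hedged sketch, look roughly like this:

# Sketch of the subclassing pattern enforced by the constructor above; the module
# containing MyWorkflow is assumed to live under the aiida.workflows package.
class MyWorkflow(Workflow):

    def __init__(self, **kwargs):
        # An explicit __init__ is mandatory; it forwards uuid/params to the superclass.
        super(MyWorkflow, self).__init__(**kwargs)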
Example no. 15
    def _backup_needed_files(self, query_sets):
        from aiida.backends.djsite.settings.settings import REPOSITORY_PATH
        repository_path = os.path.normpath(REPOSITORY_PATH)

        parent_dir_set = set()
        copy_counter = 0

        dir_no_to_copy = 0
        for query_set in query_sets:
            dir_no_to_copy += query_set.count()

        self._logger.info(
            "Start copying {} directories".format(dir_no_to_copy))

        last_progress_print = datetime.datetime.now()
        percent_progress = 0

        for query_set in query_sets:
            for item in query_set.iterator():
                if type(item) == DbWorkflow:
                    source_dir = os.path.normpath(
                        RepositoryFolder(section=Workflow._section_name,
                                         uuid=item.uuid).abspath)
                elif type(item) == DbNode:
                    source_dir = os.path.normpath(
                        RepositoryFolder(section=Node._section_name,
                                         uuid=item.uuid).abspath)
                else:
                    # Raise exception
                    self._logger.error(
                        "Unexpected item type to backup: {}".format(
                            type(item)))
                    raise BackupError(
                        "Unexpected item type to backup: {}".format(
                            type(item)))

                # Get the relative directory without the / which
                # separates the repository_path from the relative_dir.
                relative_dir = source_dir[(len(repository_path) + 1):]
                destination_dir = os.path.join(self._backup_dir, relative_dir)

                # Remove the destination directory if it already exists
                if os.path.exists(destination_dir):
                    shutil.rmtree(destination_dir)

                # Copy the needed directory
                try:
                    shutil.copytree(source_dir, destination_dir, True, None)
                except EnvironmentError as e:
                    self._logger.warning(
                        "Problem copying directory {} ".format(source_dir) +
                        "to {}. ".format(destination_dir) +
                        "More information: {} (Error no: {})".format(
                            e.strerror, e.errno))
                    # TODO: decide whether to re-raise the EnvironmentError here

                # Extract the needed parent directories
                Backup._extract_parent_dirs(relative_dir, parent_dir_set)
                copy_counter += 1

                if (self._logger.getEffectiveLevel() <= logging.INFO
                        and (datetime.datetime.now() -
                             last_progress_print).seconds > 60):
                    last_progress_print = datetime.datetime.now()
                    percent_progress = (copy_counter * 100 / dir_no_to_copy)
                    self._logger.info(
                        "Copied {} ".format(copy_counter) +
                        "directories [{}]".format(item.__class__.__name__, ) +
                        " ({}/100)".format(percent_progress))

                if (self._logger.getEffectiveLevel() <= logging.INFO
                        and percent_progress <
                    (copy_counter * 100 / dir_no_to_copy)):
                    percent_progress = (copy_counter * 100 / dir_no_to_copy)
                    last_progress_print = datetime.datetime.now()
                    self._logger.info(
                        "Copied {} ".format(copy_counter) +
                        "directories [{}]".format(item.__class__.__name__, ) +
                        " ({}/100)".format(percent_progress))

        self._logger.info("{} directories copied".format(copy_counter))

        self._logger.info("Start setting permissions")
        perm_counter = 0
        for tempRelPath in parent_dir_set:
            try:
                shutil.copystat(os.path.join(repository_path, tempRelPath),
                                os.path.join(self._backup_dir, tempRelPath))
            except OSError as e:
                self._logger.warning(
                    "Problem setting permissions to directory " +
                    "{}.".format(os.path.join(self._backup_dir, tempRelPath)))
                self._logger.warning(os.path.join(repository_path,
                                                  tempRelPath))
                self._logger.warning(
                    "More information: " +
                    "{} (Error no: {})".format(e.strerror, e.errno))
            perm_counter += 1

        self._logger.info("Set correct permissions "
                          "to {} directories.".format(perm_counter))

        self._logger.info("End of backup")
        self._logger.info("Backed up objects with modification timestamp "
                          "less or equal to {}".format(self._oldest_object_bk))
Example 16
def write_to_archive(
    folder: Union[Folder, ZipFolder],
    metadata: dict,
    all_node_uuids: Set[str],
    export_data: Dict[str, Dict[int, dict]],
    node_attributes: Dict[str, dict],
    node_extras: Dict[str, dict],
    groups_uuid: Dict[str, List[str]],
    links_uuid: List[dict],
    silent: bool,
) -> None:
    """Store data to the archive."""
    ######################################
    # Now collecting and storing
    ######################################
    # subfolder inside the export package
    nodesubfolder = folder.get_subfolder(NODES_EXPORT_SUBFOLDER,
                                         create=True,
                                         reset_limit=True)

    EXPORT_LOGGER.debug("ADDING DATA TO EXPORT ARCHIVE...")

    data = {
        "node_attributes": node_attributes,
        "node_extras": node_extras,
        "export_data": export_data,
        "links_uuid": links_uuid,
        "groups_uuid": groups_uuid,
    }
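    # Schematically, data.json will therefore hold these five top-level keys, e.g.:
    #   {"node_attributes": {"<pk>": {...}}, "node_extras": {"<pk>": {...}},
    #    "export_data": {"Node": {"<pk>": {...}}, ...},
    #    "links_uuid": [{"input": "<uuid>", "output": "<uuid>", "label": "...", "type": "..."}, ...],
    #    "groups_uuid": {"<group uuid>": ["<node uuid>", ...]}}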

    # N.B. We're really calling zipfolder.open (if exporting a zipfile)
    with folder.open("data.json", mode="w") as fhandle:
        # fhandle.write(json.dumps(data, cls=UUIDEncoder))
        fhandle.write(json.dumps(data))

    with folder.open("metadata.json", "w") as fhandle:
        fhandle.write(json.dumps(metadata))

    EXPORT_LOGGER.debug("ADDING REPOSITORY FILES TO EXPORT ARCHIVE...")

    # If there are no nodes, there are no repository files to store
    if all_node_uuids:

        progress_bar = get_progress_bar(total=len(all_node_uuids),
                                        disable=silent)
        pbar_base_str = "Exporting repository - "

        for uuid in all_node_uuids:
            sharded_uuid = export_shard_uuid(uuid)

            progress_bar.set_description_str(
                f"{pbar_base_str}UUID={uuid.split('-')[0]}", refresh=False)
            progress_bar.update()

            # Important to set create=False, otherwise it creates the subfolder twice.
            # Maybe this is a bug of insert_path?
            thisnodefolder = nodesubfolder.get_subfolder(sharded_uuid,
                                                         create=False,
                                                         reset_limit=True)

            # Make sure the node's repository folder was not deleted
            src = RepositoryFolder(section=Repository._section_name, uuid=uuid)  # pylint: disable=protected-access
            if not src.exists():
                raise exceptions.ArchiveExportError(
                    f"Unable to find the repository folder for Node with UUID={uuid} "
                    "in the local repository")

            # In this way, I copy the content of the folder, and not the folder itself
            thisnodefolder.insert_path(src=src.abspath, dest_name=".")
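A minimal sketch of exercising write_to_archive on its own, with empty collections just to show the expected argument shapes (in the real export flow these are produced by the preceding extraction steps):

from aiida.common.folders import SandboxFolder

with SandboxFolder() as folder:
    write_to_archive(
        folder=folder,
        metadata={},           # the real call passes the full metadata dict (see Example 17 below)
        all_node_uuids=set(),  # no nodes, so no repository files are copied
        export_data={},
        node_attributes={},
        node_extras={},
        groups_uuid={},
        links_uuid=[],
        silent=True,
    )
    # folder now contains only data.json and metadata.json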
Example 17
def export_tree(what,
                folder,
                allowed_licenses=None,
                forbidden_licenses=None,
                silent=False,
                include_comments=True,
                include_logs=True,
                **kwargs):
    """Export the entries passed in the 'what' list to a file tree.

    :param what: a list of entity instances; they can belong to different models/entities.
    :type what: list

    :param folder: a temporary folder to build the archive before compression.
    :type folder: :py:class:`~aiida.common.folders.Folder`

    :param allowed_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False
        otherwise.
    :type allowed_licenses: list

    :param forbidden_licenses: List or function. If a list, then checks that no license of a Data node is in the
        list. If a function, then calls function for licenses of Data nodes expecting True if license is forbidden,
        False otherwise.
    :type forbidden_licenses: list

    :param silent: suppress prints.
    :type silent: bool

    :param include_comments: In-/exclude export of comments for given node(s) in ``what``.
        Default: True, *include* comments in export (as well as relevant users).
    :type include_comments: bool

    :param include_logs: In-/exclude export of logs for given node(s) in ``what``.
        Default: True, *include* logs in export.
    :type include_logs: bool

    :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` for what rule names
        are toggleable and what the defaults are.

    :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if there are any internal errors when
        exporting.
    :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under forbidden license.
    """
    from collections import defaultdict

    if not silent:
        print('STARTING EXPORT...')

    all_fields_info, unique_identifiers = get_all_fields_info()

    entities_starting_set = defaultdict(set)

    # The set that contains the nodes ids of the nodes that should be exported
    given_data_entry_ids = set()
    given_calculation_entry_ids = set()
    given_group_entry_ids = set()
    given_computer_entry_ids = set()
    given_groups = set()
    given_log_entry_ids = set()
    given_comment_entry_ids = set()

    # I store a list of the actual dbnodes
    for entry in what:
        # This returns the class name (as in imports). E.g. for a model node:
        # aiida.backends.djsite.db.models.DbNode
        # entry_class_string = get_class_string(entry)
        # Now load the backend-independent name into entry_entity_name, e.g. Node!
        # entry_entity_name = schema_to_entity_names(entry_class_string)
        if issubclass(entry.__class__, orm.Group):
            entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid)
            given_group_entry_ids.add(entry.id)
            given_groups.add(entry)
        elif issubclass(entry.__class__, orm.Node):
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            if issubclass(entry.__class__, orm.Data):
                given_data_entry_ids.add(entry.pk)
            elif issubclass(entry.__class__, orm.ProcessNode):
                given_calculation_entry_ids.add(entry.pk)
        elif issubclass(entry.__class__, orm.Computer):
            entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid)
            given_computer_entry_ids.add(entry.pk)
        else:
            raise exceptions.ArchiveExportError(
                'I was given {} ({}), which is not a Node, Computer, or Group instance'
                .format(entry, type(entry)))

    # Add all the nodes contained within the specified groups
    for group in given_groups:
        for entry in group.nodes:
            entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid)
            if issubclass(entry.__class__, orm.Data):
                given_data_entry_ids.add(entry.pk)
            elif issubclass(entry.__class__, orm.ProcessNode):
                given_calculation_entry_ids.add(entry.pk)

    for entity, entity_set in entities_starting_set.items():
        entities_starting_set[entity] = list(entity_set)

    # We will iteratively explore the AiiDA graph to find further nodes that
    # should also be exported.
    # At the same time, we will create the links_uuid list of dicts to be exported

    if not silent:
        print('RETRIEVING LINKED NODES AND STORING LINKS...')

    to_be_exported, links_uuid, graph_traversal_rules = retrieve_linked_nodes(
        given_calculation_entry_ids, given_data_entry_ids, **kwargs)

    ## Universal "entities" attributed to all types of nodes
    # Logs
    if include_logs and to_be_exported:
        # Get related log(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Log,
                       filters={'dbnode_id': {
                           'in': to_be_exported
                       }},
                       project='id')
        res = {_[0] for _ in builder.all()}
        given_log_entry_ids.update(res)

    # Comments
    if include_comments and to_be_exported:
        # Get related comment(s) - universal for all nodes
        builder = orm.QueryBuilder()
        builder.append(orm.Comment,
                       filters={'dbnode_id': {
                           'in': to_be_exported
                       }},
                       project='id')
        res = {_[0] for _ in builder.all()}
        given_comment_entry_ids.update(res)

    # Here we get all the columns that we plan to project per entity that we
    # would like to extract
    given_entities = list()
    if given_group_entry_ids:
        given_entities.append(GROUP_ENTITY_NAME)
    if to_be_exported:
        given_entities.append(NODE_ENTITY_NAME)
    if given_computer_entry_ids:
        given_entities.append(COMPUTER_ENTITY_NAME)
    if given_log_entry_ids:
        given_entities.append(LOG_ENTITY_NAME)
    if given_comment_entry_ids:
        given_entities.append(COMMENT_ENTITY_NAME)

    entries_to_add = dict()
    for given_entity in given_entities:
        project_cols = ['id']
        # The following gets a list of fields that we need,
        # e.g. user, mtime, uuid, computer
        entity_prop = all_fields_info[given_entity].keys()

        # Here we do the necessary renaming of properties
        for prop in entity_prop:
            # nprop contains the list of projections
            nprop = (file_fields_to_model_fields[given_entity][prop] if prop
                     in file_fields_to_model_fields[given_entity] else prop)
            project_cols.append(nprop)

        # Getting the ids that correspond to the right entity
        if given_entity == GROUP_ENTITY_NAME:
            entry_ids_to_add = given_group_entry_ids
        elif given_entity == NODE_ENTITY_NAME:
            entry_ids_to_add = to_be_exported
        elif given_entity == COMPUTER_ENTITY_NAME:
            entry_ids_to_add = given_computer_entry_ids
        elif given_entity == LOG_ENTITY_NAME:
            entry_ids_to_add = given_log_entry_ids
        elif given_entity == COMMENT_ENTITY_NAME:
            entry_ids_to_add = given_comment_entry_ids

        builder = orm.QueryBuilder()
        builder.append(entity_names_to_entities[given_entity],
                       filters={'id': {
                           'in': entry_ids_to_add
                       }},
                       project=project_cols,
                       tag=given_entity,
                       outerjoin=True)
        entries_to_add[given_entity] = builder

    # TODO (Spyros) To see better! Especially for functional licenses
    # Check the licenses of exported data.
    if allowed_licenses is not None or forbidden_licenses is not None:
        builder = orm.QueryBuilder()
        builder.append(orm.Node,
                       project=['id', 'attributes.source.license'],
                       filters={'id': {
                           'in': to_be_exported
                       }})
        # Skip those nodes where the license is not set (this is the standard behavior with Django)
        node_licenses = list(
            (a, b) for [a, b] in builder.all() if b is not None)
        check_licenses(node_licenses, allowed_licenses, forbidden_licenses)

    ############################################################
    ##### Start automatic recursive export data generation #####
    ############################################################
    if not silent:
        print('STORING DATABASE ENTRIES...')

    export_data = dict()
    entity_separator = '_'
    for entity_name, partial_query in entries_to_add.items():

        foreign_fields = {
            k: v
            for k, v in all_fields_info[entity_name].items()
            # all_fields_info[model_name].items()
            if 'requires' in v
        }

        for value in foreign_fields.values():
            ref_model_name = value['requires']
            fill_in_query(partial_query, entity_name, ref_model_name,
                          [entity_name], entity_separator)

        for temp_d in partial_query.iterdict():
            for k in temp_d.keys():
                # Get current entity
                current_entity = k.split(entity_separator)[-1]

                # This is an empty result of an outer join.
                # It should not be taken into account.
                if temp_d[k]['id'] is None:
                    continue

                temp_d2 = {
                    temp_d[k]['id']:
                    serialize_dict(temp_d[k],
                                   remove_fields=['id'],
                                   rename_fields=model_fields_to_file_fields[
                                       current_entity])
                }
                try:
                    export_data[current_entity].update(temp_d2)
                except KeyError:
                    export_data[current_entity] = temp_d2

    #######################################
    # Manually manage attributes and extras
    #######################################
    # I use .get because there may be no nodes to export
    all_nodes_pk = list()
    if NODE_ENTITY_NAME in export_data:
        all_nodes_pk.extend(export_data.get(NODE_ENTITY_NAME).keys())

    if sum(len(model_data) for model_data in export_data.values()) == 0:
        if not silent:
            print('No nodes to store, exiting...')
        return

    if not silent:
        print('Exporting a total of {} db entries, of which {} nodes.'.format(
            sum(len(model_data) for model_data in export_data.values()),
            len(all_nodes_pk)))

    # ATTRIBUTES and EXTRAS
    if not silent:
        print('STORING NODE ATTRIBUTES AND EXTRAS...')
    node_attributes = {}
    node_extras = {}

    # A second QueryBuilder query to get the attributes and extras. See if this can be optimized
    if all_nodes_pk:
        all_nodes_query = orm.QueryBuilder()
        all_nodes_query.append(orm.Node,
                               filters={'id': {
                                   'in': all_nodes_pk
                               }},
                               project=['id', 'attributes', 'extras'])
        for res_pk, res_attributes, res_extras in all_nodes_query.iterall():
            node_attributes[str(res_pk)] = res_attributes
            node_extras[str(res_pk)] = res_extras

    if not silent:
        print('STORING GROUP ELEMENTS...')
    groups_uuid = dict()
    # If a group is in the exported data, we export the group/node correlation
    if GROUP_ENTITY_NAME in export_data:
        for curr_group in export_data[GROUP_ENTITY_NAME]:
            group_uuid_qb = orm.QueryBuilder()
            group_uuid_qb.append(entity_names_to_entities[GROUP_ENTITY_NAME],
                                 filters={'id': {
                                     '==': curr_group
                                 }},
                                 project=['uuid'],
                                 tag='group')
            group_uuid_qb.append(entity_names_to_entities[NODE_ENTITY_NAME],
                                 project=['uuid'],
                                 with_group='group')
            for res in group_uuid_qb.iterall():
                if str(res[0]) in groups_uuid:
                    groups_uuid[str(res[0])].append(str(res[1]))
                else:
                    groups_uuid[str(res[0])] = [str(res[1])]

    #######################################
    # Final check for unsealed ProcessNodes
    #######################################
    process_nodes = set()
    for node_pk, content in export_data.get(NODE_ENTITY_NAME, {}).items():
        if content['node_type'].startswith('process.'):
            process_nodes.add(node_pk)

    check_process_nodes_sealed(process_nodes)

    ######################################
    # Now I store
    ######################################
    # subfolder inside the export package
    nodesubfolder = folder.get_subfolder(NODES_EXPORT_SUBFOLDER,
                                         create=True,
                                         reset_limit=True)

    if not silent:
        print('STORING DATA...')

    data = {
        'node_attributes': node_attributes,
        'node_extras': node_extras,
        'export_data': export_data,
        'links_uuid': links_uuid,
        'groups_uuid': groups_uuid
    }

    # N.B. We're really calling zipfolder.open (if exporting a zipfile)
    with folder.open('data.json', mode='w') as fhandle:
        # fhandle.write(json.dumps(data, cls=UUIDEncoder))
        fhandle.write(json.dumps(data))

    # Add proper signature to unique identifiers & all_fields_info
    # Ignore if a key doesn't exist in any of the two dictionaries

    metadata = {
        'aiida_version': get_version(),
        'export_version': EXPORT_VERSION,
        'all_fields_info': all_fields_info,
        'unique_identifiers': unique_identifiers,
        'export_parameters': {
            'graph_traversal_rules': graph_traversal_rules,
            'entities_starting_set': entities_starting_set,
            'include_comments': include_comments,
            'include_logs': include_logs
        }
    }

    with folder.open('metadata.json', 'w') as fhandle:
        fhandle.write(json.dumps(metadata))

    if not silent:
        print('STORING REPOSITORY FILES...')

    # If there are no nodes, there are no repository files to store
    if all_nodes_pk:
        # Large speed increase by not getting the node itself and looping in memory in python, but just getting the uuid
        uuid_query = orm.QueryBuilder()
        uuid_query.append(orm.Node,
                          filters={'id': {
                              'in': all_nodes_pk
                          }},
                          project=['uuid'])
        for res in uuid_query.all():
            uuid = str(res[0])
            sharded_uuid = export_shard_uuid(uuid)

            # Important to set create=False, otherwise it creates the subfolder twice. Maybe this is a bug of insert_path?
            thisnodefolder = nodesubfolder.get_subfolder(sharded_uuid,
                                                         create=False,
                                                         reset_limit=True)

            # Make sure the node's repository folder was not deleted
            src = RepositoryFolder(section=Repository._section_name, uuid=uuid)  # pylint: disable=protected-access
            if not src.exists():
                raise exceptions.ArchiveExportError(
                    'Unable to find the repository folder for Node with UUID={} in the local repository'
                    .format(uuid))

            # In this way, I copy the content of the folder, and not the folder itself
            thisnodefolder.insert_path(src=src.abspath, dest_name='.')
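A minimal sketch of driving export_tree directly (normally the higher-level export() function wraps this and compresses the resulting folder; the Dict node below is only illustrative):

from aiida import orm
from aiida.common.folders import SandboxFolder

node = orm.Dict(dict={'answer': 42}).store()
with SandboxFolder() as folder:
    export_tree([node], folder=folder, silent=True)
    # folder now holds data.json, metadata.json and the node's repository files,
    # ready to be packed into a tar.gz or zip archive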
Example 18
def import_data_sqla(in_path,
                     group=None,
                     ignore_unknown_nodes=False,
                     extras_mode_existing='kcl',
                     extras_mode_new='import',
                     comment_mode='newest',
                     silent=False,
                     **kwargs):
    """Import exported AiiDA archive to the AiiDA database and repository.

    Specific for the SQLAlchemy backend.
    If ``in_path`` is a folder, calls extract_tree; otherwise, tries to detect the compression format
    (zip, tar.gz, tar.bz2, ...) and calls the correct function.

    :param in_path: the path to a file or folder that can be imported in AiiDA.
    :type in_path: str

    :param group: Group wherein all imported Nodes will be placed.
    :type group: :py:class:`~aiida.orm.groups.Group`

    :param extras_mode_existing: 3 letter code that will identify what to do with the extras import.
        The first letter acts on extras that are present in the original node and not present in the imported node.
        Can be either:
        'k' (keep it) or
        'n' (do not keep it).
        The second letter acts on the imported extras that are not present in the original node.
        Can be either:
        'c' (create it) or
        'n' (do not create it).
        The third letter defines what to do in case of a name collision.
        Can be either:
        'l' (leave the old value),
        'u' (update with a new value),
        'd' (delete the extra), or
        'a' (ask what to do if the content is different).
        For example, the default 'kcl' keeps existing extras, creates the new ones from the archive, and leaves the
        old value in case of a name collision.
    :type extras_mode_existing: str

    :param extras_mode_new: 'import' to import extras of new nodes or 'none' to ignore them.
    :type extras_mode_new: str

    :param comment_mode: Comment import modes (when same UUIDs are found).
        Can be either:
        'newest' (will keep the Comment with the most recent modification time (mtime)) or
        'overwrite' (will overwrite existing Comments with the ones from the import file).
    :type comment_mode: str

    :param silent: suppress progress bar and summary.
    :type silent: bool

    :return: New and existing Nodes and Links.
    :rtype: dict

    :raises `~aiida.tools.importexport.common.exceptions.ImportValidationError`: if parameters or the contents of
        `metadata.json` or `data.json` can not be validated.
    :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if the provided archive at ``in_path`` is
        corrupted.
    :raises `~aiida.tools.importexport.common.exceptions.IncompatibleArchiveVersionError`: if the provided archive's
        export version is not equal to the export version of AiiDA at the moment of import.
    :raises `~aiida.tools.importexport.common.exceptions.ArchiveImportError`: if there are any internal errors when
        importing.
    :raises `~aiida.tools.importexport.common.exceptions.ImportUniquenessError`: if a new unique entity can not be
        created.
    """
    from aiida.backends.sqlalchemy.models.node import DbNode, DbLink
    from aiida.backends.sqlalchemy.utils import flag_modified

    # This is the export version expected by this function
    expected_export_version = StrictVersion(EXPORT_VERSION)

    # The returned dictionary with new and existing nodes and links
    ret_dict = {}

    # Initial check(s)
    if group:
        if not isinstance(group, Group):
            raise exceptions.ImportValidationError(
                'group must be a Group entity')
        elif not group.is_stored:
            group.store()

    if silent:
        logging.disable(level=logging.CRITICAL)

    ################
    # EXTRACT DATA #
    ################
    # The sandbox has to remain open until the end
    with SandboxFolder() as folder:
        if os.path.isdir(in_path):
            extract_tree(in_path, folder)
        else:
            if tarfile.is_tarfile(in_path):
                extract_tar(in_path,
                            folder,
                            silent=silent,
                            nodes_export_subfolder=NODES_EXPORT_SUBFOLDER,
                            **kwargs)
            elif zipfile.is_zipfile(in_path):
                extract_zip(in_path,
                            folder,
                            silent=silent,
                            nodes_export_subfolder=NODES_EXPORT_SUBFOLDER,
                            **kwargs)
            else:
                raise exceptions.ImportValidationError(
                    'Unable to detect the input file format, it is neither a '
                    'tar file, nor a (possibly compressed) zip file.')

        if not folder.get_content_list():
            raise exceptions.CorruptArchive(
                'The provided file/folder ({}) is empty'.format(in_path))
        try:
            IMPORT_LOGGER.debug('CACHING metadata.json')
            with open(folder.get_abs_path('metadata.json'),
                      encoding='utf8') as fhandle:
                metadata = json.load(fhandle)

            IMPORT_LOGGER.debug('CACHING data.json')
            with open(folder.get_abs_path('data.json'),
                      encoding='utf8') as fhandle:
                data = json.load(fhandle)
        except IOError as error:
            raise exceptions.CorruptArchive(
                'Unable to find the file {} in the import file or folder'.
                format(error.filename))

        ######################
        # PRELIMINARY CHECKS #
        ######################
        export_version = StrictVersion(str(metadata['export_version']))
        if export_version != expected_export_version:
            msg = 'Export file version is {}, can import only version {}'\
                    .format(metadata['export_version'], expected_export_version)
            if export_version < expected_export_version:
                msg += "\nUse 'verdi export migrate' to update this export file."
            else:
                msg += '\nUpdate your AiiDA version in order to import this file.'

            raise exceptions.IncompatibleArchiveVersionError(msg)

        start_summary(in_path, comment_mode, extras_mode_new,
                      extras_mode_existing)

        ###################################################################
        #           CREATE UUID REVERSE TABLES AND CHECK IF               #
        #              I HAVE ALL NODES FOR THE LINKS                     #
        ###################################################################
        IMPORT_LOGGER.debug(
            'CHECKING IF NODES FROM LINKS ARE IN DB OR ARCHIVE...')

        linked_nodes = set(
            chain.from_iterable(
                (l['input'], l['output']) for l in data['links_uuid']))
        group_nodes = set(chain.from_iterable(data['groups_uuid'].values()))

        # Check that UUIDs are valid
        linked_nodes = set(x for x in linked_nodes if validate_uuid(x))
        group_nodes = set(x for x in group_nodes if validate_uuid(x))

        import_nodes_uuid = set()
        for value in data['export_data'].get(NODE_ENTITY_NAME, {}).values():
            import_nodes_uuid.add(value['uuid'])

        unknown_nodes = linked_nodes.union(group_nodes) - import_nodes_uuid

        if unknown_nodes and not ignore_unknown_nodes:
            raise exceptions.DanglingLinkError(
                'The import file refers to {} nodes with unknown UUID, therefore it cannot be imported. Either first '
                'import the unknown nodes, or export also the parents when exporting. The unknown UUIDs are:\n'
                ''.format(len(unknown_nodes)) +
                '\n'.join('* {}'.format(uuid) for uuid in unknown_nodes))

        ###################################
        # DOUBLE-CHECK MODEL DEPENDENCIES #
        ###################################
        # The entity import order. It is defined by the database model relationships.
        entity_order = [
            USER_ENTITY_NAME, COMPUTER_ENTITY_NAME, NODE_ENTITY_NAME,
            GROUP_ENTITY_NAME, LOG_ENTITY_NAME, COMMENT_ENTITY_NAME
        ]

        # Check that the archive only contains known entity names,
        # e.g. 'User', 'Computer', 'Node', 'Group'
        for import_field_name in metadata['all_fields_info']:
            if import_field_name not in entity_order:
                raise exceptions.ImportValidationError(
                    "You are trying to import an unknown model '{}'!".format(
                        import_field_name))

        for idx, entity_name in enumerate(entity_order):
            dependencies = []
            # For every field, check the dependencies given as value for the key 'requires'
            for field in metadata['all_fields_info'][entity_name].values():
                try:
                    dependencies.append(field['requires'])
                except KeyError:
                    # (No ForeignKey)
                    pass
            for dependency in dependencies:
                if dependency not in entity_order[:idx]:
                    raise exceptions.ArchiveImportError(
                        'Entity {} requires {} but would be loaded first; stopping...'
                        .format(entity_name, dependency))

        ###################################################
        # CREATE IMPORT DATA DIRECT UNIQUE_FIELD MAPPINGS #
        ###################################################
        # This is a nested dictionary of entity_name:{id:uuid}
        # to map one id (the pk) to a different one.
        # One of the things to remove for v0.4
        # {
        # 'Node': {2362: '82a897b5-fb3a-47d7-8b22-c5fe1b4f2c14',
        #           2363: 'ef04aa5d-99e7-4bfd-95ef-fe412a6a3524', 2364: '1dc59576-af21-4d71-81c2-bac1fc82a84a'},
        # 'User': {1: 'aiida@localhost'}
        # }
        IMPORT_LOGGER.debug('CREATING PK-2-UUID/EMAIL MAPPING...')
        import_unique_ids_mappings = {}
        # Export data since v0.3 uses the entity names as keys
        for entity_name, import_data in data['export_data'].items():
            # Again I need the entity_name since that's what's being stored since 0.3
            if entity_name in metadata['unique_identifiers']:
                # I have to reconvert the pk to integer
                import_unique_ids_mappings[entity_name] = {
                    int(k): v[metadata['unique_identifiers'][entity_name]]
                    for k, v in import_data.items()
                }
        ###############
        # IMPORT DATA #
        ###############
        # DO ALL WITH A TRANSACTION
        import aiida.backends.sqlalchemy

        session = aiida.backends.sqlalchemy.get_scoped_session()

        try:
            foreign_ids_reverse_mappings = {}
            new_entries = {}
            existing_entries = {}

            IMPORT_LOGGER.debug('GENERATING LIST OF DATA...')

            # Instantiate progress bar
            progress_bar = get_progress_bar(total=1,
                                            leave=False,
                                            disable=silent)
            pbar_base_str = 'Generating list of data - '

            # Get total entities from data.json
            # To be used with progress bar
            number_of_entities = 0

            # I first generate the list of data
            for entity_name in entity_order:
                entity = entity_names_to_entities[entity_name]
                # I get the unique identifier, since v0.3 stored under entity_name
                unique_identifier = metadata['unique_identifiers'].get(
                    entity_name, None)

                # Initialize the per-entity containers; since v0.3 these are keyed by entity_name
                new_entries[entity_name] = {}
                existing_entries[entity_name] = {}
                foreign_ids_reverse_mappings[entity_name] = {}

                # Not necessarily all models are exported
                if entity_name in data['export_data']:

                    IMPORT_LOGGER.debug('  %s...', entity_name)

                    progress_bar.set_description_str(pbar_base_str +
                                                     entity_name,
                                                     refresh=False)
                    number_of_entities += len(data['export_data'][entity_name])

                    if unique_identifier is not None:
                        import_unique_ids = set(
                            v[unique_identifier]
                            for v in data['export_data'][entity_name].values())

                        relevant_db_entries = {}
                        if import_unique_ids:
                            builder = QueryBuilder()
                            builder.append(entity,
                                           filters={
                                               unique_identifier: {
                                                   'in': import_unique_ids
                                               }
                                           },
                                           project='*')

                            if builder.count():
                                progress_bar = get_progress_bar(
                                    total=builder.count(), disable=silent)
                                for object_ in builder.iterall():
                                    progress_bar.update()

                                    relevant_db_entries.update({
                                        getattr(object_[0], unique_identifier):
                                        object_[0]
                                    })

                            foreign_ids_reverse_mappings[entity_name] = {
                                k: v.pk
                                for k, v in relevant_db_entries.items()
                            }

                        IMPORT_LOGGER.debug('    GOING THROUGH ARCHIVE...')

                        imported_comp_names = set()
                        for key, value in data['export_data'][
                                entity_name].items():
                            if entity_name == GROUP_ENTITY_NAME:
                                # Check if there is already a group with the same label,
                                # and if so, rename the new one
                                orig_label = value['label']
                                dupl_counter = 0
                                while QueryBuilder().append(
                                        entity,
                                        filters={
                                            'label': {
                                                '==': value['label']
                                            }
                                        }).count():
                                    # Rename the new group
                                    value[
                                        'label'] = orig_label + DUPL_SUFFIX.format(
                                            dupl_counter)
                                    dupl_counter += 1
                                    if dupl_counter == 100:
                                        raise exceptions.ImportUniquenessError(
                                            'A group of that label ( {} ) already exists and I could not create a new '
                                            'one'.format(orig_label))

                            elif entity_name == COMPUTER_ENTITY_NAME:
                                # The following is done for compatibility
                                # reasons in case the export file was generated
                                # with the Django export method. In Django the
                                # metadata and the transport parameters are
                                # stored as (unicode) strings of the serialized
                                # JSON objects and not as simple serialized
                                # JSON objects.
                                if isinstance(value['metadata'], (str, bytes)):
                                    value['metadata'] = json.loads(
                                        value['metadata'])

                                # Check if there is already a computer with the
                                # same name in the database
                                builder = QueryBuilder()
                                builder.append(
                                    entity,
                                    filters={'name': {
                                        '==': value['name']
                                    }},
                                    project=['*'],
                                    tag='res')
                                dupl = builder.count(
                                ) or value['name'] in imported_comp_names
                                dupl_counter = 0
                                orig_name = value['name']
                                while dupl:
                                    # Rename the new computer
                                    value[
                                        'name'] = orig_name + DUPL_SUFFIX.format(
                                            dupl_counter)
                                    builder = QueryBuilder()
                                    builder.append(entity,
                                                   filters={
                                                       'name': {
                                                           '==': value['name']
                                                       }
                                                   },
                                                   project=['*'],
                                                   tag='res')
                                    dupl = builder.count(
                                    ) or value['name'] in imported_comp_names
                                    dupl_counter += 1
                                    if dupl_counter == 100:
                                        raise exceptions.ImportUniquenessError(
                                            'A computer of that name ( {} ) already exists and I could not create a '
                                            'new one'.format(orig_name))

                                imported_comp_names.add(value['name'])

                            if value[unique_identifier] in relevant_db_entries:
                                # Already in DB
                                # again, switched to entity_name in v0.3
                                existing_entries[entity_name][key] = value
                            else:
                                # To be added
                                new_entries[entity_name][key] = value
                    else:
                        new_entries[entity_name] = data['export_data'][
                            entity_name]

            # Progress bar - reset for import
            progress_bar = get_progress_bar(total=number_of_entities,
                                            disable=silent)
            reset_progress_bar = {}

            # I import data from the given model
            for entity_name in entity_order:
                entity = entity_names_to_entities[entity_name]
                fields_info = metadata['all_fields_info'].get(entity_name, {})
                unique_identifier = metadata['unique_identifiers'].get(
                    entity_name, '')

                # Progress bar initialization - Model
                if reset_progress_bar:
                    progress_bar = get_progress_bar(
                        total=reset_progress_bar['total'], disable=silent)
                    progress_bar.n = reset_progress_bar['n']
                    reset_progress_bar = {}
                pbar_base_str = '{}s - '.format(entity_name)
                progress_bar.set_description_str(pbar_base_str +
                                                 'Initializing',
                                                 refresh=True)

                # EXISTING ENTRIES
                if existing_entries[entity_name]:
                    # Progress bar update - Model
                    progress_bar.set_description_str(
                        pbar_base_str + '{} existing entries'.format(
                            len(existing_entries[entity_name])),
                        refresh=True)

                for import_entry_pk, entry_data in existing_entries[
                        entity_name].items():
                    unique_id = entry_data[unique_identifier]
                    existing_entry_pk = foreign_ids_reverse_mappings[
                        entity_name][unique_id]
                    import_data = dict(
                        deserialize_field(k,
                                          v,
                                          fields_info=fields_info,
                                          import_unique_ids_mappings=
                                          import_unique_ids_mappings,
                                          foreign_ids_reverse_mappings=
                                          foreign_ids_reverse_mappings)
                        for k, v in entry_data.items())
                    # TODO COMPARE, AND COMPARE ATTRIBUTES

                    if entity_name == COMMENT_ENTITY_NAME:
                        new_entry_uuid = merge_comment(import_data,
                                                       comment_mode)
                        if new_entry_uuid is not None:
                            entry_data[unique_identifier] = new_entry_uuid
                            new_entries[entity_name][
                                import_entry_pk] = entry_data

                    if entity_name not in ret_dict:
                        ret_dict[entity_name] = {'new': [], 'existing': []}
                    ret_dict[entity_name]['existing'].append(
                        (import_entry_pk, existing_entry_pk))
                    IMPORT_LOGGER.debug('Existing %s: %s (%s->%s)',
                                        entity_name, unique_id,
                                        import_entry_pk, existing_entry_pk)

                # Store all objects for this model in a list, and store them
                # all at once at the end.
                objects_to_create = list()
                # In the following list we add the objects to be updated
                objects_to_update = list()
                # This is needed later to associate the import entry with the new pk
                import_new_entry_pks = dict()

                # NEW ENTRIES
                if new_entries[entity_name]:
                    # Progress bar update - Model
                    progress_bar.set_description_str(
                        pbar_base_str +
                        '{} new entries'.format(len(new_entries[entity_name])),
                        refresh=True)

                for import_entry_pk, entry_data in new_entries[
                        entity_name].items():
                    unique_id = entry_data[unique_identifier]
                    import_data = dict(
                        deserialize_field(k,
                                          v,
                                          fields_info=fields_info,
                                          import_unique_ids_mappings=
                                          import_unique_ids_mappings,
                                          foreign_ids_reverse_mappings=
                                          foreign_ids_reverse_mappings)
                        for k, v in entry_data.items())

                    # We convert the Django fields to SQLA. Note that some of
                    # the Django fields were converted to SQLA compatible
                    # fields by the deserialize_field method. This was done
                    # for optimization reasons in Django but makes them
                    # compatible with the SQLA schema and they don't need any
                    # further conversion.
                    if entity_name in file_fields_to_model_fields:
                        for file_fkey in file_fields_to_model_fields[
                                entity_name]:

                            # This is an exception because the DbLog model defines the `_metadata` column instead of the
                            # `metadata` column used in the Django model. This is because the SqlAlchemy model base
                            # class already has a metadata attribute that cannot be overridden. For consistency, the
                            # `DbLog` class however expects the `metadata` keyword in its constructor, so we should
                            # ignore the mapping here
                            if entity_name == LOG_ENTITY_NAME and file_fkey == 'metadata':
                                continue

                            model_fkey = file_fields_to_model_fields[
                                entity_name][file_fkey]
                            if model_fkey in import_data:
                                continue
                            import_data[model_fkey] = import_data[file_fkey]
                            import_data.pop(file_fkey, None)

                    db_entity = get_object_from_string(
                        entity_names_to_sqla_schema[entity_name])

                    objects_to_create.append(db_entity(**import_data))
                    import_new_entry_pks[unique_id] = import_entry_pk

                if entity_name == NODE_ENTITY_NAME:
                    IMPORT_LOGGER.debug(
                        'STORING NEW NODE REPOSITORY FILES & ATTRIBUTES...')

                    # NEW NODES
                    for object_ in objects_to_create:
                        import_entry_uuid = object_.uuid
                        import_entry_pk = import_new_entry_pks[
                            import_entry_uuid]

                        # Progress bar initialization - Node
                        progress_bar.update()
                        pbar_node_base_str = pbar_base_str + 'UUID={} - '.format(
                            import_entry_uuid.split('-')[0])

                        # Before storing entries in the DB, I store the files (if these are nodes).
                        # Note: only for new entries!
                        subfolder = folder.get_subfolder(
                            os.path.join(NODES_EXPORT_SUBFOLDER,
                                         export_shard_uuid(import_entry_uuid)))
                        if not subfolder.exists():
                            raise exceptions.CorruptArchive(
                                'Unable to find the repository folder for Node with UUID={} in the exported '
                                'file'.format(import_entry_uuid))
                        destdir = RepositoryFolder(
                            section=Repository._section_name,
                            uuid=import_entry_uuid)
                        # Replace the folder, possibly destroying existing previous folders, and move the files
                        # (faster if we are on the same filesystem, and in any case the source is a SandboxFolder)
                        progress_bar.set_description_str(pbar_node_base_str +
                                                         'Repository',
                                                         refresh=True)
                        destdir.replace_with_folder(subfolder.abspath,
                                                    move=True,
                                                    overwrite=True)

                        # For Nodes, we also have to store Attributes!
                        IMPORT_LOGGER.debug('STORING NEW NODE ATTRIBUTES...')
                        progress_bar.set_description_str(pbar_node_base_str +
                                                         'Attributes',
                                                         refresh=True)

                        # Get attributes from import file
                        try:
                            object_.attributes = data['node_attributes'][str(
                                import_entry_pk)]
                        except KeyError:
                            raise exceptions.CorruptArchive(
                                'Unable to find attribute info for Node with UUID={}'
                                .format(import_entry_uuid))

                        # For DbNodes, we also have to store extras
                        if extras_mode_new == 'import':
                            IMPORT_LOGGER.debug('STORING NEW NODE EXTRAS...')
                            progress_bar.set_description_str(
                                pbar_node_base_str + 'Extras', refresh=True)

                            # Get extras from import file
                            try:
                                extras = data['node_extras'][str(
                                    import_entry_pk)]
                            except KeyError:
                                raise exceptions.CorruptArchive(
                                    'Unable to find extra info for Node with UUID={}'
                                    .format(import_entry_uuid))
                            # TODO: remove when aiida extras will be moved somewhere else
                            # from here
                            extras = {
                                key: value
                                for key, value in extras.items()
                                if not key.startswith('_aiida_')
                            }
                            if object_.node_type.endswith('code.Code.'):
                                extras = {
                                    key: value
                                    for key, value in extras.items()
                                    if not key == 'hidden'
                                }
                            # till here
                            object_.extras = extras
                        elif extras_mode_new == 'none':
                            IMPORT_LOGGER.debug('SKIPPING NEW NODE EXTRAS...')
                        else:
                            raise exceptions.ImportValidationError(
                                "Unknown extras_mode_new value: {}, should be either 'import' or 'none'"
                                ''.format(extras_mode_new))

                    # EXISTING NODES (Extras)
                    IMPORT_LOGGER.debug('UPDATING EXISTING NODE EXTRAS...')

                    import_existing_entry_pks = {
                        entry_data[unique_identifier]: import_entry_pk
                        for import_entry_pk, entry_data in
                        existing_entries[entity_name].items()
                    }
                    for node in session.query(DbNode).filter(
                            DbNode.uuid.in_(import_existing_entry_pks)).all():
                        import_entry_uuid = str(node.uuid)
                        import_entry_pk = import_existing_entry_pks[
                            import_entry_uuid]

                        # Progress bar initialization - Node
                        pbar_node_base_str = pbar_base_str + 'UUID={} - '.format(
                            import_entry_uuid.split('-')[0])
                        progress_bar.set_description_str(pbar_node_base_str +
                                                         'Extras',
                                                         refresh=False)
                        progress_bar.update()

                        # Get extras from import file
                        try:
                            extras = data['node_extras'][str(import_entry_pk)]
                        except KeyError:
                            raise exceptions.CorruptArchive(
                                'Unable to find extra info for Node with UUID={}'
                                .format(import_entry_uuid))

                        old_extras = node.extras.copy()
                        # TODO: remove when aiida extras will be moved somewhere else
                        # from here
                        extras = {
                            key: value
                            for key, value in extras.items()
                            if not key.startswith('_aiida_')
                        }
                        if node.node_type.endswith('code.Code.'):
                            extras = {
                                key: value
                                for key, value in extras.items()
                                if not key == 'hidden'
                            }
                        # till here
                        new_extras = merge_extras(node.extras, extras,
                                                  extras_mode_existing)
                        if new_extras != old_extras:
                            node.extras = new_extras
                            flag_modified(node, 'extras')
                            objects_to_update.append(node)

                else:
                    # Update progress bar with new non-Node entries
                    progress_bar.update(n=len(existing_entries[entity_name]) +
                                        len(new_entries[entity_name]))

                progress_bar.set_description_str(pbar_base_str + 'Storing',
                                                 refresh=True)

                # Store them all at once; however, the PKs are not set this way...
                if objects_to_create:
                    session.add_all(objects_to_create)
                if objects_to_update:
                    session.add_all(objects_to_update)

                session.flush()

                just_saved = {}
                if import_new_entry_pks.keys():
                    reset_progress_bar = {
                        'total': progress_bar.total,
                        'n': progress_bar.n
                    }
                    progress_bar = get_progress_bar(
                        total=len(import_new_entry_pks), disable=silent)

                    builder = QueryBuilder()
                    builder.append(entity,
                                   filters={
                                       unique_identifier: {
                                           'in':
                                           list(import_new_entry_pks.keys())
                                       }
                                   },
                                   project=[unique_identifier, 'id'])

                    for entry in builder.iterall():
                        progress_bar.update()

                        just_saved.update({entry[0]: entry[1]})

                progress_bar.set_description_str(pbar_base_str + 'Done!',
                                                 refresh=True)

                # Now I have the PKs, print the info
                # Moreover, add newly created Nodes to foreign_ids_reverse_mappings
                for unique_id, new_pk in just_saved.items():
                    from uuid import UUID
                    if isinstance(unique_id, UUID):
                        unique_id = str(unique_id)
                    import_entry_pk = import_new_entry_pks[unique_id]
                    foreign_ids_reverse_mappings[entity_name][
                        unique_id] = new_pk
                    if entity_name not in ret_dict:
                        ret_dict[entity_name] = {'new': [], 'existing': []}
                    ret_dict[entity_name]['new'].append(
                        (import_entry_pk, new_pk))

                    IMPORT_LOGGER.debug('N %s: %s (%s->%s)', entity_name,
                                        unique_id, import_entry_pk, new_pk)

            IMPORT_LOGGER.debug('STORING NODE LINKS...')

            import_links = data['links_uuid']

            if import_links:
                progress_bar = get_progress_bar(total=len(import_links),
                                                disable=silent)
                pbar_base_str = 'Links - '

            for link in import_links:
                # Check for dangling Links within the supposedly self-consistent archive
                progress_bar.set_description_str(
                    pbar_base_str + 'label={}'.format(link['label']),
                    refresh=False)
                progress_bar.update()

                try:
                    in_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][
                        link['input']]
                    out_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][
                        link['output']]
                except KeyError:
                    if ignore_unknown_nodes:
                        continue
                    raise exceptions.ImportValidationError(
                        'Trying to create a link with one or both unknown nodes, stopping (in_uuid={}, out_uuid={}, '
                        'label={}, type={})'.format(link['input'],
                                                    link['output'],
                                                    link['label'],
                                                    link['type']))

                # Since backend-specific Links (DbLink) are not validated upon creation, we validate them here.
                source = QueryBuilder().append(Node,
                                               filters={
                                                   'id': in_id
                                               },
                                               project='*').first()[0]
                target = QueryBuilder().append(Node,
                                               filters={
                                                   'id': out_id
                                               },
                                               project='*').first()[0]
                link_type = LinkType(link['type'])

                # Check for existence of a triple link, i.e. unique triple.
                # If it exists, the link is already present: continue to the next link; otherwise, validate it.
                if link_triple_exists(source, target, link_type,
                                      link['label']):
                    continue

                try:
                    validate_link(source, target, link_type, link['label'])
                except ValueError as why:
                    raise exceptions.ImportValidationError(
                        'Error occurred during Link validation: {}'.format(
                            why))

                # New link
                session.add(
                    DbLink(input_id=in_id,
                           output_id=out_id,
                           label=link['label'],
                           type=link['type']))
                if 'Link' not in ret_dict:
                    ret_dict['Link'] = {'new': []}
                ret_dict['Link']['new'].append((in_id, out_id))

            IMPORT_LOGGER.debug('   (%d new links...)',
                                len(ret_dict.get('Link', {}).get('new', [])))

            IMPORT_LOGGER.debug('STORING GROUP ELEMENTS...')

            import_groups = data['groups_uuid']

            if import_groups:
                progress_bar = get_progress_bar(total=len(import_groups),
                                                disable=silent)
                pbar_base_str = 'Groups - '

            for groupuuid, groupnodes in import_groups.items():
                # TODO: cache these to avoid too many queries
                qb_group = QueryBuilder().append(
                    Group, filters={'uuid': {
                        '==': groupuuid
                    }})
                group_ = qb_group.first()[0]

                progress_bar.set_description_str(
                    pbar_base_str + 'label={}'.format(group_.label),
                    refresh=False)
                progress_bar.update()

                nodes_ids_to_add = [
                    foreign_ids_reverse_mappings[NODE_ENTITY_NAME][node_uuid]
                    for node_uuid in groupnodes
                ]
                qb_nodes = QueryBuilder().append(
                    Node, filters={'id': {
                        'in': nodes_ids_to_add
                    }})
                # Add nodes to the group, bypassing the SQLA ORM to increase speed
                nodes_to_add = [n[0].backend_entity for n in qb_nodes.all()]
                group_.backend_entity.add_nodes(nodes_to_add, skip_orm=True)

            ######################################################
            # Put everything in a specific group
            ######################################################
            existing = existing_entries.get(NODE_ENTITY_NAME, {})
            existing_pk = [
                foreign_ids_reverse_mappings[NODE_ENTITY_NAME][v['uuid']]
                for v in existing.values()
            ]
            new = new_entries.get(NODE_ENTITY_NAME, {})
            new_pk = [
                foreign_ids_reverse_mappings[NODE_ENTITY_NAME][v['uuid']]
                for v in new.values()
            ]

            pks_for_group = existing_pk + new_pk

            # So that we do not create empty groups
            if pks_for_group:
                # If the user did not specify a group, create a new import group with a unique label
                if not group:
                    from aiida.backends.sqlalchemy.models.group import DbGroup

                    # Get a unique name for the import group, based on the current (local) time
                    basename = timezone.localtime(
                        timezone.now()).strftime('%Y%m%d-%H%M%S')
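                    # e.g. '20191231-235959'; on collision a numeric suffix is appended: '20191231-235959_1', ...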
                    counter = 0
                    group_label = basename
                    while session.query(DbGroup).filter(
                            DbGroup.label == group_label).count() > 0:
                        counter += 1
                        group_label = '{}_{}'.format(basename, counter)

                        if counter == 100:
                            raise exceptions.ImportUniquenessError(
                                "Overflow of import groups (more than 100 import groups exists with basename '{}')"
                                ''.format(basename))
                    group = ImportGroup(label=group_label)
                    session.add(group.backend_entity._dbmodel)

                # Add nodes to the group, bypassing the SQLA ORM to increase speed
                builder = QueryBuilder().append(
                    Node, filters={'id': {
                        'in': pks_for_group
                    }})

                progress_bar = get_progress_bar(total=len(pks_for_group),
                                                disable=silent)
                progress_bar.set_description_str(
                    'Creating import Group - Preprocessing', refresh=True)
                first = True

                nodes = []
                for entry in builder.iterall():
                    if first:
                        progress_bar.set_description_str(
                            'Creating import Group', refresh=False)
                        first = False
                    progress_bar.update()
                    nodes.append(entry[0].backend_entity)
                group.backend_entity.add_nodes(nodes, skip_orm=True)
                progress_bar.set_description_str('Done (cleaning up)',
                                                 refresh=True)
            else:
                IMPORT_LOGGER.debug(
                    'No Nodes to import, so no new Group was created (a user-supplied Group, if any, is left unchanged)'
                )

            IMPORT_LOGGER.debug('COMMITTING EVERYTHING...')
            session.commit()

            # Finalize Progress bar
            close_progress_bar(leave=False)

            # Summarize import
            result_summary(ret_dict, getattr(group, 'label', None))

        except:
            # Finalize Progress bar
            close_progress_bar(leave=False)

            result_summary({}, None)

            IMPORT_LOGGER.debug('Rolling back')
            session.rollback()
            raise

    # Reset logging level
    if silent:
        logging.disable(level=logging.NOTSET)

    return ret_dict
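
# Usage sketch (added for illustration; not part of the original example). In practice the
# backend-specific importer above is usually reached through the generic ``import_data``
# front-end in ``aiida.tools.importexport``, which dispatches to the configured backend.
# The archive path 'export.aiida' below is a hypothetical placeholder.
def _example_import_via_frontend(archive_path='export.aiida'):
    """Import an archive through the generic front-end and report the newly created nodes."""
    from aiida import load_profile
    from aiida.tools.importexport import import_data as frontend_import_data

    load_profile()
    # With group=None the importer creates a timestamped import group automatically,
    # as implemented above.
    results = frontend_import_data(archive_path, group=None, silent=True)
    print('New nodes:', results.get('Node', {}).get('new', []))
    return results
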
Exemplo n.º 19
def import_data_dj(in_path,
                   group=None,
                   ignore_unknown_nodes=False,
                   extras_mode_existing='kcl',
                   extras_mode_new='import',
                   comment_mode='newest',
                   silent=False):
    """Import exported AiiDA archive to the AiiDA database and repository.

    Specific for the Django backend.
    If ``in_path`` is a folder, calls extract_tree; otherwise, tries to detect the compression format
    (zip, tar.gz, tar.bz2, ...) and calls the correct function.

    :param in_path: the path to a file or folder that can be imported in AiiDA.
    :type in_path: str

    :param group: Group wherein all imported Nodes will be placed.
    :type group: :py:class:`~aiida.orm.groups.Group`

    :param extras_mode_existing: three-letter code that specifies what to do with the extras of existing nodes.
        The first letter acts on extras that are present in the original node and not present in the imported node.
        Can be either:
        'k' (keep it) or
        'n' (do not keep it).
        The second letter acts on the imported extras that are not present in the original node.
        Can be either:
        'c' (create it) or
        'n' (do not create it).
        The third letter defines what to do in case of a name collision.
        Can be either:
        'l' (leave the old value),
        'u' (update with a new value),
        'd' (delete the extra), or
        'a' (ask what to do if the content is different).
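        For example, the default 'kcl' keeps existing extras, creates those present only in the archive,
        and leaves the old value in case of a collision.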
    :type extras_mode_existing: str

    :param extras_mode_new: 'import' to import extras of new nodes or 'none' to ignore them.
    :type extras_mode_new: str

    :param comment_mode: Comment import modes (when same UUIDs are found).
        Can be either:
        'newest' (will keep the Comment with the most recent modification time (mtime)) or
        'overwrite' (will overwrite existing Comments with the ones from the import file).
    :type comment_mode: str

    :param silent: suppress prints.
    :type silent: bool

    :return: New and existing Nodes and Links.
    :rtype: dict

    :raises `~aiida.tools.importexport.common.exceptions.ImportValidationError`: if parameters or the contents of
        `metadata.json` or `data.json` can not be validated.
    :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if the provided archive at ``in_path`` is
        corrupted.
    :raises `~aiida.tools.importexport.common.exceptions.IncompatibleArchiveVersionError`: if the provided archive's
        export version is not equal to the export version of AiiDA at the moment of import.
    :raises `~aiida.tools.importexport.common.exceptions.ArchiveImportError`: if there are any internal errors when
        importing.
    :raises `~aiida.tools.importexport.common.exceptions.ImportUniquenessError`: if a new unique entity can not be
        created.
    """
    from django.db import transaction  # pylint: disable=import-error,no-name-in-module
    from aiida.backends.djsite.db import models

    # This is the export version expected by this function
    expected_export_version = StrictVersion(EXPORT_VERSION)

    # The returned dictionary with new and existing nodes and links
    ret_dict = {}

    # Initial check(s)
    if group:
        if not isinstance(group, Group):
            raise exceptions.ImportValidationError(
                'group must be a Group entity')
        elif not group.is_stored:
            group.store()

    ################
    # EXTRACT DATA #
    ################
    # The sandbox has to remain open until the end
    with SandboxFolder() as folder:
        if os.path.isdir(in_path):
            extract_tree(in_path, folder)
        else:
            if tarfile.is_tarfile(in_path):
                extract_tar(in_path,
                            folder,
                            silent=silent,
                            nodes_export_subfolder=NODES_EXPORT_SUBFOLDER)
            elif zipfile.is_zipfile(in_path):
                try:
                    extract_zip(in_path,
                                folder,
                                silent=silent,
                                nodes_export_subfolder=NODES_EXPORT_SUBFOLDER)
                except ValueError as exc:
                    print(
                        'The following problem occurred while processing the provided file: {}'
                        .format(exc))
                    return
            else:
                raise exceptions.ImportValidationError(
                    'Unable to detect the input file format, it is neither a '
                    '(possibly compressed) tar file, nor a zip file.')

        if not folder.get_content_list():
            raise exceptions.CorruptArchive(
                'The provided file/folder ({}) is empty'.format(in_path))
        try:
            with open(folder.get_abs_path('metadata.json'),
                      'r',
                      encoding='utf8') as fhandle:
                metadata = json.load(fhandle)

            with open(folder.get_abs_path('data.json'), 'r',
                      encoding='utf8') as fhandle:
                data = json.load(fhandle)
        except IOError as error:
            raise exceptions.CorruptArchive(
                'Unable to find the file {} in the import file or folder'.
                format(error.filename))

        ######################
        # PRELIMINARY CHECKS #
        ######################
        export_version = StrictVersion(str(metadata['export_version']))
        if export_version != expected_export_version:
            msg = 'Export file version is {}, can import only version {}'\
                    .format(metadata['export_version'], expected_export_version)
            if export_version < expected_export_version:
                msg += "\nUse 'verdi export migrate' to update this export file."
            else:
                msg += '\nUpdate your AiiDA version in order to import this file.'

            raise exceptions.IncompatibleArchiveVersionError(msg)

        ##########################################################################
        # CREATE UUID REVERSE TABLES AND CHECK IF I HAVE ALL NODES FOR THE LINKS #
        ##########################################################################
        linked_nodes = set(
            chain.from_iterable(
                (l['input'], l['output']) for l in data['links_uuid']))
        group_nodes = set(chain.from_iterable(data['groups_uuid'].values()))

        if NODE_ENTITY_NAME in data['export_data']:
            import_nodes_uuid = set(
                v['uuid']
                for v in data['export_data'][NODE_ENTITY_NAME].values())
        else:
            import_nodes_uuid = set()

        # linked_nodes and group_nodes together contain every Node UUID referenced by links and groups,
        # while import_nodes_uuid contains the Node UUIDs actually present in export_data
        unknown_nodes = linked_nodes.union(group_nodes) - import_nodes_uuid

        if unknown_nodes and not ignore_unknown_nodes:
            raise exceptions.DanglingLinkError(
                'The import file refers to {} nodes with unknown UUID, therefore it cannot be imported. Either first '
                'import the unknown nodes, or export also the parents when exporting. The unknown UUIDs are:\n'
                ''.format(len(unknown_nodes)) +
                '\n'.join('* {}'.format(uuid) for uuid in unknown_nodes))

        ###################################
        # DOUBLE-CHECK MODEL DEPENDENCIES #
        ###################################
        # The entity import order. It is defined by the database model relationships.

        model_order = (USER_ENTITY_NAME, COMPUTER_ENTITY_NAME,
                       NODE_ENTITY_NAME, GROUP_ENTITY_NAME, LOG_ENTITY_NAME,
                       COMMENT_ENTITY_NAME)

        for import_field_name in metadata['all_fields_info']:
            if import_field_name not in model_order:
                raise exceptions.ImportValidationError(
                    "You are trying to import an unknown model '{}'!".format(
                        import_field_name))

        for idx, model_name in enumerate(model_order):
            dependencies = []
            for field in metadata['all_fields_info'][model_name].values():
                try:
                    dependencies.append(field['requires'])
                except KeyError:
                    # (No ForeignKey)
                    pass
            for dependency in dependencies:
                if dependency not in model_order[:idx]:
                    raise exceptions.ArchiveImportError(
                        'Model {} requires {} but would be loaded first; stopping...'
                        .format(model_name, dependency))

        ###################################################
        # CREATE IMPORT DATA DIRECT UNIQUE_FIELD MAPPINGS #
        ###################################################
        import_unique_ids_mappings = {}
        for model_name, import_data in data['export_data'].items():
            if model_name in metadata['unique_identifiers']:
                # I have to reconvert the pk to integer
                import_unique_ids_mappings[model_name] = {
                    int(k): v[metadata['unique_identifiers'][model_name]]
                    for k, v in import_data.items()
                }

        ###############
        # IMPORT DATA #
        ###############
        # DO ALL WITH A TRANSACTION

        # batch size for bulk create operations
        batch_size = get_config_option('db.batch_size')
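        # `db.batch_size` limits how many rows are sent per bulk_create() call further down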

        with transaction.atomic():
            foreign_ids_reverse_mappings = {}
            new_entries = {}
            existing_entries = {}

            # I first generate the list of data
            for model_name in model_order:
                cls_signature = entity_names_to_signatures[model_name]
                model = get_object_from_string(cls_signature)
                fields_info = metadata['all_fields_info'].get(model_name, {})
                unique_identifier = metadata['unique_identifiers'].get(
                    model_name, None)

                new_entries[model_name] = {}
                existing_entries[model_name] = {}

                foreign_ids_reverse_mappings[model_name] = {}

                # Not necessarily all models are exported
                if model_name in data['export_data']:

                    # Split the entries into those already present in the DB and those still to be created
                    if unique_identifier is not None:
                        import_unique_ids = set(
                            v[unique_identifier]
                            for v in data['export_data'][model_name].values())

                        relevant_db_entries_result = model.objects.filter(**{
                            '{}__in'.format(unique_identifier):
                            import_unique_ids
                        })
                        # Note: uuids need to be converted to strings
                        relevant_db_entries = {
                            str(getattr(n, unique_identifier)): n
                            for n in relevant_db_entries_result
                        }

                        foreign_ids_reverse_mappings[model_name] = {
                            k: v.pk
                            for k, v in relevant_db_entries.items()
                        }
                        for key, value in data['export_data'][
                                model_name].items():
                            if value[
                                    unique_identifier] in relevant_db_entries.keys(
                                    ):
                                # Already in DB
                                existing_entries[model_name][key] = value
                            else:
                                # To be added
                                new_entries[model_name][key] = value
                    else:
                        new_entries[model_name] = data['export_data'][
                            model_name].copy()

            # Show Comment mode if not silent
            if not silent:
                print('Comment mode: {}'.format(comment_mode))

            # I import data from the given model
            for model_name in model_order:
                cls_signature = entity_names_to_signatures[model_name]
                model = get_object_from_string(cls_signature)
                fields_info = metadata['all_fields_info'].get(model_name, {})
                unique_identifier = metadata['unique_identifiers'].get(
                    model_name, None)

                # EXISTING ENTRIES
                for import_entry_pk, entry_data in existing_entries[
                        model_name].items():
                    unique_id = entry_data[unique_identifier]
                    existing_entry_id = foreign_ids_reverse_mappings[
                        model_name][unique_id]
                    import_data = dict(
                        deserialize_field(k,
                                          v,
                                          fields_info=fields_info,
                                          import_unique_ids_mappings=
                                          import_unique_ids_mappings,
                                          foreign_ids_reverse_mappings=
                                          foreign_ids_reverse_mappings)
                        for k, v in entry_data.items())
                    # TODO COMPARE, AND COMPARE ATTRIBUTES

                    if model is models.DbComment:
                        new_entry_uuid = merge_comment(import_data,
                                                       comment_mode)
                        if new_entry_uuid is not None:
                            entry_data[unique_identifier] = new_entry_uuid
                            new_entries[model_name][
                                import_entry_pk] = entry_data

                    if model_name not in ret_dict:
                        ret_dict[model_name] = {'new': [], 'existing': []}
                    ret_dict[model_name]['existing'].append(
                        (import_entry_pk, existing_entry_id))
                    if not silent:
                        print('existing %s: %s (%s->%s)' %
                              (model_name, unique_id, import_entry_pk,
                               existing_entry_id))
                        # print("  `-> WARNING: NO DUPLICITY CHECK DONE!")
                        # CHECK ALSO FILES!

                # Store all objects for this model in a list, and store them all at once at the end.
                objects_to_create = []
                # This is needed later to associate the import entry with the new pk
                import_new_entry_pks = {}
                imported_comp_names = set()
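                # Track Computer names already assigned during this import, so that two new
                # Computers from the same archive do not end up with the same name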

                # NEW ENTRIES
                for import_entry_pk, entry_data in new_entries[
                        model_name].items():
                    unique_id = entry_data[unique_identifier]
                    import_data = dict(
                        deserialize_field(k,
                                          v,
                                          fields_info=fields_info,
                                          import_unique_ids_mappings=
                                          import_unique_ids_mappings,
                                          foreign_ids_reverse_mappings=
                                          foreign_ids_reverse_mappings)
                        for k, v in entry_data.items())

                    if model is models.DbGroup:
                        # Check if there is already a group with the same label
                        dupl_counter = 0
                        orig_label = import_data['label']
                        while model.objects.filter(label=import_data['label']):
                            import_data[
                                'label'] = orig_label + DUPL_SUFFIX.format(
                                    dupl_counter)
                            dupl_counter += 1
                            if dupl_counter == 100:
                                raise exceptions.ImportUniquenessError(
                                    "A group with label '{}' already exists and a new unique label could not be created"
                                    ''.format(orig_label))

                    elif model is models.DbComputer:
                        # Check if there is already a computer with the same name in the database
                        dupl = (model.objects.filter(name=import_data['name'])
                                or import_data['name'] in imported_comp_names)
                        orig_name = import_data['name']
                        dupl_counter = 0
                        while dupl:
                            # Rename the new computer
                            import_data['name'] = (
                                orig_name + DUPL_SUFFIX.format(dupl_counter))
                            dupl = (
                                model.objects.filter(name=import_data['name'])
                                or import_data['name'] in imported_comp_names)
                            dupl_counter += 1
                            if dupl_counter == 100:
                                raise exceptions.ImportUniquenessError(
                                    "A computer with name '{}' already exists and a new unique name could not be created"
                                    ''.format(orig_name))

                        imported_comp_names.add(import_data['name'])

                    objects_to_create.append(model(**import_data))
                    import_new_entry_pks[unique_id] = import_entry_pk

                if model_name == NODE_ENTITY_NAME:
                    if not silent:
                        print('STORING NEW NODE REPOSITORY FILES...')

                    # NEW NODES
                    for object_ in objects_to_create:
                        import_entry_uuid = object_.uuid
                        import_entry_pk = import_new_entry_pks[
                            import_entry_uuid]

                        # Before storing entries in the DB, I store the files (if these are nodes).
                        # Note: only for new entries!
                        subfolder = folder.get_subfolder(
                            os.path.join(NODES_EXPORT_SUBFOLDER,
                                         export_shard_uuid(import_entry_uuid)))
                        if not subfolder.exists():
                            raise exceptions.CorruptArchive(
                                'Unable to find the repository folder for Node with UUID={} in the exported '
                                'file'.format(import_entry_uuid))
                        destdir = RepositoryFolder(
                            section=Repository._section_name,
                            uuid=import_entry_uuid)
                        # Replace the folder, possibly destroying existing previous folders, and move the files
                        # (faster if we are on the same filesystem, and in any case the source is a SandboxFolder)
                        destdir.replace_with_folder(subfolder.abspath,
                                                    move=True,
                                                    overwrite=True)

                        # For DbNodes, we also have to store their attributes
                        if not silent:
                            print('STORING NEW NODE ATTRIBUTES...')

                        # Get attributes from import file
                        try:
                            object_.attributes = data['node_attributes'][str(
                                import_entry_pk)]
                        except KeyError:
                            raise exceptions.CorruptArchive(
                                'Unable to find attribute info for Node with UUID={}'
                                .format(import_entry_uuid))

                        # For DbNodes, we also have to store their extras
                        if extras_mode_new == 'import':
                            if not silent:
                                print('STORING NEW NODE EXTRAS...')

                            # Get extras from import file
                            try:
                                extras = data['node_extras'][str(
                                    import_entry_pk)]
                            except KeyError:
                                raise exceptions.CorruptArchive(
                                    'Unable to find extra info for Node with UUID={}'
                                    .format(import_entry_uuid))
                            # TODO: remove when AiiDA extras are moved somewhere else
                            # from here
                            extras = {
                                key: value
                                for key, value in extras.items()
                                if not key.startswith('_aiida_')
                            }
                            if object_.node_type.endswith('code.Code.'):
                                extras = {
                                    key: value
                                    for key, value in extras.items()
                                    if not key == 'hidden'
                                }
                            # till here
                            object_.extras = extras
                        elif extras_mode_new == 'none':
                            if not silent:
                                print('SKIPPING NEW NODE EXTRAS...')
                        else:
                            raise exceptions.ImportValidationError(
                                "Unknown extras_mode_new value: {}, should be either 'import' or 'none'"
                                ''.format(extras_mode_new))

                    # EXISTING NODES (Extras)
                    # For existing nodes that are also in the imported list, update their extras if necessary
                    if not silent:
                        print(
                            'UPDATING EXISTING NODE EXTRAS (mode: {})'.format(
                                extras_mode_existing))

                    import_existing_entry_pks = {
                        entry_data[unique_identifier]: import_entry_pk
                        for import_entry_pk, entry_data in
                        existing_entries[model_name].items()
                    }
                    for node in models.DbNode.objects.filter(
                            uuid__in=import_existing_entry_pks).all():  # pylint: disable=no-member
                        import_entry_uuid = str(node.uuid)
                        import_entry_pk = import_existing_entry_pks[
                            import_entry_uuid]

                        # Get extras from import file
                        try:
                            extras = data['node_extras'][str(import_entry_pk)]
                        except KeyError:
                            raise exceptions.CorruptArchive(
                                'Unable to find extra info for Node with UUID={}'
                                .format(import_entry_uuid))

                        # TODO: remove when AiiDA extras are moved somewhere else
                        # from here
                        extras = {
                            key: value
                            for key, value in extras.items()
                            if not key.startswith('_aiida_')
                        }
                        if node.node_type.endswith('code.Code.'):
                            extras = {
                                key: value
                                for key, value in extras.items()
                                if not key == 'hidden'
                            }
                        # till here
                        node.extras = merge_extras(node.extras, extras,
                                                   extras_mode_existing)

                        # Already saving existing node here to update its extras
                        node.save()

                # If the model has an mtime field, disable its automatic update
                # so that the mtime we set here is kept
                if 'mtime' in [
                        field.name for field in model._meta.local_fields
                ]:
                    with models.suppress_auto_now([(model, ['mtime'])]):
                        # Store them all at once; however, the PKs are not set this way...
                        model.objects.bulk_create(objects_to_create,
                                                  batch_size=batch_size)
                else:
                    model.objects.bulk_create(objects_to_create,
                                              batch_size=batch_size)

                # Get back the just-saved entries
                just_saved_queryset = model.objects.filter(
                    **{
                        '{}__in'.format(unique_identifier):
                        import_new_entry_pks.keys()
                    }).values_list(unique_identifier, 'pk')
                # note: convert uuids from type UUID to strings
                just_saved = {
                    str(key): value
                    for key, value in just_saved_queryset
                }

                # Now I have the PKs, print the info
                # Moreover, add newly created Nodes to foreign_ids_reverse_mappings
                for unique_id, new_pk in just_saved.items():
                    import_entry_pk = import_new_entry_pks[unique_id]
                    foreign_ids_reverse_mappings[model_name][
                        unique_id] = new_pk
                    if model_name not in ret_dict:
                        ret_dict[model_name] = {'new': [], 'existing': []}
                    ret_dict[model_name]['new'].append(
                        (import_entry_pk, new_pk))

                    if not silent:
                        print('NEW %s: %s (%s->%s)' %
                              (model_name, unique_id, import_entry_pk, new_pk))

            if not silent:
                print('STORING NODE LINKS...')
            import_links = data['links_uuid']
            links_to_store = []

            # Needed, since QueryBuilder does not yet work for recently saved Nodes
            existing_links_raw = models.DbLink.objects.all().values_list(
                'input', 'output', 'label', 'type')
            existing_links = {(l[0], l[1], l[2], l[3])
                              for l in existing_links_raw}
            existing_outgoing_unique = {(l[0], l[3])
                                        for l in existing_links_raw}
            existing_outgoing_unique_pair = {(l[0], l[2], l[3])
                                             for l in existing_links_raw}
            existing_incoming_unique = {(l[1], l[3])
                                        for l in existing_links_raw}
            existing_incoming_unique_pair = {(l[1], l[2], l[3])
                                             for l in existing_links_raw}
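            # These auxiliary sets back the 'unique' and 'unique_pair' in/outdegree checks performed below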

            calculation_node_types = 'process.calculation.'
            workflow_node_types = 'process.workflow.'
            data_node_types = 'data.'
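            # link_mapping: LinkType -> (allowed source node type prefix, allowed target node type prefix,
            #                            outdegree rule, indegree rule), used to validate every imported link below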

            link_mapping = {
                LinkType.CALL_CALC:
                (workflow_node_types, calculation_node_types, 'unique_triple',
                 'unique'),
                LinkType.CALL_WORK: (workflow_node_types, workflow_node_types,
                                     'unique_triple', 'unique'),
                LinkType.CREATE: (calculation_node_types, data_node_types,
                                  'unique_pair', 'unique'),
                LinkType.INPUT_CALC: (data_node_types, calculation_node_types,
                                      'unique_triple', 'unique_pair'),
                LinkType.INPUT_WORK: (data_node_types, workflow_node_types,
                                      'unique_triple', 'unique_pair'),
                LinkType.RETURN: (workflow_node_types, data_node_types,
                                  'unique_pair', 'unique_triple'),
            }

            for link in import_links:
                # Check for dangling Links within the, supposed, self-consistent archive
                try:
                    in_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][
                        link['input']]
                    out_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][
                        link['output']]
                except KeyError:
                    if ignore_unknown_nodes:
                        continue
                    raise exceptions.ImportValidationError(
                        'Trying to create a link with one or both unknown nodes, stopping (in_uuid={}, out_uuid={}, '
                        'label={}, type={})'.format(link['input'],
                                                    link['output'],
                                                    link['label'],
                                                    link['type']))

                # Check if link already exists, skip if it does
                # This is equivalent to an existing triple link (i.e. unique_triple from below)
                if (in_id, out_id, link['label'],
                        link['type']) in existing_links:
                    continue

                # Since backend specific Links (DbLink) are not validated upon creation, we will now validate them.
                try:
                    validate_link_label(link['label'])
                except ValueError as why:
                    raise exceptions.ImportValidationError(
                        'Error during Link label validation: {}'.format(why))

                source = models.DbNode.objects.get(id=in_id)
                target = models.DbNode.objects.get(id=out_id)

                if source.uuid == target.uuid:
                    raise exceptions.ImportValidationError(
                        'Cannot add a link from a node to itself')

                link_type = LinkType(link['type'])
                type_source, type_target, outdegree, indegree = link_mapping[
                    link_type]

                # Check if source Node is a valid type
                if not source.node_type.startswith(type_source):
                    raise exceptions.ImportValidationError(
                        'Cannot add a {} link from {} to {}'.format(
                            link_type, source.node_type, target.node_type))

                # Check if target Node is a valid type
                if not target.node_type.startswith(type_target):
                    raise exceptions.ImportValidationError(
                        'Cannot add a {} link from {} to {}'.format(
                            link_type, source.node_type, target.node_type))

                # If the outdegree is `unique` there cannot already be any other outgoing link of that type,
                # i.e., the source Node may not have a LinkType of current LinkType, going out, existing already.
                if outdegree == 'unique' and (
                        in_id, link['type']) in existing_outgoing_unique:
                    raise exceptions.ImportValidationError(
                        'Node<{}> already has an outgoing {} link'.format(
                            source.uuid, link_type))

                # If the outdegree is `unique_pair`,
                # then the link labels for outgoing links of this type should be unique,
                # i.e., the source Node may not have a LinkType of current LinkType, going out,
                # that also has the current Link label, existing already.
                elif outdegree == 'unique_pair' and \
                (in_id, link['label'], link['type']) in existing_outgoing_unique_pair:
                    raise exceptions.ImportValidationError(
                        'Node<{}> already has an outgoing {} link with label "{}"'
                        .format(source.uuid, link_type, link['label']))

                # If the indegree is `unique` there cannot already be any other incoming links of that type,
                # i.e., the target Node may not have a LinkType of current LinkType, coming in, existing already.
                if indegree == 'unique' and (
                        out_id, link['type']) in existing_incoming_unique:
                    raise exceptions.ImportValidationError(
                        'Node<{}> already has an incoming {} link'.format(
                            target.uuid, link_type))

                # If the indegree is `unique_pair`,
                # then the link labels for incoming links of this type should be unique,
                # i.e., the target Node may not have a LinkType of current LinkType, coming in
                # that also has the current Link label, existing already.
                elif indegree == 'unique_pair' and \
                (out_id, link['label'], link['type']) in existing_incoming_unique_pair:
                    raise exceptions.ImportValidationError(
                        'Node<{}> already has an incoming {} link with label "{}"'
                        .format(target.uuid, link_type, link['label']))

                # New link
                links_to_store.append(
                    models.DbLink(input_id=in_id,
                                  output_id=out_id,
                                  label=link['label'],
                                  type=link['type']))
                if 'Link' not in ret_dict:
                    ret_dict['Link'] = {'new': []}
                ret_dict['Link']['new'].append((in_id, out_id))

                # Add the new Link to the sets of existing Links, keyed by input PK, output PK, label and type
                existing_links.add(
                    (in_id, out_id, link['label'], link['type']))
                existing_outgoing_unique.add((in_id, link['type']))
                existing_outgoing_unique_pair.add(
                    (in_id, link['label'], link['type']))
                existing_incoming_unique.add((out_id, link['type']))
                existing_incoming_unique_pair.add(
                    (out_id, link['label'], link['type']))

            # Store new links
            if links_to_store:
                if not silent:
                    print('   ({} new links...)'.format(len(links_to_store)))

                models.DbLink.objects.bulk_create(links_to_store,
                                                  batch_size=batch_size)
            else:
                if not silent:
                    print('   (0 new links...)')

            if not silent:
                print('STORING GROUP ELEMENTS...')
            import_groups = data['groups_uuid']
            for groupuuid, groupnodes in import_groups.items():
                # TODO: cache these to avoid too many queries
                group_ = models.DbGroup.objects.get(uuid=groupuuid)
                nodes_to_store = [
                    foreign_ids_reverse_mappings[NODE_ENTITY_NAME][node_uuid]
                    for node_uuid in groupnodes
                ]
                if nodes_to_store:
                    group_.dbnodes.add(*nodes_to_store)

        ######################################################
        # Put everything in a specific group
        ######################################################
        existing = existing_entries.get(NODE_ENTITY_NAME, {})
        existing_pk = [
            foreign_ids_reverse_mappings[NODE_ENTITY_NAME][v['uuid']]
            for v in existing.values()
        ]
        new = new_entries.get(NODE_ENTITY_NAME, {})
        new_pk = [
            foreign_ids_reverse_mappings[NODE_ENTITY_NAME][v['uuid']]
            for v in new.values()
        ]

        pks_for_group = existing_pk + new_pk

        # So that we do not create empty groups
        if pks_for_group:
            # If the user did not specify a group, create a new import group with a unique label
            if not group:
                # Get a unique name for the import group, based on the current (local) time
                basename = timezone.localtime(
                    timezone.now()).strftime('%Y%m%d-%H%M%S')
                counter = 0
                group_label = basename

                while Group.objects.find(filters={'label': group_label}):
                    counter += 1
                    group_label = '{}_{}'.format(basename, counter)

                    if counter == 100:
                        raise exceptions.ImportUniquenessError(
                            "Overflow of import groups (more than 100 import groups exists with basename '{}')"
                            ''.format(basename))
                group = ImportGroup(label=group_label).store()

            # Add all the nodes to the new group
            # TODO: decide if we want to return the group label
            nodes = [
                entry[0]
                for entry in QueryBuilder().append(Node,
                                                   filters={
                                                       'id': {
                                                           'in': pks_for_group
                                                       }
                                                   }).all()
            ]
            group.add_nodes(nodes)

            if not silent:
                print(
                    "IMPORTED NODES ARE GROUPED IN THE IMPORT GROUP LABELED '{}'"
                    .format(group.label))
        else:
            if not silent:
                print(
                    'NO NODES TO IMPORT, SO NO NEW GROUP CREATED'
                )

    if not silent:
        print('DONE.')

    return ret_dict
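
# Usage sketch (added for illustration; not part of the original example). It assumes a loaded
# AiiDA profile configured with the Django backend; the archive path 'export.aiida' is a
# hypothetical placeholder. With group=None the importer creates its own timestamped group.
def _example_django_import(archive_path='export.aiida'):
    """Import an archive with the Django-specific importer and report the new nodes."""
    from aiida import load_profile

    load_profile()
    results = import_data_dj(archive_path, group=None, extras_mode_new='import', silent=True)
    # `results` maps entity names to lists of (PK in the archive, PK in the database) tuples
    print('New nodes:', results.get('Node', {}).get('new', []))
    return results
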
Exemplo n.º 20
def import_data(in_path, ignore_unknown_nodes=False, silent=False):
    """
    Import exported AiiDA environment to the AiiDA database.
    If 'in_path' is a folder, calls extract_tree; otherwise, tries to
    detect the compression format (zip, tar.gz, tar.bz2, ...) and calls the
    correct function.

    :param in_path: the path to a file or folder that can be imported in AiiDA
    """
    import json
    import os
    import tarfile
    import zipfile
    from itertools import chain

    from django.db import transaction
    from aiida.utils import timezone

    from aiida.orm.node import Node
    from aiida.orm.group import Group
    from aiida.common.exceptions import UniquenessError
    from aiida.common.folders import SandboxFolder, RepositoryFolder
    from aiida.backends.djsite.db import models
    from aiida.common.utils import get_class_string, get_object_from_string
    from aiida.common.datastructures import calc_states

    # This is the export version expected by this function
    expected_export_version = '0.1'

    # The name of the subfolder in which the node files are stored
    nodes_export_subfolder = 'nodes'

    # The returned dictionary with new and existing nodes and links
    ret_dict = {}

    ################
    # EXTRACT DATA #
    ################
    # The sandbox has to remain open until the end
    with SandboxFolder() as folder:
        if os.path.isdir(in_path):
            extract_tree(in_path, folder, silent=silent)
        else:
            if tarfile.is_tarfile(in_path):
                extract_tar(in_path,
                            folder,
                            silent=silent,
                            nodes_export_subfolder=nodes_export_subfolder)
            elif zipfile.is_zipfile(in_path):
                extract_zip(in_path,
                            folder,
                            silent=silent,
                            nodes_export_subfolder=nodes_export_subfolder)
            else:
                raise ValueError(
                    "Unable to detect the input file format, it "
                    "is neither a (possibly compressed) tar file, "
                    "nor a zip file.")

        try:
            with open(folder.get_abs_path('metadata.json')) as f:
                metadata = json.load(f)

            with open(folder.get_abs_path('data.json')) as f:
                data = json.load(f)
        except IOError as e:
            raise ValueError("Unable to find the file {} in the import "
                             "file or folder".format(e.filename))

        ######################
        # PRELIMINARY CHECKS #
        ######################
        if metadata['export_version'] != expected_export_version:
            raise ValueError(
                "File export version is {}, but I can import only "
                "version {}".format(metadata['export_version'],
                                    expected_export_version))

        ##########################################################################
        # CREATE UUID REVERSE TABLES AND CHECK IF I HAVE ALL NODES FOR THE LINKS #
        ##########################################################################
        linked_nodes = set(
            chain.from_iterable(
                (l['input'], l['output']) for l in data['links_uuid']))
        group_nodes = set(chain.from_iterable(
            data['groups_uuid'].itervalues()))

        # I preload the nodes since I need to check each of them later, and I also
        # store them in a reverse table.
        # I break up the query due to SQLite limitations.
        relevant_db_nodes = {}
        for group in grouper(999, linked_nodes):
            relevant_db_nodes.update({
                n.uuid: n
                for n in models.DbNode.objects.filter(uuid__in=group)
            })

        db_nodes_uuid = set(relevant_db_nodes.keys())
        dbnode_model = get_class_string(models.DbNode)
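        # get_class_string returns the full dotted path, e.g. 'aiida.backends.djsite.db.models.DbNode'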
        import_nodes_uuid = set(
            v['uuid'] for v in data['export_data'][dbnode_model].values())

        unknown_nodes = linked_nodes.union(group_nodes) - db_nodes_uuid.union(
            import_nodes_uuid)

        if unknown_nodes and not ignore_unknown_nodes:
            raise ValueError(
                "The import file refers to {} nodes with unknown UUID, therefore "
                "it cannot be imported. Either first import the unknown nodes, "
                "or export also the parents when exporting. The unknown UUIDs "
                "are:\n".format(len(unknown_nodes)) +
                "\n".join('* {}'.format(uuid) for uuid in unknown_nodes))

        ###################################
        # DOUBLE-CHECK MODEL DEPENDENCIES #
        ###################################
        # I hardcode here the model order, for simplicity; in any case, this is
        # fixed by the export version
        model_order = [
            get_class_string(m) for m in (
                models.DbUser,
                models.DbComputer,
                models.DbNode,
                models.DbGroup,
            )
        ]

        # Models that do appear in the import file, but whose import is
        # managed manually
        model_manual = [
            get_class_string(m) for m in (
                models.DbLink,
                models.DbAttribute,
            )
        ]

        all_known_models = model_order + model_manual

        for import_field_name in metadata['all_fields_info']:
            if import_field_name not in all_known_models:
                raise NotImplementedError(
                    "Apparently, you are importing a "
                    "file with a model '{}', but this does not appear in "
                    "all_known_models!".format(import_field_name))

        for idx, model_name in enumerate(model_order):
            dependencies = []
            for field in metadata['all_fields_info'][model_name].values():
                try:
                    dependencies.append(field['requires'])
                except KeyError:
                    # (No ForeignKey)
                    pass
            for dependency in dependencies:
                if dependency not in model_order[:idx]:
                    raise ValueError(
                        "Model {} requires {} but would be loaded "
                        "first; stopping...".format(model_name, dependency))

        ###################################################
        # CREATE IMPORT DATA DIRECT UNIQUE_FIELD MAPPINGS #
        ###################################################
        import_unique_ids_mappings = {}
        for model_name, import_data in data['export_data'].iteritems():
            if model_name in metadata['unique_identifiers']:
                # I have to reconvert the pk to integer
                import_unique_ids_mappings[model_name] = {
                    int(k): v[metadata['unique_identifiers'][model_name]]
                    for k, v in import_data.iteritems()
                }

        ###############
        # IMPORT DATA #
        ###############
        # DO ALL WITH A TRANSACTION
        with transaction.commit_on_success():
            foreign_ids_reverse_mappings = {}
            new_entries = {}
            existing_entries = {}

            # I first generate the list of data
            for model_name in model_order:
                Model = get_object_from_string(model_name)
                fields_info = metadata['all_fields_info'].get(model_name, {})
                unique_identifier = metadata['unique_identifiers'].get(
                    model_name, None)

                new_entries[model_name] = {}
                existing_entries[model_name] = {}

                foreign_ids_reverse_mappings[model_name] = {}

                # Not necessarily all models are exported
                if model_name in data['export_data']:

                    if unique_identifier is not None:
                        import_unique_ids = set(
                            v[unique_identifier]
                            for v in data['export_data'][model_name].values())

                        relevant_db_entries = {
                            getattr(n, unique_identifier): n
                            for n in Model.objects.filter(
                                **{
                                    '{}__in'.format(unique_identifier):
                                    import_unique_ids
                                })
                        }

                        foreign_ids_reverse_mappings[model_name] = {
                            k: v.pk
                            for k, v in relevant_db_entries.iteritems()
                        }
                        for k, v in data['export_data'][model_name].iteritems(
                        ):
                            if v[unique_identifier] in relevant_db_entries.keys(
                            ):
                                # Already in DB
                                existing_entries[model_name][k] = v
                            else:
                                # To be added
                                new_entries[model_name][k] = v
                    else:
                        new_entries[model_name] = data['export_data'][
                            model_name].copy()

            # I import data from the given model
            for model_name in model_order:
                Model = get_object_from_string(model_name)
                fields_info = metadata['all_fields_info'].get(model_name, {})
                unique_identifier = metadata['unique_identifiers'].get(
                    model_name, None)

                for import_entry_id, entry_data in existing_entries[
                        model_name].iteritems():
                    unique_id = entry_data[unique_identifier]
                    existing_entry_id = foreign_ids_reverse_mappings[
                        model_name][unique_id]
                    # TODO COMPARE, AND COMPARE ATTRIBUTES
                    if model_name not in ret_dict:
                        ret_dict[model_name] = {'new': [], 'existing': []}
                    ret_dict[model_name]['existing'].append(
                        (import_entry_id, existing_entry_id))
                    if not silent:
                        print "existing %s: %s (%s->%s)" % (
                            model_name, unique_id, import_entry_id,
                            existing_entry_id)
                        # print "  `-> WARNING: NO DUPLICITY CHECK DONE!"
                        # CHECK ALSO FILES!

                # Store all objects for this model in a list, and store them
                # all at once at the end.
                objects_to_create = []
                # This is needed later to associate the import entry with the new pk
                import_entry_ids = {}
                for import_entry_id, entry_data in new_entries[
                        model_name].iteritems():
                    unique_id = entry_data[unique_identifier]
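                    # deserialize_field converts each serialized field back to
                    # its DB value, translating foreign-key references via the
                    # import_unique_ids_mappings / foreign_ids_reverse_mappings
                    # dictionaries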
                    import_data = dict(
                        deserialize_field(
                            k, v,
                            fields_info=fields_info,
                            import_unique_ids_mappings=import_unique_ids_mappings,
                            foreign_ids_reverse_mappings=foreign_ids_reverse_mappings)
                        for k, v in entry_data.iteritems())

                    objects_to_create.append(Model(**import_data))
                    import_entry_ids[unique_id] = import_entry_id

                # Before storing entries in the DB, I store the files (if these
                # are nodes). Note: only for new entries!
                if model_name == get_class_string(models.DbNode):
                    if not silent:
                        print "STORING NEW NODE FILES..."
                    for o in objects_to_create:

                        subfolder = folder.get_subfolder(
                            os.path.join(nodes_export_subfolder,
                                         export_shard_uuid(o.uuid)))
                        if not subfolder.exists():
                            raise ValueError(
                                "Unable to find the repository "
                                "folder for node with UUID={} in the exported "
                                "file".format(o.uuid))
                        destdir = RepositoryFolder(section=Node._section_name,
                                                   uuid=o.uuid)
                        # Replace the folder, possibly destroying existing
                        # previous folders, and move the files (faster if we
                        # are on the same filesystem, and
                        # in any case the source is a SandboxFolder)
                        destdir.replace_with_folder(subfolder.abspath,
                                                    move=True,
                                                    overwrite=True)

                # Store them all at once; however, the PKs are not set this way...
                Model.objects.bulk_create(objects_to_create)
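                # bulk_create does not set the PKs on the Python objects it
                # created, so query the table again by unique identifier to
                # recover them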

                # Get back the just-saved entries
                just_saved = dict(
                    Model.objects.filter(
                        **{
                            "{}__in".format(unique_identifier):
                            import_entry_ids.keys()
                        }).values_list(unique_identifier, 'pk'))

                imported_states = []
                if model_name == get_class_string(models.DbNode):
                    if not silent:
                        print "SETTING THE IMPORTED STATES FOR NEW NODES..."
                    # The state is set for all nodes, even though it should be
                    # set only for calculation nodes
                    for unique_id, new_pk in just_saved.iteritems():
                        imported_states.append(
                            models.DbCalcState(dbnode_id=new_pk,
                                               state=calc_states.IMPORTED))
                    models.DbCalcState.objects.bulk_create(imported_states)

                # Now that the PKs are known, print the info and update
                # foreign_ids_reverse_mappings
                for unique_id, new_pk in just_saved.iteritems():
                    import_entry_id = import_entry_ids[unique_id]
                    foreign_ids_reverse_mappings[model_name][
                        unique_id] = new_pk
                    if model_name not in ret_dict:
                        ret_dict[model_name] = {'new': [], 'existing': []}
                    ret_dict[model_name]['new'].append(
                        (import_entry_id, new_pk))

                    if not silent:
                        print "NEW %s: %s (%s->%s)" % (model_name, unique_id,
                                                       import_entry_id, new_pk)

                # For DbNodes, we also have to store Attributes!
                if model_name == get_class_string(models.DbNode):
                    if not silent:
                        print "STORING NEW NODE ATTRIBUTES..."
                    for unique_id, new_pk in just_saved.iteritems():
                        import_entry_id = import_entry_ids[unique_id]
                        # Get attributes from import file
                        try:
                            attributes = data['node_attributes'][str(
                                import_entry_id)]
                            attributes_conversion = data[
                                'node_attributes_conversion'][str(
                                    import_entry_id)]
                        except KeyError:
                            raise ValueError(
                                "Unable to find attribute info "
                                "for DbNode with UUID = {}".format(unique_id))

                        # Here I have to deserialize the attributes
                        deserialized_attributes = deserialize_attributes(
                            attributes, attributes_conversion)
                        models.DbAttribute.reset_values_for_node(
                            dbnode=new_pk,
                            attributes=deserialized_attributes,
                            with_transaction=False)

            if not silent:
                print "STORING NODE LINKS..."
            ## TODO: check that we are not creating input links for an already
            ##       existing node...
            import_links = data['links_uuid']
            links_to_store = []

            # Needed for fast checks of existing links
            existing_links_raw = models.DbLink.objects.all().values_list(
                'input', 'output', 'label')
            existing_links_labels = {(l[0], l[1]): l[2]
                                     for l in existing_links_raw}
            existing_input_links = {(l[1], l[2]): l[0]
                                    for l in existing_links_raw}
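            # existing_links_labels: (input_pk, output_pk) -> label of the link
            # existing_input_links:  (output_pk, label) -> input_pk feeding it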

            dbnode_reverse_mappings = foreign_ids_reverse_mappings[
                get_class_string(models.DbNode)]
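            # For each link in the archive there are four possible cases:
            #   1. the same (input, output) link exists with the same label:
            #      nothing to do
            #   2. it exists with a different label: raise (links are not
            #      renamed)
            #   3. another input already feeds this output with the same
            #      label: raise
            #   4. otherwise: schedule the link for creation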
            for link in import_links:
                try:
                    in_id = dbnode_reverse_mappings[link['input']]
                    out_id = dbnode_reverse_mappings[link['output']]
                except KeyError:
                    if ignore_unknown_nodes:
                        continue
                    else:
                        raise ValueError("Trying to create a link with one "
                                         "or both unknown nodes, stopping "
                                         "(in_uuid={}, out_uuid={}, "
                                         "label={})".format(
                                             link['input'], link['output'],
                                             link['label']))

                try:
                    existing_label = existing_links_labels[in_id, out_id]
                    if existing_label != link['label']:
                        raise ValueError(
                            "Trying to rename an existing link, "
                            "stopping (in={}, out={}, old_label={}, "
                            "new_label={})".format(in_id, out_id,
                                                   existing_label,
                                                   link['label']))
                    # Otherwise do nothing: the link is already in place and
                    # has the correct label
                except KeyError:
                    try:
                        existing_input = existing_input_links[out_id,
                                                              link['label']]
                        # If existing_input were the correct one, I would have found
                        # it already in the previous step!
                        raise ValueError(
                            "There exists already an input link to "
                            "node {} with label {} but it does not "
                            "come from the expected input {}".format(
                                out_id, link['label'], in_id))
                    except KeyError:
                        # New link
                        links_to_store.append(
                            models.DbLink(input_id=in_id,
                                          output_id=out_id,
                                          label=link['label']))
                        if 'aiida.backends.djsite.db.models.DbLink' not in ret_dict:
                            ret_dict[
                                'aiida.backends.djsite.db.models.DbLink'] = {
                                    'new': []
                                }
                        ret_dict['aiida.backends.djsite.db.models.DbLink'][
                            'new'].append((in_id, out_id))

            # Store new links
            if links_to_store:
                if not silent:
                    print "   ({} new links...)".format(len(links_to_store))

                models.DbLink.objects.bulk_create(links_to_store)
            else:
                if not silent:
                    print "   (0 new links...)"

            if not silent:
                print "STORING GROUP ELEMENTS..."
            import_groups = data['groups_uuid']
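            # The groups themselves are expected to exist at this point (either
            # pre-existing in the DB or created above with the other models);
            # here the imported nodes are only attached to them, translating
            # the archive UUIDs into the new PKs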
            for groupuuid, groupnodes in import_groups.iteritems():
                # TODO: cache these to avoid too many queries
                group = models.DbGroup.objects.get(uuid=groupuuid)
                nodes_to_store = [
                    dbnode_reverse_mappings[node_uuid]
                    for node_uuid in groupnodes
                ]
                if nodes_to_store:
                    group.dbnodes.add(*nodes_to_store)

            ######################################################
            # Put everything in a specific group
            dbnode_model_name = get_class_string(models.DbNode)
            existing = existing_entries.get(dbnode_model_name, {})
            existing_pk = [
                foreign_ids_reverse_mappings[dbnode_model_name][v['uuid']]
                for v in existing.itervalues()
            ]
            new = new_entries.get(dbnode_model_name, {})
            new_pk = [
                foreign_ids_reverse_mappings[dbnode_model_name][v['uuid']]
                for v in new.itervalues()
            ]

            pks_for_group = existing_pk + new_pk

            # So that we do not create empty groups
            if pks_for_group:
                # Get a unique name for the import group, based on the
                # current (local) time
                basename = timezone.localtime(
                    timezone.now()).strftime("%Y%m%d-%H%M%S")
                counter = 0
                created = False
                while not created:
                    if counter == 0:
                        group_name = basename
                    else:
                        group_name = "{}_{}".format(basename, counter)
                    try:
                        group = Group(name=group_name,
                                      type_string=IMPORTGROUP_TYPE).store()
                        created = True
                    except UniquenessError:
                        counter += 1

                # Add all the nodes to the new group
                # TODO: decide if we want to return the group name
                group.add_nodes(
                    models.DbNode.objects.filter(pk__in=pks_for_group))

                if not silent:
                    print "IMPORTED NODES GROUPED IN IMPORT GROUP NAMED '{}'".format(
                        group.name)
            else:
                if not silent:
                    print "NO DBNODES TO IMPORT, SO NO GROUP CREATED"

    if not silent:
        print "*** WARNING: MISSING EXISTING UUID CHECKS!!"
        print "*** WARNING: TODO: UPDATE IMPORT_DATA WITH DEFAULT VALUES! (e.g. calc status, user pwd, ...)"
        print "DONE."

    return ret_dict
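

# The following lines are an illustrative sketch added for clarity; they are
# not part of the original module. Assuming the enclosing function is the old
# Django-backend `import_data` taking the path of an export archive (the file
# name 'export.aiida' below is only a placeholder), the returned summary could
# be inspected roughly like this:
#
#     ret = import_data('export.aiida', silent=True)
#     for model_name, entries in sorted(ret.iteritems()):
#         print "%s: %d new, %d existing" % (
#             model_name,
#             len(entries.get('new', [])),
#             len(entries.get('existing', [])))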