Example no. 1
def cmd_stage_relax(details, max_atoms):
    """Commands to analyse the relax stage of the project."""
    import collections
    from aiida.orm import load_group, QueryBuilder, Group, WorkChainNode, Data

    group = load_group('workchain/relax')

    filters_structure = {}
    filters_workchain = {'attributes.exit_status': 0}

    if max_atoms is not None:
        filters_structure['attributes.sites'] = {'shorter': max_atoms + 1}

    if not details:
        query = QueryBuilder()
        query.append(Group, filters={'id': group.pk}, tag='group')
        query.append(WorkChainNode,
                     with_group='group',
                     filters=filters_workchain,
                     tag='relax',
                     project='id')
        query.append(WorkChainNode,
                     with_incoming='relax',
                     project='id',
                     tag='base')
        query.append(Data,
                     with_outgoing='relax',
                     edge_filters={'label': 'structure'},
                     filters=filters_structure)

        mapping = collections.defaultdict(list)

        for relax, called in query.iterall():
            mapping[relax].append(called)

        counts = []

        for called in mapping.values():
            counts.append(len(called))

        table = []
        counter = collections.Counter(counts)
        total = sum(counter.values())
        cumulative = 0

        for iterations, count in sorted(counter.items(),
                                        key=lambda item: item[1],
                                        reverse=True):
            percentage = (count / total) * 100
            cumulative += percentage
            table.append((count, percentage, cumulative, iterations))

        click.echo(
            tabulate.tabulate(
                table,
                headers=['Count', 'Percentage', 'Cumulative', 'Iterations']))

    else:
        print('test')
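The tabulation at the end of this command is independent of AiiDA and can be tried on its own. Below is a minimal sketch, assuming only the tabulate package is installed; the iteration counts are made-up placeholders standing in for the len(called) values collected above.

import collections

import tabulate

# Hypothetical per-workchain iteration counts (stand-ins for len(called) above).
counts = [1, 1, 2, 1, 3, 2, 1, 1, 2, 5]

table = []
counter = collections.Counter(counts)
total = sum(counter.values())
cumulative = 0

# Most common iteration counts first, with percentage and a running cumulative total.
for iterations, count in sorted(counter.items(), key=lambda item: item[1], reverse=True):
    percentage = (count / total) * 100
    cumulative += percentage
    table.append((count, percentage, cumulative, iterations))

print(tabulate.tabulate(table, headers=['Count', 'Percentage', 'Cumulative', 'Iterations']))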
Example no. 2
    def __init__(self, title=''):
        self.title = title

        # Structure objects we want to query for.
        self.query_structure_type = (DataFactory('structure'), DataFactory('cif'))

        # Extracting available process labels.
        qbuilder = QueryBuilder().append((CalcJobNode, WorkChainNode), project="label")
        self.drop_label = ipw.Dropdown(options=sorted({'All'}.union({i[0] for i in qbuilder.iterall() if i[0]})),
                                       value='All',
                                       description='Process Label',
                                       disabled=True,
                                       style={'description_width': '120px'},
                                       layout={'width': '50%'})
        self.drop_label.observe(self.search, names='value')

        # Disable process labels selection if we are not looking for the calculated structures.
        def disable_drop_label(change):
            self.drop_label.disabled = not change['new'] == 'calculated'

        # Select structures kind.
        self.mode = ipw.RadioButtons(options=['all', 'uploaded', 'edited', 'calculated'], layout={'width': '25%'})
        self.mode.observe(self.search, names='value')
        self.mode.observe(disable_drop_label, names='value')

        # Date range.
        # Note: there is a Date picker widget, but it currently does not work in Safari:
        # https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20List.html#Date-picker
        date_text = ipw.HTML(value='<p>Select the date range:</p>')
        self.start_date_widget = ipw.Text(value='', description='From: ', style={'description_width': '120px'})
        self.end_date_widget = ipw.Text(value='', description='To: ')

        # Search button.
        btn_search = ipw.Button(description='Search',
                                button_style='info',
                                layout={
                                    'width': 'initial',
                                    'margin': '2px 0 0 2em'
                                })
        btn_search.on_click(self.search)

        age_selection = ipw.VBox(
            [date_text, ipw.HBox([self.start_date_widget, self.end_date_widget, btn_search])],
            layout={
                'border': '1px solid #fafafa',
                'padding': '1em'
            })

        h_line = ipw.HTML('<hr>')
        box = ipw.VBox([age_selection, h_line, ipw.HBox([self.mode, self.drop_label])])

        self.results = ipw.Dropdown(layout={'width': '900px'})
        self.results.observe(self._on_select_structure)
        self.search()
        super().__init__([box, h_line, self.results])
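The dropdown above is populated by a single QueryBuilder call. A minimal sketch of just that query is shown below, assuming an AiiDA profile is configured and loaded with load_profile(); outside the widget it simply prints the sorted set of process labels.

from aiida import load_profile
from aiida.orm import CalcJobNode, WorkChainNode, QueryBuilder

load_profile()  # assumes a configured AiiDA profile

# Collect all non-empty process labels of calculation jobs and work chains.
qbuilder = QueryBuilder().append((CalcJobNode, WorkChainNode), project='label')
labels = sorted({'All'}.union({row[0] for row in qbuilder.iterall() if row[0]}))
print(labels)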
Example no. 3
def cmd_stage_scf(max_atoms):
    """Commands to analyse the scf stage of the project."""
    import collections
    from aiida.orm import load_group, QueryBuilder, Group, CalcJobNode, WorkChainNode, Data

    group = load_group('workchain/scf')

    filters_structure = {}
    filters_workchain = {'attributes.exit_status': 0}

    if max_atoms is not None:
        filters_structure['attributes.sites'] = {'shorter': max_atoms + 1}

    query = QueryBuilder()
    query.append(Group, filters={'id': group.pk}, tag='group')
    query.append(WorkChainNode,
                 with_group='group',
                 filters=filters_workchain,
                 tag='scf',
                 project='id')
    query.append(CalcJobNode,
                 with_incoming='scf',
                 project='attributes.exit_status')
    query.append(Data,
                 with_outgoing='scf',
                 edge_filters={'label': 'pw__structure'},
                 filters=filters_structure)

    mapping = collections.defaultdict(list)

    for scf, exit_status in query.iterall():
        mapping[scf].append(exit_status)

    counts = []

    for exit_statuses in mapping.values():
        counts.append(tuple(exit_statuses))

    table = []
    counter = collections.Counter(counts)
    total = sum(counter.values())
    cumulative = 0

    for exit_statuses, count in sorted(counter.items(),
                                       key=lambda item: item[1],
                                       reverse=True):
        percentage = (count / total) * 100
        cumulative += percentage
        table.append((count, percentage, cumulative, exit_statuses))

    click.echo(
        tabulate.tabulate(
            table,
            headers=['Count', 'Percentage', 'Cumulative', 'Exit statuses']))
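The max_atoms option is implemented with the 'shorter' operator, which compares the length of the attributes.sites list. A minimal sketch of that filter in isolation, assuming a loaded AiiDA profile and an arbitrary cutoff of 20 atoms:

from aiida import load_profile
from aiida.orm import QueryBuilder, StructureData

load_profile()  # assumes a configured AiiDA profile

max_atoms = 20  # hypothetical cutoff
query = QueryBuilder()
# 'shorter' filters on the length of the stored list of sites.
query.append(StructureData,
             filters={'attributes.sites': {'shorter': max_atoms + 1}},
             project='id')
print('{} structures with at most {} sites'.format(query.count(), max_atoms))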
Example no. 4
def _make_import_group(*, group: Optional[ImportGroup],
                       node_pks: List[int]) -> ImportGroup:
    """Make an import group containing all imported nodes.

    :param group: Use an existing group
    :param node_pks: node pks to add to group

    """
    # So that we do not create empty groups
    if not node_pks:
        IMPORT_LOGGER.debug('No nodes to import, so no import group created')
        return group

    # If the user did not specify a group, create a new one with a unique label
    if not group:
        # Get a unique name for the import group, based on the current (local) time
        basename = timezone.localtime(timezone.now()).strftime('%Y%m%d-%H%M%S')
        counter = 0
        group_label = basename

        while Group.objects.find(filters={'label': group_label}):
            counter += 1
            group_label = f'{basename}_{counter}'

            if counter == MAX_GROUPS:
                raise exceptions.ImportUniquenessError(
                    f"Overflow of import groups (more than {MAX_GROUPS} groups exists with basename '{basename}')"
                )
        group = ImportGroup(label=group_label).store()

    # Add all the nodes to the new group
    builder = QueryBuilder().append(Node, filters={'id': {'in': node_pks}})

    first = True
    nodes = []
    description = 'Creating import Group - Preprocessing'

    with get_progress_reporter()(total=len(node_pks),
                                 desc=description) as progress:
        for entry in builder.iterall():
            if first:
                progress.set_description_str('Creating import Group',
                                             refresh=False)
                first = False
            progress.update()
            nodes.append(entry[0])

        group.add_nodes(nodes)
        progress.set_description_str('Done (cleaning up)', refresh=True)

    return group
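The label-deduplication loop can be exercised without a database. The sketch below replaces the Group.objects.find query with an in-memory set of existing labels; MAX_GROUPS and the example labels are placeholders.

from datetime import datetime

MAX_GROUPS = 100  # hypothetical cap, mirroring the guard in the function above
existing_labels = {'20240101-120000', '20240101-120000_1'}  # made-up existing labels

basename = datetime.now().strftime('%Y%m%d-%H%M%S')
counter = 0
group_label = basename

# Append an increasing counter until the label is unique.
while group_label in existing_labels:
    counter += 1
    group_label = f'{basename}_{counter}'
    if counter == MAX_GROUPS:
        raise RuntimeError(f"more than {MAX_GROUPS} groups exist with basename '{basename}'")

print(group_label)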
Example no. 5
def traj_to_atoms(traj, combine_ancesters=False, eng_key="enthalpy"):
    """
    Generate a list of ASE Atoms given an AiiDA TrajectoryData object
    :param bool combine_ancesters: If true will try to combine trajectory
    from ancestor calculations

    :returns: A list of atoms for the trajectory.
    """
    from ase import Atoms
    from ase.calculators.singlepoint import SinglePointCalculator
    from aiida.orm import QueryBuilder, Node, CalcJobNode
    from aiida_castep.common import OUTPUT_LINKNAMES

    # If a CalcJobNode is passed, select its output trajectory
    if isinstance(traj, CalcJobNode):
        traj = traj.outputs.__getattr__(OUTPUT_LINKNAMES['trajectory'])
    # Combine trajectory from ancesters
    if combine_ancesters is True:
        qbd = QueryBuilder()
        qbd.append(Node, filters={"uuid": traj.uuid})
        qbd.append(CalcJobNode, tag="ans", ancestor_of=Node)
        qbd.order_by({"ans": "id"})
        calcs = [_[0] for _ in qbd.iterall()]
        atoms_list = []
        for calc in calcs:
            atoms_list.extend(
                traj_to_atoms(calc.outputs.__getattr__(
                    OUTPUT_LINKNAMES['trajectory']),
                              combine_ancesters=False,
                              eng_key=eng_key))
        return atoms_list
    forces = traj.get_array("forces")
    symbols = traj.get_array("symbols")
    positions = traj.get_array("positions")
    try:
        eng = traj.get_array(eng_key)
    except KeyError:
        # Fall back to per-frame None energies so the zip() below still works
        eng = [None] * len(positions)
    cells = traj.get_array("cells")
    atoms_traj = []
    for cell, pos, eng_, force in zip(cells, positions, eng, forces):
        atoms = Atoms(symbols=symbols, cell=cell, pbc=True, positions=pos)
        calc = SinglePointCalculator(atoms, energy=eng_, forces=force)
        atoms.set_calculator(calc)
        atoms_traj.append(atoms)
    return atoms_traj
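The per-frame conversion at the end only needs arrays, not a database. A minimal sketch with NumPy arrays standing in for the TrajectoryData arrays (all values are made up); it assumes ase and numpy are installed.

import numpy as np
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator

# Two fabricated frames of a two-atom cell.
symbols = ['Si', 'Si']
cells = np.array([np.eye(3) * 5.43] * 2)                            # (frames, 3, 3)
positions = np.array([[[0.0, 0.0, 0.0], [1.36, 1.36, 1.36]]] * 2)   # (frames, atoms, xyz)
forces = np.zeros((2, 2, 3))
energies = np.array([-10.0, -10.1])

atoms_traj = []
for cell, pos, energy, force in zip(cells, positions, energies, forces):
    atoms = Atoms(symbols=symbols, cell=cell, pbc=True, positions=pos)
    # Attach the precomputed energy and forces as a single-point calculator.
    atoms.calc = SinglePointCalculator(atoms, energy=energy, forces=force)
    atoms_traj.append(atoms)

print(len(atoms_traj), 'frames converted')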
Example no. 6
def migrate_otfg_family():
    """Migrate the old OTFG families to new families"""
    old_types = [OLD_OTFGGROUP_TYPE, "data.castep.usp.family"]
    q = QueryBuilder()
    q.append(Group, filters={'type_string': {'in': old_types}})

    migrated = []
    for (old_group, ) in q.iterall():
        new_group, created = OTFGGroup.objects.get_or_create(
            label=old_group.label, description=old_group.description)
        new_group.add_nodes(list(old_group.nodes))
        new_group.store()
        migrated.append(new_group.label)
        if created:
            print("Created new style Group for <{}>".format(old_group.label))
        else:
            print("Adding nodes to existing group <{}>".format(
                old_group.label))

    return
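A minimal sketch of the get_or_create pattern used above, assuming a loaded AiiDA profile and the 1.x-style .objects collection API that the snippet itself uses; the group label is a placeholder rather than a real family.

from aiida import load_profile
from aiida.orm import Group, load_group

load_profile()  # assumes a configured AiiDA profile

old_group = load_group('old/family/label')  # placeholder label of an existing group

# Reuse the target group if it exists, otherwise create and store it.
new_group, created = Group.objects.get_or_create(
    label=old_group.label + '_migrated', description=old_group.description)
new_group.add_nodes(list(old_group.nodes))
print('Created' if created else 'Reused', new_group.label)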
Example no. 7
def delete_nodes(pks, verbosity=0, dry_run=False, force=False, **kwargs):
    """Delete nodes by a list of pks.

    This command will delete not only the specified nodes, but also the ones that are
    linked to these and should be also deleted in order to keep a consistent provenance
    according to the rules explained in the concepts section of the documentation.
    In summary:

    1. If a DATA node is deleted, any process nodes linked to it will also be deleted.

    2. If a CALC node is deleted, any incoming WORK node (callers) will be deleted as
    well whereas any incoming DATA node (inputs) will be kept. Outgoing DATA nodes
    (outputs) will be deleted by default but this can be disabled.

    3. If a WORK node is deleted, any incoming WORK node (callers) will be deleted as
    well, but all DATA nodes will be kept. Outgoing WORK or CALC nodes will be kept by
    default, but deletion of either kind of connected node can be enabled.

    These rules are 'recursive', so if a CALC node is deleted, then its output DATA
    nodes will be deleted as well, and then any CALC node that may have those as
    inputs, and so on.

    :param pks: a list of the PKs of the nodes to delete
    :param int verbosity: 0 prints nothing,
                          1 prints just sums and totals,
                          2 prints individual nodes.
    :param bool dry_run:
        Just perform a dry run and do not delete anything. Print statistics according
        to the verbosity level set.
    :param bool force:
        Do not ask for confirmation before deleting the nodes.
    :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` for which rule
        names are toggleable and what the defaults are.
    """
    # pylint: disable=too-many-arguments,too-many-branches,too-many-locals,too-many-statements
    from aiida.backends.utils import delete_nodes_and_connections
    from aiida.common import exceptions
    from aiida.orm import Node, QueryBuilder, load_node
    from aiida.tools.graph.graph_traversers import get_nodes_delete

    starting_pks = []
    for pk in pks:
        try:
            load_node(pk)
        except exceptions.NotExistent:
            echo.echo_warning(
                f'warning: node with pk<{pk}> does not exist, skipping')
        else:
            starting_pks.append(pk)

    # An empty set might be problematic for the queries done below.
    if not starting_pks:
        if verbosity:
            echo.echo('Nothing to delete')
        return

    pks_set_to_delete = get_nodes_delete(starting_pks, **kwargs)['nodes']

    if verbosity > 0:
        echo.echo('I {} delete {} node{}'.format(
            'would' if dry_run else 'will', len(pks_set_to_delete),
            's' if len(pks_set_to_delete) > 1 else ''))
        if verbosity > 1:
            builder = QueryBuilder().append(
                Node,
                filters={'id': {
                    'in': pks_set_to_delete
                }},
                project=('uuid', 'id', 'node_type', 'label'))
            echo.echo(f"The nodes I {'would' if dry_run else 'will'} delete:")
            for uuid, pk, type_string, label in builder.iterall():
                try:
                    short_type_string = type_string.split('.')[-2]
                except IndexError:
                    short_type_string = type_string
                echo.echo(f'   {uuid} {pk} {short_type_string} {label}')

    if dry_run:
        if verbosity > 0:
            echo.echo(
                '\nThis was a dry run, exiting without deleting anything')
        return

    # Ask for user confirmation, unless the force flag is set
    if not force:
        echo.echo_warning(
            f'YOU ARE ABOUT TO DELETE {len(pks_set_to_delete)} NODES! THIS CANNOT BE UNDONE!'
        )
        if not click.confirm('Shall I continue?'):
            echo.echo('Exiting without deleting')
            return

    # Recover the list of folders to delete before actually deleting the nodes. I will delete the folders only later,
    # so that if there is a problem during the deletion of the nodes in the DB, I don't delete the folders
    repositories = [load_node(pk)._repository for pk in pks_set_to_delete]  # pylint: disable=protected-access

    if verbosity > 0:
        echo.echo('Starting node deletion...')
    delete_nodes_and_connections(pks_set_to_delete)

    if verbosity > 0:
        echo.echo(
            'Nodes deleted from database, deleting files from the repository now...'
        )

    # If we are here, we managed to delete the entries from the DB.
    # I can now delete the folders
    for repository in repositories:
        repository.erase(force=True)

    if verbosity > 0:
        echo.echo('Deletion completed.')
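A minimal usage sketch, assuming this delete_nodes helper is importable from the surrounding module and an AiiDA profile is loaded; the PKs are placeholders. Running with dry_run=True first prints what would be removed without touching the database.

pks_to_delete = [1234, 5678]  # placeholder PKs

# Dry run: list the nodes that would be deleted, including traversed ones.
delete_nodes(pks_to_delete, verbosity=2, dry_run=True)

# Actual deletion, skipping the interactive confirmation prompt.
# delete_nodes(pks_to_delete, verbosity=1, force=True)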
Example no. 8
    def search(self, _=None):
        """Launch the search of structures in AiiDA database."""
        self.preprocess()

        qbuild = QueryBuilder()

        # If the date range is valid, use it for the search
        try:
            start_date = datetime.datetime.strptime(
                self.start_date_widget.value, '%Y-%m-%d')
            end_date = datetime.datetime.strptime(
                self.end_date_widget.value,
                '%Y-%m-%d') + datetime.timedelta(hours=24)

        # Otherwise fall back to the default range (i.e. the last 7 days)
        except ValueError:
            start_date = datetime.datetime.now() - datetime.timedelta(days=7)
            end_date = datetime.datetime.now() + datetime.timedelta(hours=24)

            self.start_date_widget.value = start_date.strftime('%Y-%m-%d')
            self.end_date_widget.value = end_date.strftime('%Y-%m-%d')

        filters = {}
        filters['ctime'] = {'and': [{'>': start_date}, {'<=': end_date}]}

        if self.mode.value == "uploaded":
            qbuild2 = QueryBuilder().append(self.query_structure_type,
                                            project=["id"],
                                            tag='structures').append(
                                                Node,
                                                with_outgoing='structures')
            processed_nodes = [n[0] for n in qbuild2.all()]
            if processed_nodes:
                filters['id'] = {"!in": processed_nodes}
            qbuild.append(self.query_structure_type, filters=filters)

        elif self.mode.value == "calculated":
            if self.drop_label.value == 'All':
                qbuild.append((CalcJobNode, WorkChainNode),
                              tag='calcjobworkchain')
            else:
                qbuild.append((CalcJobNode, WorkChainNode),
                              filters={'label': self.drop_label.value},
                              tag='calcjobworkchain')
            qbuild.append(self.query_structure_type,
                          with_incoming='calcjobworkchain',
                          filters=filters)

        elif self.mode.value == "edited":
            qbuild.append(CalcFunctionNode)
            qbuild.append(self.query_structure_type,
                          with_incoming=CalcFunctionNode,
                          filters=filters)

        elif self.mode.value == "all":
            qbuild.append(self.query_structure_type, filters=filters)

        qbuild.order_by({self.query_structure_type: {'ctime': 'desc'}})
        matches = {n[0] for n in qbuild.iterall()}
        matches = sorted(matches, reverse=True, key=lambda n: n.ctime)

        options = OrderedDict()
        options["Select a Structure ({} found)".format(len(matches))] = False

        for mch in matches:
            label = "PK: {}".format(mch.id)
            label += " | " + mch.ctime.strftime("%Y-%m-%d %H:%M")
            label += " | " + mch.get_extra("formula")
            label += " | " + mch.node_type.split('.')[-2]
            label += " | " + mch.label
            label += " | " + mch.description
            options[label] = mch

        self.results.options = options
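The creation-time filter at the core of this search can be reproduced in a few lines. A minimal sketch, assuming a loaded AiiDA profile; it lists structures created in roughly the last week, mirroring the widget's default date range.

import datetime

from aiida import load_profile
from aiida.orm import QueryBuilder, StructureData

load_profile()  # assumes a configured AiiDA profile

end_date = datetime.datetime.now() + datetime.timedelta(hours=24)
start_date = datetime.datetime.now() - datetime.timedelta(days=7)

qbuild = QueryBuilder().append(
    StructureData,
    filters={'ctime': {'and': [{'>': start_date}, {'<=': end_date}]}},
    project='id')
qbuild.order_by({StructureData: {'ctime': 'desc'}})
print(qbuild.count(), 'structures created in the last 7 days')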
Example no. 9
def import_data_sqla(in_path,
                     group=None,
                     ignore_unknown_nodes=False,
                     extras_mode_existing='kcl',
                     extras_mode_new='import',
                     comment_mode='newest',
                     silent=False,
                     **kwargs):
    """Import exported AiiDA archive to the AiiDA database and repository.

    Specific for the SQLAlchemy backend.
    If ``in_path`` is a folder, calls extract_tree; otherwise, tries to detect the compression format
    (zip, tar.gz, tar.bz2, ...) and calls the correct function.

    :param in_path: the path to a file or folder that can be imported in AiiDA.
    :type in_path: str

    :param group: Group wherein all imported Nodes will be placed.
    :type group: :py:class:`~aiida.orm.groups.Group`

    :param extras_mode_existing: 3 letter code that will identify what to do with the extras import.
        The first letter acts on extras that are present in the original node and not present in the imported node.
        Can be either:
        'k' (keep it) or
        'n' (do not keep it).
        The second letter acts on the imported extras that are not present in the original node.
        Can be either:
        'c' (create it) or
        'n' (do not create it).
        The third letter defines what to do in case of a name collision.
        Can be either:
        'l' (leave the old value),
        'u' (update with a new value),
        'd' (delete the extra), or
        'a' (ask what to do if the content is different).
    :type extras_mode_existing: str

    :param extras_mode_new: 'import' to import extras of new nodes or 'none' to ignore them.
    :type extras_mode_new: str

    :param comment_mode: Comment import modes (when same UUIDs are found).
        Can be either:
        'newest' (will keep the Comment with the most recent modification time (mtime)) or
        'overwrite' (will overwrite existing Comments with the ones from the import file).
    :type comment_mode: str

    :param silent: suppress progress bar and summary.
    :type silent: bool

    :return: New and existing Nodes and Links.
    :rtype: dict

    :raises `~aiida.tools.importexport.common.exceptions.ImportValidationError`: if parameters or the contents of
        `metadata.json` or `data.json` can not be validated.
    :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if the provided archive at ``in_path`` is
        corrupted.
    :raises `~aiida.tools.importexport.common.exceptions.IncompatibleArchiveVersionError`: if the provided archive's
        export version is not equal to the export version of AiiDA at the moment of import.
    :raises `~aiida.tools.importexport.common.exceptions.ArchiveImportError`: if there are any internal errors when
        importing.
    :raises `~aiida.tools.importexport.common.exceptions.ImportUniquenessError`: if a new unique entity can not be
        created.
    """
    from aiida.backends.sqlalchemy.models.node import DbNode, DbLink
    from aiida.backends.sqlalchemy.utils import flag_modified

    # This is the export version expected by this function
    expected_export_version = StrictVersion(EXPORT_VERSION)

    # The returned dictionary with new and existing nodes and links
    ret_dict = {}

    # Initial check(s)
    if group:
        if not isinstance(group, Group):
            raise exceptions.ImportValidationError(
                'group must be a Group entity')
        elif not group.is_stored:
            group.store()

    if silent:
        logging.disable(level=logging.CRITICAL)

    ################
    # EXTRACT DATA #
    ################
    # The sandbox has to remain open until the end
    with SandboxFolder() as folder:
        if os.path.isdir(in_path):
            extract_tree(in_path, folder)
        else:
            if tarfile.is_tarfile(in_path):
                extract_tar(in_path,
                            folder,
                            silent=silent,
                            nodes_export_subfolder=NODES_EXPORT_SUBFOLDER,
                            **kwargs)
            elif zipfile.is_zipfile(in_path):
                extract_zip(in_path,
                            folder,
                            silent=silent,
                            nodes_export_subfolder=NODES_EXPORT_SUBFOLDER,
                            **kwargs)
            else:
                raise exceptions.ImportValidationError(
                    'Unable to detect the input file format, it is neither a '
                    'tar file, nor a (possibly compressed) zip file.')

        if not folder.get_content_list():
            raise exceptions.CorruptArchive(
                'The provided file/folder ({}) is empty'.format(in_path))
        try:
            IMPORT_LOGGER.debug('CACHING metadata.json')
            with open(folder.get_abs_path('metadata.json'),
                      encoding='utf8') as fhandle:
                metadata = json.load(fhandle)

            IMPORT_LOGGER.debug('CACHING data.json')
            with open(folder.get_abs_path('data.json'),
                      encoding='utf8') as fhandle:
                data = json.load(fhandle)
        except IOError as error:
            raise exceptions.CorruptArchive(
                'Unable to find the file {} in the import file or folder'.
                format(error.filename))

        ######################
        # PRELIMINARY CHECKS #
        ######################
        export_version = StrictVersion(str(metadata['export_version']))
        if export_version != expected_export_version:
            msg = 'Export file version is {}, can import only version {}'\
                    .format(metadata['export_version'], expected_export_version)
            if export_version < expected_export_version:
                msg += "\nUse 'verdi export migrate' to update this export file."
            else:
                msg += '\nUpdate your AiiDA version in order to import this file.'

            raise exceptions.IncompatibleArchiveVersionError(msg)

        start_summary(in_path, comment_mode, extras_mode_new,
                      extras_mode_existing)

        ###################################################################
        #           CREATE UUID REVERSE TABLES AND CHECK IF               #
        #              I HAVE ALL NODES FOR THE LINKS                     #
        ###################################################################
        IMPORT_LOGGER.debug(
            'CHECKING IF NODES FROM LINKS ARE IN DB OR ARCHIVE...')

        linked_nodes = set(
            chain.from_iterable(
                (l['input'], l['output']) for l in data['links_uuid']))
        group_nodes = set(chain.from_iterable(data['groups_uuid'].values()))

        # Check that UUIDs are valid
        linked_nodes = set(x for x in linked_nodes if validate_uuid(x))
        group_nodes = set(x for x in group_nodes if validate_uuid(x))

        import_nodes_uuid = set()
        for value in data['export_data'].get(NODE_ENTITY_NAME, {}).values():
            import_nodes_uuid.add(value['uuid'])

        unknown_nodes = linked_nodes.union(group_nodes) - import_nodes_uuid

        if unknown_nodes and not ignore_unknown_nodes:
            raise exceptions.DanglingLinkError(
                'The import file refers to {} nodes with unknown UUID, therefore it cannot be imported. Either first '
                'import the unknown nodes, or export also the parents when exporting. The unknown UUIDs are:\n'
                ''.format(len(unknown_nodes)) +
                '\n'.join('* {}'.format(uuid) for uuid in unknown_nodes))

        ###################################
        # DOUBLE-CHECK MODEL DEPENDENCIES #
        ###################################
        # The entity import order. It is defined by the database model relationships.
        entity_order = [
            USER_ENTITY_NAME, COMPUTER_ENTITY_NAME, NODE_ENTITY_NAME,
            GROUP_ENTITY_NAME, LOG_ENTITY_NAME, COMMENT_ENTITY_NAME
        ]

        # Every entity name in the archive must be one of the known entities
        # (e.g. 'User', 'Computer', 'Node', 'Group'); otherwise abort.
        for import_field_name in metadata['all_fields_info']:
            if import_field_name not in entity_order:
                raise exceptions.ImportValidationError(
                    "You are trying to import an unknown model '{}'!".format(
                        import_field_name))

        for idx, entity_name in enumerate(entity_order):
            dependencies = []
            # For every field, collect the dependencies given as the value of the 'requires' key
            for field in metadata['all_fields_info'][entity_name].values():
                try:
                    dependencies.append(field['requires'])
                except KeyError:
                    # (No ForeignKey)
                    pass
            for dependency in dependencies:
                if dependency not in entity_order[:idx]:
                    raise exceptions.ArchiveImportError(
                        'Entity {} requires {} but would be loaded first; stopping...'
                        .format(entity_name, dependency))

        ###################################################
        # CREATE IMPORT DATA DIRECT UNIQUE_FIELD MAPPINGS #
        ###################################################
        # This is nested dictionary of entity_name:{id:uuid}
        # to map one id (the pk) to a different one.
        # One of the things to remove for v0.4
        # {
        # 'Node': {2362: '82a897b5-fb3a-47d7-8b22-c5fe1b4f2c14',
        #           2363: 'ef04aa5d-99e7-4bfd-95ef-fe412a6a3524', 2364: '1dc59576-af21-4d71-81c2-bac1fc82a84a'},
        # 'User': {1: 'aiida@localhost'}
        # }
        IMPORT_LOGGER.debug('CREATING PK-2-UUID/EMAIL MAPPING...')
        import_unique_ids_mappings = {}
        # Since export version 0.3, the export data is keyed by entity name
        for entity_name, import_data in data['export_data'].items():
            # Again I need the entity_name since that's what's being stored since 0.3
            if entity_name in metadata['unique_identifiers']:
                # I have to reconvert the pk to integer
                import_unique_ids_mappings[entity_name] = {
                    int(k): v[metadata['unique_identifiers'][entity_name]]
                    for k, v in import_data.items()
                }
        ###############
        # IMPORT DATA #
        ###############
        # DO ALL WITH A TRANSACTION
        import aiida.backends.sqlalchemy

        session = aiida.backends.sqlalchemy.get_scoped_session()

        try:
            foreign_ids_reverse_mappings = {}
            new_entries = {}
            existing_entries = {}

            IMPORT_LOGGER.debug('GENERATING LIST OF DATA...')

            # Instantiate progress bar
            progress_bar = get_progress_bar(total=1,
                                            leave=False,
                                            disable=silent)
            pbar_base_str = 'Generating list of data - '

            # Get total entities from data.json
            # To be used with progress bar
            number_of_entities = 0

            # I first generate the list of data
            for entity_name in entity_order:
                entity = entity_names_to_entities[entity_name]
                # Get the unique identifier, stored under entity_name since v0.3
                unique_identifier = metadata['unique_identifiers'].get(
                    entity_name, None)

                # Initialize the per-entity containers, keyed by entity_name (since v0.3)
                new_entries[entity_name] = {}
                existing_entries[entity_name] = {}
                foreign_ids_reverse_mappings[entity_name] = {}

                # Not necessarily all models are exported
                if entity_name in data['export_data']:

                    IMPORT_LOGGER.debug('  %s...', entity_name)

                    progress_bar.set_description_str(pbar_base_str +
                                                     entity_name,
                                                     refresh=False)
                    number_of_entities += len(data['export_data'][entity_name])

                    if unique_identifier is not None:
                        import_unique_ids = set(
                            v[unique_identifier]
                            for v in data['export_data'][entity_name].values())

                        relevant_db_entries = {}
                        if import_unique_ids:
                            builder = QueryBuilder()
                            builder.append(entity,
                                           filters={
                                               unique_identifier: {
                                                   'in': import_unique_ids
                                               }
                                           },
                                           project='*')

                            if builder.count():
                                progress_bar = get_progress_bar(
                                    total=builder.count(), disable=silent)
                                for object_ in builder.iterall():
                                    progress_bar.update()

                                    relevant_db_entries.update({
                                        getattr(object_[0], unique_identifier):
                                        object_[0]
                                    })

                            foreign_ids_reverse_mappings[entity_name] = {
                                k: v.pk
                                for k, v in relevant_db_entries.items()
                            }

                        IMPORT_LOGGER.debug('    GOING THROUGH ARCHIVE...')

                        imported_comp_names = set()
                        for key, value in data['export_data'][
                                entity_name].items():
                            if entity_name == GROUP_ENTITY_NAME:
                                # Check if there is already a group with the same label,
                                # and if so, generate a new unique label
                                orig_label = value['label']
                                dupl_counter = 0
                                while QueryBuilder().append(
                                        entity,
                                        filters={
                                            'label': {
                                                '==': value['label']
                                            }
                                        }).count():
                                    # Rename the new group
                                    value['label'] = orig_label + DUPL_SUFFIX.format(dupl_counter)
                                    dupl_counter += 1
                                    if dupl_counter == 100:
                                        raise exceptions.ImportUniquenessError(
                                            'A group of that label ( {} ) already exists and I could not create a new '
                                            'one'.format(orig_label))

                            elif entity_name == COMPUTER_ENTITY_NAME:
                                # The following is done for compatibility
                                # reasons in case the export file was generated
                                # with the Django export method. In Django the
                                # metadata and the transport parameters are
                                # stored as (unicode) strings of the serialized
                                # JSON objects and not as simple serialized
                                # JSON objects.
                                if isinstance(value['metadata'], (str, bytes)):
                                    value['metadata'] = json.loads(
                                        value['metadata'])

                                # Check if there is already a computer with the
                                # same name in the database
                                builder = QueryBuilder()
                                builder.append(
                                    entity,
                                    filters={'name': {
                                        '==': value['name']
                                    }},
                                    project=['*'],
                                    tag='res')
                                dupl = builder.count() or value['name'] in imported_comp_names
                                dupl_counter = 0
                                orig_name = value['name']
                                while dupl:
                                    # Rename the new computer
                                    value['name'] = orig_name + DUPL_SUFFIX.format(dupl_counter)
                                    builder = QueryBuilder()
                                    builder.append(entity,
                                                   filters={
                                                       'name': {
                                                           '==': value['name']
                                                       }
                                                   },
                                                   project=['*'],
                                                   tag='res')
                                    dupl = builder.count() or value['name'] in imported_comp_names
                                    dupl_counter += 1
                                    if dupl_counter == 100:
                                        raise exceptions.ImportUniquenessError(
                                            'A computer of that name ( {} ) already exists and I could not create a '
                                            'new one'.format(orig_name))

                                imported_comp_names.add(value['name'])

                            if value[unique_identifier] in relevant_db_entries:
                                # Already in DB
                                # again, switched to entity_name in v0.3
                                existing_entries[entity_name][key] = value
                            else:
                                # To be added
                                new_entries[entity_name][key] = value
                    else:
                        new_entries[entity_name] = data['export_data'][
                            entity_name]

            # Progress bar - reset for import
            progress_bar = get_progress_bar(total=number_of_entities,
                                            disable=silent)
            reset_progress_bar = {}

            # I import data from the given model
            for entity_name in entity_order:
                entity = entity_names_to_entities[entity_name]
                fields_info = metadata['all_fields_info'].get(entity_name, {})
                unique_identifier = metadata['unique_identifiers'].get(
                    entity_name, '')

                # Progress bar initialization - Model
                if reset_progress_bar:
                    progress_bar = get_progress_bar(
                        total=reset_progress_bar['total'], disable=silent)
                    progress_bar.n = reset_progress_bar['n']
                    reset_progress_bar = {}
                pbar_base_str = '{}s - '.format(entity_name)
                progress_bar.set_description_str(pbar_base_str +
                                                 'Initializing',
                                                 refresh=True)

                # EXISTING ENTRIES
                if existing_entries[entity_name]:
                    # Progress bar update - Model
                    progress_bar.set_description_str(
                        pbar_base_str + '{} existing entries'.format(
                            len(existing_entries[entity_name])),
                        refresh=True)

                for import_entry_pk, entry_data in existing_entries[
                        entity_name].items():
                    unique_id = entry_data[unique_identifier]
                    existing_entry_pk = foreign_ids_reverse_mappings[
                        entity_name][unique_id]
                    import_data = dict(
                        deserialize_field(k,
                                          v,
                                          fields_info=fields_info,
                                          import_unique_ids_mappings=
                                          import_unique_ids_mappings,
                                          foreign_ids_reverse_mappings=
                                          foreign_ids_reverse_mappings)
                        for k, v in entry_data.items())
                    # TODO COMPARE, AND COMPARE ATTRIBUTES

                    if entity_name == COMMENT_ENTITY_NAME:
                        new_entry_uuid = merge_comment(import_data,
                                                       comment_mode)
                        if new_entry_uuid is not None:
                            entry_data[unique_identifier] = new_entry_uuid
                            new_entries[entity_name][
                                import_entry_pk] = entry_data

                    if entity_name not in ret_dict:
                        ret_dict[entity_name] = {'new': [], 'existing': []}
                    ret_dict[entity_name]['existing'].append(
                        (import_entry_pk, existing_entry_pk))
                    IMPORT_LOGGER.debug('Existing %s: %s (%s->%s)',
                                        entity_name, unique_id,
                                        import_entry_pk, existing_entry_pk)

                # Store all objects for this model in a list, and store them
                # all at once at the end.
                objects_to_create = list()
                # In the following list we add the objects to be updated
                objects_to_update = list()
                # This is needed later to associate the import entry with the new pk
                import_new_entry_pks = dict()

                # NEW ENTRIES
                if new_entries[entity_name]:
                    # Progress bar update - Model
                    progress_bar.set_description_str(
                        pbar_base_str +
                        '{} new entries'.format(len(new_entries[entity_name])),
                        refresh=True)

                for import_entry_pk, entry_data in new_entries[
                        entity_name].items():
                    unique_id = entry_data[unique_identifier]
                    import_data = dict(
                        deserialize_field(k,
                                          v,
                                          fields_info=fields_info,
                                          import_unique_ids_mappings=
                                          import_unique_ids_mappings,
                                          foreign_ids_reverse_mappings=
                                          foreign_ids_reverse_mappings)
                        for k, v in entry_data.items())

                    # We convert the Django fields to SQLA. Note that some of
                    # the Django fields were converted to SQLA compatible
                    # fields by the deserialize_field method. This was done
                    # for optimization reasons in Django but makes them
                    # compatible with the SQLA schema and they don't need any
                    # further conversion.
                    if entity_name in file_fields_to_model_fields:
                        for file_fkey in file_fields_to_model_fields[
                                entity_name]:

                            # This is an exception because the DbLog model defines the `_metadata` column instead of the
                            # `metadata` column used in the Django model. This is because the SqlAlchemy model base
                            # class already has a metadata attribute that cannot be overridden. For consistency, the
                            # `DbLog` class however expects the `metadata` keyword in its constructor, so we should
                            # ignore the mapping here
                            if entity_name == LOG_ENTITY_NAME and file_fkey == 'metadata':
                                continue

                            model_fkey = file_fields_to_model_fields[
                                entity_name][file_fkey]
                            if model_fkey in import_data:
                                continue
                            import_data[model_fkey] = import_data[file_fkey]
                            import_data.pop(file_fkey, None)

                    db_entity = get_object_from_string(
                        entity_names_to_sqla_schema[entity_name])

                    objects_to_create.append(db_entity(**import_data))
                    import_new_entry_pks[unique_id] = import_entry_pk

                if entity_name == NODE_ENTITY_NAME:
                    IMPORT_LOGGER.debug(
                        'STORING NEW NODE REPOSITORY FILES & ATTRIBUTES...')

                    # NEW NODES
                    for object_ in objects_to_create:
                        import_entry_uuid = object_.uuid
                        import_entry_pk = import_new_entry_pks[
                            import_entry_uuid]

                        # Progress bar initialization - Node
                        progress_bar.update()
                        pbar_node_base_str = pbar_base_str + 'UUID={} - '.format(
                            import_entry_uuid.split('-')[0])

                        # Before storing entries in the DB, I store the files (if these are nodes).
                        # Note: only for new entries!
                        subfolder = folder.get_subfolder(
                            os.path.join(NODES_EXPORT_SUBFOLDER,
                                         export_shard_uuid(import_entry_uuid)))
                        if not subfolder.exists():
                            raise exceptions.CorruptArchive(
                                'Unable to find the repository folder for Node with UUID={} in the exported '
                                'file'.format(import_entry_uuid))
                        destdir = RepositoryFolder(
                            section=Repository._section_name,
                            uuid=import_entry_uuid)
                        # Replace the folder, possibly destroying existing previous folders, and move the files
                        # (faster if we are on the same filesystem, and in any case the source is a SandboxFolder)
                        progress_bar.set_description_str(pbar_node_base_str +
                                                         'Repository',
                                                         refresh=True)
                        destdir.replace_with_folder(subfolder.abspath,
                                                    move=True,
                                                    overwrite=True)

                        # For Nodes, we also have to store Attributes!
                        IMPORT_LOGGER.debug('STORING NEW NODE ATTRIBUTES...')
                        progress_bar.set_description_str(pbar_node_base_str +
                                                         'Attributes',
                                                         refresh=True)

                        # Get attributes from import file
                        try:
                            object_.attributes = data['node_attributes'][str(
                                import_entry_pk)]
                        except KeyError:
                            raise exceptions.CorruptArchive(
                                'Unable to find attribute info for Node with UUID={}'
                                .format(import_entry_uuid))

                        # For DbNodes, we also have to store extras
                        if extras_mode_new == 'import':
                            IMPORT_LOGGER.debug('STORING NEW NODE EXTRAS...')
                            progress_bar.set_description_str(
                                pbar_node_base_str + 'Extras', refresh=True)

                            # Get extras from import file
                            try:
                                extras = data['node_extras'][str(
                                    import_entry_pk)]
                            except KeyError:
                                raise exceptions.CorruptArchive(
                                    'Unable to find extra info for Node with UUID={}'
                                    .format(import_entry_uuid))
                            # TODO: remove when aiida extras will be moved somewhere else
                            # from here
                            extras = {
                                key: value
                                for key, value in extras.items()
                                if not key.startswith('_aiida_')
                            }
                            if object_.node_type.endswith('code.Code.'):
                                extras = {
                                    key: value
                                    for key, value in extras.items()
                                    if not key == 'hidden'
                                }
                            # till here
                            object_.extras = extras
                        elif extras_mode_new == 'none':
                            IMPORT_LOGGER.debug('SKIPPING NEW NODE EXTRAS...')
                        else:
                            raise exceptions.ImportValidationError(
                                "Unknown extras_mode_new value: {}, should be either 'import' or 'none'"
                                ''.format(extras_mode_new))

                    # EXISTING NODES (Extras)
                    IMPORT_LOGGER.debug('UPDATING EXISTING NODE EXTRAS...')

                    import_existing_entry_pks = {
                        entry_data[unique_identifier]: import_entry_pk
                        for import_entry_pk, entry_data in
                        existing_entries[entity_name].items()
                    }
                    for node in session.query(DbNode).filter(
                            DbNode.uuid.in_(import_existing_entry_pks)).all():
                        import_entry_uuid = str(node.uuid)
                        import_entry_pk = import_existing_entry_pks[
                            import_entry_uuid]

                        # Progress bar initialization - Node
                        pbar_node_base_str = pbar_base_str + 'UUID={} - '.format(
                            import_entry_uuid.split('-')[0])
                        progress_bar.set_description_str(pbar_node_base_str +
                                                         'Extras',
                                                         refresh=False)
                        progress_bar.update()

                        # Get extras from import file
                        try:
                            extras = data['node_extras'][str(import_entry_pk)]
                        except KeyError:
                            raise exceptions.CorruptArchive(
                                'Unable to find extra info for Node with UUID={}'
                                .format(import_entry_uuid))

                        old_extras = node.extras.copy()
                        # TODO: remove when aiida extras will be moved somewhere else
                        # from here
                        extras = {
                            key: value
                            for key, value in extras.items()
                            if not key.startswith('_aiida_')
                        }
                        if node.node_type.endswith('code.Code.'):
                            extras = {
                                key: value
                                for key, value in extras.items()
                                if not key == 'hidden'
                            }
                        # till here
                        new_extras = merge_extras(node.extras, extras,
                                                  extras_mode_existing)
                        if new_extras != old_extras:
                            node.extras = new_extras
                            flag_modified(node, 'extras')
                            objects_to_update.append(node)

                else:
                    # Update progress bar with new non-Node entries
                    progress_bar.update(n=len(existing_entries[entity_name]) +
                                        len(new_entries[entity_name]))

                progress_bar.set_description_str(pbar_base_str + 'Storing',
                                                 refresh=True)

                # Store them all at once; however, the PKs are not set this way...
                if objects_to_create:
                    session.add_all(objects_to_create)
                if objects_to_update:
                    session.add_all(objects_to_update)

                session.flush()

                just_saved = {}
                if import_new_entry_pks.keys():
                    reset_progress_bar = {
                        'total': progress_bar.total,
                        'n': progress_bar.n
                    }
                    progress_bar = get_progress_bar(
                        total=len(import_new_entry_pks), disable=silent)

                    builder = QueryBuilder()
                    builder.append(entity,
                                   filters={
                                       unique_identifier: {
                                           'in':
                                           list(import_new_entry_pks.keys())
                                       }
                                   },
                                   project=[unique_identifier, 'id'])

                    for entry in builder.iterall():
                        progress_bar.update()

                        just_saved.update({entry[0]: entry[1]})

                progress_bar.set_description_str(pbar_base_str + 'Done!',
                                                 refresh=True)

                # Now I have the PKs, print the info
                # Moreover, add newly created Nodes to foreign_ids_reverse_mappings
                for unique_id, new_pk in just_saved.items():
                    from uuid import UUID
                    if isinstance(unique_id, UUID):
                        unique_id = str(unique_id)
                    import_entry_pk = import_new_entry_pks[unique_id]
                    foreign_ids_reverse_mappings[entity_name][
                        unique_id] = new_pk
                    if entity_name not in ret_dict:
                        ret_dict[entity_name] = {'new': [], 'existing': []}
                    ret_dict[entity_name]['new'].append(
                        (import_entry_pk, new_pk))

                    IMPORT_LOGGER.debug('N %s: %s (%s->%s)', entity_name,
                                        unique_id, import_entry_pk, new_pk)

            IMPORT_LOGGER.debug('STORING NODE LINKS...')

            import_links = data['links_uuid']

            if import_links:
                progress_bar = get_progress_bar(total=len(import_links),
                                                disable=silent)
                pbar_base_str = 'Links - '

            for link in import_links:
                # Check for dangling links within the supposedly self-consistent archive
                progress_bar.set_description_str(
                    pbar_base_str + 'label={}'.format(link['label']),
                    refresh=False)
                progress_bar.update()

                try:
                    in_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][
                        link['input']]
                    out_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][
                        link['output']]
                except KeyError:
                    if ignore_unknown_nodes:
                        continue
                    raise exceptions.ImportValidationError(
                        'Trying to create a link with one or both unknown nodes, stopping (in_uuid={}, out_uuid={}, '
                        'label={}, type={})'.format(link['input'],
                                                    link['output'],
                                                    link['label'],
                                                    link['type']))

                # Since backend specific Links (DbLink) are not validated upon creation, we will now validate them.
                source = QueryBuilder().append(Node,
                                               filters={
                                                   'id': in_id
                                               },
                                               project='*').first()[0]
                target = QueryBuilder().append(Node,
                                               filters={
                                                   'id': out_id
                                               },
                                               project='*').first()[0]
                link_type = LinkType(link['type'])

                # Check whether the link triple (source, target, type, label) already exists.
                # If it does, skip to the next link; otherwise, validate it.
                if link_triple_exists(source, target, link_type,
                                      link['label']):
                    continue

                try:
                    validate_link(source, target, link_type, link['label'])
                except ValueError as why:
                    raise exceptions.ImportValidationError(
                        'Error occurred during Link validation: {}'.format(
                            why))

                # New link
                session.add(
                    DbLink(input_id=in_id,
                           output_id=out_id,
                           label=link['label'],
                           type=link['type']))
                if 'Link' not in ret_dict:
                    ret_dict['Link'] = {'new': []}
                ret_dict['Link']['new'].append((in_id, out_id))

            IMPORT_LOGGER.debug('   (%d new links...)',
                                len(ret_dict.get('Link', {}).get('new', [])))

            IMPORT_LOGGER.debug('STORING GROUP ELEMENTS...')

            import_groups = data['groups_uuid']

            if import_groups:
                progress_bar = get_progress_bar(total=len(import_groups),
                                                disable=silent)
                pbar_base_str = 'Groups - '

            for groupuuid, groupnodes in import_groups.items():
                # TODO: cache these to avoid too many queries
                qb_group = QueryBuilder().append(
                    Group, filters={'uuid': {
                        '==': groupuuid
                    }})
                group_ = qb_group.first()[0]

                progress_bar.set_description_str(
                    pbar_base_str + 'label={}'.format(group_.label),
                    refresh=False)
                progress_bar.update()

                nodes_ids_to_add = [
                    foreign_ids_reverse_mappings[NODE_ENTITY_NAME][node_uuid]
                    for node_uuid in groupnodes
                ]
                qb_nodes = QueryBuilder().append(
                    Node, filters={'id': {
                        'in': nodes_ids_to_add
                    }})
                # Add nodes to the group, bypassing the SQLA ORM to increase speed
                nodes_to_add = [n[0].backend_entity for n in qb_nodes.all()]
                group_.backend_entity.add_nodes(nodes_to_add, skip_orm=True)

            ######################################################
            # Put everything in a specific group
            ######################################################
            existing = existing_entries.get(NODE_ENTITY_NAME, {})
            existing_pk = [
                foreign_ids_reverse_mappings[NODE_ENTITY_NAME][v['uuid']]
                for v in existing.values()
            ]
            new = new_entries.get(NODE_ENTITY_NAME, {})
            new_pk = [
                foreign_ids_reverse_mappings[NODE_ENTITY_NAME][v['uuid']]
                for v in new.values()
            ]

            pks_for_group = existing_pk + new_pk

            # So that we do not create empty groups
            if pks_for_group:
                # If the user did not specify a group, create a new unique import group
                if not group:
                    from aiida.backends.sqlalchemy.models.group import DbGroup

                    # Get a unique name for the import group, based on the current (local) time
                    basename = timezone.localtime(
                        timezone.now()).strftime('%Y%m%d-%H%M%S')
                    counter = 0
                    group_label = basename
                    while session.query(DbGroup).filter(
                            DbGroup.label == group_label).count() > 0:
                        counter += 1
                        group_label = '{}_{}'.format(basename, counter)

                        if counter == 100:
                            raise exceptions.ImportUniquenessError(
                                "Overflow of import groups (more than 100 import groups exists with basename '{}')"
                                ''.format(basename))
                    group = ImportGroup(label=group_label)
                    session.add(group.backend_entity._dbmodel)

                # Add nodes to the group, bypassing the SQLA ORM to increase speed
                builder = QueryBuilder().append(
                    Node, filters={'id': {
                        'in': pks_for_group
                    }})

                progress_bar = get_progress_bar(total=len(pks_for_group),
                                                disable=silent)
                progress_bar.set_description_str(
                    'Creating import Group - Preprocessing', refresh=True)
                first = True

                nodes = []
                for entry in builder.iterall():
                    if first:
                        progress_bar.set_description_str(
                            'Creating import Group', refresh=False)
                        first = False
                    progress_bar.update()
                    nodes.append(entry[0].backend_entity)
                group.backend_entity.add_nodes(nodes, skip_orm=True)
                progress_bar.set_description_str('Done (cleaning up)',
                                                 refresh=True)
            else:
                IMPORT_LOGGER.debug(
                    'No Nodes to import, so no import Group was created (unless one already existed)'
                )

            IMPORT_LOGGER.debug('COMMITTING EVERYTHING...')
            session.commit()

            # Finalize Progress bar
            close_progress_bar(leave=False)

            # Summarize import
            result_summary(ret_dict, getattr(group, 'label', None))

        except:  # pylint: disable=bare-except
            # Bare except on purpose: roll back the session on any failure, then re-raise.
            # Finalize Progress bar
            close_progress_bar(leave=False)

            result_summary({}, None)

            IMPORT_LOGGER.debug('Rolling back')
            session.rollback()
            raise

    # Reset logging level
    if silent:
        logging.disable(level=logging.NOTSET)

    return ret_dict
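
The dictionary returned above maps each entity name to lists of (import_entry_pk, database_pk) pairs under the keys 'new' and 'existing', plus a 'Link' key whose 'new' list holds (input_id, output_id) pairs. A minimal caller-side sketch of how such a result could be summarized; the helper name summarize_import is made up for illustration:

# Sketch only: summarize the dictionary returned by the importer above.
# ret_dict[entity_name] == {'new': [(import_pk, db_pk), ...], 'existing': [...]}
# ret_dict['Link']      == {'new': [(input_id, output_id), ...]}
def summarize_import(ret_dict):
    """Print how many new and existing entries were imported per entity."""
    for entity_name, entries in sorted(ret_dict.items()):
        print('{}: {} new, {} existing'.format(entity_name,
                                               len(entries.get('new', [])),
                                               len(entries.get('existing', []))))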
Example No. 10
def import_data_dj(in_path,
                   group=None,
                   ignore_unknown_nodes=False,
                   extras_mode_existing='kcl',
                   extras_mode_new='import',
                   comment_mode='newest',
                   silent=False,
                   **kwargs):
    """Import exported AiiDA archive to the AiiDA database and repository.

    Specific for the Django backend.
    If ``in_path`` is a folder, calls extract_tree; otherwise, tries to detect the compression format
    (zip, tar.gz, tar.bz2, ...) and calls the correct function.

    :param in_path: the path to a file or folder that can be imported in AiiDA.
    :type in_path: str

    :param group: Group wherein all imported Nodes will be placed.
    :type group: :py:class:`~aiida.orm.groups.Group`

    :param extras_mode_existing: 3 letter code that will identify what to do with the extras import.
        The first letter acts on extras that are present in the original node and not present in the imported node.
        Can be either:
        'k' (keep it) or
        'n' (do not keep it).
        The second letter acts on the imported extras that are not present in the original node.
        Can be either:
        'c' (create it) or
        'n' (do not create it).
        The third letter defines what to do in case of a name collision.
        Can be either:
        'l' (leave the old value),
        'u' (update with a new value),
        'd' (delete the extra), or
        'a' (ask what to do if the content is different).
    :type extras_mode_existing: str

    :param extras_mode_new: 'import' to import extras of new nodes or 'none' to ignore them.
    :type extras_mode_new: str

    :param comment_mode: Comment import modes (when same UUIDs are found).
        Can be either:
        'newest' (will keep the Comment with the most recent modification time (mtime)) or
        'overwrite' (will overwrite existing Comments with the ones from the import file).
    :type comment_mode: str

    :param silent: suppress progress bar and summary.
    :type silent: bool

    :return: New and existing Nodes and Links.
    :rtype: dict

    :raises `~aiida.tools.importexport.common.exceptions.ImportValidationError`: if parameters or the contents of
        `metadata.json` or `data.json` can not be validated.
    :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if the provided archive at ``in_path`` is
        corrupted.
    :raises `~aiida.tools.importexport.common.exceptions.IncompatibleArchiveVersionError`: if the provided archive's
        export version is not equal to the export version of AiiDA at the moment of import.
    :raises `~aiida.tools.importexport.common.exceptions.ArchiveImportError`: if there are any internal errors when
        importing.
    :raises `~aiida.tools.importexport.common.exceptions.ImportUniquenessError`: if a new unique entity can not be
        created.
    """
    from django.db import transaction  # pylint: disable=import-error,no-name-in-module
    from aiida.backends.djsite.db import models

    # This is the export version expected by this function
    expected_export_version = StrictVersion(EXPORT_VERSION)

    # The returned dictionary with new and existing nodes and links
    ret_dict = {}

    # Initial check(s)
    if group:
        if not isinstance(group, Group):
            raise exceptions.ImportValidationError(
                'group must be a Group entity')
        elif not group.is_stored:
            group.store()

    if silent:
        logging.disable(level=logging.CRITICAL)

    ################
    # EXTRACT DATA #
    ################
    # The sandbox has to remain open until the end
    with SandboxFolder() as folder:
        if os.path.isdir(in_path):
            extract_tree(in_path, folder)
        else:
            if tarfile.is_tarfile(in_path):
                extract_tar(in_path,
                            folder,
                            silent=silent,
                            nodes_export_subfolder=NODES_EXPORT_SUBFOLDER,
                            **kwargs)
            elif zipfile.is_zipfile(in_path):
                extract_zip(in_path,
                            folder,
                            silent=silent,
                            nodes_export_subfolder=NODES_EXPORT_SUBFOLDER,
                            **kwargs)
            else:
                raise exceptions.ImportValidationError(
                    'Unable to detect the input file format, it is neither a '
                    'tar file, nor a (possibly compressed) zip file.')

        if not folder.get_content_list():
            raise exceptions.CorruptArchive(
                'The provided file/folder ({}) is empty'.format(in_path))
        try:
            with open(folder.get_abs_path('metadata.json'),
                      'r',
                      encoding='utf8') as fhandle:
                metadata = json.load(fhandle)

            with open(folder.get_abs_path('data.json'), 'r',
                      encoding='utf8') as fhandle:
                data = json.load(fhandle)
        except IOError as error:
            raise exceptions.CorruptArchive(
                'Unable to find the file {} in the import file or folder'.
                format(error.filename))

        ######################
        # PRELIMINARY CHECKS #
        ######################
        export_version = StrictVersion(str(metadata['export_version']))
        if export_version != expected_export_version:
            msg = 'Export file version is {}, can import only version {}'\
                    .format(metadata['export_version'], expected_export_version)
            if export_version < expected_export_version:
                msg += "\nUse 'verdi export migrate' to update this export file."
            else:
                msg += '\nUpdate your AiiDA version in order to import this file.'

            raise exceptions.IncompatibleArchiveVersionError(msg)

        start_summary(in_path, comment_mode, extras_mode_new,
                      extras_mode_existing)

        ##########################################################################
        # CREATE UUID REVERSE TABLES AND CHECK IF I HAVE ALL NODES FOR THE LINKS #
        ##########################################################################
        linked_nodes = set(
            chain.from_iterable(
                (l['input'], l['output']) for l in data['links_uuid']))
        group_nodes = set(chain.from_iterable(data['groups_uuid'].values()))

        if NODE_ENTITY_NAME in data['export_data']:
            import_nodes_uuid = set(
                v['uuid']
                for v in data['export_data'][NODE_ENTITY_NAME].values())
        else:
            import_nodes_uuid = set()

        # linked_nodes and group_nodes were gathered from the links and the group memberships, respectively;
        # import_nodes_uuid contains the nodes that are actually present in export_data
        unknown_nodes = linked_nodes.union(group_nodes) - import_nodes_uuid

        if unknown_nodes and not ignore_unknown_nodes:
            raise exceptions.DanglingLinkError(
                'The import file refers to {} nodes with unknown UUID, therefore it cannot be imported. Either first '
                'import the unknown nodes, or export also the parents when exporting. The unknown UUIDs are:\n'
                ''.format(len(unknown_nodes)) +
                '\n'.join('* {}'.format(uuid) for uuid in unknown_nodes))

        ###################################
        # DOUBLE-CHECK MODEL DEPENDENCIES #
        ###################################
        # The entity import order. It is defined by the database model relationships.
        model_order = (USER_ENTITY_NAME, COMPUTER_ENTITY_NAME,
                       NODE_ENTITY_NAME, GROUP_ENTITY_NAME, LOG_ENTITY_NAME,
                       COMMENT_ENTITY_NAME)

        for import_field_name in metadata['all_fields_info']:
            if import_field_name not in model_order:
                raise exceptions.ImportValidationError(
                    "You are trying to import an unknown model '{}'!".format(
                        import_field_name))

        for idx, model_name in enumerate(model_order):
            dependencies = []
            for field in metadata['all_fields_info'][model_name].values():
                try:
                    dependencies.append(field['requires'])
                except KeyError:
                    # (No ForeignKey)
                    pass
            for dependency in dependencies:
                if dependency not in model_order[:idx]:
                    raise exceptions.ArchiveImportError(
                        'Model {} requires {}, which would only be loaded later; stopping...'
                        .format(model_name, dependency))

        ###################################################
        # CREATE IMPORT DATA DIRECT UNIQUE_FIELD MAPPINGS #
        ###################################################
        import_unique_ids_mappings = {}
        for model_name, import_data in data['export_data'].items():
            if model_name in metadata['unique_identifiers']:
                # I have to reconvert the pk to integer
                import_unique_ids_mappings[model_name] = {
                    int(k): v[metadata['unique_identifiers'][model_name]]
                    for k, v in import_data.items()
                }

        ###############
        # IMPORT DATA #
        ###############
        # DO ALL WITH A TRANSACTION
        # !!! EXCEPT: Creating final import Group containing all Nodes in archive

        # batch size for bulk create operations
        batch_size = get_config_option('db.batch_size')

        with transaction.atomic():
            foreign_ids_reverse_mappings = {}
            new_entries = {}
            existing_entries = {}

            IMPORT_LOGGER.debug('GENERATING LIST OF DATA...')

            # Instantiate progress bar
            progress_bar = get_progress_bar(total=1,
                                            leave=False,
                                            disable=silent)
            pbar_base_str = 'Generating list of data - '

            # Get total entities from data.json
            # To be used with progress bar
            number_of_entities = 0

            # I first generate the list of data
            for model_name in model_order:
                cls_signature = entity_names_to_signatures[model_name]
                model = get_object_from_string(cls_signature)
                fields_info = metadata['all_fields_info'].get(model_name, {})
                unique_identifier = metadata['unique_identifiers'].get(
                    model_name, None)

                new_entries[model_name] = {}
                existing_entries[model_name] = {}

                foreign_ids_reverse_mappings[model_name] = {}

                # Not necessarily all models are exported
                if model_name in data['export_data']:

                    IMPORT_LOGGER.debug('  %s...', model_name)

                    progress_bar.set_description_str(pbar_base_str +
                                                     model_name,
                                                     refresh=False)
                    number_of_entities += len(data['export_data'][model_name])

                    # Skip entities that are already present in the DB
                    if unique_identifier is not None:
                        import_unique_ids = set(
                            v[unique_identifier]
                            for v in data['export_data'][model_name].values())

                        relevant_db_entries = {}
                        if import_unique_ids:
                            relevant_db_entries_result = model.objects.filter(
                                **{
                                    '{}__in'.format(unique_identifier):
                                    import_unique_ids
                                })

                            # Note: UUIDs need to be converted to strings
                            if relevant_db_entries_result.count():
                                progress_bar = get_progress_bar(
                                    total=relevant_db_entries_result.count(),
                                    disable=silent)
                                # Imitating QueryBuilder.iterall() with default settings
                                for object_ in relevant_db_entries_result.iterator(
                                        chunk_size=100):
                                    progress_bar.update()
                                    relevant_db_entries.update({
                                        str(getattr(object_, unique_identifier)):
                                        object_
                                    })

                        foreign_ids_reverse_mappings[model_name] = {
                            k: v.pk
                            for k, v in relevant_db_entries.items()
                        }

                        IMPORT_LOGGER.debug('    GOING THROUGH ARCHIVE...')

                        imported_comp_names = set()
                        for key, value in data['export_data'][
                                model_name].items():
                            if model_name == GROUP_ENTITY_NAME:
                                # Check if there is already a group with the same label
                                dupl_counter = 0
                                orig_label = value['label']
                                while model.objects.filter(
                                        label=value['label']):
                                    value[
                                        'label'] = orig_label + DUPL_SUFFIX.format(
                                            dupl_counter)
                                    dupl_counter += 1
                                    if dupl_counter == 100:
                                        raise exceptions.ImportUniquenessError(
                                            'A group of that label ( {} ) already exists and I could not create a new '
                                            'one'.format(orig_label))

                            elif model_name == COMPUTER_ENTITY_NAME:
                                # Check if there is already a computer with the same name in the database
                                dupl = (
                                    model.objects.filter(name=value['name'])
                                    or value['name'] in imported_comp_names)
                                orig_name = value['name']
                                dupl_counter = 0
                                while dupl:
                                    # Rename the new computer
                                    value[
                                        'name'] = orig_name + DUPL_SUFFIX.format(
                                            dupl_counter)
                                    dupl = (model.objects.filter(
                                        name=value['name']) or value['name']
                                            in imported_comp_names)
                                    dupl_counter += 1
                                    if dupl_counter == 100:
                                        raise exceptions.ImportUniquenessError(
                                            'A computer of that name ( {} ) already exists and I could not create a '
                                            'new one'.format(orig_name))

                                imported_comp_names.add(value['name'])

                            if value[unique_identifier] in relevant_db_entries:
                                # Already in DB
                                existing_entries[model_name][key] = value
                            else:
                                # To be added
                                new_entries[model_name][key] = value
                    else:
                        new_entries[model_name] = data['export_data'][
                            model_name]

            # Reset for import
            progress_bar = get_progress_bar(total=number_of_entities,
                                            disable=silent)

            # I import data from the given model
            for model_name in model_order:
                # Progress bar initialization - Model
                pbar_base_str = '{}s - '.format(model_name)
                progress_bar.set_description_str(pbar_base_str +
                                                 'Initializing',
                                                 refresh=True)

                cls_signature = entity_names_to_signatures[model_name]
                model = get_object_from_string(cls_signature)
                fields_info = metadata['all_fields_info'].get(model_name, {})
                unique_identifier = metadata['unique_identifiers'].get(
                    model_name, None)

                # EXISTING ENTRIES
                if existing_entries[model_name]:
                    # Progress bar update - Model
                    progress_bar.set_description_str(
                        pbar_base_str + '{} existing entries'.format(
                            len(existing_entries[model_name])),
                        refresh=True)

                for import_entry_pk, entry_data in existing_entries[
                        model_name].items():
                    unique_id = entry_data[unique_identifier]
                    existing_entry_id = foreign_ids_reverse_mappings[
                        model_name][unique_id]
                    import_data = dict(
                        deserialize_field(k,
                                          v,
                                          fields_info=fields_info,
                                          import_unique_ids_mappings=
                                          import_unique_ids_mappings,
                                          foreign_ids_reverse_mappings=
                                          foreign_ids_reverse_mappings)
                        for k, v in entry_data.items())
                    # TODO COMPARE, AND COMPARE ATTRIBUTES

                    if model is models.DbComment:
                        new_entry_uuid = merge_comment(import_data,
                                                       comment_mode)
                        if new_entry_uuid is not None:
                            entry_data[unique_identifier] = new_entry_uuid
                            new_entries[model_name][
                                import_entry_pk] = entry_data

                    if model_name not in ret_dict:
                        ret_dict[model_name] = {'new': [], 'existing': []}
                    ret_dict[model_name]['existing'].append(
                        (import_entry_pk, existing_entry_id))
                    IMPORT_LOGGER.debug('Existing %s: %s (%s->%s)', model_name,
                                        unique_id, import_entry_pk,
                                        existing_entry_id)
                    # print('  `-> WARNING: NO DUPLICITY CHECK DONE!')
                    # CHECK ALSO FILES!

                # Collect all objects for this model in a list and store them all at once at the end.
                objects_to_create = []
                # This is needed later to associate the import entry with the new pk
                import_new_entry_pks = {}

                # NEW ENTRIES
                if new_entries[model_name]:
                    # Progress bar update - Model
                    progress_bar.set_description_str(
                        pbar_base_str +
                        '{} new entries'.format(len(new_entries[model_name])),
                        refresh=True)

                for import_entry_pk, entry_data in new_entries[
                        model_name].items():
                    unique_id = entry_data[unique_identifier]
                    import_data = dict(
                        deserialize_field(k,
                                          v,
                                          fields_info=fields_info,
                                          import_unique_ids_mappings=
                                          import_unique_ids_mappings,
                                          foreign_ids_reverse_mappings=
                                          foreign_ids_reverse_mappings)
                        for k, v in entry_data.items())

                    objects_to_create.append(model(**import_data))
                    import_new_entry_pks[unique_id] = import_entry_pk

                if model_name == NODE_ENTITY_NAME:
                    IMPORT_LOGGER.debug('STORING NEW NODE REPOSITORY FILES...')

                    # NEW NODES
                    for object_ in objects_to_create:
                        import_entry_uuid = object_.uuid
                        import_entry_pk = import_new_entry_pks[
                            import_entry_uuid]

                        # Progress bar initialization - Node
                        progress_bar.update()
                        pbar_node_base_str = pbar_base_str + 'UUID={} - '.format(
                            import_entry_uuid.split('-')[0])

                        # Before storing entries in the DB, I store the files (if these are nodes).
                        # Note: only for new entries!
                        subfolder = folder.get_subfolder(
                            os.path.join(NODES_EXPORT_SUBFOLDER,
                                         export_shard_uuid(import_entry_uuid)))
                        if not subfolder.exists():
                            raise exceptions.CorruptArchive(
                                'Unable to find the repository folder for Node with UUID={} in the exported '
                                'file'.format(import_entry_uuid))
                        destdir = RepositoryFolder(
                            section=Repository._section_name,
                            uuid=import_entry_uuid)
                        # Replace the folder, destroying any pre-existing content, and move the files
                        # (faster if we are on the same filesystem; in any case the source is a SandboxFolder)
                        progress_bar.set_description_str(pbar_node_base_str +
                                                         'Repository',
                                                         refresh=True)
                        destdir.replace_with_folder(subfolder.abspath,
                                                    move=True,
                                                    overwrite=True)

                        # For DbNodes, we also have to store their attributes
                        IMPORT_LOGGER.debug('STORING NEW NODE ATTRIBUTES...')
                        progress_bar.set_description_str(pbar_node_base_str +
                                                         'Attributes',
                                                         refresh=True)

                        # Get attributes from import file
                        try:
                            object_.attributes = data['node_attributes'][str(
                                import_entry_pk)]
                        except KeyError:
                            raise exceptions.CorruptArchive(
                                'Unable to find attribute info for Node with UUID={}'
                                .format(import_entry_uuid))

                        # For DbNodes, we also have to store their extras
                        if extras_mode_new == 'import':
                            IMPORT_LOGGER.debug('STORING NEW NODE EXTRAS...')
                            progress_bar.set_description_str(
                                pbar_node_base_str + 'Extras', refresh=True)

                            # Get extras from import file
                            try:
                                extras = data['node_extras'][str(
                                    import_entry_pk)]
                            except KeyError:
                                raise exceptions.CorruptArchive(
                                    'Unable to find extra info for Node with UUID={}'
                                    .format(import_entry_uuid))
                            # TODO: remove when aiida extras are moved somewhere else
                            # from here
                            extras = {
                                key: value
                                for key, value in extras.items()
                                if not key.startswith('_aiida_')
                            }
                            if object_.node_type.endswith('code.Code.'):
                                extras = {
                                    key: value
                                    for key, value in extras.items()
                                    if not key == 'hidden'
                                }
                            # till here
                            object_.extras = extras
                        elif extras_mode_new == 'none':
                            IMPORT_LOGGER.debug('SKIPPING NEW NODE EXTRAS...')
                        else:
                            raise exceptions.ImportValidationError(
                                "Unknown extras_mode_new value: {}, should be either 'import' or 'none'"
                                ''.format(extras_mode_new))

                    # EXISTING NODES (Extras)
                    # For the existing nodes that are also in the imported list we also update their extras if necessary
                    IMPORT_LOGGER.debug('UPDATING EXISTING NODE EXTRAS...')

                    import_existing_entry_pks = {
                        entry_data[unique_identifier]: import_entry_pk
                        for import_entry_pk, entry_data in
                        existing_entries[model_name].items()
                    }
                    for node in models.DbNode.objects.filter(
                            uuid__in=import_existing_entry_pks).all():  # pylint: disable=no-member
                        import_entry_uuid = str(node.uuid)
                        import_entry_pk = import_existing_entry_pks[
                            import_entry_uuid]

                        # Progress bar initialization - Node
                        pbar_node_base_str = pbar_base_str + 'UUID={} - '.format(
                            import_entry_uuid.split('-')[0])
                        progress_bar.set_description_str(pbar_node_base_str +
                                                         'Extras',
                                                         refresh=False)
                        progress_bar.update()

                        # Get extras from import file
                        try:
                            extras = data['node_extras'][str(import_entry_pk)]
                        except KeyError:
                            raise exceptions.CorruptArchive(
                                'Unable to find extra info for Node with UUID={}'
                                .format(import_entry_uuid))

                        old_extras = node.extras.copy()
                        # TODO: remove when aiida extras are moved somewhere else
                        # from here
                        extras = {
                            key: value
                            for key, value in extras.items()
                            if not key.startswith('_aiida_')
                        }
                        if node.node_type.endswith('code.Code.'):
                            extras = {
                                key: value
                                for key, value in extras.items()
                                if not key == 'hidden'
                            }
                        # till here
                        new_extras = merge_extras(node.extras, extras,
                                                  extras_mode_existing)

                        if new_extras != old_extras:
                            # Already saving existing node here to update its extras
                            node.extras = new_extras
                            node.save()

                else:
                    # Update progress bar with new non-Node entries
                    progress_bar.update(n=len(existing_entries[model_name]) +
                                        len(new_entries[model_name]))

                progress_bar.set_description_str(pbar_base_str + 'Storing',
                                                 refresh=True)

                # If the model has an mtime field, disable its automatic update
                # so that the mtime we have set here is kept
                if 'mtime' in [
                        field.name for field in model._meta.local_fields
                ]:
                    with models.suppress_auto_now([(model, ['mtime'])]):
                        # Store them all at once; note, however, that the PKs are not set this way...
                        model.objects.bulk_create(objects_to_create,
                                                  batch_size=batch_size)
                else:
                    model.objects.bulk_create(objects_to_create,
                                              batch_size=batch_size)

                # Get back the just-saved entries
                just_saved_queryset = model.objects.filter(
                    **{
                        '{}__in'.format(unique_identifier):
                        import_new_entry_pks.keys()
                    }).values_list(unique_identifier, 'pk')
                # note: convert uuids from type UUID to strings
                just_saved = {
                    str(key): value
                    for key, value in just_saved_queryset
                }

                # Now I have the PKs, print the info
                # Moreover, add newly created Nodes to foreign_ids_reverse_mappings
                for unique_id, new_pk in just_saved.items():
                    import_entry_pk = import_new_entry_pks[unique_id]
                    foreign_ids_reverse_mappings[model_name][
                        unique_id] = new_pk
                    if model_name not in ret_dict:
                        ret_dict[model_name] = {'new': [], 'existing': []}
                    ret_dict[model_name]['new'].append(
                        (import_entry_pk, new_pk))

                    IMPORT_LOGGER.debug('New %s: %s (%s->%s)', model_name,
                                        unique_id, import_entry_pk, new_pk)

            IMPORT_LOGGER.debug('STORING NODE LINKS...')
            import_links = data['links_uuid']
            links_to_store = []

            # Needed, since QueryBuilder does not yet work for recently saved Nodes
            existing_links_raw = models.DbLink.objects.all().values_list(
                'input', 'output', 'label', 'type')
            existing_links = {(l[0], l[1], l[2], l[3])
                              for l in existing_links_raw}
            existing_outgoing_unique = {(l[0], l[3])
                                        for l in existing_links_raw}
            existing_outgoing_unique_pair = {(l[0], l[2], l[3])
                                             for l in existing_links_raw}
            existing_incoming_unique = {(l[1], l[3])
                                        for l in existing_links_raw}
            existing_incoming_unique_pair = {(l[1], l[2], l[3])
                                             for l in existing_links_raw}

            calculation_node_types = 'process.calculation.'
            workflow_node_types = 'process.workflow.'
            data_node_types = 'data.'

            link_mapping = {
                LinkType.CALL_CALC:
                (workflow_node_types, calculation_node_types, 'unique_triple',
                 'unique'),
                LinkType.CALL_WORK: (workflow_node_types, workflow_node_types,
                                     'unique_triple', 'unique'),
                LinkType.CREATE: (calculation_node_types, data_node_types,
                                  'unique_pair', 'unique'),
                LinkType.INPUT_CALC: (data_node_types, calculation_node_types,
                                      'unique_triple', 'unique_pair'),
                LinkType.INPUT_WORK: (data_node_types, workflow_node_types,
                                      'unique_triple', 'unique_pair'),
                LinkType.RETURN: (workflow_node_types, data_node_types,
                                  'unique_pair', 'unique_triple'),
            }

            if import_links:
                progress_bar = get_progress_bar(total=len(import_links),
                                                disable=silent)
                pbar_base_str = 'Links - '

            for link in import_links:
                # Check for dangling Links within the supposedly self-consistent archive
                progress_bar.set_description_str(
                    pbar_base_str + 'label={}'.format(link['label']),
                    refresh=False)
                progress_bar.update()

                try:
                    in_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][
                        link['input']]
                    out_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][
                        link['output']]
                except KeyError:
                    if ignore_unknown_nodes:
                        continue
                    raise exceptions.ImportValidationError(
                        'Trying to create a link with one or both unknown nodes, stopping (in_uuid={}, out_uuid={}, '
                        'label={}, type={})'.format(link['input'],
                                                    link['output'],
                                                    link['label'],
                                                    link['type']))

                # Check if the link already exists and skip it if it does.
                # This is equivalent to the unique_triple constraint in link_mapping above.
                if (in_id, out_id, link['label'],
                        link['type']) in existing_links:
                    continue

                # Since backend-specific Links (DbLink) are not validated upon creation, validate them now.
                try:
                    validate_link_label(link['label'])
                except ValueError as why:
                    raise exceptions.ImportValidationError(
                        'Error during Link label validation: {}'.format(why))

                source = models.DbNode.objects.get(id=in_id)
                target = models.DbNode.objects.get(id=out_id)

                if source.uuid == target.uuid:
                    raise exceptions.ImportValidationError(
                        'Cannot add a link to oneself')

                link_type = LinkType(link['type'])
                type_source, type_target, outdegree, indegree = link_mapping[
                    link_type]

                # Check if source Node is a valid type
                if not source.node_type.startswith(type_source):
                    raise exceptions.ImportValidationError(
                        'Cannot add a {} link from {} to {}'.format(
                            link_type, source.node_type, target.node_type))

                # Check if target Node is a valid type
                if not target.node_type.startswith(type_target):
                    raise exceptions.ImportValidationError(
                        'Cannot add a {} link from {} to {}'.format(
                            link_type, source.node_type, target.node_type))

                # If the outdegree is `unique`, there cannot already be any other outgoing link of that type,
                # i.e., the source Node may not already have an outgoing link of the current LinkType.
                if outdegree == 'unique' and (
                        in_id, link['type']) in existing_outgoing_unique:
                    raise exceptions.ImportValidationError(
                        'Node<{}> already has an outgoing {} link'.format(
                            source.uuid, link_type))

                # If the outdegree is `unique_pair`, the link labels for outgoing links of this type must be unique,
                # i.e., the source Node may not already have an outgoing link of the current LinkType with this label.
                elif outdegree == 'unique_pair' and \
                (in_id, link['label'], link['type']) in existing_outgoing_unique_pair:
                    raise exceptions.ImportValidationError(
                        'Node<{}> already has an outgoing {} link with label "{}"'
                        .format(source.uuid, link_type, link['label']))

                # If the indegree is `unique`, there cannot already be any other incoming link of that type,
                # i.e., the target Node may not already have an incoming link of the current LinkType.
                if indegree == 'unique' and (
                        out_id, link['type']) in existing_incoming_unique:
                    raise exceptions.ImportValidationError(
                        'Node<{}> already has an incoming {} link'.format(
                            target.uuid, link_type))

                # If the indegree is `unique_pair`, the link labels for incoming links of this type must be unique,
                # i.e., the target Node may not already have an incoming link of the current LinkType with this label.
                elif indegree == 'unique_pair' and \
                (out_id, link['label'], link['type']) in existing_incoming_unique_pair:
                    raise exceptions.ImportValidationError(
                        'Node<{}> already has an incoming {} link with label "{}"'
                        .format(target.uuid, link_type, link['label']))

                # New link
                links_to_store.append(
                    models.DbLink(input_id=in_id,
                                  output_id=out_id,
                                  label=link['label'],
                                  type=link['type']))
                if 'Link' not in ret_dict:
                    ret_dict['Link'] = {'new': []}
                ret_dict['Link']['new'].append((in_id, out_id))

                # Register the new Link in the sets of existing Links (input PK, output PK, label, type)
                existing_links.add(
                    (in_id, out_id, link['label'], link['type']))
                existing_outgoing_unique.add((in_id, link['type']))
                existing_outgoing_unique_pair.add(
                    (in_id, link['label'], link['type']))
                existing_incoming_unique.add((out_id, link['type']))
                existing_incoming_unique_pair.add(
                    (out_id, link['label'], link['type']))

            # Store new links
            if links_to_store:
                IMPORT_LOGGER.debug('   (%d new links...)',
                                    len(links_to_store))

                models.DbLink.objects.bulk_create(links_to_store,
                                                  batch_size=batch_size)
            else:
                IMPORT_LOGGER.debug('   (0 new links...)')

            IMPORT_LOGGER.debug('STORING GROUP ELEMENTS...')

            import_groups = data['groups_uuid']

            if import_groups:
                progress_bar = get_progress_bar(total=len(import_groups),
                                                disable=silent)
                pbar_base_str = 'Groups - '

            for groupuuid, groupnodes in import_groups.items():
                # TODO: cache these to avoid too many queries
                group_ = models.DbGroup.objects.get(uuid=groupuuid)

                progress_bar.set_description_str(
                    pbar_base_str + 'label={}'.format(group_.label),
                    refresh=False)
                progress_bar.update()

                nodes_to_store = [
                    foreign_ids_reverse_mappings[NODE_ENTITY_NAME][node_uuid]
                    for node_uuid in groupnodes
                ]
                if nodes_to_store:
                    group_.dbnodes.add(*nodes_to_store)

        ######################################################
        # Put everything in a specific group
        ######################################################
        existing = existing_entries.get(NODE_ENTITY_NAME, {})
        existing_pk = [
            foreign_ids_reverse_mappings[NODE_ENTITY_NAME][v['uuid']]
            for v in existing.values()
        ]
        new = new_entries.get(NODE_ENTITY_NAME, {})
        new_pk = [
            foreign_ids_reverse_mappings[NODE_ENTITY_NAME][v['uuid']]
            for v in new.values()
        ]

        pks_for_group = existing_pk + new_pk

        # So that we do not create empty groups
        if pks_for_group:
            # If the user did not specify a group, create a new unique import group
            if not group:
                # Get a unique name for the import group, based on the current (local) time
                basename = timezone.localtime(
                    timezone.now()).strftime('%Y%m%d-%H%M%S')
                counter = 0
                group_label = basename

                while Group.objects.find(filters={'label': group_label}):
                    counter += 1
                    group_label = '{}_{}'.format(basename, counter)

                    if counter == 100:
                        raise exceptions.ImportUniquenessError(
                            "Overflow of import groups (more than 100 import groups exists with basename '{}')"
                            ''.format(basename))
                group = ImportGroup(label=group_label).store()

            # Add all the nodes to the new group
            builder = QueryBuilder().append(
                Node, filters={'id': {
                    'in': pks_for_group
                }})

            progress_bar = get_progress_bar(total=len(pks_for_group),
                                            disable=silent)
            progress_bar.set_description_str(
                'Creating import Group - Preprocessing', refresh=True)
            first = True

            nodes = []
            for entry in builder.iterall():
                if first:
                    progress_bar.set_description_str('Creating import Group',
                                                     refresh=False)
                    first = False
                progress_bar.update()
                nodes.append(entry[0])
            group.add_nodes(nodes)
            progress_bar.set_description_str('Done (cleaning up)',
                                             refresh=True)
        else:
            IMPORT_LOGGER.debug(
                'No Nodes to import, so no import Group was created (unless one already existed)'
            )

    # Finalize Progress bar
    close_progress_bar(leave=False)

    # Summarize import
    result_summary(ret_dict, getattr(group, 'label', None))

    # Reset logging level
    if silent:
        logging.disable(level=logging.NOTSET)

    return ret_dict
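
A hedged usage sketch for import_data_dj as documented above; the archive path and group label are placeholders, and extras_mode_existing='kcl' spells out: keep existing extras, create missing ones, and leave the old value on collisions.

# Usage sketch (the path and group label are placeholders, not from the source).
from aiida.orm import Group

import_group = Group(label='my-imports').store()
result = import_data_dj(
    '/path/to/archive.aiida',      # archive file or extracted folder
    group=import_group,            # place all imported Nodes in this Group
    extras_mode_existing='kcl',    # keep / create / leave-old on collision
    extras_mode_new='import',      # also import extras of new Nodes
    comment_mode='newest',         # keep the most recently modified Comment
    silent=True,                   # no progress bar or summary
)
print(sorted(result))              # entity names, plus 'Link' if links were added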
Example No. 11
def calculations_for_label(label):
    """Return a dict mapping CalcJobNode label to UUID for all calculation jobs launched by
    work chains that were in turn called by MPDSCrystalWorkchain nodes matching the given label."""
    qb = QueryBuilder()
    qb.append(MPDSCrystalWorkchain, filters={'label': {'like': label}}, tag='root')
    qb.append(WorkChainNode, with_incoming='root', tag='parent')
    qb.append(CalcJobNode, with_incoming='parent', project=['label', 'uuid'])
    return {calc_label: calc_uuid for calc_label, calc_uuid in qb.iterall()}
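
Because the filter uses the 'like' operator, the label argument accepts SQL-style wildcards; the pattern below is purely illustrative.

# Usage sketch: '%' acts as a wildcard in the 'like' filter above.
for calc_label, calc_uuid in calculations_for_label('MgO%').items():
    print(calc_label, calc_uuid)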
Example No. 12
def _select_entity_data(*, entity_name: str, reader: ArchiveReaderAbstract,
                        new_entries: Dict[str, Dict[str, dict]],
                        existing_entries: Dict[str, Dict[str, dict]],
                        foreign_ids_reverse_mappings: Dict[str, Dict[str,
                                                                     int]],
                        extras_mode_new: str):
    """Select the data to import by comparing the AiiDA database to the archive contents."""
    entity = entity_names_to_entities[entity_name]

    unique_identifier = reader.metadata.unique_identifiers.get(
        entity_name, None)

    # Not necessarily all models are present in the archive
    if entity_name not in reader.entity_names:
        return

    existing_entries.setdefault(entity_name, {})
    new_entries.setdefault(entity_name, {})

    if unique_identifier is None:
        new_entries[entity_name] = {
            str(pk): fields
            for pk, fields in reader.iter_entity_fields(entity_name)
        }
        return

    # Skip entities that are already present in the DB
    import_unique_ids = set(f[unique_identifier]
                            for _, f in reader.iter_entity_fields(
                                entity_name, fields=(unique_identifier, )))

    relevant_db_entries = {}
    if import_unique_ids:
        builder = QueryBuilder()
        builder.append(entity,
                       filters={unique_identifier: {
                           'in': import_unique_ids
                       }},
                       project='*')

        if builder.count():
            with get_progress_reporter()(
                    desc=f'Finding existing entities - {entity_name}',
                    total=builder.count()) as progress:
                for object_ in builder.iterall():
                    progress.update()
                    # Note: UUIDs need to be converted to strings
                    relevant_db_entries.update({
                        str(getattr(object_[0], unique_identifier)):
                        object_[0]
                    })

    foreign_ids_reverse_mappings[entity_name] = {
        k: v.pk
        for k, v in relevant_db_entries.items()
    }

    entity_count = reader.entity_count(entity_name)
    if not entity_count:
        return

    with get_progress_reporter()(
            desc=f'Reading archived entities - {entity_name}',
            total=entity_count) as progress:
        imported_comp_names = set()
        for pk, fields in reader.iter_entity_fields(entity_name):
            if entity_name == GROUP_ENTITY_NAME:
                # Check if there is already a group with the same label
                # and, if so, generate a new unique label
                orig_label = fields['label']
                dupl_counter = 0
                while QueryBuilder().append(entity,
                                            filters={
                                                'label': {
                                                    '==': fields['label']
                                                }
                                            }).count():
                    # Rename the new group
                    fields['label'] = orig_label + DUPL_SUFFIX.format(
                        dupl_counter)
                    dupl_counter += 1
                    if dupl_counter == MAX_GROUPS:
                        raise exceptions.ImportUniquenessError(
                            f'A group of that label ( {orig_label} ) already exists and I could not create a new one'
                        )

            elif entity_name == COMPUTER_ENTITY_NAME:
                # The following is done for compatibility
                # reasons in case the archive file was generated
                # with the Django export method. In Django the
                # metadata and the transport parameters are
                # stored as (unicode) strings of the serialized
                # JSON objects and not as simple serialized
                # JSON objects.
                if isinstance(fields['metadata'], (str, bytes)):
                    fields['metadata'] = json.loads(fields['metadata'])

                # Check if there is already a computer with the
                # same name in the database
                builder = QueryBuilder()
                builder.append(entity,
                               filters={'name': {
                                   '==': fields['name']
                               }},
                               project=['*'],
                               tag='res')
                dupl = builder.count() or fields['name'] in imported_comp_names
                dupl_counter = 0
                orig_name = fields['name']
                while dupl:
                    # Rename the new computer
                    fields['name'] = orig_name + DUPL_SUFFIX.format(
                        dupl_counter)
                    builder = QueryBuilder()
                    builder.append(entity,
                                   filters={'name': {
                                       '==': fields['name']
                                   }},
                                   project=['*'],
                                   tag='res')
                    dupl = builder.count(
                    ) or fields['name'] in imported_comp_names
                    dupl_counter += 1
                    if dupl_counter == MAX_COMPUTERS:
                        raise exceptions.ImportUniquenessError(
                            f'A computer of that name ( {orig_name} ) already exists and I could not create a new one'
                        )

                imported_comp_names.add(fields['name'])

            if fields[unique_identifier] in relevant_db_entries:
                # Already in DB
                existing_entries[entity_name][str(pk)] = fields
            else:
                # To be added
                if entity_name == NODE_ENTITY_NAME:
                    # format extras
                    fields = _sanitize_extras(fields)
                    if extras_mode_new != 'import':
                        fields.pop('extras', None)
                new_entries[entity_name][str(pk)] = fields
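
The two renaming loops above share one pattern: probe for a name collision and, if one is found, append a numbered suffix until a free name is reached or a hard limit is hit. Below is a minimal stand-alone sketch of that pattern; `label_exists`, `DUPL_SUFFIX` and `MAX_ATTEMPTS` are stand-ins for the QueryBuilder check and the module-level constants used above, not the actual import code.

DUPL_SUFFIX = ' (Imported #{})'  # assumed format, analogous to the constant above
MAX_ATTEMPTS = 100               # stand-in for MAX_GROUPS / MAX_COMPUTERS


def deduplicate_label(orig_label, label_exists):
    """Return `orig_label` or a suffixed variant that does not collide."""
    label = orig_label
    counter = 0
    while label_exists(label):
        # Rename, exactly like the import loop above: suffix 0, 1, 2, ...
        label = orig_label + DUPL_SUFFIX.format(counter)
        counter += 1
        if counter == MAX_ATTEMPTS:
            raise RuntimeError(f'could not find a free label for {orig_label!r}')
    return label


# With 'relax' and 'relax (Imported #0)' taken, the next free label is
# 'relax (Imported #1)'.
taken = {'relax', 'relax (Imported #0)'}
print(deduplicate_label('relax', lambda lbl: lbl in taken))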
Exemplo n.º 13
0
    def __init__(self, title=""):
        self.title = title

        # Structure objects we want to query for.
        self.query_structure_type = (DataFactory("structure"),
                                     DataFactory("cif"))

        # Extracting available process labels.
        qbuilder = QueryBuilder().append((CalcJobNode, WorkChainNode),
                                         project="label")
        self.drop_label = ipw.Dropdown(
            options=sorted({"All"
                            }.union({i[0]
                                     for i in qbuilder.iterall() if i[0]})),
            value="All",
            description="Process Label",
            disabled=True,
            style={"description_width": "120px"},
            layout={"width": "50%"},
        )
        self.drop_label.observe(self.search, names="value")

        # Disable process labels selection if we are not looking for the calculated structures.
        def disable_drop_label(change):
            self.drop_label.disabled = not change["new"] == "calculated"

        # Select structures kind.
        self.mode = ipw.RadioButtons(
            options=["all", "uploaded", "edited", "calculated"],
            layout={"width": "25%"})
        self.mode.observe(self.search, names="value")
        self.mode.observe(disable_drop_label, names="value")

        # Date range.
        # Note: there is Date picker widget, but it currently does not work in Safari:
        # https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20List.html#Date-picker
        date_text = ipw.HTML(value="<p>Select the date range:</p>")
        self.start_date_widget = ipw.Text(value="",
                                          description="From: ",
                                          style={"description_width": "120px"})
        self.end_date_widget = ipw.Text(value="", description="To: ")

        # Search button.
        btn_search = ipw.Button(
            description="Search",
            button_style="info",
            layout={
                "width": "initial",
                "margin": "2px 0 0 2em"
            },
        )
        btn_search.on_click(self.search)

        age_selection = ipw.VBox(
            [
                date_text,
                ipw.HBox([
                    self.start_date_widget, self.end_date_widget, btn_search
                ]),
            ],
            layout={
                "border": "1px solid #fafafa",
                "padding": "1em"
            },
        )

        h_line = ipw.HTML("<hr>")
        box = ipw.VBox(
            [age_selection, h_line,
             ipw.HBox([self.mode, self.drop_label])])

        self.results = ipw.Dropdown(layout={"width": "900px"})
        self.results.observe(self._on_select_structure, names="value")
        self.search()
        super().__init__([box, h_line, self.results])
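
A hedged sketch of how the 'Process Label' options above could be collected outside of the widget: one QueryBuilder query over calculation and workflow nodes, keeping only distinct, non-empty labels. It assumes a configured AiiDA profile; the printed labels are hypothetical.

from aiida import load_profile
from aiida.orm import CalcJobNode, QueryBuilder, WorkChainNode

load_profile()

qbuilder = QueryBuilder().append((CalcJobNode, WorkChainNode), project="label")
# Distinct, non-empty labels plus the catch-all "All" entry used by the dropdown.
labels = sorted({"All"}.union({label for label, in qbuilder.iterall() if label}))
print(labels)  # e.g. ["All", "PwBaseWorkChain", "relax"]  (hypothetical)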
Exemplo n.º 14
0
def delete_nodes(
    pks: Iterable[int],
    verbosity: Optional[int] = None,
    dry_run: Union[bool, Callable[[Set[int]], bool]] = True,
    force: Optional[bool] = None,
    **traversal_rules: bool
) -> Tuple[Set[int], bool]:
    """Delete nodes given a list of "starting" PKs.

    This command will delete not only the specified nodes, but also the ones that are
    linked to these and should also be deleted in order to keep a consistent provenance
    according to the rules explained in the Topics - Provenance section of the documentation.
    In summary:

    1. If a DATA node is deleted, any process nodes linked to it will also be deleted.

    2. If a CALC node is deleted, any incoming WORK node (callers) will be deleted as
    well whereas any incoming DATA node (inputs) will be kept. Outgoing DATA nodes
    (outputs) will be deleted by default but this can be disabled.

    3. If a WORK node is deleted, any incoming WORK node (callers) will be deleted as
    well, but all DATA nodes will be kept. Outgoing WORK or CALC nodes will be kept by
    default, but deletion of either or both kinds of connected nodes can be enabled.

    These rules are 'recursive', so if a CALC node is deleted, then its output DATA
    nodes will be deleted as well, and then any CALC node that may have those as
    inputs, and so on.

    .. deprecated:: 1.6.0
        The `verbosity` keyword will be removed in `v2.0.0`, set the level of `DELETE_LOGGER` instead.

    .. deprecated:: 1.6.0
        The `force` keyword will be removed in `v2.0.0`, use the `dry_run` option instead.

    :param pks: a list of starting PKs of the nodes to delete
        (the full set will be based on the traversal rules)

    :param dry_run:
        If True, return the pks to delete without deleting anything.
        If False, delete the pks without confirmation.
        If callable, a function that returns True/False based on the pks, e.g.
        ``dry_run=lambda pks: True``; returning True aborts the deletion.

    :param traversal_rules: graph traversal rules.
        See :const:`aiida.common.links.GraphTraversalRules` for what rule names
        are toggleable and what the defaults are.

    :returns: (pks to delete, whether they were deleted)

    """
    # pylint: disable=too-many-arguments,too-many-branches,too-many-locals,too-many-statements

    if verbosity is not None:
        warnings.warn(
            'The verbosity option is deprecated and will be removed in `aiida-core==2.0.0`. '
            'Set the level of DELETE_LOGGER instead', AiidaDeprecationWarning
        )  # pylint: disable=no-member

    if force is not None:
        warnings.warn(
            'The force option is deprecated and will be removed in `aiida-core==2.0.0`. '
            'Use dry_run instead', AiidaDeprecationWarning
        )  # pylint: disable=no-member
        if force is True:
            dry_run = False

    def _missing_callback(_pks: Iterable[int]):
        for _pk in _pks:
            DELETE_LOGGER.warning(f'warning: node with pk<{_pk}> does not exist, skipping')

    pks_set_to_delete = get_nodes_delete(pks, get_links=False, missing_callback=_missing_callback,
                                         **traversal_rules)['nodes']

    DELETE_LOGGER.info('%s Node(s) marked for deletion', len(pks_set_to_delete))

    if pks_set_to_delete and DELETE_LOGGER.level == logging.DEBUG:
        builder = QueryBuilder().append(
            Node, filters={'id': {
                'in': pks_set_to_delete
            }}, project=('uuid', 'id', 'node_type', 'label')
        )
        DELETE_LOGGER.debug('Node(s) to delete:')
        for uuid, pk, type_string, label in builder.iterall():
            try:
                short_type_string = type_string.split('.')[-2]
            except IndexError:
                short_type_string = type_string
            DELETE_LOGGER.debug(f'   {uuid} {pk} {short_type_string} {label}')

    if dry_run is True:
        DELETE_LOGGER.info('This was a dry run, exiting without deleting anything')
        return (pks_set_to_delete, False)

    # confirm deletion
    if callable(dry_run) and dry_run(pks_set_to_delete):
        DELETE_LOGGER.info('This was a dry run, exiting without deleting anything')
        return (pks_set_to_delete, False)

    if not pks_set_to_delete:
        return (pks_set_to_delete, True)

    # Recover the list of folders to delete before actually deleting the nodes. I will delete the folders only later,
    # so that if there is a problem during the deletion of the nodes in the DB, I don't delete the folders
    repositories = [load_node(pk)._repository for pk in pks_set_to_delete]  # pylint: disable=protected-access

    DELETE_LOGGER.info('Starting node deletion...')
    delete_nodes_and_connections(pks_set_to_delete)

    DELETE_LOGGER.info('Nodes deleted from database, deleting files from the repository now...')

    # If we are here, we managed to delete the entries from the DB.
    # I can now delete the folders
    for repository in repositories:
        repository.erase(force=True)

    DELETE_LOGGER.info('Deletion of nodes completed.')

    return (pks_set_to_delete, True)
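
A hedged usage sketch for the `delete_nodes` API documented above. It assumes a loaded AiiDA profile and that the function is already in scope (its module path differs between aiida-core versions, so no import is shown); PK 1234 is a placeholder.

# 1. Default dry run: nothing is deleted, only the affected PKs are reported.
pks, deleted = delete_nodes([1234], dry_run=True)
print(f'{len(pks)} node(s) would be deleted')  # `deleted` is False here

# 2. Confirmation via a callable: it receives the full PK set and returning
#    True makes the call behave like a dry run (nothing is deleted).
def abort_deletion(pks):
    return input(f'Delete {len(pks)} nodes? [y/N] ').lower() != 'y'

pks, deleted = delete_nodes([1234], dry_run=abort_deletion)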
Exemplo n.º 15
0
def delete_nodes(pks, verbosity=0, dry_run=False, force=False, **kwargs):
    """Delete nodes by a list of pks.

    This command will delete not only the specified nodes, but also the ones that are
    linked to these and should also be deleted in order to keep a consistent provenance
    according to the rules explained in the concepts section of the documentation.
    In summary:

    1. If a DATA node is deleted, any process nodes linked to it will also be deleted.

    2. If a CALC node is deleted, any incoming WORK node (callers) will be deleted as
    well whereas any incoming DATA node (inputs) will be kept. Outgoing DATA nodes
    (outputs) will be deleted by default but this can be disabled.

    3. If a WORK node is deleted, any incoming WORK node (callers) will be deleted as
    well, but all DATA nodes will be kept. Outgoing WORK or CALC nodes will be kept by
    default, but deletion of either or both kinds of connected nodes can be enabled.

    These rules are 'recursive', so if a CALC node is deleted, then its output DATA
    nodes will be deleted as well, and then any CALC node that may have those as
    inputs, and so on.

    :param pks: a list of the PKs of the nodes to delete
    :param int verbosity: 0 prints nothing,
                          1 prints just sums and total,
                          2 prints individual nodes.

    :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` for what rule
        names are toggleable and what the defaults are.
    :param bool dry_run:
        Just perform a dry run and do not delete anything. Print statistics according
        to the verbosity level set.
    :param bool force:
        Do not ask for confirmation to delete nodes.
    """
    # pylint: disable=too-many-arguments,too-many-branches,too-many-locals,too-many-statements
    from aiida.backends.utils import delete_nodes_and_connections
    from aiida.common import exceptions
    from aiida.common.links import GraphTraversalRules
    from aiida.orm import Node, QueryBuilder, load_node

    starting_pks = []
    for pk in pks:
        try:
            load_node(pk)
        except exceptions.NotExistent:
            echo.echo_warning('warning: node with pk<{}> does not exist, skipping'.format(pk))
        else:
            starting_pks.append(pk)

    # An empty set might be problematic for the queries done below.
    if not starting_pks:
        if verbosity:
            echo.echo('Nothing to delete')
        return

    follow_forwards = []
    follow_backwards = []

    # Validate the graph traversal rules and build the lists of link types to follow
    # when determining the complete set of nodes to be deleted
    for name, rule in GraphTraversalRules.DELETE.value.items():

        # Check that rules that are not toggleable are not specified in the keyword arguments
        if not rule.toggleable and name in kwargs:
            raise exceptions.ExportValidationError('traversal rule {} is not toggleable'.format(name))

        follow = kwargs.pop(name, rule.default)

        if follow:
            if rule.direction == 'forward':
                follow_forwards.append(rule.link_type.value)
            elif rule.direction == 'backward':
                follow_backwards.append(rule.link_type.value)
            else:
                raise exceptions.InternalError('unrecognized direction `{}` for graph traversal rule'.format(rule.direction))

    links_backwards = {'type': {'in': follow_backwards}}
    links_forwards = {'type': {'in': follow_forwards}}

    operational_set = set().union(set(starting_pks))
    accumulator_set = set().union(set(starting_pks))

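    # Iteratively expand the set of nodes to delete: follow the allowed forward
    # and backward link types from the current frontier (`operational_set`),
    # excluding anything already accumulated, until no new PKs are found.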
    while operational_set:
        new_pks_set = set()

        query_nodes = QueryBuilder()
        query_nodes.append(Node, filters={'id': {'in': operational_set}}, tag='sources')
        query_nodes.append(
            Node,
            filters={'id': {
                '!in': accumulator_set
            }},
            edge_filters=links_forwards,
            with_incoming='sources',
            project='id'
        )
        new_pks_set.update(i for i, in query_nodes.iterall())

        query_nodes = QueryBuilder()
        query_nodes.append(Node, filters={'id': {'in': operational_set}}, tag='sources')
        query_nodes.append(
            Node,
            filters={'id': {
                '!in': accumulator_set
            }},
            edge_filters=links_backwards,
            with_outgoing='sources',
            project='id'
        )
        new_pks_set.update(i for i, in query_nodes.iterall())

        operational_set = new_pks_set.difference(accumulator_set)
        accumulator_set.update(new_pks_set)

    pks_set_to_delete = accumulator_set

    if verbosity > 0:
        echo.echo(
            'I {} delete {} node{}'.format(
                'would' if dry_run else 'will', len(pks_set_to_delete), 's' if len(pks_set_to_delete) > 1 else ''
            )
        )
        if verbosity > 1:
            builder = QueryBuilder().append(
                Node, filters={'id': {
                    'in': pks_set_to_delete
                }}, project=('uuid', 'id', 'node_type', 'label')
            )
            echo.echo('The nodes I {} delete:'.format('would' if dry_run else 'will'))
            for uuid, pk, type_string, label in builder.iterall():
                try:
                    short_type_string = type_string.split('.')[-2]
                except IndexError:
                    short_type_string = type_string
                echo.echo('   {} {} {} {}'.format(uuid, pk, short_type_string, label))

    if dry_run:
        if verbosity > 0:
            echo.echo('\nThis was a dry run, exiting without deleting anything')
        return

    # Ask for user confirmation, unless force is set
    if not force:
        echo.echo_warning('YOU ARE ABOUT TO DELETE {} NODES! THIS CANNOT BE UNDONE!'.format(len(pks_set_to_delete)))
        if not click.confirm('Shall I continue?'):
            echo.echo('Exiting without deleting')
            return

    # Recover the list of folders to delete before actually deleting the nodes. I will delete the folders only later,
    # so that if there is a problem during the deletion of the nodes in the DB, I don't delete the folders
    repositories = [load_node(pk)._repository for pk in pks_set_to_delete]  # pylint: disable=protected-access

    if verbosity > 0:
        echo.echo('Starting node deletion...')
    delete_nodes_and_connections(pks_set_to_delete)

    if verbosity > 0:
        echo.echo('Nodes deleted from database, deleting files from the repository now...')

    # If we are here, we managed to delete the entries from the DB.
    # I can now delete the folders
    for repository in repositories:
        repository.erase(force=True)

    if verbosity > 0:
        echo.echo('Deletion completed.')
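
The loop above computes a breadth-first closure over the provenance graph: repeatedly query the neighbours of the current frontier, skip anything already collected, and stop when the frontier is empty. A minimal, graph-agnostic sketch of that pattern follows, where the hypothetical `neighbours` callable stands in for the two QueryBuilder queries.

def traversal_closure(starting_pks, neighbours):
    """Return the closure of `starting_pks` under the `neighbours` callable.

    `neighbours(frontier)` must return the set of PKs directly linked to any
    PK in `frontier` (the role played by the forward/backward queries above).
    """
    accumulator = set(starting_pks)
    frontier = set(starting_pks)
    while frontier:
        reached = neighbours(frontier) - accumulator
        accumulator.update(reached)
        frontier = reached
    return accumulator


# Toy in-memory provenance graph, just to exercise the helper.
edges = {1: {2, 3}, 2: {4}, 3: set(), 4: set()}
print(traversal_closure({1}, lambda pks: set().union(*(edges[p] for p in pks))))
# -> {1, 2, 3, 4}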
Exemplo n.º 16
0
def _retrieve_linked_nodes_query(current_node, input_type, output_type,
                                 direction, link_type_value):
    """Helper function for :py:func:`~aiida.tools.importexport.dbexport.utils.retrieve_linked_nodes`

    A general :py:class:`~aiida.orm.querybuilder.QueryBuilder` query, retrieving linked Nodes and returning link
    information and the found Nodes.

    :param current_node: The current Node's PK.
    :type current_node: int

    :param input_type: Source Node class for Link
    :type input_type: :py:class:`~aiida.orm.nodes.data.data.Data`,
        :py:class:`~aiida.orm.nodes.process.process.ProcessNode`.

    :param output_type: Target Node class for Link
    :type output_type: :py:class:`~aiida.orm.nodes.data.data.Data`,
        :py:class:`~aiida.orm.nodes.process.process.ProcessNode`.

    :param direction: Link direction, must be either ``'forward'`` or ``'backward'``.
    :type direction: str

    :param link_type_value: A :py:class:`~aiida.common.links.LinkType` value, e.g. ``LinkType.RETURN.value``.
    :type link_type_value: str

    :return: Dictionary of link information to be used for the export archive and set of found Nodes.
    """
    found_nodes = set()
    links_uuid_dict = {}
    filters_input = {}
    filters_output = {}

    if direction == 'forward':
        filters_input['id'] = current_node
    elif direction == 'backward':
        filters_output['id'] = current_node
    else:
        raise exceptions.ExportValidationError(
            'direction must be either "forward" or "backward"')

    builder = QueryBuilder()
    builder.append(input_type,
                   project=['uuid', 'id'],
                   tag='input',
                   filters=filters_input)
    builder.append(output_type,
                   project=['uuid', 'id'],
                   with_incoming='input',
                   filters=filters_output,
                   edge_filters={'type': link_type_value},
                   edge_project=['label', 'type'])

    for input_uuid, input_pk, output_uuid, output_pk, link_label, link_type in builder.iterall(
    ):
        links_uuid_entry = {
            'input': str(input_uuid),
            'output': str(output_uuid),
            'label': str(link_label),
            'type': str(link_type)
        }
        links_uuid_dict[frozenset(links_uuid_entry.items())] = links_uuid_entry

        node_pk = output_pk if direction == 'forward' else input_pk
        found_nodes.add(node_pk)

    return links_uuid_dict, found_nodes
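
A hedged example of how the helper above might be invoked, assuming a loaded AiiDA profile and that the function is in scope: collect the Data nodes created by a given process node by following CREATE links forward. PK 42 is a placeholder.

from aiida.common.links import LinkType
from aiida.orm import Data, ProcessNode

links_uuid_dict, found_nodes = _retrieve_linked_nodes_query(
    current_node=42,
    input_type=ProcessNode,
    output_type=Data,
    direction='forward',
    link_type_value=LinkType.CREATE.value,
)
for link_info in links_uuid_dict.values():
    print(link_info['input'], '->', link_info['output'], link_info['label'])
print('Found node PKs:', found_nodes)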
Exemplo n.º 17
0
    def get_or_create_group(self):
        """Return the current `AutoGroup`, or create one if None has been set yet.

        This function implements a somewhat complex logic that is nevertheless needed
        to make sure that, even if ``verdi run`` is called multiple times at the same
        time, e.g. in a bash for loop, two concurrent ``verdi run`` Unix processes never
        end up crashing the code by trying to create the same group with the same label
        (see PR #3650).

        Instead, if this concurrency issue happens, one of the two processes will get an
        IntegrityError from the DB and will then recover by trying to create a group with
        a different label (with a numeric suffix appended), until it manages to create it.
        """
        from aiida.orm import QueryBuilder

        # When this function is called for the first time, just generate
        # a new group label (later on, after this ``if`` block).
        # In that case, we will later cache the group label in ``self._group_label``,
        # so that the group with the same label can be returned quickly in future
        # calls of this method.
        if self._group_label is not None:
            builder = QueryBuilder().append(
                AutoGroup, filters={'label': self._group_label})
            results = [res[0] for res in builder.iterall()]
            if results:
                # If it is not empty, it should have only one result due to the uniqueness constraints
                assert len(
                    results
                ) == 1, 'I got more than one autogroup with the same label!'
                return results[0]
            # There are no results: probably the group has been deleted.
            # I continue as if it was not cached
            self._group_label = None

        label_prefix = self.get_group_label_prefix()
        # Do a preliminary QB query to avoid too many try/except iterations
        # if many of the prefix_NUMBER groups already exist
        queryb = QueryBuilder().append(
            AutoGroup,
            filters={
                'or': [{
                    'label': {
                        '==': label_prefix
                    }
                }, {
                    'label': {
                        'like': escape_for_sql_like(label_prefix + '_') + '%'
                    }
                }]
            },
            project='label')
        existing_group_labels = [
            res[0][len(label_prefix):] for res in queryb.all()
        ]
        existing_group_ints = []
        for label in existing_group_labels:
            if label == '':
                # This is just the bare prefix without a suffix - corresponds to counter = 0
                existing_group_ints.append(0)
            elif label.startswith('_'):
                try:
                    existing_group_ints.append(int(label[1:]))
                except ValueError:
                    # It's not an integer, so it will never collide - just ignore it
                    pass

        if not existing_group_ints:
            counter = 0
        else:
            counter = max(existing_group_ints) + 1

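        # Optimistically store the group; an IntegrityError means another process
        # created a group with this label first, so bump the counter and retry.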
        while True:
            try:
                label = label_prefix if counter == 0 else '{}_{}'.format(
                    label_prefix, counter)
                group = AutoGroup(label=label).store()
                self._group_label = group.label
            except exceptions.IntegrityError:
                counter += 1
            else:
                break

        return group
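
The same concurrency-safe pattern in isolation, as a minimal sketch: pick the next candidate label and retry on a uniqueness violation. `store_group` and `UniquenessError` are hypothetical stand-ins for `AutoGroup(label=...).store()` and the database IntegrityError used above.

def create_unique(label_prefix, existing_ints, store_group, UniquenessError):
    """Store a group under the first free `prefix`/`prefix_N` label."""
    counter = max(existing_ints) + 1 if existing_ints else 0
    while True:
        label = label_prefix if counter == 0 else f'{label_prefix}_{counter}'
        try:
            return store_group(label)
        except UniquenessError:
            # Another process won the race for this label: try the next one.
            counter += 1


# Tiny in-memory demonstration of the retry loop.
stored = {'verdi_autogroup', 'verdi_autogroup_1'}

def fake_store(label):
    if label in stored:
        raise ValueError(label)  # plays the role of the IntegrityError
    stored.add(label)
    return label

print(create_unique('verdi_autogroup', [0, 1], fake_store, ValueError))
# -> 'verdi_autogroup_2'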
Exemplo n.º 18
0
def _store_entity_data(*, reader: ArchiveReaderAbstract, entity_name: str,
                       comment_mode: str, extras_mode_existing: str,
                       new_entries: Dict[str, Dict[str, dict]],
                       existing_entries: Dict[str, Dict[str, dict]],
                       foreign_ids_reverse_mappings: Dict[str, Dict[str, int]],
                       import_unique_ids_mappings: Dict[str, Dict[int, str]],
                       ret_dict: dict, session: Session):
    """Store the entity data on the AiiDA profile."""
    from aiida.backends.sqlalchemy.utils import flag_modified
    from aiida.backends.sqlalchemy.models.node import DbNode

    entity = entity_names_to_entities[entity_name]

    fields_info = reader.metadata.all_fields_info.get(entity_name, {})
    unique_identifier = reader.metadata.unique_identifiers.get(
        entity_name, None)

    pbar_base_str = f'{entity_name}s - '

    # EXISTING ENTRIES
    if existing_entries[entity_name]:

        with get_progress_reporter()(
                total=len(existing_entries[entity_name]),
                desc=f'{pbar_base_str} existing entries') as progress:

            for import_entry_pk, entry_data in existing_entries[
                    entity_name].items():

                progress.update()

                unique_id = entry_data[unique_identifier]
                existing_entry_pk = foreign_ids_reverse_mappings[entity_name][
                    unique_id]
                import_data = dict(
                    deserialize_field(
                        k,
                        v,
                        fields_info=fields_info,
                        import_unique_ids_mappings=import_unique_ids_mappings,
                        foreign_ids_reverse_mappings=
                        foreign_ids_reverse_mappings)
                    for k, v in entry_data.items())
                # TODO COMPARE, AND COMPARE ATTRIBUTES

                if entity_name == COMMENT_ENTITY_NAME:
                    new_entry_uuid = merge_comment(import_data, comment_mode)
                    if new_entry_uuid is not None:
                        entry_data[unique_identifier] = new_entry_uuid
                        new_entries[entity_name][import_entry_pk] = entry_data

                if entity_name not in ret_dict:
                    ret_dict[entity_name] = {'new': [], 'existing': []}
                ret_dict[entity_name]['existing'].append(
                    (import_entry_pk, existing_entry_pk))

                # print('  `-> WARNING: NO DUPLICITY CHECK DONE!')
                # CHECK ALSO FILES!

    # Store all objects for this model in a list, and store them all in once at the end.
    objects_to_create = []
    # In the following list we add the objects to be updated
    objects_to_update = []
    # This is needed later to associate the import entry with the new pk
    import_new_entry_pks = {}

    # NEW ENTRIES
    for import_entry_pk, entry_data in new_entries[entity_name].items():
        unique_id = entry_data[unique_identifier]
        import_data = dict(
            deserialize_field(
                k,
                v,
                fields_info=fields_info,
                import_unique_ids_mappings=import_unique_ids_mappings,
                foreign_ids_reverse_mappings=foreign_ids_reverse_mappings)
            for k, v in entry_data.items())

        # Map the Django field names to their SQLA counterparts. Note that some
        # of the Django fields were already converted to SQLA-compatible values
        # by the deserialize_field method; that was done for optimization
        # reasons in Django, but it also makes them compatible with the SQLA
        # schema, so they need no further conversion.
        if entity_name in file_fields_to_model_fields:
            for file_fkey in file_fields_to_model_fields[entity_name]:

                # This is an exception because the DbLog model defines the `_metadata` column instead of the
                # `metadata` column used in the Django model. This is because the SqlAlchemy model base
                # class already has a metadata attribute that cannot be overridden. For consistency, the
                # `DbLog` class however expects the `metadata` keyword in its constructor, so we should
                # ignore the mapping here
                if entity_name == LOG_ENTITY_NAME and file_fkey == 'metadata':
                    continue

                model_fkey = file_fields_to_model_fields[entity_name][
                    file_fkey]
                if model_fkey in import_data:
                    continue
                import_data[model_fkey] = import_data[file_fkey]
                import_data.pop(file_fkey, None)

        db_entity = get_object_from_string(
            entity_names_to_sqla_schema[entity_name])

        objects_to_create.append(db_entity(**import_data))
        import_new_entry_pks[unique_id] = import_entry_pk

    if entity_name == NODE_ENTITY_NAME:

        # Before storing entries in the DB, I store the files (if these are nodes).
        # Note: only for new entries!
        uuids_to_create = [obj.uuid for obj in objects_to_create]
        _copy_node_repositories(uuids_to_create=uuids_to_create, reader=reader)

        # For the existing nodes that are also in the imported list we also update their extras if necessary
        if existing_entries[entity_name]:

            with get_progress_reporter()(
                    total=len(existing_entries[entity_name]),
                    desc='Updating existing node extras') as progress:

                import_existing_entry_pks = {
                    entry_data[unique_identifier]: import_entry_pk
                    for import_entry_pk, entry_data in
                    existing_entries[entity_name].items()
                }
                for node in session.query(DbNode).filter(
                        DbNode.uuid.in_(import_existing_entry_pks)).all():
                    import_entry_uuid = str(node.uuid)
                    import_entry_pk = import_existing_entry_pks[
                        import_entry_uuid]

                    pbar_node_base_str = f"{pbar_base_str}UUID={import_entry_uuid.split('-')[0]} - "
                    progress.set_description_str(f'{pbar_node_base_str}Extras',
                                                 refresh=False)
                    progress.update()

                    old_extras = node.extras.copy()
                    extras = existing_entries[entity_name][str(
                        import_entry_pk)].get('extras', {})

                    new_extras = merge_extras(node.extras, extras,
                                              extras_mode_existing)

                    if new_extras != old_extras:
                        node.extras = new_extras
                        flag_modified(node, 'extras')
                        objects_to_update.append(node)

    # Store them all at once; note, however, that the PKs are not set this way...
    if objects_to_create:
        session.add_all(objects_to_create)
    if objects_to_update:
        session.add_all(objects_to_update)

    session.flush()

    if not import_new_entry_pks:
        return

    with get_progress_reporter()(
            total=len(import_new_entry_pks),
            desc=f'{pbar_base_str} storing new') as progress:

        just_saved = {}

        builder = QueryBuilder()
        builder.append(entity,
                       filters={
                           unique_identifier: {
                               'in': list(import_new_entry_pks.keys())
                           }
                       },
                       project=[unique_identifier, 'id'])

        for entry in builder.iterall():
            progress.update()
            just_saved.update({entry[0]: entry[1]})

        # Now that the PKs are known, register the newly created entries:
        # add them to foreign_ids_reverse_mappings and record them in ret_dict
        from uuid import UUID

        for unique_id, new_pk in just_saved.items():
            if isinstance(unique_id, UUID):
                unique_id = str(unique_id)
            import_entry_pk = import_new_entry_pks[unique_id]
            foreign_ids_reverse_mappings[entity_name][unique_id] = new_pk
            if entity_name not in ret_dict:
                ret_dict[entity_name] = {'new': [], 'existing': []}
            ret_dict[entity_name]['new'].append((import_entry_pk, new_pk))
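
A hedged, stand-alone sketch of the final UUID-to-PK lookup above: once the new rows have been flushed, a single QueryBuilder query maps each unique identifier back to its freshly assigned PK. It assumes a loaded profile; `Node` and the placeholder UUID list stand in for `entity` and `import_new_entry_pks.keys()`.

from aiida import load_profile
from aiida.orm import Node, QueryBuilder

load_profile()

new_uuids = ['00000000-0000-0000-0000-000000000000']  # placeholder UUIDs of just-stored nodes
builder = QueryBuilder().append(
    Node, filters={'uuid': {'in': new_uuids}}, project=['uuid', 'id'])
uuid_to_pk = {str(uuid): pk for uuid, pk in builder.iterall()}
print(uuid_to_pk)  # e.g. {'00000000-...': 1234}  (hypothetical)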