Example #1
    def create_structure_data():
        """Create StructureData object."""
        alat = 4.  # angstrom
        cell = [
            [alat, 0., 0.],
            [0., alat, 0.],
            [0., 0., alat],
        ]

        # BaTiO3 cubic structure
        struc = StructureData(cell=cell)
        struc.append_atom(position=(0., 0., 0.), symbols='Ba')
        struc.append_atom(position=(alat / 2., alat / 2., alat / 2.),
                          symbols='Ti')
        struc.append_atom(position=(alat / 2., alat / 2., 0.), symbols='O')
        struc.append_atom(position=(alat / 2., 0., alat / 2.), symbols='O')
        struc.append_atom(position=(0., alat / 2., alat / 2.), symbols='O')
        struc.store()

        # Create 2 groups and add the data to one of them
        g_ne = Group(label='non_empty_group')
        g_ne.store()
        g_ne.add_nodes(struc)

        g_e = Group(label='empty_group')
        g_e.store()

        return {
            DummyVerdiDataListable.NODE_ID_STR: struc.id,
            DummyVerdiDataListable.NON_EMPTY_GROUP_ID_STR: g_ne.id,
            DummyVerdiDataListable.EMPTY_GROUP_ID_STR: g_e.id
        }
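A sketch of how a caller might consume the returned identifiers (assuming the fixture is reachable as a plain function/staticmethod and that the `DummyVerdiDataListable` constants are string keys):

from aiida.orm import load_group, load_node

ids = create_structure_data()
struc = load_node(ids[DummyVerdiDataListable.NODE_ID_STR])
non_empty = load_group(pk=ids[DummyVerdiDataListable.NON_EMPTY_GROUP_ID_STR])
# the structure was added to the non-empty group by the fixture
assert struc.uuid in [node.uuid for node in non_empty.nodes]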
Example #2
def launch(input_group, input_structures, target_supercellsize,
           solute_elements, structure_comments, structure_group_label,
           structure_group_description, dryrun):
    """
    Script for generating substitutional and vacancy defects for an input structure
    or all structures in an input group.
    """
    if not dryrun:
        structure_group = Group.objects.get_or_create(
            label=structure_group_label,
            description=structure_group_description)[0]
    else:
        structure_group = None

    if input_group:
        # Group(label) would create a new, unstored group; load the existing one instead
        input_group = load_group(input_group)
        structure_nodes = get_allstructurenodes_fromgroup(input_group)
    elif input_structures:
        input_structures = input_structures.split(',')
        structure_nodes = [load_node(x) for x in input_structures]
    else:
        raise Exception("Must use either input group or input structures")

    solute_elements = prep_elementlist(solute_elements)

    for structure_node in structure_nodes:
        extras = {
            'input_structure': structure_node.uuid,
            'structure_comments': structure_comments
        }

        input_structure_ase = structure_node.get_ase()
        #Unfortunately, we must get the unique sites prior to supercell
        #creation, meaning changes in site index can cause bugs
        unique_sites = get_unique_sites(input_structure_ase)
        if target_supercellsize is not None:
            target_supercellsize = int(target_supercellsize)
            extras['target_supercellsize'] = target_supercellsize
            input_structure_ase = generate_supercell(input_structure_ase,
                                                     target_supercellsize)[1]
        for unique_site in unique_sites:

            site_index, element_index, wyckoff = unique_site
            extras['site_index'] = site_index
            extras['element_index'] = element_index
            extras['wyckoff'] = wyckoff

            for element in solute_elements:
                defect_structure = input_structure_ase.copy()
                extras['element_new'] = element
                if defect_structure[site_index].symbol == element:
                    continue
                else:
                    defect_structure[site_index].symbol = element
                    store_asestructure(defect_structure, extras,
                                       structure_group, dryrun)
Example #3
def test_show_argument_type(clear_db, run_cli_command, get_pseudo_family):
    """Test that `aiida-pseudo show` only accepts instances of `PseudoPotentialFamily` or subclasses as argument."""
    pseudo_family = get_pseudo_family(label='pseudo-family',
                                      cls=PseudoPotentialFamily)
    upf_family = get_pseudo_family(label='upf-family', cls=UpfFamily)
    normal_group = Group('normal-group').store()

    run_cli_command(cmd_show, [pseudo_family.label])
    run_cli_command(cmd_show, [upf_family.label])
    run_cli_command(cmd_show, [normal_group.label], raises=SystemExit)
Example #4
def test_complete(setup_groups, parameter_type):
    """Test the `complete` method that provides auto-complete functionality."""
    entity_01, entity_02, entity_03 = setup_groups
    entity_04 = Group(label='xavier').store()

    options = [item[0] for item in parameter_type.complete(None, '')]
    assert sorted(options) == sorted([entity_01.label, entity_02.label, entity_03.label, entity_04.label])

    options = [item[0] for item in parameter_type.complete(None, 'xa')]
    assert sorted(options) == sorted([entity_04.label])
Example #5
def setup_groups(clear_database_before_test):
    """Create some groups to test the `GroupParamType` parameter type for the command line infrastructure.

    We create an initial group with a random name and then on purpose create two groups with a name that matches exactly
    the ID and UUID, respectively, of the first one. This allows us to test the rules implemented to solve ambiguities
    that arise when determining the identifier type.
    """
    entity_01 = Group(label='group_01').store()
    entity_02 = AutoGroup(label=str(entity_01.pk)).store()
    entity_03 = ImportGroup(label=str(entity_01.uuid)).store()
    return entity_01, entity_02, entity_03
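A small check of the collisions this fixture sets up, grounded in the fixture itself (it deliberately does not assert any particular resolution order):

def test_label_collisions(setup_groups):
    """The bare string of entity_01's pk/uuid is also a valid label of another group."""
    entity_01, entity_02, entity_03 = setup_groups
    assert entity_02.label == str(entity_01.pk)    # a label that looks like an ID
    assert entity_03.label == str(entity_01.uuid)  # a label that looks like a UUID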
Example #6
    def test_queryhelp(self):
        """
        Here I test the queryhelp by seeing whether results are the same as using the append method.
        I also check passing of tuples.
        """

        from aiida.orm.data.structure import StructureData
        from aiida.orm.data.parameter import ParameterData
        from aiida.orm.data import Data
        from aiida.orm.querybuilder import QueryBuilder
        from aiida.orm.group import Group
        from aiida.orm.computer import Computer
        g = Group(name='helloworld').store()
        for cls in (StructureData, ParameterData, Data):
            obj = cls()
            obj._set_attr('foo-qh2', 'bar')
            obj.store()
            g.add_nodes(obj)

        for cls, expected_count, subclassing in (
            (StructureData, 1, True),
            (ParameterData, 1, True),
            (Data, 3, True),
            (Data, 1, False),
            ((ParameterData, StructureData), 2, True),
            ((ParameterData, StructureData), 2, False),
            ((ParameterData, Data), 2, False),
            ((ParameterData, Data), 3, True),
            ((ParameterData, Data, StructureData), 3, False),
        ):
            qb = QueryBuilder()
            qb.append(cls,
                      filters={'attributes.foo-qh2': 'bar'},
                      subclassing=subclassing,
                      project='uuid')
            self.assertEqual(qb.count(), expected_count)

            qh = qb.get_json_compatible_queryhelp()
            qb_new = QueryBuilder(**qh)
            self.assertEqual(qb_new.count(), expected_count)
            self.assertEqual(sorted([uuid for uuid, in qb.all()]),
                             sorted([uuid for uuid, in qb_new.all()]))

        qb = QueryBuilder().append(Group, filters={'name': 'helloworld'})
        self.assertEqual(qb.count(), 1)

        qb = QueryBuilder().append((Group, ), filters={'name': 'helloworld'})
        self.assertEqual(qb.count(), 1)

        qb = QueryBuilder().append(Computer)
        self.assertEqual(qb.count(), 1)

        qb = QueryBuilder().append(cls=(Computer, ))
        self.assertEqual(qb.count(), 1)
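The round-trip at the heart of this test, isolated (same pre-1.0 API as the test itself):

qb = QueryBuilder().append(Group, filters={'name': 'helloworld'})
qh = qb.get_json_compatible_queryhelp()  # a plain, json-serializable dict
qb_clone = QueryBuilder(**qh)            # reconstructs an equivalent query
assert qb.count() == qb_clone.count()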
Example #7
    def setUpClass(cls):  # pylint: disable=arguments-differ
        super().setUpClass()
        orm.Computer(label='comp',
                     hostname='localhost',
                     transport_type='local',
                     scheduler_type='direct',
                     workdir='/tmp/aiida').store()
        cls.ids = cls.create_structure_data()

        for group_label in ['xyz structure group', 'ase structure group']:
            Group(label=group_label).store()
Example #8
def get_nodes_from_group(group, return_format='uuid'):
    """
    Return a list of node identifiers (uuid or pk) for a given group,
    which may be specified by name, pk or Group object.
    """
    from aiida.orm import Group
    from aiida.common.exceptions import NotExistent

    nodes = []

    try:
        group_pk = int(group)
    except (ValueError, TypeError):
        # TypeError covers Group instances and None as well
        group_pk = None
        group_name = group

    if group_pk is not None:
        try:
            str_group = Group(dbgroup=group_pk)
        except NotExistent:
            str_group = None
            message = ('You have to provide a valid pk for a Group '
                       'or a Group name. Reference key: "group". '
                       'Given pk={} is not a valid group '
                       '(or is your group name an integer?)'.format(group_pk))
            print(message)
    elif isinstance(group, Group):
        str_group = group
    elif group_name is not None:
        try:
            str_group = Group.get_from_string(group_name)
        except NotExistent:
            str_group = None
            message = (
                'You have to provide a valid pk for a Group or a Group name. '
                'Given group name={} is not a valid group '
                '(or is your group name an integer?)'.format(group_name))
            print(message)
    else:
        print('I could not handle the given input; please pass a Group, pk or group name.')
        return nodes

    if str_group is None:
        return nodes

    for node in str_group.nodes:
        if return_format == 'uuid':
            nodes.append(node.uuid)
        elif return_format == 'pk':
            nodes.append(node.pk)

    return nodes
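A hypothetical call, assuming a stored group labelled 'my_structures' exists:

uuids = get_nodes_from_group('my_structures')
pks = get_nodes_from_group('my_structures', return_format='pk')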
Example #9
    def convert(self, value, param, ctx):
        from aiida.orm import Group, GroupTypeString
        try:
            group = super(GroupParamType, self).convert(value, param, ctx)
        except click.BadParameter:
            if self._create_if_not_exist:
                group = Group(label=value,
                              type_string=GroupTypeString.USER.value)
            else:
                raise

        return group
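A sketch of how such a parameter type is typically attached to a click option (the option name is illustrative, and the `create_if_not_exist` constructor flag is an assumption inferred from the attribute used above):

import click

@click.command()
@click.option('-G', '--group', type=GroupParamType(create_if_not_exist=True),
              help='An existing group, or a label for a new one.')
def my_command(group):
    # `group` arrives as a Group instance, possibly unstored if newly created
    if not group.is_stored:
        group.store()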
Example #10
    def test_import_to_group(self):
        """
        Test import to existing Group and that Nodes are added correctly for multiple imports of the same,
        as well as separate, archives.
        """
        archives = [
            get_archive_file('arithmetic.add.aiida', filepath='calcjob'),
            get_archive_file(self.newest_archive, filepath=self.archive_path)
        ]

        group_label = 'import_madness'
        group = Group(group_label).store()

        self.assertTrue(group.is_empty, msg='The Group should be empty.')

        # Invoke `verdi import`, making sure there are no exceptions
        options = ['-G', group.label] + [archives[0]]
        result = self.cli_runner.invoke(cmd_import.cmd_import, options)
        self.assertIsNone(result.exception, msg=result.output)
        self.assertEqual(result.exit_code, 0, msg=result.output)

        self.assertFalse(group.is_empty, msg='The Group should no longer be empty.')

        nodes_in_group = group.count()

        # Invoke `verdi import` again, making sure Group count doesn't change
        options = ['-G', group.label] + [archives[0]]
        result = self.cli_runner.invoke(cmd_import.cmd_import, options)
        self.assertIsNone(result.exception, msg=result.output)
        self.assertEqual(result.exit_code, 0, msg=result.output)

        self.assertEqual(
            group.count(),
            nodes_in_group,
            msg='The Group count should not have changed from {}. Instead it is now {}'.format(
                nodes_in_group, group.count()
            )
        )

        # Invoke `verdi import` again with new archive, making sure Group count is upped
        options = ['-G', group.label] + [archives[1]]
        result = self.cli_runner.invoke(cmd_import.cmd_import, options)
        self.assertIsNone(result.exception, msg=result.output)
        self.assertEqual(result.exit_code, 0, msg=result.output)

        self.assertGreater(
            group.count(),
            nodes_in_group,
            msg='There should now be more than {} nodes in group {}, instead there are {}'.format(
                nodes_in_group, group_label, group.count()
            )
        )
Example #11
    def create_cif_data(cls):
        with tempfile.NamedTemporaryFile(mode='w+') as fhandle:
            filename = fhandle.name
            fhandle.write(cls.valid_sample_cif_str)
            fhandle.flush()
            a = CifData(file=filename, source={'version': '1234', 'db_name': 'COD', 'id': '0000001'})
            a.store()

            g_ne = Group(label='non_empty_group')
            g_ne.store()
            g_ne.add_nodes(a)

            g_e = Group(label='empty_group')
            g_e.store()

        cls.cif = a

        return {
            TestVerdiDataListable.NODE_ID_STR: a.id,
            TestVerdiDataListable.NON_EMPTY_GROUP_ID_STR: g_ne.id,
            TestVerdiDataListable.EMPTY_GROUP_ID_STR: g_e.id
        }
Example #12
def get_para_from_group(element, group):
    """
    Get the parameter node for a given element from a given group of
    parameter nodes (quite greedy, but done straightforwardly).
    """

    report = []

    try:
        group_pk = int(group)
    except ValueError:
        group_pk = None
        group_name = group

    if group_pk is not None:
        try:
            para_group = Group(dbgroup=group_pk)
        except NotExistent:
            para_group = None
            message = ('You have to provide a valid pk for a Group of '
                       'parameters or a Group name. Reference key: "group". '
                       'Given pk={} is not a valid group '
                       '(or is your group name an integer?)'.format(group_pk))
            report.append(message)
    else:
        try:
            para_group = Group.get_from_string(group_name)
        except NotExistent:
            para_group = None
            message = ('You have to provide a valid pk for a Group of '
                       'parameters or a Group name. Wf_para key: "para_group". '
                       'Given group name={} is not a valid group '
                       '(or is your group name an integer?)'.format(group_name))
            report.append(message)
            #abort_nowait('I abort, because I have no structures to calculate ...')

    if para_group is None:
        return None, report

    para_nodes = para_group.nodes

    parameter = None

    for para in para_nodes:
        formula = para.get_extras().get('element', None)
        #eformula = formula.translate(None, digits) # remove digits, !python3 differs
        if formula == element:
            return para, report

    report.append('Parameter node for element {} not found in group {}'
                  ''.format(element, group))

    return parameter, report
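A hypothetical call, assuming a stored group labelled 'fleur_parameters' whose nodes carry an 'element' extra:

para, report = get_para_from_group('Fe', 'fleur_parameters')
if para is None:
    print('\n'.join(report))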
Example #13
    def make_group(self):
        """Create the curated-xxx_XXXX_vx group, put the orig_cif inside, and exit if it already exists."""
        self.ctx.orig_cif = self.exposed_inputs(ZeoppMultistageDdecWorkChain)['structure']
        self.ctx.group = Group(
            label="curated-{}_{}_v3".format(self.ctx.orig_cif.extras["class_material"],
                                            self.ctx.orig_cif.label),
            description="Group collecting the results of CURATED-COFs/MOFs/ZEOs: "
            "v3 is consistent with the API of Feb 2020")
        # REMEMBER: this will raise if a group with the same label already exists!
        self.ctx.group.store()
        include_node("orig_cif", self.ctx.orig_cif, self.ctx.group)
Example #14
    def setUpClass(cls, *args, **kwargs):
        super(TestVerdiExport, cls).setUpClass(*args, **kwargs)
        from aiida.orm import Code, Computer, Group, Node

        cls.computer = Computer(name='comp',
                                hostname='localhost',
                                transport_type='local',
                                scheduler_type='direct',
                                workdir='/tmp/aiida').store()

        cls.code = Code(remote_computer_exec=(cls.computer,
                                              '/bin/true')).store()
        cls.group = Group(name='test_group').store()
        cls.node = Node().store()
Example #15
    def test_serialize_group(self):
        """
        Test that serialization and deserialization of Groups works.
        Also make sure that the serialized data is json-serializable
        """
        group_name = 'groupie'
        group_a = Group(name=group_name).store()

        data = {'group': group_a}

        serialized_data = serialize_data(data)
        json_dumped = json.dumps(serialized_data)
        deserialized_data = deserialize_data(serialized_data)

        self.assertEqual(data['group'].uuid, deserialized_data['group'].uuid)
        self.assertEqual(data['group'].name, deserialized_data['group'].name)
Example #16
    def get_or_create_famgroup(cls, famname):
        """Return a PAW family group, creating it if it does not exist yet."""
        from aiida.orm import Group
        from aiida.djsite.utils import get_automatic_user

        # TODO: maybe replace with Group.get_or_create?
        try:
            group = Group.get(name=famname, type_string=cls.group_type)
            group_created = False
        except NotExistent:
            group = Group(name=famname, type_string=cls.group_type,
                          user=get_automatic_user())
            group_created = True

        if group.user != get_automatic_user():
            raise UniquenessError("There is already a PAW family group "
                                  "with name {}, but it belongs to user {}, "
                                  "therefore you cannot modify it".format(
                                      famname, group.user.email))
        return group, group_created
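The TODO above points at the built-in helper; under the same pre-1.0 API the try/except block could plausibly collapse to the line below (note that, unlike the code above, Group.get_or_create stores a newly created group immediately):

group, group_created = Group.get_or_create(name=famname, type_string=cls.group_type)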
Example #17
    def _prepare_group_for_upload(cls, group_name, group_description=None, dry_run=False):
        """Prepare a (possibly new) group to upload a POTCAR family to."""
        if not dry_run:
            group, group_created = Group.objects.get_or_create(label=group_name, type_string=cls.potcar_family_type_string)
        else:
            group = cls.get_potcar_group(group_name)
            group_created = bool(not group)
            if not group:
                group = Group(label=group_name)

        if group.user.pk != get_current_user().pk:
            raise UniquenessError(
                'There is already a POTCAR family group with name {}, but it belongs to user {}, therefore you cannot modify it'.format(
                    group_name, group.user.email))

        if group_description:
            group.description = group_description
        elif group_created:
            raise ValueError('A new POTCAR family {} should be created but no description was given!'.format(group_name))

        return group
Example #18
    def test_load_group(self):
        """Test the functionality of load_group."""
        name = 'groupie'
        group = Group(label=name).store()

        # Load through label
        loaded_group = load_group(group.label)
        self.assertEqual(loaded_group.uuid, group.uuid)

        # Load through uuid
        loaded_group = load_group(group.uuid)
        self.assertEqual(loaded_group.uuid, group.uuid)

        # Load through pk
        loaded_group = load_group(group.pk)
        self.assertEqual(loaded_group.uuid, group.uuid)

        # Load through label explicitly
        loaded_group = load_group(label=group.label)
        self.assertEqual(loaded_group.uuid, group.uuid)

        # Load through uuid explicitly
        loaded_group = load_group(uuid=group.uuid)
        self.assertEqual(loaded_group.uuid, group.uuid)

        # Load through pk explicitly
        loaded_group = load_group(pk=group.pk)
        self.assertEqual(loaded_group.uuid, group.uuid)

        # Load through partial uuid without a dash
        loaded_group = load_group(uuid=group.uuid[:8])
        self.assertEqual(loaded_group.uuid, group.uuid)

        # Load through partial uuid including a dash
        loaded_group = load_group(uuid=group.uuid[:10])
        self.assertEqual(loaded_group.uuid, group.uuid)

        with self.assertRaises(NotExistent):
            load_group('non-existent-uuid')
Example #19
    def upload_basisset_family(cls,
                               folder,
                               group_name,
                               group_description,
                               stop_if_existing=True,
                               extension=".basis",
                               dry_run=False):
        """
        Upload a set of Basis Set files in a given group.

        :param folder: a path containing all Basis Set files to be added.
            Only files ending in the set extension (case-insensitive) are considered.
        :param group_name: the name of the group to create. If it exists and is
            non-empty, a UniquenessError is raised.
        :param group_description: a string to be set as the group description.
            Overwrites previous descriptions, if the group was existing.
        :param stop_if_existing: if True, check for the md5 of the files and,
            if the file already exists in the DB, raises a MultipleObjectsError.
            If False, simply adds the existing BasisSetData node to the group.
        :param extension: the filename extension to look for
        :param dry_run: If True, do not change the database.
        """
        from aiida.common import aiidalogger
        from aiida.orm import Group
        from aiida.common.exceptions import UniquenessError, NotExistent
        from aiida_crystal17.aiida_compatability import get_automatic_user

        automatic_user = get_automatic_user()

        if not os.path.isdir(folder):
            raise ValueError("folder must be a directory")

        # only files, and only those ending with the specified extension;
        # go to the real file if it is a symlink
        files = [
            os.path.realpath(os.path.join(folder, i))
            for i in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, i))
            and i.lower().endswith(extension)
        ]

        nfiles = len(files)

        try:
            group = Group.get(name=group_name, type_string=BASISGROUP_TYPE)
            group_created = False
        except NotExistent:
            group = Group(name=group_name,
                          type_string=BASISGROUP_TYPE,
                          user=automatic_user)
            group_created = True

        if group.user.email != automatic_user.email:
            raise UniquenessError(
                "There is already a BasisFamily group with name {}"
                ", but it belongs to user {}, therefore you "
                "cannot modify it".format(group_name, group.user.email))

        # Always update description, even if the group already existed
        group.description = group_description

        # NOTE: GROUP SAVED ONLY AFTER CHECKS OF UNICITY

        basis_and_created = _retrieve_basis_sets(files, stop_if_existing)
        # check whether basisset are unique per element
        elements = [(i[0].element, i[0].md5sum) for i in basis_and_created]
        # If group already exists, check also that I am not inserting more than
        # once the same element
        if not group_created:
            for aiida_n in group.nodes:
                # Skip non-basis sets
                if not isinstance(aiida_n, BasisSetData):
                    continue
                elements.append((aiida_n.element, aiida_n.md5sum))

        # Discard elements with the same MD5; those would not be stored twice
        elements = set(elements)
        elements_names = [e[0] for e in elements]

        if len(elements_names) != len(set(elements_names)):
            duplicates = {x for x in elements_names if elements_names.count(x) > 1}
            duplicates_string = ", ".join(duplicates)
            raise UniquenessError(
                "More than one Basis found for the elements: " +
                duplicates_string + ".")

        # At this point, save the group, if still unstored
        if group_created and not dry_run:
            group.store()

        # save the basis set in the database, and add them to group
        for basisset, created in basis_and_created:
            if created:
                if not dry_run:
                    basisset.store()

                aiidalogger.debug("New node {0} created for file {1}".format(  # pylint: disable=logging-format-interpolation
                    basisset.uuid, basisset.filename))
            else:
                aiidalogger.debug("Reusing node {0} for file {1}".format(  # pylint: disable=logging-format-interpolation
                    basisset.uuid, basisset.filename))

        # Add elements to the group all together
        if not dry_run:
            group.add_nodes(basis for basis, created in basis_and_created)

        nuploaded = len([_ for _, created in basis_and_created if created])

        return nfiles, nuploaded
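A hypothetical invocation, assuming the method is exposed on the BasisSetData class and that ./sto3g contains *.basis files:

nfiles, nuploaded = BasisSetData.upload_basisset_family(
    './sto3g', 'sto3g_family', 'STO-3G basis sets',
    stop_if_existing=False, dry_run=True)
print('{} files scanned, {} would be uploaded'.format(nfiles, nuploaded))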
Example #20
def upload_psf_family(folder,
                      group_name,
                      group_description,
                      stop_if_existing=True):
    """
    Upload a set of PSF files in a given group.

    :param folder: a path containing all PSF files to be added.
        Only files ending in .PSF (case-insensitive) are considered.
    :param group_name: the name of the group to create. If it exists and is
        non-empty, a UniquenessError is raised.
    :param group_description: a string to be set as the group description.
        Overwrites previous descriptions, if the group was existing.
    :param stop_if_existing: if True, check for the md5 of the files and,
        if the file already exists in the DB, raises a MultipleObjectsError.
        If False, simply adds the existing PsfData node to the group.
    """
    import os
    import aiida.common
    from aiida.common import aiidalogger
    from aiida.orm import Group
    from aiida.common.exceptions import UniquenessError, NotExistent
    from aiida.backends.utils import get_automatic_user
    from aiida.orm.querybuilder import QueryBuilder
    if not os.path.isdir(folder):
        raise ValueError("folder must be a directory")

    # only files, and only those ending with .psf or .PSF;
    # go to the real file if it is a symlink
    files = [
        os.path.realpath(os.path.join(folder, i)) for i in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, i))
        and i.lower().endswith('.psf')
    ]

    nfiles = len(files)

    try:
        group = Group.get(name=group_name, type_string=PSFGROUP_TYPE)
        group_created = False
    except NotExistent:
        group = Group(name=group_name,
                      type_string=PSFGROUP_TYPE,
                      user=get_automatic_user())
        group_created = True

    if group.user != get_automatic_user():
        raise UniquenessError("There is already a PsfFamily group with name {}"
                              ", but it belongs to user {}, therefore you "
                              "cannot modify it".format(
                                  group_name, group.user.email))

    # Always update description, even if the group already existed
    group.description = group_description

    # NOTE: GROUP SAVED ONLY AFTER CHECKS OF UNICITY

    pseudo_and_created = []

    for f in files:
        md5sum = aiida.common.utils.md5_file(f)
        qb = QueryBuilder()
        qb.append(PsfData, filters={'attributes.md5': {'==': md5sum}})
        existing_psf = qb.first()

        #existing_psf = PsfData.query(dbattributes__key="md5",
        #                            dbattributes__tval = md5sum)

        if existing_psf is None:
            # return the psfdata instances, not stored
            pseudo, created = PsfData.get_or_create(f,
                                                    use_first=True,
                                                    store_psf=False)
            # to check whether only one psf per element exists
            # NOTE: actually, created has the meaning of "to_be_created"
            pseudo_and_created.append((pseudo, created))
        else:
            if stop_if_existing:
                raise ValueError("A PSF with identical MD5 to {} cannot be "
                                 "added with stop_if_existing".format(f))
            existing_psf = existing_psf[0]
            pseudo_and_created.append((existing_psf, False))

    # check whether pseudo are unique per element
    elements = [(i[0].element, i[0].md5sum) for i in pseudo_and_created]
    # If group already exists, check also that I am not inserting more than
    # once the same element
    if not group_created:
        for aiida_n in group.nodes:
            # Skip non-pseudos
            if not isinstance(aiida_n, PsfData):
                continue
            elements.append((aiida_n.element, aiida_n.md5sum))

    # Discard elements with the same MD5; those would not be stored twice
    elements = set(elements)
    elements_names = [e[0] for e in elements]

    if len(elements_names) != len(set(elements_names)):
        duplicates = {x for x in elements_names if elements_names.count(x) > 1}
        duplicates_string = ", ".join(duplicates)
        raise UniquenessError("More than one PSF found for the elements: " +
                              duplicates_string + ".")

    # At this point, save the group, if still unstored
    if group_created:
        group.store()

    # save the psf in the database, and add them to group
    for pseudo, created in pseudo_and_created:
        if created:
            pseudo.store()

            aiidalogger.debug("New node {} created for file {}".format(
                pseudo.uuid, pseudo.filename))
        else:
            aiidalogger.debug("Reusing node {} for file {}".format(
                pseudo.uuid, pseudo.filename))

    # Add elements to the group all together
    group.add_nodes(pseudo for pseudo, created in pseudo_and_created)

    nuploaded = len([_ for _, created in pseudo_and_created if created])

    return nfiles, nuploaded
Example #21
    def get_calcs_from_groups(self):
        """
        Extract the crystal structures and parameter data nodes from the given
        groups and create calculation 'pairs' (stru, para).
        """
        wf_dict = self.inputs.wf_parameters.get_dict()
        #get all delta structure
        str_gr = wf_dict.get('struc_group', 'delta')

        try:
            group_pk = int(str_gr)
        except ValueError:
            group_pk = None
            group_name = str_gr

        if group_pk is not None:
            try:
                str_group = Group(dbgroup=group_pk)
            except NotExistent:
                str_group = None
                message = ('You have to provide a valid pk for a Group of '
                           'structures or a Group name. Wf_para key: "struc_group". '
                           'Given pk={} is not a valid group '
                           '(or is your group name an integer?)'.format(group_pk))
                #print(message)
                self.report(message)
                self.abort_nowait('I abort, because I have no structures to calculate ...')
        else:
            try:
                str_group = Group.get_from_string(group_name)
            except NotExistent:
                str_group = None
                message = ('You have to provide a valid pk for a Group of '
                           'structures or a Group name. Wf_para key: "struc_group". '
                           'Given group name={} is not a valid group '
                           '(or is your group name an integer?)'.format(group_name))
                #print(message)
                self.report(message)
                self.abort_nowait('I abort, because I have no structures to calculate ...')


        #get all delta parameters
        para_gr = wf_dict.get('para_group', 'delta')

        if not para_gr:
            # warning: fall back to inpgen defaults
            message = 'COMMENT: I did receive "para_group=None" as input. I will use inpgen defaults'
            self.report(message)

        try:
            group_pk = int(para_gr)
        except ValueError:
            group_pk = None
            group_name = para_gr

        if group_pk is not None:
            try:
                para_group = Group(dbgroup=group_pk)
            except NotExistent:
                para_group = None
                message = ('ERROR: You have to provide a valid pk for a Group of '
                           'parameters or a Group name (or use None for inpgen defaults). '
                           'Wf_para key: "para_group". Given pk={} is not a valid group '
                           '(or is your group name an integer?)'.format(group_pk))
                #print(message)
                self.report(message)
                self.abort_nowait('ERROR: I abort, because I have no parameters to calculate and '
                                  'I guess you did not want to use the inpgen default...')
        else:
            try:
                para_group = Group.get_from_string(group_name)
            except NotExistent:
                para_group = None
                message = ('ERROR: You have to provide a valid pk for a Group of '
                           'parameters or a Group name (or use None for inpgen defaults). '
                           'Wf_para key: "para_group". Given group name={} is not a valid group '
                           '(or is your group name an integer?)'.format(group_name))
                #print(message)
                self.report(message)
                self.abort_nowait('ERROR: I abort, because I have no parameters to calculate and '
                                  'I guess you did not want to use the inpgen default...')

        # creating calculation pairs (structure, parameters)

        para_nodes = list(para_group.nodes)
        n_para = len(para_nodes)
        stru_nodes = str_group.nodes
        n_stru = len(stru_nodes)
        if n_para != n_stru:
            message = ('COMMENT: You did not provide the same number of parameter '
                       'nodes as structure nodes. Is this wanted? npara={} nstru={}'.format(n_para, n_stru))
            self.report(message)
        calcs = []
        for struc in stru_nodes:
            para = get_paranode(struc, para_nodes)
            calcs.append((struc, para))
        pprint(calcs[:20])
        self.ctx.calcs_to_run = calcs
Example #22
    def setUpClass(cls, *args, **kwargs):
        super(TestVerdiCalculation, cls).setUpClass(*args, **kwargs)
        from aiida.backends.tests.utils.fixtures import import_archive_fixture
        from aiida.common.exceptions import ModificationNotAllowed
        from aiida.common.links import LinkType
        from aiida.orm import Code, Computer, Group, Node, JobCalculation, CalculationFactory
        from aiida.orm.data.parameter import ParameterData
        from aiida.orm.querybuilder import QueryBuilder
        from aiida.work.processes import ProcessState

        rmq_config = rmq.get_rmq_config()

        # These two need to share a common event loop otherwise the first will never send
        # the message while the daemon is running listening to intercept
        cls.runner = runners.Runner(rmq_config=rmq_config,
                                    rmq_submit=True,
                                    poll_interval=0.)
        cls.daemon_runner = runners.DaemonRunner(rmq_config=rmq_config,
                                                 rmq_submit=True,
                                                 poll_interval=0.)

        cls.computer = Computer(name='comp',
                                hostname='localhost',
                                transport_type='local',
                                scheduler_type='direct',
                                workdir='/tmp/aiida').store()

        cls.code = Code(remote_computer_exec=(cls.computer,
                                              '/bin/true')).store()
        cls.group = Group(name='test_group').store()
        cls.node = Node().store()
        cls.calcs = []

        from aiida.orm.backend import construct_backend
        backend = construct_backend()
        authinfo = backend.authinfos.create(
            computer=cls.computer, user=backend.users.get_automatic_user())
        authinfo.store()

        # Create 13 JobCalculations (one for each CalculationState)
        for calculation_state in calc_states:

            calc = JobCalculation(computer=cls.computer,
                                  resources={
                                      'num_machines': 1,
                                      'num_mpiprocs_per_machine': 1
                                  }).store()

            # Trying to set NEW will raise, but in this case we don't need to change the state
            try:
                calc._set_state(calculation_state)
            except ModificationNotAllowed:
                pass

            try:
                exit_status = JobCalculationExitStatus[calculation_state]
            except KeyError:
                if calculation_state == 'IMPORTED':
                    calc._set_process_state(ProcessState.FINISHED)
                else:
                    calc._set_process_state(ProcessState.RUNNING)
            else:
                calc._set_exit_status(exit_status)
                calc._set_process_state(ProcessState.FINISHED)

            cls.calcs.append(calc)

            if calculation_state == 'PARSING':

                cls.KEY_ONE = 'key_one'
                cls.KEY_TWO = 'key_two'
                cls.VAL_ONE = 'val_one'
                cls.VAL_TWO = 'val_two'

                output_parameters = ParameterData(dict={
                    cls.KEY_ONE: cls.VAL_ONE,
                    cls.KEY_TWO: cls.VAL_TWO,
                }).store()

                output_parameters.add_link_from(calc,
                                                'output_parameters',
                                                link_type=LinkType.RETURN)

                # Create shortcut for easy dereferencing
                cls.result_job = calc

                # Add a single calc to a group
                cls.group.add_nodes([calc])

        # Load the fixture containing a single ArithmeticAddCalculation node
        import_archive_fixture(
            'calculation/simpleplugins.arithmetic.add.aiida')

        # Get the imported ArithmeticAddCalculation node
        ArithmeticAddCalculation = CalculationFactory(
            'simpleplugins.arithmetic.add')
        calculations = QueryBuilder().append(ArithmeticAddCalculation).all()[0]
        cls.arithmetic_job = calculations[0]
Example #23
    def create_trajectory_data():
        """Create TrajectoryData object with two arrays."""

        traj = TrajectoryData()

        # I create sample data
        stepids = np.array([60, 70])
        times = stepids * 0.01
        cells = np.array([
            [[2., 0., 0.], [0., 2., 0.], [0., 0., 2.]],
            [[3., 0., 0.], [0., 3., 0.], [0., 0., 3.]],
        ])
        symbols = ['H', 'O', 'C']
        positions = np.array([[[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]],
                              [[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5,
                                                               1.5]]])
        velocities = np.array([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                               [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5],
                                [-0.5, -0.5, -0.5]]])

        # I set the node
        traj.set_trajectory(stepids=stepids,
                            cells=cells,
                            symbols=symbols,
                            positions=positions,
                            times=times,
                            velocities=velocities)

        traj.store()

        # Create 2 groups and add the data to one of them
        g_ne = Group(label='non_empty_group')
        g_ne.store()
        g_ne.add_nodes(traj)

        g_e = Group(label='empty_group')
        g_e.store()

        return {
            DummyVerdiDataListable.NODE_ID_STR: traj.id,
            DummyVerdiDataListable.NON_EMPTY_GROUP_ID_STR: g_ne.id,
            DummyVerdiDataListable.EMPTY_GROUP_ID_STR: g_e.id
        }
Example #24
    def setUpClass(cls, *args, **kwargs):
        super(TestVerdiGroupSetup, cls).setUpClass(*args, **kwargs)
        from aiida.orm import Group
        for grp in ["dummygroup1", "dummygroup2", "dummygroup3", "dummygroup4"]:
            Group(name=grp).store()
Example #25
# Clear and delete CSS groups created in a previous session
delete_groups()

# Query for all the CURATED-COFs
qb = QueryBuilder()
qb.append(Group, filters={'label': {'like': r'curated-cof\_%\_v%'}})
all_groups = qb.all(flat=True)

# Create groups for each and fill them with desired nodes
print("Creating discovery groups, with tagged nodes:", dis_nodes)

all_dis_groups = []
for full_group in all_groups:
    mat_id = full_group.label.split("_")[1]
    dis_group = Group(label=GROUP_DIR + mat_id).store()
    all_dis_groups.append(dis_group)
    left_nodes = dis_nodes.copy()
    for node in full_group.nodes:
        if TAG_KEY not in node.extras:
            sys.exit("WARNING: node <{}> has no extra '{}'".format(
                node.id, TAG_KEY))
        if node.extras[TAG_KEY] in left_nodes:
            dis_group.add_nodes(node)
            left_nodes.remove(node.extras[TAG_KEY])
            if node.extras[TAG_KEY] == 'orig_cif':  # to change at a certain point!
                dis_group.set_extra('mat_id', mat_id)
                for extra_key in [
                        'doi_ref', 'workflow_version', 'name_conventional',
                        'class_material'
Example #26
    def create_structure_bands():
        """Create bands structure object."""
        alat = 4.  # angstrom
        cell = [
            [alat, 0., 0.],
            [0., alat, 0.],
            [0., 0., alat],
        ]
        strct = StructureData(cell=cell)
        strct.append_atom(position=(0., 0., 0.), symbols='Fe')
        strct.append_atom(position=(alat / 2., alat / 2., alat / 2.),
                          symbols='O')
        strct.store()

        @calcfunction
        def connect_structure_bands(strct):  # pylint: disable=unused-argument
            alat = 4.
            cell = np.array([
                [alat, 0., 0.],
                [0., alat, 0.],
                [0., 0., alat],
            ])

            kpnts = KpointsData()
            kpnts.set_cell(cell)
            kpnts.set_kpoints([[0., 0., 0.], [0.1, 0.1, 0.1]])

            bands = BandsData()
            bands.set_kpointsdata(kpnts)
            bands.set_bands([[1.0, 2.0], [3.0, 4.0]])
            return bands

        bands = connect_structure_bands(strct)

        bands_isolated = BandsData()
        bands_isolated.store()

        # Create 2 groups and add the data to one of them
        g_ne = Group(label='non_empty_group')
        g_ne.store()
        g_ne.add_nodes(bands)
        g_ne.add_nodes(bands_isolated)

        g_e = Group(label='empty_group')
        g_e.store()

        return {
            DummyVerdiDataListable.NODE_ID_STR: bands.id,
            DummyVerdiDataListable.NON_EMPTY_GROUP_ID_STR: g_ne.id,
            DummyVerdiDataListable.EMPTY_GROUP_ID_STR: g_e.id
        }
Example #27
def import_data_dj(
    in_path,
    group=None,
    ignore_unknown_nodes=False,
    extras_mode_existing='kcl',
    extras_mode_new='import',
    comment_mode='newest',
    silent=False
):
    """Import exported AiiDA archive to the AiiDA database and repository.

    Specific for the Django backend.
    If ``in_path`` is a folder, calls extract_tree; otherwise, tries to detect the compression format
    (zip, tar.gz, tar.bz2, ...) and calls the correct function.

    :param in_path: the path to a file or folder that can be imported in AiiDA.
    :type in_path: str

    :param group: Group wherein all imported Nodes will be placed.
    :type group: :py:class:`~aiida.orm.groups.Group`

    :param extras_mode_existing: 3 letter code that will identify what to do with the extras import.
        The first letter acts on extras that are present in the original node and not present in the imported node.
        Can be either:
        'k' (keep it) or
        'n' (do not keep it).
        The second letter acts on the imported extras that are not present in the original node.
        Can be either:
        'c' (create it) or
        'n' (do not create it).
        The third letter defines what to do in case of a name collision.
        Can be either:
        'l' (leave the old value),
        'u' (update with a new value),
        'd' (delete the extra), or
        'a' (ask what to do if the content is different).
    :type extras_mode_existing: str

    :param extras_mode_new: 'import' to import extras of new nodes or 'none' to ignore them.
    :type extras_mode_new: str

    :param comment_mode: Comment import modes (when same UUIDs are found).
        Can be either:
        'newest' (will keep the Comment with the most recent modification time (mtime)) or
        'overwrite' (will overwrite existing Comments with the ones from the import file).
    :type comment_mode: str

    :param silent: suppress prints.
    :type silent: bool

    :return: New and existing Nodes and Links.
    :rtype: dict

    :raises `~aiida.tools.importexport.common.exceptions.ImportValidationError`: if parameters or the contents of
        `metadata.json` or `data.json` can not be validated.
    :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if the provided archive at ``in_path`` is
        corrupted.
    :raises `~aiida.tools.importexport.common.exceptions.IncompatibleArchiveVersionError`: if the provided archive's
        export version is not equal to the export version of AiiDA at the moment of import.
    :raises `~aiida.tools.importexport.common.exceptions.ArchiveImportError`: if there are any internal errors when
        importing.
    :raises `~aiida.tools.importexport.common.exceptions.ImportUniquenessError`: if a new unique entity can not be
        created.
    """
    from django.db import transaction  # pylint: disable=import-error,no-name-in-module
    from aiida.backends.djsite.db import models

    # This is the export version expected by this function
    expected_export_version = StrictVersion(EXPORT_VERSION)

    # The returned dictionary with new and existing nodes and links
    ret_dict = {}

    # Initial check(s)
    if group:
        if not isinstance(group, Group):
            raise exceptions.ImportValidationError('group must be a Group entity')
        elif not group.is_stored:
            group.store()

    ################
    # EXTRACT DATA #
    ################
    # The sandbox has to remain open until the end
    with SandboxFolder() as folder:
        if os.path.isdir(in_path):
            extract_tree(in_path, folder)
        else:
            if tarfile.is_tarfile(in_path):
                extract_tar(in_path, folder, silent=silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER)
            elif zipfile.is_zipfile(in_path):
                try:
                    extract_zip(in_path, folder, silent=silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER)
                except ValueError as exc:
                    print('The following problem occurred while processing the provided file: {}'.format(exc))
                    return
            else:
                raise exceptions.ImportValidationError(
                    'Unable to detect the input file format, it is neither a '
                    '(possibly compressed) tar file, nor a zip file.'
                )

        if not folder.get_content_list():
            raise exceptions.CorruptArchive('The provided file/folder ({}) is empty'.format(in_path))
        try:
            with io.open(folder.get_abs_path('metadata.json'), 'r', encoding='utf8') as fhandle:
                metadata = json.load(fhandle)

            with io.open(folder.get_abs_path('data.json'), 'r', encoding='utf8') as fhandle:
                data = json.load(fhandle)
        except IOError as error:
            raise exceptions.CorruptArchive(
                'Unable to find the file {} in the import file or folder'.format(error.filename)
            )

        ######################
        # PRELIMINARY CHECKS #
        ######################
        export_version = StrictVersion(str(metadata['export_version']))
        if export_version != expected_export_version:
            msg = 'Export file version is {}, can import only version {}'\
                    .format(metadata['export_version'], expected_export_version)
            if export_version < expected_export_version:
                msg += "\nUse 'verdi export migrate' to update this export file."
            else:
                msg += '\nUpdate your AiiDA version in order to import this file.'

            raise exceptions.IncompatibleArchiveVersionError(msg)

        ##########################################################################
        # CREATE UUID REVERSE TABLES AND CHECK IF I HAVE ALL NODES FOR THE LINKS #
        ##########################################################################
        linked_nodes = set(chain.from_iterable((l['input'], l['output']) for l in data['links_uuid']))
        group_nodes = set(chain.from_iterable(six.itervalues(data['groups_uuid'])))

        if NODE_ENTITY_NAME in data['export_data']:
            import_nodes_uuid = set(v['uuid'] for v in data['export_data'][NODE_ENTITY_NAME].values())
        else:
            import_nodes_uuid = set()

        # the combined set of linked_nodes and group_nodes was obtained from looking at all the links
        # the set of import_nodes_uuid was received from the stuff actually referred to in export_data
        unknown_nodes = linked_nodes.union(group_nodes) - import_nodes_uuid

        if unknown_nodes and not ignore_unknown_nodes:
            raise exceptions.DanglingLinkError(
                'The import file refers to {} nodes with unknown UUID, therefore it cannot be imported. Either first '
                'import the unknown nodes, or export also the parents when exporting. The unknown UUIDs are:\n'
                ''.format(len(unknown_nodes)) + '\n'.join('* {}'.format(uuid) for uuid in unknown_nodes)
            )

        ###################################
        # DOUBLE-CHECK MODEL DEPENDENCIES #
        ###################################
        # The entity import order. It is defined by the database model relationships.

        model_order = (
            USER_ENTITY_NAME, COMPUTER_ENTITY_NAME, NODE_ENTITY_NAME, GROUP_ENTITY_NAME, LOG_ENTITY_NAME,
            COMMENT_ENTITY_NAME
        )

        for import_field_name in metadata['all_fields_info']:
            if import_field_name not in model_order:
                raise exceptions.ImportValidationError(
                    "You are trying to import an unknown model '{}'!".format(import_field_name)
                )

        for idx, model_name in enumerate(model_order):
            dependencies = []
            for field in metadata['all_fields_info'][model_name].values():
                try:
                    dependencies.append(field['requires'])
                except KeyError:
                    # (No ForeignKey)
                    pass
            for dependency in dependencies:
                if dependency not in model_order[:idx]:
                    raise exceptions.ArchiveImportError(
                        'Model {} requires {} but would be loaded first; stopping...'.format(model_name, dependency)
                    )

        ###################################################
        # CREATE IMPORT DATA DIRECT UNIQUE_FIELD MAPPINGS #
        ###################################################
        import_unique_ids_mappings = {}
        for model_name, import_data in data['export_data'].items():
            if model_name in metadata['unique_identifiers']:
                # I have to reconvert the pk to integer
                import_unique_ids_mappings[model_name] = {
                    int(k): v[metadata['unique_identifiers'][model_name]] for k, v in import_data.items()
                }

        ###############
        # IMPORT DATA #
        ###############
        # DO ALL WITH A TRANSACTION
        with transaction.atomic():
            foreign_ids_reverse_mappings = {}
            new_entries = {}
            existing_entries = {}

            # I first generate the list of data
            for model_name in model_order:
                cls_signature = entity_names_to_signatures[model_name]
                model = get_object_from_string(cls_signature)
                fields_info = metadata['all_fields_info'].get(model_name, {})
                unique_identifier = metadata['unique_identifiers'].get(model_name, None)

                new_entries[model_name] = {}
                existing_entries[model_name] = {}

                foreign_ids_reverse_mappings[model_name] = {}

                # Not necessarily all models are exported
                if model_name in data['export_data']:

                    # skip nodes that are already present in the DB
                    if unique_identifier is not None:
                        import_unique_ids = set(v[unique_identifier] for v in data['export_data'][model_name].values())

                        relevant_db_entries_result = model.objects.filter(
                            **{'{}__in'.format(unique_identifier): import_unique_ids}
                        )
                        # Note: uuids need to be converted to strings
                        relevant_db_entries = {
                            str(getattr(n, unique_identifier)): n for n in relevant_db_entries_result
                        }

                        foreign_ids_reverse_mappings[model_name] = {k: v.pk for k, v in relevant_db_entries.items()}
                        for key, value in data['export_data'][model_name].items():
                            if value[unique_identifier] in relevant_db_entries.keys():
                                # Already in DB
                                existing_entries[model_name][key] = value
                            else:
                                # To be added
                                new_entries[model_name][key] = value
                    else:
                        new_entries[model_name] = data['export_data'][model_name].copy()

            # Show Comment mode if not silent
            if not silent:
                print('Comment mode: {}'.format(comment_mode))

            # I import data from the given model
            for model_name in model_order:
                cls_signature = entity_names_to_signatures[model_name]
                model = get_object_from_string(cls_signature)
                fields_info = metadata['all_fields_info'].get(model_name, {})
                unique_identifier = metadata['unique_identifiers'].get(model_name, None)

                # EXISTING ENTRIES
                for import_entry_pk, entry_data in existing_entries[model_name].items():
                    unique_id = entry_data[unique_identifier]
                    existing_entry_id = foreign_ids_reverse_mappings[model_name][unique_id]
                    import_data = dict(
                        deserialize_field(
                            k,
                            v,
                            fields_info=fields_info,
                            import_unique_ids_mappings=import_unique_ids_mappings,
                            foreign_ids_reverse_mappings=foreign_ids_reverse_mappings
                        ) for k, v in entry_data.items()
                    )
                    # TODO COMPARE, AND COMPARE ATTRIBUTES

                    if model is models.DbComment:
                        new_entry_uuid = merge_comment(import_data, comment_mode)
                        if new_entry_uuid is not None:
                            entry_data[unique_identifier] = new_entry_uuid
                            new_entries[model_name][import_entry_pk] = entry_data

                    if model_name not in ret_dict:
                        ret_dict[model_name] = {'new': [], 'existing': []}
                    ret_dict[model_name]['existing'].append((import_entry_pk, existing_entry_id))
                    if not silent:
                        print('existing %s: %s (%s->%s)' % (model_name, unique_id, import_entry_pk, existing_entry_id))
                        # print("  `-> WARNING: NO DUPLICITY CHECK DONE!")
                        # CHECK ALSO FILES!

                # Store all objects for this model in a list, and store them all at once at the end.
                objects_to_create = []
                # This is needed later to associate the import entry with the new pk
                import_new_entry_pks = {}
                imported_comp_names = set()

                # NEW ENTRIES
                for import_entry_pk, entry_data in new_entries[model_name].items():
                    unique_id = entry_data[unique_identifier]
                    import_data = dict(
                        deserialize_field(
                            k,
                            v,
                            fields_info=fields_info,
                            import_unique_ids_mappings=import_unique_ids_mappings,
                            foreign_ids_reverse_mappings=foreign_ids_reverse_mappings
                        ) for k, v in entry_data.items()
                    )

                    if model is models.DbGroup:
                        # Check if there is already a group with the same name
                        dupl_counter = 0
                        orig_label = import_data['label']
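                        # Assuming DUPL_SUFFIX has the form ' (Imported #{})', a clashing
                        # label 'mygroup' would become 'mygroup (Imported #0)', then '#1', etc.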
                        while model.objects.filter(label=import_data['label']):
                            import_data['label'] = orig_label + DUPL_SUFFIX.format(dupl_counter)
                            dupl_counter += 1
                            if dupl_counter == 100:
                                raise exceptions.ImportUniquenessError(
                                    'A group of that label ( {} ) already exists and I could not create a new one'
                                    ''.format(orig_label)
                                )

                    elif model is models.DbComputer:
                        # Check if there is already a computer with the same name in the database
                        dupl = (
                            model.objects.filter(name=import_data['name']) or import_data['name'] in imported_comp_names
                        )
                        orig_name = import_data['name']
                        dupl_counter = 0
                        while dupl:
                            # Rename the new computer
                            import_data['name'] = (orig_name + DUPL_SUFFIX.format(dupl_counter))
                            dupl = (
                                model.objects.filter(name=import_data['name']) or
                                import_data['name'] in imported_comp_names
                            )
                            dupl_counter += 1
                            if dupl_counter == 100:
                                raise exceptions.ImportUniquenessError(
                                    'A computer of that name ( {} ) already exists and I could not create a new one'
                                    ''.format(orig_name)
                                )

                        imported_comp_names.add(import_data['name'])

                    objects_to_create.append(model(**import_data))
                    import_new_entry_pks[unique_id] = import_entry_pk

                if model_name == NODE_ENTITY_NAME:
                    if not silent:
                        print('STORING NEW NODE REPOSITORY FILES...')

                    # NEW NODES
                    for object_ in objects_to_create:
                        import_entry_uuid = object_.uuid
                        import_entry_pk = import_new_entry_pks[import_entry_uuid]

                        # Before storing entries in the DB, I store the files (if these are nodes).
                        # Note: only for new entries!
                        subfolder = folder.get_subfolder(
                            os.path.join(NODES_EXPORT_SUBFOLDER, export_shard_uuid(import_entry_uuid))
                        )
                        if not subfolder.exists():
                            raise exceptions.CorruptArchive(
                                'Unable to find the repository folder for Node with UUID={} in the exported '
                                'file'.format(import_entry_uuid)
                            )
                        destdir = RepositoryFolder(section=Repository._section_name, uuid=import_entry_uuid)
                        # Replace the folder, possibly destroying existing previous folders, and move the files
                        # (faster if we are on the same filesystem, and in any case the source is a SandboxFolder)
                        destdir.replace_with_folder(subfolder.abspath, move=True, overwrite=True)

                        # For DbNodes, we also have to store their attributes
                        if not silent:
                            print('STORING NEW NODE ATTRIBUTES...')

                        # Get attributes from import file
                        try:
                            object_.attributes = data['node_attributes'][str(import_entry_pk)]
                        except KeyError:
                            raise exceptions.CorruptArchive(
                                'Unable to find attribute info for Node with UUID={}'.format(import_entry_uuid)
                            )

                        # For DbNodes, we also have to store their extras
                        if extras_mode_new == 'import':
                            if not silent:
                                print('STORING NEW NODE EXTRAS...')

                            # Get extras from import file
                            try:
                                extras = data['node_extras'][str(import_entry_pk)]
                            except KeyError:
                                raise exceptions.CorruptArchive(
                                    'Unable to find extra info for Node with UUID={}'.format(import_entry_uuid)
                                )
                            # TODO: remove when aiida extras are moved somewhere else
                            # from here
                            extras = {key: value for key, value in extras.items() if not key.startswith('_aiida_')}
                            if object_.node_type.endswith('code.Code.'):
                                extras = {key: value for key, value in extras.items() if key != 'hidden'}
                            # till here
                            object_.extras = extras
                        elif extras_mode_new == 'none':
                            if not silent:
                                print('SKIPPING NEW NODE EXTRAS...')
                        else:
                            raise exceptions.ImportValidationError(
                                "Unknown extras_mode_new value: {}, should be either 'import' or 'none'"
                                ''.format(extras_mode_new)
                            )

                    # EXISTING NODES (Extras)
                    # For the existing nodes that are also in the import list, we update their extras if necessary
                    if not silent:
                        print('UPDATING EXISTING NODE EXTRAS (mode: {})'.format(extras_mode_existing))

                    import_existing_entry_pks = {
                        entry_data[unique_identifier]: import_entry_pk
                        for import_entry_pk, entry_data in existing_entries[model_name].items()
                    }
                    for node in models.DbNode.objects.filter(uuid__in=import_existing_entry_pks).all():  # pylint: disable=no-member
                        import_entry_uuid = str(node.uuid)
                        import_entry_pk = import_existing_entry_pks[import_entry_uuid]

                        # Get extras from import file
                        try:
                            extras = data['node_extras'][str(import_entry_pk)]
                        except KeyError:
                            raise exceptions.CorruptArchive(
                                'Unable to find extra info for Node with UUID={}'.format(import_entry_uuid)
                            )

                        # TODO: remove when aiida extras are moved somewhere else
                        # from here
                        extras = {key: value for key, value in extras.items() if not key.startswith('_aiida_')}
                        if node.node_type.endswith('code.Code.'):
                            extras = {key: value for key, value in extras.items() if key != 'hidden'}
                        # till here
                        node.extras = merge_extras(node.extras, extras, extras_mode_existing)

                        # Already saving existing node here to update its extras
                        node.save()

                # If the model has an mtime field, disable the automatic update
                # to keep the mtime that we have set here
                if 'mtime' in [field.name for field in model._meta.local_fields]:
                    with models.suppress_auto_now([(model, ['mtime'])]):
                        # Store them all at once; however, the PKs are not set in this way...
                        model.objects.bulk_create(objects_to_create)
                else:
                    model.objects.bulk_create(objects_to_create)

                # Get back the just-saved entries
                just_saved_queryset = model.objects.filter(
                    **{
                        '{}__in'.format(unique_identifier): import_new_entry_pks.keys()
                    }
                ).values_list(unique_identifier, 'pk')
                # note: convert uuids from type UUID to strings
                just_saved = {str(key): value for key, value in just_saved_queryset}

                # Now I have the PKs, print the info
                # Moreover, add newly created Nodes to foreign_ids_reverse_mappings
                for unique_id, new_pk in just_saved.items():
                    import_entry_pk = import_new_entry_pks[unique_id]
                    foreign_ids_reverse_mappings[model_name][unique_id] = new_pk
                    if model_name not in ret_dict:
                        ret_dict[model_name] = {'new': [], 'existing': []}
                    ret_dict[model_name]['new'].append((import_entry_pk, new_pk))

                    if not silent:
                        print('NEW %s: %s (%s->%s)' % (model_name, unique_id, import_entry_pk, new_pk))

            if not silent:
                print('STORING NODE LINKS...')
            import_links = data['links_uuid']
            links_to_store = []

            # Needed, since QueryBuilder does not yet work for recently saved Nodes
            existing_links_raw = models.DbLink.objects.all().values_list('input', 'output', 'label', 'type')
            existing_links = {(l[0], l[1], l[2], l[3]) for l in existing_links_raw}
            existing_outgoing_unique = {(l[0], l[3]) for l in existing_links_raw}
            existing_outgoing_unique_pair = {(l[0], l[2], l[3]) for l in existing_links_raw}
            existing_incoming_unique = {(l[1], l[3]) for l in existing_links_raw}
            existing_incoming_unique_pair = {(l[1], l[2], l[3]) for l in existing_links_raw}
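            # These sets are keyed as follows and drive the outdegree/indegree checks below:
            #   existing_outgoing_unique:      (input_pk, type)
            #   existing_outgoing_unique_pair: (input_pk, label, type)
            #   existing_incoming_unique:      (output_pk, type)
            #   existing_incoming_unique_pair: (output_pk, label, type)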

            calculation_node_types = 'process.calculation.'
            workflow_node_types = 'process.workflow.'
            data_node_types = 'data.'

            link_mapping = {
                LinkType.CALL_CALC: (workflow_node_types, calculation_node_types, 'unique_triple', 'unique'),
                LinkType.CALL_WORK: (workflow_node_types, workflow_node_types, 'unique_triple', 'unique'),
                LinkType.CREATE: (calculation_node_types, data_node_types, 'unique_pair', 'unique'),
                LinkType.INPUT_CALC: (data_node_types, calculation_node_types, 'unique_triple', 'unique_pair'),
                LinkType.INPUT_WORK: (data_node_types, workflow_node_types, 'unique_triple', 'unique_pair'),
                LinkType.RETURN: (workflow_node_types, data_node_types, 'unique_pair', 'unique_triple'),
            }
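            # Each entry maps a LinkType to
            # (source type prefix, target type prefix, outdegree rule, indegree rule).
            # Reading the CREATE entry, for example: a calculation may create many data
            # nodes as long as the link labels are unique (unique_pair going out), while
            # a data node can be created by exactly one calculation (unique coming in).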

            for link in import_links:
                # Check for dangling links within the supposedly self-consistent archive
                try:
                    in_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][link['input']]
                    out_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][link['output']]
                except KeyError:
                    if ignore_unknown_nodes:
                        continue
                    else:
                        raise exceptions.ImportValidationError(
                            'Trying to create a link with one or both unknown nodes, stopping (in_uuid={}, '
                            'out_uuid={}, label={}, type={})'.format(
                                link['input'], link['output'], link['label'], link['type']
                            )
                        )

                # Check if link already exists, skip if it does
                # This is equivalent to an existing triple link (i.e. unique_triple from below)
                if (in_id, out_id, link['label'], link['type']) in existing_links:
                    continue

                # Since backend specific Links (DbLink) are not validated upon creation, we will now validate them.
                try:
                    validate_link_label(link['label'])
                except ValueError as why:
                    raise exceptions.ImportValidationError('Error during Link label validation: {}'.format(why))

                source = models.DbNode.objects.get(id=in_id)
                target = models.DbNode.objects.get(id=out_id)

                if source.uuid == target.uuid:
                    raise exceptions.ImportValidationError('Cannot add a link to oneself')

                link_type = LinkType(link['type'])
                type_source, type_target, outdegree, indegree = link_mapping[link_type]

                # Check if source Node is a valid type
                if not source.node_type.startswith(type_source):
                    raise exceptions.ImportValidationError(
                        'Cannot add a {} link from {} to {}'.format(link_type, source.node_type, target.node_type)
                    )

                # Check if target Node is a valid type
                if not target.node_type.startswith(type_target):
                    raise exceptions.ImportValidationError(
                        'Cannot add a {} link from {} to {}'.format(link_type, source.node_type, target.node_type)
                    )

                # If the outdegree is `unique`, there cannot already be any other outgoing link
                # of that type, i.e., the source node may not already have an outgoing link of the current LinkType.
                if outdegree == 'unique' and (in_id, link['type']) in existing_outgoing_unique:
                    raise exceptions.ImportValidationError(
                        'Node<{}> already has an outgoing {} link'.format(source.uuid, link_type)
                    )

                # If the outdegree is `unique_pair`, the link labels of outgoing links of this
                # type must be unique, i.e., the source node may not already have an outgoing
                # link of the current LinkType with the current link label.
                elif outdegree == 'unique_pair' and \
                        (in_id, link['label'], link['type']) in existing_outgoing_unique_pair:
                    raise exceptions.ImportValidationError(
                        'Node<{}> already has an outgoing {} link with label "{}"'.format(
                            source.uuid, link_type, link['label']
                        )
                    )

                # If the indegree is `unique`, there cannot already be any other incoming link
                # of that type, i.e., the target node may not already have an incoming link of the current LinkType.
                if indegree == 'unique' and (out_id, link['type']) in existing_incoming_unique:
                    raise exceptions.ImportValidationError(
                        'Node<{}> already has an incoming {} link'.format(target.uuid, link_type)
                    )

                # If the indegree is `unique_pair`, the link labels of incoming links of this
                # type must be unique, i.e., the target node may not already have an incoming
                # link of the current LinkType with the current link label.
                elif indegree == 'unique_pair' and \
                        (out_id, link['label'], link['type']) in existing_incoming_unique_pair:
                    raise exceptions.ImportValidationError(
                        'Node<{}> already has an incoming {} link with label "{}"'.format(
                            target.uuid, link_type, link['label']
                        )
                    )

                # New link
                links_to_store.append(
                    models.DbLink(input_id=in_id, output_id=out_id, label=link['label'], type=link['type'])
                )
                if 'Link' not in ret_dict:
                    ret_dict['Link'] = {'new': []}
                ret_dict['Link']['new'].append((in_id, out_id))

                # Add the new link to the sets of existing links, keyed by input PK, output PK, label and type
                existing_links.add((in_id, out_id, link['label'], link['type']))
                existing_outgoing_unique.add((in_id, link['type']))
                existing_outgoing_unique_pair.add((in_id, link['label'], link['type']))
                existing_incoming_unique.add((out_id, link['type']))
                existing_incoming_unique_pair.add((out_id, link['label'], link['type']))

            # Store new links
            if links_to_store:
                if not silent:
                    print('   ({} new links...)'.format(len(links_to_store)))

                models.DbLink.objects.bulk_create(links_to_store)
            else:
                if not silent:
                    print('   (0 new links...)')

            if not silent:
                print('STORING GROUP ELEMENTS...')
            import_groups = data['groups_uuid']
            for groupuuid, groupnodes in import_groups.items():
                # TODO: cache these to avoid too many queries
                group_ = models.DbGroup.objects.get(uuid=groupuuid)
                nodes_to_store = [foreign_ids_reverse_mappings[NODE_ENTITY_NAME][node_uuid] for node_uuid in groupnodes]
                if nodes_to_store:
                    group_.dbnodes.add(*nodes_to_store)

        ######################################################
        # Put everything in a specific group
        ######################################################
        existing = existing_entries.get(NODE_ENTITY_NAME, {})
        existing_pk = [foreign_ids_reverse_mappings[NODE_ENTITY_NAME][v['uuid']] for v in six.itervalues(existing)]
        new = new_entries.get(NODE_ENTITY_NAME, {})
        new_pk = [foreign_ids_reverse_mappings[NODE_ENTITY_NAME][v['uuid']] for v in six.itervalues(new)]

        pks_for_group = existing_pk + new_pk

        # So that we do not create empty groups
        if pks_for_group:
            # All imported nodes are placed in `group`; if the user did not
            # specify one, create a new uniquely-labeled import group
            if not group:
                # Get a unique label for the import group, based on the current (local) time
                basename = timezone.localtime(timezone.now()).strftime('%Y%m%d-%H%M%S')
                counter = 0
                group_label = basename

                while Group.objects.find(filters={'label': group_label}):
                    counter += 1
                    group_label = '{}_{}'.format(basename, counter)

                    if counter == 100:
                        raise exceptions.ImportUniquenessError(
                            "Overflow of import groups (more than 100 import groups exists with basename '{}')"
                            ''.format(basename)
                        )
                group = Group(label=group_label, type_string=IMPORTGROUP_TYPE).store()

            # Add all the nodes to the new group
            # TODO: decide if we want to return the group label
            nodes = [entry[0] for entry in QueryBuilder().append(Node, filters={'id': {'in': pks_for_group}}).all()]
            group.add_nodes(nodes)

            if not silent:
                print("IMPORTED NODES ARE GROUPED IN THE IMPORT GROUP LABELED '{}'".format(group.label))
        else:
            if not silent:
                print('NO NODES TO IMPORT, SO NO GROUP CREATED, IF IT DID NOT ALREADY EXIST')

    if not silent:
        print('*** WARNING: MISSING EXISTING UUID CHECKS!!')
        print('*** WARNING: TODO: UPDATE IMPORT_DATA WITH DEFAULT VALUES! (e.g. calc status, user pwd, ...)')
        print('DONE.')

    return ret_dict
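
# A minimal usage sketch, not part of the original example: it assumes the
# enclosing function is exposed as `import_data_dj` (the Django-backend
# importer), that an AiiDA profile is already loaded, and that the archive
# path below exists; all of these names are illustrative.
def _example_import_dj():
    from aiida.orm import Group

    target = Group(label='import_target').store()
    result = import_data_dj('/tmp/export.aiida', group=target, silent=True)
    print('new nodes:', result.get('Node', {}).get('new', []))
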
Example #28
0
def import_data_sqla(
    in_path,
    group=None,
    ignore_unknown_nodes=False,
    extras_mode_existing='kcl',
    extras_mode_new='import',
    comment_mode='newest',
    silent=False
):
    """Import exported AiiDA archive to the AiiDA database and repository.

    Specific for the SQLAlchemy backend.
    If ``in_path`` is a folder, calls extract_tree; otherwise, tries to detect the compression format
    (zip, tar.gz, tar.bz2, ...) and calls the correct function.

    :param in_path: the path to a file or folder that can be imported in AiiDA.
    :type in_path: str

    :param group: Group wherein all imported Nodes will be placed.
    :type group: :py:class:`~aiida.orm.groups.Group`

    :param extras_mode_existing: 3-letter code determining how to handle extras on import.
        The first letter acts on extras that are present in the original node and not present in the imported node.
        Can be either:
        'k' (keep it) or
        'n' (do not keep it).
        The second letter acts on the imported extras that are not present in the original node.
        Can be either:
        'c' (create it) or
        'n' (do not create it).
        The third letter defines what to do in case of a name collision.
        Can be either:
        'l' (leave the old value),
        'u' (update with a new value),
        'd' (delete the extra), or
        'a' (ask what to do if the content is different).
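        For example, the default 'kcl' keeps extras that exist only on the original
        node, creates imported extras that are missing, and leaves the old value
        when names collide.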
    :type extras_mode_existing: str

    :param extras_mode_new: 'import' to import extras of new nodes or 'none' to ignore them.
    :type extras_mode_new: str

    :param comment_mode: Comment import modes (when same UUIDs are found).
        Can be either:
        'newest' (will keep the Comment with the most recent modification time (mtime)) or
        'overwrite' (will overwrite existing Comments with the ones from the import file).
    :type comment_mode: str

    :param silent: suppress prints.
    :type silent: bool

    :return: New and existing Nodes and Links.
    :rtype: dict

    :raises `~aiida.tools.importexport.common.exceptions.ImportValidationError`: if parameters or the contents of
        `metadata.json` or `data.json` can not be validated.
    :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if the provided archive at ``in_path`` is
        corrupted.
    :raises `~aiida.tools.importexport.common.exceptions.IncompatibleArchiveVersionError`: if the provided archive's
        export version is not equal to the export version of AiiDA at the moment of import.
    :raises `~aiida.tools.importexport.common.exceptions.ArchiveImportError`: if there are any internal errors when
        importing.
    :raises `~aiida.tools.importexport.common.exceptions.ImportUniquenessError`: if a new unique entity can not be
        created.
    """
    from aiida.backends.sqlalchemy.models.node import DbNode, DbLink
    from aiida.backends.sqlalchemy.utils import flag_modified

    # This is the export version expected by this function
    expected_export_version = StrictVersion(EXPORT_VERSION)

    # The returned dictionary with new and existing nodes and links
    ret_dict = {}

    # Initial check(s)
    if group:
        if not isinstance(group, Group):
            raise exceptions.ImportValidationError('group must be a Group entity')
        elif not group.is_stored:
            group.store()

    ################
    # EXTRACT DATA #
    ################
    # The sandbox has to remain open until the end
    with SandboxFolder() as folder:
        if os.path.isdir(in_path):
            extract_tree(in_path, folder)
        else:
            if tarfile.is_tarfile(in_path):
                extract_tar(in_path, folder, silent=silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER)
            elif zipfile.is_zipfile(in_path):
                extract_zip(in_path, folder, silent=silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER)
            else:
                raise exceptions.ImportValidationError(
                    'Unable to detect the input file format, it is neither a '
                    '(possibly compressed) tar file, nor a zip file.'
                )

        if not folder.get_content_list():
            raise exceptions.CorruptArchive('The provided file/folder ({}) is empty'.format(in_path))
        try:
            with io.open(folder.get_abs_path('metadata.json'), encoding='utf8') as fhandle:
                metadata = json.load(fhandle)

            with io.open(folder.get_abs_path('data.json'), encoding='utf8') as fhandle:
                data = json.load(fhandle)
        except IOError as error:
            raise exceptions.CorruptArchive(
                'Unable to find the file {} in the import file or folder'.format(error.filename)
            )

        ######################
        # PRELIMINARY CHECKS #
        ######################
        export_version = StrictVersion(str(metadata['export_version']))
        if export_version != expected_export_version:
            msg = 'Export file version is {}, can import only version {}'.format(
                metadata['export_version'], expected_export_version
            )
            if export_version < expected_export_version:
                msg += "\nUse 'verdi export migrate' to update this export file."
            else:
                msg += '\nUpdate your AiiDA version in order to import this file.'

            raise exceptions.IncompatibleArchiveVersionError(msg)

        ###################################################################
        #           CREATE UUID REVERSE TABLES AND CHECK IF               #
        #              I HAVE ALL NODES FOR THE LINKS                     #
        ###################################################################
        linked_nodes = set(chain.from_iterable((l['input'], l['output']) for l in data['links_uuid']))
        group_nodes = set(chain.from_iterable(six.itervalues(data['groups_uuid'])))

        # Check that UUIDs are valid
        linked_nodes = set(x for x in linked_nodes if validate_uuid(x))
        group_nodes = set(x for x in group_nodes if validate_uuid(x))

        import_nodes_uuid = set()
        for value in data['export_data'].get(NODE_ENTITY_NAME, {}).values():
            import_nodes_uuid.add(value['uuid'])

        unknown_nodes = linked_nodes.union(group_nodes) - import_nodes_uuid

        if unknown_nodes and not ignore_unknown_nodes:
            raise exceptions.DanglingLinkError(
                'The import file refers to {} nodes with unknown UUID, therefore it cannot be imported. Either first '
                'import the unknown nodes, or export also the parents when exporting. The unknown UUIDs are:\n'
                ''.format(len(unknown_nodes)) + '\n'.join('* {}'.format(uuid) for uuid in unknown_nodes)
            )

        ###################################
        # DOUBLE-CHECK MODEL DEPENDENCIES #
        ###################################
        # The entity import order. It is defined by the database model relationships.
        entity_sig_order = [
            entity_names_to_signatures[m] for m in (
                USER_ENTITY_NAME, COMPUTER_ENTITY_NAME, NODE_ENTITY_NAME, GROUP_ENTITY_NAME, LOG_ENTITY_NAME,
                COMMENT_ENTITY_NAME
            )
        ]

        #  I make a new list that contains the entity names:
        # eg: ['User', 'Computer', 'Node', 'Group']
        all_entity_names = [signatures_to_entity_names[entity_sig] for entity_sig in entity_sig_order]
        for import_field_name in metadata['all_fields_info']:
            if import_field_name not in all_entity_names:
                raise exceptions.ImportValidationError(
                    "You are trying to import an unknown model '{}'!".format(import_field_name)
                )

        for idx, entity_sig in enumerate(entity_sig_order):
            dependencies = []
            entity_name = signatures_to_entity_names[entity_sig]
            # For every field, check the dependencies given as the value of the 'requires' key
            for field in metadata['all_fields_info'][entity_name].values():
                try:
                    dependencies.append(field['requires'])
                except KeyError:
                    # (No ForeignKey)
                    pass
            for dependency in dependencies:
                if dependency not in all_entity_names[:idx]:
                    raise exceptions.ArchiveImportError(
                        'Entity {} requires {} but would be loaded first; stopping...'.format(entity_sig, dependency)
                    )

        ###################################################
        # CREATE IMPORT DATA DIRECT UNIQUE_FIELD MAPPINGS #
        ###################################################
        # This is nested dictionary of entity_name:{id:uuid}
        # to map one id (the pk) to a different one.
        # One of the things to remove for v0.4
        # {
        # u'Node': {2362: u'82a897b5-fb3a-47d7-8b22-c5fe1b4f2c14',
        #           2363: u'ef04aa5d-99e7-4bfd-95ef-fe412a6a3524', 2364: u'1dc59576-af21-4d71-81c2-bac1fc82a84a'},
        # u'User': {1: u'aiida@localhost'}
        # }
        import_unique_ids_mappings = {}
        # Since v0.3, the export data is keyed by entity_name
        for entity_name, import_data in data['export_data'].items():
            # Use the entity_name, since that is what has been stored since v0.3
            if entity_name in metadata['unique_identifiers']:
                # I have to reconvert the pk to integer
                import_unique_ids_mappings[entity_name] = {
                    int(k): v[metadata['unique_identifiers'][entity_name]] for k, v in import_data.items()
                }
        ###############
        # IMPORT DATA #
        ###############
        # DO ALL WITH A TRANSACTION
        import aiida.backends.sqlalchemy

        session = aiida.backends.sqlalchemy.get_scoped_session()

        try:
            foreign_ids_reverse_mappings = {}
            new_entries = {}
            existing_entries = {}

            # I first generate the list of data
            for entity_sig in entity_sig_order:
                entity_name = signatures_to_entity_names[entity_sig]
                entity = entity_names_to_entities[entity_name]
                # Get the unique identifier; since v0.3 it is stored under the entity_name
                unique_identifier = metadata['unique_identifiers'].get(entity_name, None)

                # Initialize the per-entity bookkeeping; since v0.3 it is keyed by entity_name
                new_entries[entity_name] = {}
                existing_entries[entity_name] = {}
                foreign_ids_reverse_mappings[entity_name] = {}

                # Not necessarily all models are exported
                if entity_name in data['export_data']:

                    if unique_identifier is not None:
                        import_unique_ids = set(v[unique_identifier] for v in data['export_data'][entity_name].values())

                        relevant_db_entries = dict()
                        if import_unique_ids:
                            builder = QueryBuilder()
                            builder.append(
                                entity,
                                filters={unique_identifier: {
                                    'in': import_unique_ids
                                }},
                                project=['*'],
                                tag='res'
                            )
                            relevant_db_entries = {
                                str(getattr(v[0], unique_identifier)):  # str() to convert UUID() to string
                                v[0] for v in builder.all()
                            }

                            foreign_ids_reverse_mappings[entity_name] = {
                                k: v.pk for k, v in relevant_db_entries.items()
                            }

                        imported_comp_names = set()
                        for key, value in data['export_data'][entity_name].items():
                            if entity_name == GROUP_ENTITY_NAME:
                                # Check if there is already a group with the same name,
                                # and if so, recreate the name
                                orig_label = value['label']
                                dupl_counter = 0
                                while QueryBuilder().append(entity, filters={'label': {'==': value['label']}}).count():
                                    # Rename the new group
                                    value['label'] = orig_label + DUPL_SUFFIX.format(dupl_counter)
                                    dupl_counter += 1
                                    if dupl_counter == 100:
                                        raise exceptions.ImportUniquenessError(
                                            'A group of that label ( {} ) already exists and I could not create a new '
                                            'one'.format(orig_label)
                                        )

                            elif entity_name == COMPUTER_ENTITY_NAME:
                                # The following is done for compatibility
                                # reasons in case the export file was generated
                                # with the Django export method. In Django the
                                # metadata and the transport parameters are
                                # stored as (unicode) strings of the serialized
                                # JSON objects and not as simple serialized
                                # JSON objects.
                                if isinstance(value['metadata'], (six.string_types, six.binary_type)):
                                    value['metadata'] = json.loads(value['metadata'])

                                # Check if there is already a computer with the
                                # same name in the database
                                builder = QueryBuilder()
                                builder.append(
                                    entity, filters={'name': {
                                        '==': value['name']
                                    }}, project=['*'], tag='res'
                                )
                                dupl = (builder.count() or value['name'] in imported_comp_names)
                                dupl_counter = 0
                                orig_name = value['name']
                                while dupl:
                                    # Rename the new computer
                                    value['name'] = (orig_name + DUPL_SUFFIX.format(dupl_counter))
                                    builder = QueryBuilder()
                                    builder.append(
                                        entity, filters={'name': {
                                            '==': value['name']
                                        }}, project=['*'], tag='res'
                                    )
                                    dupl = (builder.count() or value['name'] in imported_comp_names)
                                    dupl_counter += 1
                                    if dupl_counter == 100:
                                        raise exceptions.ImportUniquenessError(
                                            'A computer of that name ( {} ) already exists and I could not create a '
                                            'new one'.format(orig_name)
                                        )

                                imported_comp_names.add(value['name'])

                            if value[unique_identifier] in relevant_db_entries:
                                # Already in DB
                                # again, switched to entity_name in v0.3
                                existing_entries[entity_name][key] = value
                            else:
                                # To be added
                                new_entries[entity_name][key] = value
                    else:
                        # Copy, so that the original export data is not modified
                        new_entries[entity_name] = data['export_data'][entity_name].copy()

            # Show Comment mode if not silent
            if not silent:
                print('Comment mode: {}'.format(comment_mode))

            # Import the data, entity by entity, in the defined order
            for entity_sig in entity_sig_order:
                entity_name = signatures_to_entity_names[entity_sig]
                entity = entity_names_to_entities[entity_name]
                fields_info = metadata['all_fields_info'].get(entity_name, {})
                unique_identifier = metadata['unique_identifiers'].get(entity_name, '')

                # EXISTING ENTRIES
                for import_entry_pk, entry_data in existing_entries[entity_name].items():
                    unique_id = entry_data[unique_identifier]
                    existing_entry_pk = foreign_ids_reverse_mappings[entity_name][unique_id]
                    import_data = dict(
                        deserialize_field(
                            k,
                            v,
                            fields_info=fields_info,
                            import_unique_ids_mappings=import_unique_ids_mappings,
                            foreign_ids_reverse_mappings=foreign_ids_reverse_mappings
                        ) for k, v in entry_data.items()
                    )
                    # TODO COMPARE, AND COMPARE ATTRIBUTES

                    if entity_sig is entity_names_to_signatures[COMMENT_ENTITY_NAME]:
                        new_entry_uuid = merge_comment(import_data, comment_mode)
                        if new_entry_uuid is not None:
                            entry_data[unique_identifier] = new_entry_uuid
                            new_entries[entity_name][import_entry_pk] = entry_data

                    if entity_name not in ret_dict:
                        ret_dict[entity_name] = {'new': [], 'existing': []}
                    ret_dict[entity_name]['existing'].append((import_entry_pk, existing_entry_pk))
                    if not silent:
                        print('existing %s: %s (%s->%s)' % (entity_sig, unique_id, import_entry_pk, existing_entry_pk))

                # Store all objects for this model in a list, and store them
                # all at once at the end.
                objects_to_create = list()
                # In the following list we add the objects to be updated
                objects_to_update = list()
                # This is needed later to associate the import entry with the new pk
                import_new_entry_pks = dict()

                # NEW ENTRIES
                for import_entry_pk, entry_data in new_entries[entity_name].items():
                    unique_id = entry_data[unique_identifier]
                    import_data = dict(
                        deserialize_field(
                            k,
                            v,
                            fields_info=fields_info,
                            import_unique_ids_mappings=import_unique_ids_mappings,
                            foreign_ids_reverse_mappings=foreign_ids_reverse_mappings
                        ) for k, v in entry_data.items()
                    )

                    # Convert the Django field names to their SQLA counterparts. Note
                    # that some Django fields were already converted to SQLA-compatible
                    # fields by the deserialize_field method (done for optimization
                    # reasons on the Django side); those need no further conversion.
                    if entity_name in file_fields_to_model_fields:
                        for file_fkey in file_fields_to_model_fields[entity_name]:

                            # This is an exception because the DbLog model defines the `_metadata` column instead of the
                            # `metadata` column used in the Django model. This is because the SqlAlchemy model base
                            # class already has a metadata attribute that cannot be overridden. For consistency, the
                            # `DbLog` class however expects the `metadata` keyword in its constructor, so we should
                            # ignore the mapping here
                            if entity_name == LOG_ENTITY_NAME and file_fkey == 'metadata':
                                continue

                            model_fkey = file_fields_to_model_fields[entity_name][file_fkey]
                            if model_fkey in import_data:
                                continue
                            import_data[model_fkey] = import_data[file_fkey]
                            import_data.pop(file_fkey, None)

                    db_entity = get_object_from_string(entity_names_to_sqla_schema[entity_name])

                    objects_to_create.append(db_entity(**import_data))
                    import_new_entry_pks[unique_id] = import_entry_pk

                if entity_sig == entity_names_to_signatures[NODE_ENTITY_NAME]:
                    if not silent:
                        print('STORING NEW NODE REPOSITORY FILES & ATTRIBUTES...')

                    # NEW NODES
                    for object_ in objects_to_create:
                        import_entry_uuid = object_.uuid
                        import_entry_pk = import_new_entry_pks[import_entry_uuid]

                        # Before storing entries in the DB, I store the files (if these are nodes).
                        # Note: only for new entries!
                        subfolder = folder.get_subfolder(
                            os.path.join(NODES_EXPORT_SUBFOLDER, export_shard_uuid(import_entry_uuid))
                        )
                        if not subfolder.exists():
                            raise exceptions.CorruptArchive(
                                'Unable to find the repository folder for Node with UUID={} in the exported '
                                'file'.format(import_entry_uuid)
                            )
                        destdir = RepositoryFolder(section=Repository._section_name, uuid=import_entry_uuid)
                        # Replace the folder, possibly destroying existing previous folders, and move the files
                        # (faster if we are on the same filesystem, and in any case the source is a SandboxFolder)
                        destdir.replace_with_folder(subfolder.abspath, move=True, overwrite=True)

                        # For Nodes, we also have to store Attributes!
                        # Get attributes from import file
                        try:
                            object_.attributes = data['node_attributes'][str(import_entry_pk)]
                        except KeyError:
                            raise exceptions.CorruptArchive(
                                'Unable to find attribute info for Node with UUID={}'.format(import_entry_uuid)
                            )

                        # For DbNodes, we also have to store extras
                        # Get extras from import file
                        if extras_mode_new == 'import':
                            if not silent:
                                print('STORING NEW NODE EXTRAS...')
                            try:
                                extras = data['node_extras'][str(import_entry_pk)]
                            except KeyError:
                                raise exceptions.CorruptArchive(
                                    'Unable to find extra info for Node with UUID={}'.format(import_entry_uuid)
                                )
                            # TODO: remove when aiida extras are moved somewhere else
                            # from here
                            extras = {key: value for key, value in extras.items() if not key.startswith('_aiida_')}
                            if object_.node_type.endswith('code.Code.'):
                                extras = {key: value for key, value in extras.items() if key != 'hidden'}
                            # till here
                            object_.extras = extras
                        elif extras_mode_new == 'none':
                            if not silent:
                                print('SKIPPING NEW NODE EXTRAS...')
                        else:
                            raise exceptions.ImportValidationError(
                                "Unknown extras_mode_new value: {}, should be either 'import' or 'none'"
                                ''.format(extras_mode_new)
                            )

                    # EXISTING NODES (Extras)
                    if not silent:
                        print('UPDATING EXISTING NODE EXTRAS (mode: {})'.format(extras_mode_existing))

                    import_existing_entry_pks = {
                        entry_data[unique_identifier]: import_entry_pk
                        for import_entry_pk, entry_data in existing_entries[entity_name].items()
                    }
                    for node in session.query(DbNode).filter(DbNode.uuid.in_(import_existing_entry_pks)).all():
                        import_entry_uuid = str(node.uuid)
                        import_entry_pk = import_existing_entry_pks[import_entry_uuid]

                        # Get extras from import file
                        try:
                            extras = data['node_extras'][str(import_entry_pk)]
                        except KeyError:
                            raise exceptions.CorruptArchive(
                                'Unable to find extra info for Node with UUID={}'.format(import_entry_uuid)
                            )

                        # TODO: remove when aiida extras are moved somewhere else
                        # from here
                        extras = {key: value for key, value in extras.items() if not key.startswith('_aiida_')}
                        if node.node_type.endswith('code.Code.'):
                            extras = {key: value for key, value in extras.items() if key != 'hidden'}
                        # till here
                        node.extras = merge_extras(node.extras, extras, extras_mode_existing)
                        flag_modified(node, 'extras')
                        objects_to_update.append(node)

                # Store them all at once; however, the PKs are not set in this way...
                if objects_to_create:
                    session.add_all(objects_to_create)
                if objects_to_update:
                    session.add_all(objects_to_update)

                session.flush()
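                # flush() emits the INSERTs inside the open transaction, so the new
                # rows (and their PKs) are visible to the QueryBuilder lookup below
                # without committing yet.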

                if import_new_entry_pks:
                    builder = QueryBuilder()
                    builder.append(
                        entity,
                        filters={unique_identifier: {
                            'in': list(import_new_entry_pks.keys())
                        }},
                        project=[unique_identifier, 'id'],
                        tag='res'
                    )
                    just_saved = {v[0]: v[1] for v in builder.all()}
                else:
                    just_saved = dict()

                # Now I have the PKs, print the info
                # Moreover, add newly created Nodes to foreign_ids_reverse_mappings
                from uuid import UUID
                for unique_id, new_pk in just_saved.items():
                    if isinstance(unique_id, UUID):
                        unique_id = str(unique_id)
                    import_entry_pk = import_new_entry_pks[unique_id]
                    foreign_ids_reverse_mappings[entity_name][unique_id] = new_pk
                    if entity_name not in ret_dict:
                        ret_dict[entity_name] = {'new': [], 'existing': []}
                    ret_dict[entity_name]['new'].append((import_entry_pk, new_pk))

                    if not silent:
                        print('NEW %s: %s (%s->%s)' % (entity_sig, unique_id, import_entry_pk, new_pk))

            if not silent:
                print('STORING NODE LINKS...')

            import_links = data['links_uuid']

            for link in import_links:
                # Check for dangling links within the supposedly self-consistent archive
                try:
                    in_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][link['input']]
                    out_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][link['output']]
                except KeyError:
                    if ignore_unknown_nodes:
                        continue
                    else:
                        raise exceptions.ImportValidationError(
                            'Trying to create a link with one or both unknown nodes, stopping (in_uuid={}, '
                            'out_uuid={}, label={}, type={})'.format(
                                link['input'], link['output'], link['label'], link['type']
                            )
                        )

                # Since backend specific Links (DbLink) are not validated upon creation, we will now validate them.
                source = QueryBuilder().append(Node, filters={'id': in_id}, project='*').first()[0]
                target = QueryBuilder().append(Node, filters={'id': out_id}, project='*').first()[0]
                link_type = LinkType(link['type'])

                # Check for existence of a triple link, i.e. unique triple.
                # If it exists, then the link already exists, continue to next link, otherwise, validate link.
                if link_triple_exists(source, target, link_type, link['label']):
                    continue

                try:
                    validate_link(source, target, link_type, link['label'])
                except ValueError as why:
                    raise exceptions.ImportValidationError('Error occurred during Link validation: {}'.format(why))

                # New link
                session.add(DbLink(input_id=in_id, output_id=out_id, label=link['label'], type=link['type']))
                if 'Link' not in ret_dict:
                    ret_dict['Link'] = {'new': []}
                ret_dict['Link']['new'].append((in_id, out_id))

            if not silent:
                print('   ({} new links...)'.format(len(ret_dict.get('Link', {}).get('new', []))))

            if not silent:
                print('STORING GROUP ELEMENTS...')
            import_groups = data['groups_uuid']
            for groupuuid, groupnodes in import_groups.items():
                # TODO: cache these to avoid too many queries
                qb_group = QueryBuilder().append(Group, filters={'uuid': {'==': groupuuid}})
                group_ = qb_group.first()[0]
                nodes_ids_to_add = [
                    foreign_ids_reverse_mappings[NODE_ENTITY_NAME][node_uuid] for node_uuid in groupnodes
                ]
                qb_nodes = QueryBuilder().append(Node, filters={'id': {'in': nodes_ids_to_add}})
                # Add the nodes to the group, bypassing the SQLA ORM for speed
                nodes_to_add = [n[0].backend_entity for n in qb_nodes.all()]
                group_.backend_entity.add_nodes(nodes_to_add, skip_orm=True)

            ######################################################
            # Put everything in a specific group
            ######################################################
            existing = existing_entries.get(NODE_ENTITY_NAME, {})
            existing_pk = [foreign_ids_reverse_mappings[NODE_ENTITY_NAME][v['uuid']] for v in six.itervalues(existing)]
            new = new_entries.get(NODE_ENTITY_NAME, {})
            new_pk = [foreign_ids_reverse_mappings[NODE_ENTITY_NAME][v['uuid']] for v in six.itervalues(new)]

            pks_for_group = existing_pk + new_pk

            # So that we do not create empty groups
            if pks_for_group:
                # All imported nodes are placed in `group`; if the user did not
                # specify one, create a new uniquely-labeled import group
                if not group:
                    from aiida.backends.sqlalchemy.models.group import DbGroup

                    # Get a unique label for the import group, based on the current (local) time
                    basename = timezone.localtime(timezone.now()).strftime('%Y%m%d-%H%M%S')
                    counter = 0
                    group_label = basename
                    while session.query(DbGroup).filter(DbGroup.label == group_label).count() > 0:
                        counter += 1
                        group_label = '{}_{}'.format(basename, counter)

                        if counter == 100:
                            raise exceptions.ImportUniquenessError(
                                "Overflow of import groups (more than 100 import groups exists with basename '{}')"
                                ''.format(basename)
                            )
                    group = Group(label=group_label, type_string=IMPORTGROUP_TYPE)
                    session.add(group.backend_entity._dbmodel)

                # Add the nodes to the group, bypassing the SQLA ORM for speed
                nodes = [
                    entry[0].backend_entity
                    for entry in QueryBuilder().append(Node, filters={
                        'id': {
                            'in': pks_for_group
                        }
                    }).all()
                ]
                group.backend_entity.add_nodes(nodes, skip_orm=True)
                if not silent:
                    print("IMPORTED NODES ARE GROUPED IN THE IMPORT GROUP LABELED '{}'".format(group.label))
            else:
                if not silent:
                    print('NO NODES TO IMPORT, SO NO GROUP CREATED, IF IT DID NOT ALREADY EXIST')

            if not silent:
                print('COMMITTING EVERYTHING...')
            session.commit()
        except:
            if not silent:
                print('Rolling back')
            session.rollback()
            raise

    if not silent:
        print('*** WARNING: MISSING EXISTING UUID CHECKS!!')
        print('*** WARNING: TODO: UPDATE IMPORT_DATA WITH DEFAULT VALUES! (e.g. calc status, user pwd, ...)')
        print('DONE.')

    return ret_dict
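
# A minimal usage sketch, not part of the original example: it assumes a loaded
# AiiDA profile on the SQLAlchemy backend and an existing archive file
# 'export.aiida'; both are illustrative.
def _example_import_sqla():
    result = import_data_sqla('export.aiida', extras_mode_existing='kcl', comment_mode='newest', silent=True)
    for entity_name, info in result.items():
        print(entity_name, 'new:', len(info.get('new', [])), 'existing:', len(info.get('existing', [])))
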
# Assumed imports for this snippet (AiiDA 1.x API); the snippet was truncated
# mid-definition, so its opening lines are restored here.
from aiida.engine import run
from aiida.orm import Group, StructureData

alat = 4.  # angstrom
cell = [
    [
        alat,
        0.,
        0.,
    ],
    [
        0.,
        alat,
        0.,
    ],
    [
        0.,
        0.,
        alat,
    ],
]

# BaTiO3 cubic structure
s = StructureData(cell=cell)
s.append_atom(position=(0., 0., 0.), symbols=['Ba'])
s.append_atom(position=(alat / 2., alat / 2., alat / 2.), symbols=['Ti'])
s.append_atom(position=(alat / 2., alat / 2., 0.), symbols=['O'])
s.append_atom(position=(alat / 2., 0., alat / 2.), symbols=['O'])
s.append_atom(position=(0., alat / 2., alat / 2.), symbols=['O'])

g = Group(label="input_group").store()  # the current Group API uses `label`, not `name`
g.add_nodes(s.store())

# TestWorkChain and code are assumed to be defined elsewhere in the original source
w = TestWorkChain
run(w, structure=s, code=code)