Example #1
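Creates a CifData node from a temporary CIF file, stores it in a non-empty group next to an empty one, and returns the identifiers used by the listing tests.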
    def create_cif_data(cls):
        with tempfile.NamedTemporaryFile() as f:
            filename = f.name
            f.write(cls.valid_sample_cif_str)
            f.flush()
            a = CifData(file=filename,
                        source={
                            'version': '1234',
                            'db_name': 'COD',
                            'id': '0000001'
                        })
            a.store()

            g_ne = Group(name='non_empty_group')
            g_ne.store()
            g_ne.add_nodes(a)

            g_e = Group(name='empty_group')
            g_e.store()

        return {
            TestVerdiDataListable.NODE_ID_STR: a.id,
            TestVerdiDataListable.NON_EMPTY_GROUP_ID_STR: g_ne.id,
            TestVerdiDataListable.EMPTY_GROUP_ID_STR: g_e.id
        }
Example #2
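The results step of a meta-convergence workchain: it reports whether convergence was reached, selects the last-run workchain, optionally stores the final PwCalculation in a group, and attaches the exposed outputs.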
    def results(self):
        """Attach the output parameters and structure of the last workchain to the outputs."""
        if self.ctx.is_converged and self.ctx.iteration <= self.inputs.max_meta_convergence_iterations.value:
            self.report('workchain completed after {} iterations'.format(
                self.ctx.iteration))
        else:
            self.report(
                'maximum number of meta convergence iterations exceeded')

        # Get the latest workchain, which is either the workchain_scf if it ran or otherwise the last regular workchain
        try:
            workchain = self.ctx.workchain_scf
            structure = workchain.inp.structure
        except AttributeError:
            workchain = self.ctx.workchains[-1]
            structure = workchain.out.output_structure

        if 'group' in self.inputs:
            # Retrieve the final successful PwCalculation through the output_parameters of the PwBaseWorkChain
            try:
                calculation = workchain.out.output_parameters.get_inputs(
                    node_type=PwCalculation)[0]
            except (AttributeError, IndexError):
                self.report(
                    'could not retrieve the last run PwCalculation to add to the result group'
                )
            else:
                group, _ = Group.get_or_create(name=self.inputs.group.value)
                group.add_nodes(calculation)
                self.report(
                    "storing the final PwCalculation<{}> in the group '{}'".
                    format(calculation.pk, self.inputs.group.value))

        self.out_many(self.exposed_outputs(workchain, PwBaseWorkChain))
        self.out('output_structure', structure)
Example #3
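Builds and stores a cubic BaTiO3 StructureData node, then creates the non-empty/empty group pair used by the listing tests.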
    def create_structure_data():
        from aiida.orm.data.structure import StructureData, Site, Kind
        from aiida.orm.group import Group

        alat = 4.  # angstrom
        cell = [
            [alat, 0., 0.],
            [0., alat, 0.],
            [0., 0., alat],
        ]

        # BaTiO3 cubic structure
        struc = StructureData(cell=cell)
        struc.append_atom(position=(0., 0., 0.), symbols='Ba')
        struc.append_atom(position=(alat / 2., alat / 2., alat / 2.),
                          symbols='Ti')
        struc.append_atom(position=(alat / 2., alat / 2., 0.), symbols='O')
        struc.append_atom(position=(alat / 2., 0., alat / 2.), symbols='O')
        struc.append_atom(position=(0., alat / 2., alat / 2.), symbols='O')
        struc.store()

        # Create 2 groups and add the data to one of them
        g_ne = Group(name='non_empty_group')
        g_ne.store()
        g_ne.add_nodes(struc)

        g_e = Group(name='empty_group')
        g_e.store()

        return {
            TestVerdiDataListable.NODE_ID_STR: struc.id,
            TestVerdiDataListable.NON_EMPTY_GROUP_ID_STR: g_ne.id,
            TestVerdiDataListable.EMPTY_GROUP_ID_STR: g_e.id
        }
Example #4
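Verifies that adding 100 nodes with skip_orm=True (a faster insertion path that bypasses the ORM layer) works for several batch_size values.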
    def test_group_batch_size(self):
        """
        Test that the group addition in batches works as expected.
        """
        from aiida.orm.group import Group
        from aiida.orm.data import Data

        # Create 100 nodes
        nodes = []
        for _ in range(100):
            nodes.append(Data().store())

        # Add the nodes to groups using different batch sizes, and check at
        # the end that the addition was correct.
        batch_sizes = (1, 3, 10, 1000)
        for batch_size in batch_sizes:
            group = Group(name='test_batches_' + str(batch_size)).store()
            group.add_nodes(nodes, skip_orm=True, batch_size=batch_size)
            self.assertEqual(set(_.pk for _ in nodes),
                             set(_.pk for _ in group.nodes))
Example #5
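Checks that a Group can be reconstructed from its underlying dbgroup and that passing additional constructor parameters alongside dbgroup is rejected.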
    def test_creation_from_dbgroup(self):
        n = Node().store()

        g = Group(name='testgroup_from_dbgroup')
        g.store()
        g.add_nodes(n)

        dbgroup = g.dbgroup

        with self.assertRaises(ValueError):
            # Cannot pass more parameters, even if valid, if
            # dbgroup is specified
            Group(dbgroup=dbgroup, name="test")

        gcopy = Group(dbgroup=dbgroup)

        self.assertEquals(g.pk, gcopy.pk)
        self.assertEquals(g.uuid, gcopy.uuid)

        # To avoid to find it in further tests
        g.delete()
Example #6
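Test-class setup that creates a secondary user and a group, then populates bands, structure, CIF, and trajectory data for the verdi data command tests.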
    def setUpClass(cls, *args, **kwargs):
        """
        Create some data needed for the tests
        """
        super(TestVerdiDataCommands, cls).setUpClass()

        from aiida.orm.user import User
        from aiida.orm.group import Group

        # Create a secondary user
        new_email = "[email protected]"
        new_user = User(email=new_email)
        new_user.force_save()

        # Create a group to add specific data inside
        g1 = Group(name=cls.group_name)
        g1.store()
        cls.group_id = g1.id

        cls.create_bands_data(cls.cmd_to_nodeid_map,
                              cls.cmd_to_nodeid_map_for_groups,
                              cls.cmd_to_nodeid_map_for_nuser, g1, new_user)

        cls.create_structure_data(cls.cmd_to_nodeid_map,
                                  cls.cmd_to_nodeid_map_for_groups,
                                  cls.cmd_to_nodeid_map_for_nuser, g1,
                                  new_user)

        cls.create_cif_data(cls.cmd_to_nodeid_map,
                            cls.cmd_to_nodeid_map_for_groups,
                            cls.cmd_to_nodeid_map_for_nuser, g1, new_user)

        cls.create_trajectory_data(cls.cmd_to_nodeid_map,
                                   cls.cmd_to_nodeid_map_for_groups,
                                   cls.cmd_to_nodeid_map_for_nuser, g1,
                                   new_user)
Example #7
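The results step of a band-structure workchain: it attaches the output nodes and optionally stores the band structure in a group obtained via Group.get_or_create.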
    def results(self):
        """Attach the desired output nodes directly as outputs of the workchain."""
        self.report('workchain successfully completed')
        self.out('scf_parameters',
                 self.ctx.workchain_scf.out.output_parameters)
        self.out('band_parameters',
                 self.ctx.workchain_bands.out.output_parameters)
        self.out('band_structure', self.ctx.workchain_bands.out.output_band)

        if 'group' in self.inputs:
            output_band = self.ctx.workchain_bands.out.output_band
            group, _ = Group.get_or_create(name=self.inputs.group.value)
            group.add_nodes(output_band)
            self.report("storing the output_band<{}> in the group '{}'".format(
                output_band.pk, self.inputs.group.value))
Example #8
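Reads CIF files from a subfolder into StructureData nodes and collects them in a group, creating the group only if it does not exist yet.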
def load_example_structures():
    """ Read input structures into the database

    Structures are read from subfolder "example-structures"
    and stored in the group "example-structures".

    :return: group of available structures
    """
    from aiida.common.exceptions import NotExistent
    from aiida.orm.group import Group

    try:
        # 'group_name' is assumed to be defined at module level; the
        # docstring implies it is 'example-structures'
        group = Group.get(name=group_name)

    except NotExistent:
        import glob
        import os
        from ase.io import read
        from aiida.orm.data.structure import StructureData

        paths = glob.glob(group_name + '/*.cif')

        structure_nodes = []
        for path in paths:
            fname = os.path.basename(path)
            name = os.path.splitext(fname)[0]

            structure = StructureData(ase=read(path))
            if "ML" in name:
                # surface normal of monolayers should be oriented along z
                structure.set_pbc([True, True, False])
            else:
                structure.set_pbc([True, True, True])
            structure.label = name
            print("Storing {} in database".format(name))
            structure.store()
            structure_nodes.append(structure)

        group = Group(name=group_name)
        group.store()
        group.description = ("Set of atomic structures used by examples "
                             "for AiiDA plugins of different codes")

        group.add_nodes(structure_nodes)

    return group
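
A minimal usage sketch; group_name is not part of the snippet, so its value here is an assumption based on the docstring:

# Hypothetical module-level constant implied by the docstring:
group_name = 'example-structures'

group = load_example_structures()
for structure in group.nodes:
    print("{}: pk={}".format(structure.label, structure.pk))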
Example #9
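Checks that a group's name, description, and is_user_defined flag are preserved across storing.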
    def test_name_desc(self):
        g = Group(name='testgroup2', description='some desc')
        self.assertEquals(g.name, 'testgroup2')
        self.assertEquals(g.description, 'some desc')
        self.assertTrue(g.is_user_defined)
        g.store()
        # Same checks after storing
        self.assertEquals(g.name, 'testgroup2')
        self.assertTrue(g.is_user_defined)
        self.assertEquals(g.description, 'some desc')

        # To avoid to find it in further tests
        g.delete()
Example #10
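Legacy Django-backend helper that renders a formatted table of calculations, filtered by state, age, group, PKs, or user.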
    def _list_calculations_old(cls, states=None, past_days=None, group=None,
                               group_pk=None, all_users=False, pks=None,
                               relative_ctime=True):
        """
        Return a string with a description of the AiiDA calculations.

        .. todo:: does not support the query for the IMPORTED state (since it
          checks the state in the Attributes, not in the DbCalcState table).
          Decide which is the correct logic and implement the correct query.

        :param states: a list of strings with states. If set, print only the
            calculations in those states; otherwise, show all.
            Default = None.
        :param past_days: If specified, show only calculations that were
            created in the given number of past days.
        :param group: If specified, show only calculations belonging to a
            user-defined group with the given name.
            Can use colons to separate the group name from the type,
            as specified in :py:meth:`aiida.orm.group.Group.get_from_string`
            method.
        :param group_pk: If specified, show only calculations belonging to a
            user-defined group with the given PK.
        :param pks: if specified, must be a list of integers; only
            calculations within that list are shown. Otherwise, all
            calculations are shown.
            If specified, sets states to None and ignores the
            value of the ``past_days`` option.
        :param relative_ctime: if True, print the creation time relative to
            now (e.g. "2 days ago"). Default = True.
        :param all_users: if True, list calculations belonging to all users.
            Default = False.

        :return: a string with description of calculations.
        """
        # I assume that calc_states are strings. If this changes in the future,
        # update the filter below from dbattributes__tval to the correct field.
        from aiida.backends.djsite.db.models import DbAuthInfo, DbAttribute
        from aiida.daemon.timestamps import get_last_daemon_timestamp

        if states:
            for state in states:
                if state not in calc_states:
                    return "Invalid state provided: {}.".format(state)

        warnings_list = []

        now = timezone.now()

        if pks:
            q_object = Q(pk__in=pks)
        else:
            q_object = Q()

            if group is not None:
                g_pk = Group.get_from_string(group).pk
                q_object.add(Q(dbgroups__pk=g_pk), Q.AND)

            if group_pk is not None:
                q_object.add(Q(dbgroups__pk=group_pk), Q.AND)

            if not all_users:
                q_object.add(Q(user=get_automatic_user()), Q.AND)

            if states is not None:
                q_object.add(Q(dbattributes__key='state',
                               dbattributes__tval__in=states, ), Q.AND)
            if past_days is not None:
                n_days_ago = now - datetime.timedelta(days=past_days)
                q_object.add(Q(ctime__gte=n_days_ago), Q.AND)

        calc_list_pk = list(
            cls.query(q_object).distinct().values_list('pk', flat=True))

        calc_list = cls.query(pk__in=calc_list_pk).order_by('ctime')

        scheduler_states = dict(
            DbAttribute.objects.filter(dbnode__pk__in=calc_list_pk,
                                       key='scheduler_state').values_list(
                'dbnode__pk', 'tval'))

        # I do the query now, so that the list of pks gets cached
        calc_list_data = list(
            calc_list.filter(
                # dbcomputer__dbauthinfo__aiidauser=F('user')
            ).distinct().order_by('ctime').values(
                'pk', 'dbcomputer__name', 'ctime',
                'type', 'dbcomputer__enabled',
                'dbcomputer__pk',
                'user__pk'))
        list_comp_pk = [i['dbcomputer__pk'] for i in calc_list_data]
        list_aiduser_pk = [i['user__pk']
                           for i in calc_list_data]
        enabled_data = DbAuthInfo.objects.filter(
            dbcomputer__pk__in=list_comp_pk, aiidauser__pk__in=list_aiduser_pk
        ).values_list('dbcomputer__pk', 'aiidauser__pk', 'enabled')

        enabled_auth_dict = {(i[0], i[1]): i[2] for i in enabled_data}

        states = {c.pk: c._get_state_string() for c in calc_list}

        scheduler_lastcheck = dict(DbAttribute.objects.filter(
            dbnode__in=calc_list,
            key='scheduler_lastchecktime').values_list('dbnode__pk', 'dval'))

        ## Get the last daemon check
        try:
            last_daemon_check = get_last_daemon_timestamp('updater',
                                                          when='stop')
        except ValueError:
            last_check_string = ("# Last daemon state_updater check: "
                                 "(Error while retrieving the information)")
        else:
            if last_daemon_check is None:
                last_check_string = "# Last daemon state_updater check: (Never)"
            else:
                last_check_string = ("# Last daemon state_updater check: "
                                     "{} ({})".format(
                    str_timedelta(now - last_daemon_check,
                                  negative_to_zero=True),
                    timezone.localtime(last_daemon_check).strftime(
                        "at %H:%M:%S on %Y-%m-%d")))

        disabled_ignorant_states = [
            None, calc_states.FINISHED, calc_states.SUBMISSIONFAILED,
            calc_states.RETRIEVALFAILED, calc_states.PARSINGFAILED,
            calc_states.FAILED
        ]

        if not calc_list:
            return last_check_string
        else:
            # first save a matrix of results to be printed
            res_str_list = [last_check_string]
            str_matrix = []
            title = ['# Pk', 'State', 'Creation',
                     'Sched. state', 'Computer', 'Type']
            str_matrix.append(title)
            len_title = [len(i) for i in title]

            for calcdata in calc_list_data:
                remote_state = "None"

                calc_state = states[calcdata['pk']]
                remote_computer = calcdata['dbcomputer__name']
                try:
                    sched_state = scheduler_states.get(calcdata['pk'], None)
                    if sched_state is None:
                        remote_state = "(unknown)"
                    else:
                        remote_state = '{}'.format(sched_state)
                        if calc_state == calc_states.WITHSCHEDULER:
                            last_check = scheduler_lastcheck.get(calcdata['pk'],
                                                                 None)
                            if last_check is not None:
                                when_string = " {}".format(
                                    str_timedelta(now - last_check, short=True,
                                                  negative_to_zero=True))
                                verb_string = "was "
                            else:
                                when_string = ""
                                verb_string = ""
                            remote_state = "{}{}{}".format(verb_string,
                                                           sched_state,
                                                           when_string)
                except ValueError:
                    raise

                calc_module = from_type_to_pluginclassname(
                    calcdata['type']).rsplit(".", 1)[0]
                prefix = 'calculation.job.'
                prefix_len = len(prefix)
                if calc_module.startswith(prefix):
                    calc_module = calc_module[prefix_len:].strip()

                if relative_ctime:
                    calc_ctime = str_timedelta(now - calcdata['ctime'],
                                               negative_to_zero=True,
                                               max_num_fields=1)
                else:
                    local_ctime = timezone.localtime(
                        calcdata['ctime']).isoformat()
                    date_part = local_ctime.split('T')[0]
                    time_part = local_ctime.split('T')[1].split('.')[
                        0].rsplit(":", 1)[0]
                    calc_ctime = " ".join([date_part, time_part])

                the_state = states[calcdata['pk']]

                # decide if it is needed to print enabled/disabled information
                # By default, if the computer is not configured for the
                # given user, assume it is user_enabled
                user_enabled = enabled_auth_dict.get(
                    (calcdata['dbcomputer__pk'],
                     calcdata['user__pk']), True)
                global_enabled = calcdata["dbcomputer__enabled"]

                enabled = "" if (user_enabled and global_enabled or
                                 the_state in disabled_ignorant_states) else " [Disabled]"

                str_matrix.append([calcdata['pk'],
                                   the_state,
                                   calc_ctime,
                                   remote_state,
                                   remote_computer + "{}".format(enabled),
                                   calc_module
                                   ])

            # prepare a formatted text of minimal row length (to fit in terminals!)
            rows = []
            for j in range(len(str_matrix[0])):
                rows.append([len(str(i[j])) for i in str_matrix])
            line_lengths = [str(max(max(rows[i]), len_title[i])) for i in
                            range(len(rows))]
            fmt_string = "{:<" + "}|{:<".join(line_lengths) + "}"
            for row in str_matrix:
                res_str_list.append(fmt_string.format(*[str(i) for i in row]))

            res_str_list += ["# {}".format(_) for _ in warnings_list]
            return "\n".join(res_str_list)
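
A hedged invocation sketch; the snippet shows only the method body, so JobCalculation as the owning class is an assumption:

# Hypothetical call; _list_calculations_old is assumed to be a classmethod
# of JobCalculation:
print(JobCalculation._list_calculations_old(past_days=7, relative_ctime=True))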
Example #11
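Script-style snippet that builds a BaTiO3 structure, stores it in an input group, and launches a workchain on it. The opening lines were truncated in the original listing; they are reconstructed below from the identical cell definitions in Examples #3 and #18.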
alat = 4.  # angstrom
cell = [
    [alat, 0., 0.],
    [0., alat, 0.],
    [0., 0., alat],
]

# BaTiO3 cubic structure
s = StructureData(cell=cell)
s.append_atom(position=(0., 0., 0.), symbols=['Ba'])
s.append_atom(position=(alat / 2., alat / 2., alat / 2.), symbols=['Ti'])
s.append_atom(position=(alat / 2., alat / 2., 0.), symbols=['O'])
s.append_atom(position=(alat / 2., 0., alat / 2.), symbols=['O'])
s.append_atom(position=(0., alat / 2., alat / 2.), symbols=['O'])

g = Group(name="input_group").store()
g.add_nodes(s.store())

w = TestWorkChain
run(w, structure=s, code=code)  # 'code' is assumed to be set up elsewhere
Example #12
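Tests Group.query and Group.get by node membership and by user, including the NotExistent and MultipleObjectsError cases.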
    def test_query(self):
        """
        Test if queries are working
        """
        from aiida.orm.group import Group
        from aiida.common.exceptions import NotExistent, MultipleObjectsError
        from aiida.backends.djsite.db.models import DbUser
        from aiida.backends.djsite.utils import get_automatic_user

        g1 = Group(name='testquery1').store()
        g2 = Group(name='testquery2').store()

        n1 = Node().store()
        n2 = Node().store()
        n3 = Node().store()
        n4 = Node().store()

        g1.add_nodes([n1, n2])
        g2.add_nodes([n1, n3])

        newuser = DbUser.objects.create_user(email='*****@*****.**', password='')
        g3 = Group(name='testquery3', user=newuser).store()

        # I should find it
        g1copy = Group.get(uuid=g1.uuid)
        self.assertEquals(g1.pk, g1copy.pk)

        # Try queries
        res = Group.query(nodes=n4)
        self.assertEquals([_.pk for _ in res], [])

        res = Group.query(nodes=n1)
        self.assertEquals([_.pk for _ in res], [_.pk for _ in [g1, g2]])

        res = Group.query(nodes=n2)
        self.assertEquals([_.pk for _ in res], [_.pk for _ in [g1]])

        # I try to use 'get' with zero or multiple results
        with self.assertRaises(NotExistent):
            Group.get(nodes=n4)
        with self.assertRaises(MultipleObjectsError):
            Group.get(nodes=n1)

        self.assertEquals(Group.get(nodes=n2).pk, g1.pk)

        # Query by user
        res = Group.query(user=newuser)
        self.assertEquals(set(_.pk for _ in res), set(_.pk for _ in [g3]))

        # Same query, but using a string (the username=email) instead of
        # a DbUser object
        res = Group.query(user=newuser.email)
        self.assertEquals(set(_.pk for _ in res), set(_.pk for _ in [g3]))

        res = Group.query(user=get_automatic_user())
        self.assertEquals(set(_.pk for _ in res), set(_.pk for _ in [g1, g2]))

        # Final cleanup
        g1.delete()
        g2.delete()
        newuser.delete()
Example #13
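Tests that deleting a group removes it from the database, and that re-storing the Python object yields an empty group with the same name and description.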
    def test_delete(self):
        from aiida.orm.group import Group
        from aiida.common.exceptions import NotExistent

        n = Node().store()

        g = Group(name='testgroup3', description='some other desc')
        g.store()

        gcopy = Group.get(name='testgroup3')
        self.assertEquals(g.uuid, gcopy.uuid)

        g.add_nodes(n)
        self.assertEquals(len(g.nodes), 1)

        g.delete()

        with self.assertRaises(NotExistent):
            # The group does not exist anymore
            Group.get(name='testgroup3')

        # I should be able to restore it
        g.store()

        # Now, however, by deleting and recreating it, I lost the elements
        self.assertEquals(len(g.nodes), 0)
        self.assertEquals(g.name, 'testgroup3')
        self.assertEquals(g.description, 'some other desc')
        self.assertTrue(g.is_user_defined)

        # To avoid to find it in further tests
        g.delete()
Example #14
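Tests remove_nodes with single Node and DbNode arguments as well as plain and mixed lists, including removal of a node that is not in the group.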
    def test_remove_nodes(self):
        """
        Test node removal
        """
        from aiida.orm.group import Group

        n1 = Node().store()
        n2 = Node().store()
        n3 = Node().store()
        n4 = Node().store()
        n5 = Node().store()
        n6 = Node().store()
        n7 = Node().store()
        n8 = Node().store()
        n_out = Node().store()

        g = Group(name='test_remove_nodes').store()

        # Add initial nodes
        g.add_nodes([n1, n2, n3, n4, n5, n6, n7, n8])
        # Check
        self.assertEquals(
            set([_.pk for _ in [n1, n2, n3, n4, n5, n6, n7, n8]]),
            set([_.pk for _ in g.nodes]))

        # Remove a node that is not in the group: nothing should happen
        # (same behavior of Django)
        g.remove_nodes(n_out)
        # Re-check
        self.assertEquals(
            set([_.pk for _ in [n1, n2, n3, n4, n5, n6, n7, n8]]),
            set([_.pk for _ in g.nodes]))

        # Remove one Node and check
        g.remove_nodes(n4)
        self.assertEquals(set([_.pk for _ in [n1, n2, n3, n5, n6, n7, n8]]),
                          set([_.pk for _ in g.nodes]))
        # Remove one DbNode and check
        g.remove_nodes(n7.dbnode)
        self.assertEquals(set([_.pk for _ in [n1, n2, n3, n5, n6, n8]]),
                          set([_.pk for _ in g.nodes]))
        # Remove a list of Nodes and check
        g.remove_nodes([n1, n8])
        self.assertEquals(set([_.pk for _ in [n2, n3, n5, n6]]),
                          set([_.pk for _ in g.nodes]))
        # Remove the same list again: nothing should change
        g.remove_nodes([n1, n8])
        self.assertEquals(set([_.pk for _ in [n2, n3, n5, n6]]),
                          set([_.pk for _ in g.nodes]))
        # Remove a list of DbNodes and check
        g.remove_nodes([n2.dbnode, n5.dbnode])
        self.assertEquals(set([_.pk for _ in [n3, n6]]),
                          set([_.pk for _ in g.nodes]))

        # Remove a mixed list of Nodes and DbNodes and check
        g.remove_nodes([n3, n6.dbnode])
        self.assertEquals(set(), set([_.pk for _ in g.nodes]))

        # Cleanup
        g.delete()
Example #15
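Tests the argument types accepted by add_nodes (Node, DbNode, and lists thereof), including re-adding a node that is already a member.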
    def test_add_nodes(self):
        """
        Test different ways of adding nodes
        """
        from aiida.orm.group import Group

        n1 = Node().store()
        n2 = Node().store()
        n3 = Node().store()
        n4 = Node().store()
        n5 = Node().store()
        n6 = Node().store()
        n7 = Node().store()
        n8 = Node().store()

        g = Group(name='test_adding_nodes')
        g.store()
        # Single node
        g.add_nodes(n1)
        # List of nodes
        g.add_nodes([n2, n3])
        # Single DbNode
        g.add_nodes(n4.dbnode)
        # List of DbNodes
        g.add_nodes([n5.dbnode, n6.dbnode])
        # List of Nodes and DbNodes
        g.add_nodes([n7, n8.dbnode])

        # Check
        self.assertEquals(
            set([_.pk for _ in [n1, n2, n3, n4, n5, n6, n7, n8]]),
            set([_.pk for _ in g.nodes]))

        # Try to add a node that is already present: there should be no problem
        g.add_nodes(n1)
        self.assertEquals(
            set([_.pk for _ in [n1, n2, n3, n4, n5, n6, n7, n8]]),
            set([_.pk for _ in g.nodes]))

        # Cleanup
        g.delete()
Example #16
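Tests that a group's description can be updated both before and after the group is stored.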
    def test_description(self):
        """
        Test the update of the description both for stored and unstored
        groups.
        """
        from aiida.orm.group import Group

        n = Node().store()

        g1 = Group(name='testgroupdescription1', description="g1").store()
        g1.add_nodes(n)

        g2 = Group(name='testgroupdescription2', description="g2")

        # Preliminary checks
        self.assertTrue(g1.is_stored)
        self.assertFalse(g2.is_stored)
        self.assertEquals(g1.description, "g1")
        self.assertEquals(g2.description, "g2")

        # Change
        g1.description = "new1"
        g2.description = "new2"

        # Test that the groups remained in their proper stored state and that
        # the description was updated
        self.assertTrue(g1.is_stored)
        self.assertFalse(g2.is_stored)
        self.assertEquals(g1.description, "new1")
        self.assertEquals(g2.description, "new2")

        # Store g2 and check that the description is OK
        g2.store()
        self.assertTrue(g2.is_stored)
        self.assertEquals(g2.description, "new2")

        # clean-up
        g1.delete()
        g2.delete()
Example #17
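Tests the error conditions of group creation: a missing name, an unknown keyword argument, and adding nodes to an unstored group or unstored nodes to a stored one.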
    def test_creation(self):
        from aiida.orm.group import Group

        n = Node()
        stored_n = Node().store()

        with self.assertRaises(ValueError):
            # No name specified
            g = Group()

        g = Group(name='testgroup')

        with self.assertRaises(ValueError):
            # Too many parameters
            g = Group(name='testgroup', not_existing_kwarg=True)

        with self.assertRaises(ModificationNotAllowed):
            # g unstored
            g.add_nodes(n)

        with self.assertRaises(ModificationNotAllowed):
            # g unstored
            g.add_nodes(stored_n)

        g.store()

        with self.assertRaises(ValueError):
            # n unstored
            g.add_nodes(n)

        g.add_nodes(stored_n)

        nodes = list(g.nodes)
        self.assertEquals(len(nodes), 1)
        self.assertEquals(nodes[0].pk, stored_n.pk)

        # To avoid to find it in further tests
        g.delete()
Example #18
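Creates a structure and, through a workfunction, a BandsData node, which is then added to the non-empty group used by the listing tests.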
    def create_structure_bands():
        alat = 4.  # angstrom
        cell = [
            [alat, 0., 0.],
            [0., alat, 0.],
            [0., 0., alat],
        ]
        s = StructureData(cell=cell)
        s.append_atom(position=(0., 0., 0.), symbols='Fe')
        s.append_atom(position=(alat / 2., alat / 2., alat / 2.), symbols='O')
        s.store()

        # The @wf (workfunction) decorator records this call in the
        # provenance graph, linking the returned BandsData to the input
        # structure
        @wf
        def connect_structure_bands(structure):
            alat = 4.
            cell = np.array([
                [alat, 0., 0.],
                [0., alat, 0.],
                [0., 0., alat],
            ])

            k = KpointsData()
            k.set_cell(cell)
            k.set_kpoints_path([('G', 'M', 2)])

            b = BandsData()
            b.set_kpointsdata(k)
            b.set_bands([[1.0, 2.0], [3.0, 4.0]])

            k.store()
            b.store()

            return b

        b = connect_structure_bands(s)

        # Create 2 groups and add the data to one of them
        g_ne = Group(name='non_empty_group')
        g_ne.store()
        g_ne.add_nodes(b)

        g_e = Group(name='empty_group')
        g_e.store()

        return {
            TestVerdiDataListable.NODE_ID_STR: b.id,
            TestVerdiDataListable.NON_EMPTY_GROUP_ID_STR: g_ne.id,
            TestVerdiDataListable.EMPTY_GROUP_ID_STR: g_e.id
        }
Example #19
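Django-backend test that storing or renaming a group to an already existing name raises IntegrityError.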
    def test_rename_existing(self):
        """
        Test that renaming to an already existing name is not permitted
        """
        from aiida.orm.group import Group

        name_group_a = 'group_a'
        name_group_b = 'group_b'
        name_group_c = 'group_c'

        group_a = Group(name=name_group_a, description='I am the Original G')
        group_a.store()

        # Before storing everything should be fine
        group_b = Group(name=name_group_a,
                        description='They will try to rename me')
        group_c = Group(name=name_group_c,
                        description='They will try to rename me')

        # Storing a group with a duplicate name should trigger IntegrityError
        with self.assertRaises(exceptions.IntegrityError):
            group_b.store()

        # Before storing, renaming to an existing name is still fine
        group_c.name = name_group_a

        # Reverting to unique name before storing
        group_c.name = name_group_c
        group_c.store()

        # After storing name change to existing should raise
        with self.assertRaises(exceptions.IntegrityError):
            group_c.name = name_group_a
Example #20
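SQLAlchemy variant of the previous test, wrapping each expected IntegrityError in a nested session transaction that is rolled back afterwards.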
    def test_rename_existing(self):
        """
        Test that renaming to an already existing name is not permitted
        """
        from aiida.backends.sqlalchemy import get_scoped_session
        from aiida.orm.group import Group

        name_group_a = 'group_a'
        name_group_b = 'group_b'
        name_group_c = 'group_c'

        group_a = Group(name=name_group_a, description='I am the Original G')
        group_a.store()

        # Before storing everything should be fine
        group_b = Group(name=name_group_a, description='They will try to rename me')
        group_c = Group(name=name_group_c, description='They will try to rename me')

        session = get_scoped_session()

        # Storing a group with a duplicate name should trigger IntegrityError
        try:
            session.begin_nested()
            with self.assertRaises(exceptions.IntegrityError):
                group_b.store()
        finally:
            session.rollback()

        # Before storing, renaming to an existing name is still fine
        group_c.name = name_group_a

        # Reverting to unique name before storing
        group_c.name = name_group_c
        group_c.store()

        # After storing name change to existing should raise
        try:
            session.begin_nested()
            with self.assertRaises(exceptions.IntegrityError):
                group_c.name = name_group_a
        finally:
            session.rollback()
Example #21
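Legacy Django-backend importer that extracts an AiiDA export archive, recreates nodes, links, and group memberships, and finally collects all imported nodes in a timestamped import group.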
def import_data(in_path, ignore_unknown_nodes=False, silent=False):
    """
    Import exported AiiDA environment to the AiiDA database.
    If the 'in_path' is a folder, calls extract_tree; otherwise, tries to
    detect the compression format (zip, tar.gz, tar.bz2, ...) and calls the
    correct function.

    :param in_path: the path to a file or folder that can be imported in AiiDA
    """
    import json
    import os
    import tarfile
    import zipfile
    from itertools import chain

    from django.db import transaction
    from aiida.utils import timezone

    from aiida.orm.node import Node
    from aiida.orm.group import Group
    from aiida.common.exceptions import UniquenessError
    from aiida.common.folders import SandboxFolder, RepositoryFolder
    from aiida.backends.djsite.db import models
    from aiida.common.utils import get_class_string, get_object_from_string
    from aiida.common.datastructures import calc_states

    # This is the export version expected by this function
    expected_export_version = '0.1'

    # The name of the subfolder in which the node files are stored
    nodes_export_subfolder = 'nodes'

    # The returned dictionary with new and existing nodes and links
    ret_dict = {}

    ################
    # EXTRACT DATA #
    ################
    # The sandbox has to remain open until the end
    with SandboxFolder() as folder:
        if os.path.isdir(in_path):
            extract_tree(in_path, folder, silent=silent)
        else:
            if tarfile.is_tarfile(in_path):
                extract_tar(in_path,
                            folder,
                            silent=silent,
                            nodes_export_subfolder=nodes_export_subfolder)
            elif zipfile.is_zipfile(in_path):
                extract_zip(in_path,
                            folder,
                            silent=silent,
                            nodes_export_subfolder=nodes_export_subfolder)
            else:
                raise ValueError(
                    "Unable to detect the input file format, it "
                    "is neither a (possibly compressed) tar file, "
                    "nor a zip file.")

        try:
            with open(folder.get_abs_path('metadata.json')) as f:
                metadata = json.load(f)

            with open(folder.get_abs_path('data.json')) as f:
                data = json.load(f)
        except IOError as e:
            raise ValueError("Unable to find the file {} in the import "
                             "file or folder".format(e.filename))

        ######################
        # PRELIMINARY CHECKS #
        ######################
        if metadata['export_version'] != expected_export_version:
            raise ValueError(
                "File export version is {}, but I can import only "
                "version {}".format(metadata['export_version'],
                                    expected_export_version))

        ##########################################################################
        # CREATE UUID REVERSE TABLES AND CHECK IF I HAVE ALL NODES FOR THE LINKS #
        ##########################################################################
        linked_nodes = set(
            chain.from_iterable(
                (l['input'], l['output']) for l in data['links_uuid']))
        group_nodes = set(chain.from_iterable(
            data['groups_uuid'].itervalues()))

        # I preload the nodes since I need to check each of them later, and I
        # also store them in a reverse table.
        # I break up the query due to SQLite limitations (max. 999 variables)
        relevant_db_nodes = {}
        for group in grouper(999, linked_nodes):
            relevant_db_nodes.update({
                n.uuid: n
                for n in models.DbNode.objects.filter(uuid__in=group)
            })

        db_nodes_uuid = set(relevant_db_nodes.keys())
        dbnode_model = get_class_string(models.DbNode)
        import_nodes_uuid = set(
            v['uuid'] for v in data['export_data'][dbnode_model].values())

        unknown_nodes = linked_nodes.union(group_nodes) - db_nodes_uuid.union(
            import_nodes_uuid)

        if unknown_nodes and not ignore_unknown_nodes:
            raise ValueError(
                "The import file refers to {} nodes with unknown UUID, therefore "
                "it cannot be imported. Either first import the unknown nodes, "
                "or export also the parents when exporting. The unknown UUIDs "
                "are:\n".format(len(unknown_nodes)) +
                "\n".join('* {}'.format(uuid) for uuid in unknown_nodes))

        ###################################
        # DOUBLE-CHECK MODEL DEPENDENCIES #
        ###################################
        # I hardcode here the model order, for simplicity; in any case, this is
        # fixed by the export version
        model_order = [
            get_class_string(m) for m in (
                models.DbUser,
                models.DbComputer,
                models.DbNode,
                models.DbGroup,
            )
        ]

        # Models that do appear in the import file, but whose import is
        # managed manually
        model_manual = [
            get_class_string(m) for m in (
                models.DbLink,
                models.DbAttribute,
            )
        ]

        all_known_models = model_order + model_manual

        for import_field_name in metadata['all_fields_info']:
            if import_field_name not in all_known_models:
                raise NotImplementedError(
                    "Apparently, you are importing a "
                    "file with a model '{}', but this does not appear in "
                    "all_known_models!".format(import_field_name))

        for idx, model_name in enumerate(model_order):
            dependencies = []
            for field in metadata['all_fields_info'][model_name].values():
                try:
                    dependencies.append(field['requires'])
                except KeyError:
                    # (No ForeignKey)
                    pass
            for dependency in dependencies:
                if dependency not in model_order[:idx]:
                    raise ValueError(
                        "Model {} requires {} but would be loaded "
                        "first; stopping...".format(model_name, dependency))

        ###################################################
        # CREATE IMPORT DATA DIRECT UNIQUE_FIELD MAPPINGS #
        ###################################################
        import_unique_ids_mappings = {}
        for model_name, import_data in data['export_data'].iteritems():
            if model_name in metadata['unique_identifiers']:
                # I have to reconvert the pk to integer
                import_unique_ids_mappings[model_name] = {
                    int(k): v[metadata['unique_identifiers'][model_name]]
                    for k, v in import_data.iteritems()
                }

        ###############
        # IMPORT DATA #
        ###############
        # DO ALL WITH A TRANSACTION
        with transaction.commit_on_success():
            foreign_ids_reverse_mappings = {}
            new_entries = {}
            existing_entries = {}

            # I first generate the list of data
            for model_name in model_order:
                Model = get_object_from_string(model_name)
                fields_info = metadata['all_fields_info'].get(model_name, {})
                unique_identifier = metadata['unique_identifiers'].get(
                    model_name, None)

                new_entries[model_name] = {}
                existing_entries[model_name] = {}

                foreign_ids_reverse_mappings[model_name] = {}

                # Not necessarily all models are exported
                if model_name in data['export_data']:

                    if unique_identifier is not None:
                        import_unique_ids = set(
                            v[unique_identifier]
                            for v in data['export_data'][model_name].values())

                        relevant_db_entries = {
                            getattr(n, unique_identifier): n
                            for n in Model.objects.filter(
                                **{
                                    '{}__in'.format(unique_identifier):
                                    import_unique_ids
                                })
                        }

                        foreign_ids_reverse_mappings[model_name] = {
                            k: v.pk
                            for k, v in relevant_db_entries.iteritems()
                        }
                        for k, v in data['export_data'][model_name].iteritems(
                        ):
                            if v[unique_identifier] in relevant_db_entries.keys(
                            ):
                                # Already in DB
                                existing_entries[model_name][k] = v
                            else:
                                # To be added
                                new_entries[model_name][k] = v
                    else:
                        new_entries[model_name] = data['export_data'][
                            model_name].copy()

            # I import data from the given model
            for model_name in model_order:
                Model = get_object_from_string(model_name)
                fields_info = metadata['all_fields_info'].get(model_name, {})
                unique_identifier = metadata['unique_identifiers'].get(
                    model_name, None)

                for import_entry_id, entry_data in existing_entries[
                        model_name].iteritems():
                    unique_id = entry_data[unique_identifier]
                    existing_entry_id = foreign_ids_reverse_mappings[
                        model_name][unique_id]
                    # TODO COMPARE, AND COMPARE ATTRIBUTES
                    if model_name not in ret_dict:
                        ret_dict[model_name] = {'new': [], 'existing': []}
                    ret_dict[model_name]['existing'].append(
                        (import_entry_id, existing_entry_id))
                    if not silent:
                        print "existing %s: %s (%s->%s)" % (
                            model_name, unique_id, import_entry_id,
                            existing_entry_id)
                        # print "  `-> WARNING: NO DUPLICITY CHECK DONE!"
                        # CHECK ALSO FILES!

                # Store all objects for this model in a list, and store them
                # all at once at the end.
                objects_to_create = []
                # This is needed later to associate the import entry with the new pk
                import_entry_ids = {}
                for import_entry_id, entry_data in new_entries[
                        model_name].iteritems():
                    unique_id = entry_data[unique_identifier]
                    import_data = dict(
                        deserialize_field(k,
                                          v,
                                          fields_info=fields_info,
                                          import_unique_ids_mappings=
                                          import_unique_ids_mappings,
                                          foreign_ids_reverse_mappings=
                                          foreign_ids_reverse_mappings)
                        for k, v in entry_data.iteritems())

                    objects_to_create.append(Model(**import_data))
                    import_entry_ids[unique_id] = import_entry_id

                # Before storing entries in the DB, I store the files (if these
                # are nodes). Note: only for new entries!
                if model_name == get_class_string(models.DbNode):
                    if not silent:
                        print "STORING NEW NODE FILES..."
                    for o in objects_to_create:

                        subfolder = folder.get_subfolder(
                            os.path.join(nodes_export_subfolder,
                                         export_shard_uuid(o.uuid)))
                        if not subfolder.exists():
                            raise ValueError(
                                "Unable to find the repository "
                                "folder for node with UUID={} in the exported "
                                "file".format(o.uuid))
                        destdir = RepositoryFolder(section=Node._section_name,
                                                   uuid=o.uuid)
                        # Replace the folder, possibly destroying existing
                        # previous folders, and move the files (faster if we
                        # are on the same filesystem, and
                        # in any case the source is a SandboxFolder)
                        destdir.replace_with_folder(subfolder.abspath,
                                                    move=True,
                                                    overwrite=True)

                # Store them all at once; note, however, that bulk_create
                # does not set the PKs on the created objects
                Model.objects.bulk_create(objects_to_create)

                # Get back the just-saved entries
                just_saved = dict(
                    Model.objects.filter(
                        **{
                            "{}__in".format(unique_identifier):
                            import_entry_ids.keys()
                        }).values_list(unique_identifier, 'pk'))

                imported_states = []
                if model_name == get_class_string(models.DbNode):
                    if not silent:
                        print "SETTING THE IMPORTED STATES FOR NEW NODES..."
                    # I set for all nodes, even if I should set it only
                    # for calculations
                    for unique_id, new_pk in just_saved.iteritems():
                        imported_states.append(
                            models.DbCalcState(dbnode_id=new_pk,
                                               state=calc_states.IMPORTED))
                    models.DbCalcState.objects.bulk_create(imported_states)

                # Now that I have the PKs, print the info.
                # Moreover, set the foreign_ids_reverse_mappings
                for unique_id, new_pk in just_saved.iteritems():
                    import_entry_id = import_entry_ids[unique_id]
                    foreign_ids_reverse_mappings[model_name][
                        unique_id] = new_pk
                    if model_name not in ret_dict:
                        ret_dict[model_name] = {'new': [], 'existing': []}
                    ret_dict[model_name]['new'].append(
                        (import_entry_id, new_pk))

                    if not silent:
                        print "NEW %s: %s (%s->%s)" % (model_name, unique_id,
                                                       import_entry_id, new_pk)

                # For DbNodes, we also have to store Attributes!
                if model_name == get_class_string(models.DbNode):
                    if not silent:
                        print "STORING NEW NODE ATTRIBUTES..."
                    for unique_id, new_pk in just_saved.iteritems():
                        import_entry_id = import_entry_ids[unique_id]
                        # Get attributes from import file
                        try:
                            attributes = data['node_attributes'][str(
                                import_entry_id)]
                            attributes_conversion = data[
                                'node_attributes_conversion'][str(
                                    import_entry_id)]
                        except KeyError:
                            raise ValueError(
                                "Unable to find attribute info "
                                "for DbNode with UUID = {}".format(unique_id))

                        # Here I have to deserialize the attributes
                        deserialized_attributes = deserialize_attributes(
                            attributes, attributes_conversion)
                        models.DbAttribute.reset_values_for_node(
                            dbnode=new_pk,
                            attributes=deserialized_attributes,
                            with_transaction=False)

            if not silent:
                print "STORING NODE LINKS..."
            ## TODO: check that we are not creating input links of an already
            ##       existing node...
            import_links = data['links_uuid']
            links_to_store = []

            # Needed for fast checks of existing links
            existing_links_raw = models.DbLink.objects.all().values_list(
                'input', 'output', 'label')
            existing_links_labels = {(l[0], l[1]): l[2]
                                     for l in existing_links_raw}
            existing_input_links = {(l[1], l[2]): l[0]
                                    for l in existing_links_raw}

            dbnode_reverse_mappings = foreign_ids_reverse_mappings[
                get_class_string(models.DbNode)]
            for link in import_links:
                try:
                    in_id = dbnode_reverse_mappings[link['input']]
                    out_id = dbnode_reverse_mappings[link['output']]
                except KeyError:
                    if ignore_unknown_nodes:
                        continue
                    else:
                        raise ValueError("Trying to create a link with one "
                                         "or both unknown nodes, stopping "
                                         "(in_uuid={}, out_uuid={}, "
                                         "label={})".format(
                                             link['input'], link['output'],
                                             link['label']))

                try:
                    existing_label = existing_links_labels[in_id, out_id]
                    if existing_label != link['label']:
                        raise ValueError(
                            "Trying to rename an existing link name, "
                            "stopping (in={}, out={}, old_label={}, "
                            "new_label={})".format(in_id, out_id,
                                                   existing_label,
                                                   link['label']))
                        # Do nothing, the link is already in place and has the correct
                        # name
                except KeyError:
                    try:
                        existing_input = existing_input_links[out_id,
                                                              link['label']]
                        # If existing_input were the correct one, I would have found
                        # it already in the previous step!
                        raise ValueError(
                            "There already exists an input link to "
                            "node {} with label {} but it does not "
                            "come from the expected input {}".format(
                                out_id, link['label'], in_id))
                    except KeyError:
                        # New link
                        links_to_store.append(
                            models.DbLink(input_id=in_id,
                                          output_id=out_id,
                                          label=link['label']))
                        if 'aiida.backends.djsite.db.models.DbLink' not in ret_dict:
                            ret_dict[
                                'aiida.backends.djsite.db.models.DbLink'] = {
                                    'new': []
                                }
                        ret_dict['aiida.backends.djsite.db.models.DbLink'][
                            'new'].append((in_id, out_id))

            # Store new links
            if links_to_store:
                if not silent:
                    print "   ({} new links...)".format(len(links_to_store))

                models.DbLink.objects.bulk_create(links_to_store)
            else:
                if not silent:
                    print "   (0 new links...)"

            if not silent:
                print "STORING GROUP ELEMENTS..."
            import_groups = data['groups_uuid']
            for groupuuid, groupnodes in import_groups.iteritems():
                # TODO: cache these to avoid too many queries
                group = models.DbGroup.objects.get(uuid=groupuuid)
                nodes_to_store = [
                    dbnode_reverse_mappings[node_uuid]
                    for node_uuid in groupnodes
                ]
                if nodes_to_store:
                    group.dbnodes.add(*nodes_to_store)

            ######################################################
            # Put everything in a specific group
            dbnode_model_name = get_class_string(models.DbNode)
            existing = existing_entries.get(dbnode_model_name, {})
            existing_pk = [
                foreign_ids_reverse_mappings[dbnode_model_name][v['uuid']]
                for v in existing.itervalues()
            ]
            new = new_entries.get(dbnode_model_name, {})
            new_pk = [
                foreign_ids_reverse_mappings[dbnode_model_name][v['uuid']]
                for v in new.itervalues()
            ]

            pks_for_group = existing_pk + new_pk

            # So that we do not create empty groups
            if pks_for_group:
                # Get an unique name for the import group, based on the
                # current (local) time
                basename = timezone.localtime(
                    timezone.now()).strftime("%Y%m%d-%H%M%S")
                counter = 0
                created = False
                while not created:
                    if counter == 0:
                        group_name = basename
                    else:
                        group_name = "{}_{}".format(basename, counter)
                    try:
                        group = Group(name=group_name,
                                      type_string=IMPORTGROUP_TYPE).store()
                        created = True
                    except UniquenessError:
                        counter += 1

                # Add all the nodes to the new group
                # TODO: decide if we want to return the group name
                group.add_nodes(
                    models.DbNode.objects.filter(pk__in=pks_for_group))

                if not silent:
                    print "IMPORTED NODES GROUPED IN IMPORT GROUP NAMED '{}'".format(
                        group.name)
            else:
                if not silent:
                    print "NO DBNODES TO IMPORT, SO NO GROUP CREATED"

    if not silent:
        print "*** WARNING: MISSING EXISTING UUID CHECKS!!"
        print "*** WARNING: TODO: UPDATE IMPORT_DATA WITH DEFAULT VALUES! (e.g. calc status, user pwd, ...)"
        print "DONE."

    return ret_dict
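
A minimal call sketch, assuming 'export.aiida' is an archive produced by the matching export function (Python 2, as in the function body):

ret = import_data('export.aiida', silent=True)
for model_name, entries in ret.iteritems():
    print("{}: {} new, {} existing".format(
        model_name,
        len(entries.get('new', [])),
        len(entries.get('existing', []))))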
Example #22
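Builds and stores a TrajectoryData node with sample steps, cells, positions, and velocities, plus the usual non-empty/empty group pair for the listing tests.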
    def create_trajectory_data():
        from aiida.orm.data.array.trajectory import TrajectoryData
        from aiida.orm.group import Group
        import numpy

        # Create a node with two arrays
        n = TrajectoryData()

        # I create sample data
        stepids = numpy.array([60, 70])
        times = stepids * 0.01
        cells = numpy.array([[[2., 0., 0.], [0., 2., 0.], [0., 0., 2.]],
                             [[3., 0., 0.], [0., 3., 0.], [0., 0., 3.]]])
        symbols = numpy.array(['H', 'O', 'C'])
        positions = numpy.array([[[0., 0., 0.], [0.5, 0.5, 0.5],
                                  [1.5, 1.5, 1.5]],
                                 [[0., 0., 0.], [0.5, 0.5, 0.5],
                                  [1.5, 1.5, 1.5]]])
        velocities = numpy.array([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                                  [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5],
                                   [-0.5, -0.5, -0.5]]])

        # I set the node
        n.set_trajectory(stepids=stepids,
                         cells=cells,
                         symbols=symbols,
                         positions=positions,
                         times=times,
                         velocities=velocities)

        n.store()

        # Create 2 groups and add the data to one of them
        g_ne = Group(name='non_empty_group')
        g_ne.store()
        g_ne.add_nodes(n)

        g_e = Group(name='empty_group')
        g_e.store()

        return {
            TestVerdiDataListable.NODE_ID_STR: n.id,
            TestVerdiDataListable.NON_EMPTY_GROUP_ID_STR: g_ne.id,
            TestVerdiDataListable.EMPTY_GROUP_ID_STR: g_e.id
        }
Example #23
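Like the plain add_nodes test, but exercising the skip_orm=True fast path for all supported argument types.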
    def test_group_general(self):
        """
        General tests to verify that group addition with the skip_orm=True
        flag works properly
        """
        from aiida.orm.group import Group
        from aiida.orm.data import Data

        node_01 = Data().store()
        node_02 = Data().store()
        node_03 = Data().store()
        node_04 = Data().store()
        node_05 = Data().store()
        node_06 = Data().store()
        node_07 = Data().store()
        node_08 = Data().store()
        nodes = [
            node_01, node_02, node_03, node_04, node_05, node_06, node_07,
            node_08
        ]

        group = Group(name='test_adding_nodes').store()
        # Single node
        group.add_nodes(node_01, skip_orm=True)
        # List of nodes
        group.add_nodes([node_02, node_03], skip_orm=True)
        # Single DbNode
        group.add_nodes(node_04.dbnode, skip_orm=True)
        # List of DbNodes
        group.add_nodes([node_05.dbnode, node_06.dbnode], skip_orm=True)
        # List of orm.Nodes and DbNodes
        group.add_nodes([node_07, node_08.dbnode], skip_orm=True)

        # Check
        self.assertEqual(set(_.pk for _ in nodes),
                         set(_.pk for _ in group.nodes))

        # Try to add a node that is already present: there should be no problem
        group.add_nodes(node_01, skip_orm=True)
        self.assertEqual(set(_.pk for _ in nodes),
                         set(_.pk for _ in group.nodes))