コード例 #1
0
ファイル: test_orbitals.py プロジェクト: CasperWA/aiida_core
    def test_validation_for_fields(self):
        """Check that out-of-range quantum numbers are rejected.

        Each case builds a realhydrogen orbital at a fixed position with one
        offending quantum number, expects a ValidationError, and checks that the
        error message mentions the offending field.
        """
        orbital_cls = OrbitalFactory('realhydrogen')

        invalid_cases = [
            ({'angular_momentum': 100, 'magnetic_number': 0, 'radial_nodes': 2},
             'angular_momentum'),
            ({'angular_momentum': 1, 'magnetic_number': 3, 'radial_nodes': 2},
             'magnetic number must be in the range'),
            ({'angular_momentum': 1, 'magnetic_number': 0, 'radial_nodes': 100},
             'radial_nodes'),
        ]
        for quantum_numbers, expected_fragment in invalid_cases:
            # keep 'position' first so the kwarg order matches the other tests
            kwargs = {'position': (-1, -2, -3)}
            kwargs.update(quantum_numbers)
            with self.assertRaises(ValidationError) as context:
                orbital_cls(**kwargs)
            self.assertIn(expected_fragment, str(context.exception))
コード例 #2
0
ファイル: test_orbitals.py プロジェクト: CasperWA/aiida_core
    def test_get_name_from_quantum_numbers(self):
        """
        Check that ``get_name_from_quantum_numbers`` maps (l[, m]) to the
        expected orbital name, with and without an explicit magnetic number.
        """
        orbital_cls = OrbitalFactory('realhydrogen')

        expectations = (
            ({'angular_momentum': 1}, 'P'),
            ({'angular_momentum': 0}, 'S'),
            ({'angular_momentum': 0, 'magnetic_number': 0}, 'S'),
            ({'angular_momentum': 1, 'magnetic_number': 1}, 'PX'),
            ({'angular_momentum': 2, 'magnetic_number': 4}, 'DXY'),
        )
        for quantum_numbers, expected_name in expectations:
            actual_name = orbital_cls.get_name_from_quantum_numbers(
                **quantum_numbers)
            self.assertEqual(actual_name, expected_name)
コード例 #3
0
def find_orbitals_from_statelines(out_info_dict):
    """
    Convert the projwfc "state lines" into a list of realhydrogen orbitals.

    Each state line describes one atomic state, taken from the
    pseudopotential, that is used for the projection. From every line the
    atom number, kind name, ``l`` and ``m`` are extracted. ``radial_nodes``
    is assigned by counting how many earlier lines carry an identical set of
    these values. The position comes from the structure site with that index.

    :param out_info_dict: contains various technical internals useful in
        parsing; this function reads ``out_file`` (sequence of output lines),
        ``wfc_lines`` (indices into ``out_file`` of the state lines) and
        ``structure`` (whose ``sites`` provide the positions).
    :return: orbitals, a list of orbitals suitable for setting ProjectionData
    :raise QEOutputParsingError: if a state line does not have the expected
        format (a field is missing or is not an integer).
    """
    out_file = out_info_dict['out_file']
    atomnum_re = re.compile(r'atom (.*?)\(')
    element_re = re.compile(r'\((.*?)\)')
    lnum_re = re.compile(r'l=(.*?)m=')
    mnum_re = re.compile(r'm=(.*?)\)')

    state_dicts = []
    for wfc_line in out_info_dict['wfc_lines']:
        state_line = out_file[wfc_line]
        try:
            state_dict = {
                # atom number and m are shifted by -1 to match orbital
                # (0-based) indexing
                'atomnum': int(atomnum_re.findall(state_line)[0]) - 1,
                'kind_name': element_re.findall(state_line)[0].strip(),
                'angular_momentum': int(lnum_re.findall(state_line)[0]),
                'magnetic_number': int(mnum_re.findall(state_line)[0]) - 1,
            }
        except (ValueError, IndexError):
            # IndexError: one of the regexes found no match on this line
            raise QEOutputParsingError(
                'State lines are not formatted in a standard way.')
        state_dicts.append(state_dict)

    # radial_nodes: number of earlier state lines with exactly the same
    # values. Counted in one pass with a dict keyed on the (hashable)
    # sorted items, which is equivalent to comparing the dicts for equality.
    seen_counts = {}
    new_state_dicts = []
    for state_dict in state_dicts:
        key = tuple(sorted(state_dict.items()))
        radial_nodes = seen_counts.get(key, 0)
        seen_counts[key] = radial_nodes + 1
        state_dict = state_dict.copy()
        state_dict['radial_nodes'] = radial_nodes
        new_state_dicts.append(state_dict)
    state_dicts = new_state_dicts

    # replace the atom index by the position of the corresponding site
    structure = out_info_dict['structure']
    for state_dict in state_dicts:
        site_index = state_dict.pop('atomnum')
        state_dict['position'] = structure.sites[site_index].position

    RealhydrogenOrbital = OrbitalFactory('realhydrogen')  # pylint: disable=invalid-name
    return [RealhydrogenOrbital(**state_dict) for state_dict in state_dicts]
コード例 #4
0
ファイル: test_orbitals.py プロジェクト: CasperWA/aiida_core
    def test_optional_fields(self):
        """
        Check the optional ``spin`` and ``diffusivity`` fields: defaults are
        filled in when they are omitted, given values are stored, and invalid
        values raise a ValidationError naming the field.
        """
        orbital_cls = OrbitalFactory('realhydrogen')
        required = {
            'position': (-1, -2, -3),
            'angular_momentum': 1,
            'magnetic_number': 0,
            'radial_nodes': 2
        }

        # Omitted optional fields take their default values
        default_orbital = orbital_cls(**required)
        self.assertEqual(default_orbital.get_orbital_dict()['spin'], 0)
        self.assertEqual(default_orbital.get_orbital_dict()['diffusivity'], None)

        # Explicit optional values are kept as given
        custom_orbital = orbital_cls(**dict(required, spin=1, diffusivity=3.1))
        self.assertEqual(custom_orbital.get_orbital_dict()['spin'], 1)
        self.assertEqual(custom_orbital.get_orbital_dict()['diffusivity'], 3.1)

        # Invalid optional values are reported by field name
        bad_inputs = ((1, 'a', 'diffusivity'), (5, 3.1, 'spin'))
        for bad_spin, bad_diffusivity, field_name in bad_inputs:
            with self.assertRaises(ValidationError) as context:
                orbital_cls(**dict(
                    required, spin=bad_spin, diffusivity=bad_diffusivity))
            self.assertIn(field_name, str(context.exception))
コード例 #5
0
ファイル: test_orbitals.py プロジェクト: CasperWA/aiida_core
    def test_unknown_fields(self):
        """Check that an unrecognised field raises a ValidationError naming it."""
        orbital_cls = OrbitalFactory('realhydrogen')
        unexpected_key = 'some_strange_key'

        kwargs = {
            'position': (-1, -2, -3),
            'angular_momentum': 1,
            'magnetic_number': 0,
            'radial_nodes': 2,
            unexpected_key: 1,
        }
        with self.assertRaises(ValidationError) as context:
            orbital_cls(**kwargs)
        self.assertIn(unexpected_key, str(context.exception))
コード例 #6
0
ファイル: orbital.py プロジェクト: CasperWA/aiida_core
    def get_orbitals(self, **kwargs):
        """
        Return the stored orbitals, optionally filtered by orbital-dict values.

        Every keyword argument acts as a filter. An orbital is returned only if
        its stored orbital dict contains that key and the stored value compares
        equal (``==``) to the given one. Orbitals lacking a filter key are
        skipped rather than raising ``KeyError``. With no keyword arguments,
        all orbitals are returned.

        :kwargs: attributes that filter the set of returned orbitals
            (e.g. ``angular_momentum=1``)
        :return list_of_outputs: a list of orbitals, each rebuilt with the
            orbital class named by its ``_orbital_type`` entry
        :raise AttributeError: if no orbitals have been stored yet
        :raise ValidationError: if a stored orbital dict has no
            ``_orbital_type`` entry
        """
        # deep copy: the dicts are mutated below (pop of '_orbital_type')
        orbital_dicts = copy.deepcopy(self.get_attribute('orbital_dicts', None))
        if orbital_dicts is None:
            raise AttributeError('Orbitals must be set before being retrieved')

        def matches(orbital_dict):
            """Return True if ``orbital_dict`` satisfies every filter in kwargs."""
            # the membership test comes first so a missing key means
            # "no match" instead of a KeyError
            return all(key in orbital_dict and orbital_dict[key] == value
                       for key, value in kwargs.items())

        list_of_outputs = []
        for orbital_dict in orbital_dicts:
            if not matches(orbital_dict):
                continue
            try:
                orbital_type = orbital_dict.pop('_orbital_type')
            except KeyError:
                raise ValidationError(
                    'No _orbital_type found in: {}'.format(orbital_dict))

            cls = OrbitalFactory(orbital_type)
            list_of_outputs.append(cls(**orbital_dict))
        return list_of_outputs
コード例 #7
0
ファイル: test_orbitals.py プロジェクト: CasperWA/aiida_core
    def test_required_fields(self):
        """Check that required fields are enforced and stored when provided."""
        orbital_cls = OrbitalFactory('realhydrogen')

        # The fields required by the base class (position only) do not suffice
        with self.assertRaises(ValidationError):
            orbital_cls(position=(1, 2, 3))

        orbital = orbital_cls(
            position=(-1, -2, -3),
            angular_momentum=1,
            magnetic_number=0,
            radial_nodes=2)

        for axis, expected_coord in enumerate((-1., -2., -3.)):
            self.assertAlmostEqual(
                orbital.get_orbital_dict()['position'][axis], expected_coord)
        expected_fields = (('angular_momentum', 1), ('magnetic_number', 0),
                           ('radial_nodes', 2))
        for field_name, expected_value in expected_fields:
            self.assertAlmostEqual(
                orbital.get_orbital_dict()[field_name], expected_value)
コード例 #8
0
def find_orbitals_from_statelines(out_info_dict):
    """This function reads in all the state_lines, that is, the lines describing which atomic states, taken from the
    pseudopotential, are used for the projection. Then it converts these state_lines into a set of orbitals.

    The orbital class depends on the calculation type:

    * ``spinorbit``: ``spinorbithydrogen``, with ``total_angular_momentum`` (j)
      and ``magnetic_number`` (m_j, as parsed, not shifted);
    * non-collinear: ``noncollinearhydrogen``, with ``spin`` (s_z) and a
      0-based ``magnetic_number``;
    * collinear / no spin: ``realhydrogen``, with a 0-based ``magnetic_number``.

    ``radial_nodes`` is the number of earlier state lines with exactly the
    same parsed values.

    :param out_info_dict: contains various technical internals useful in parsing; this function reads
        ``out_file``, ``wfc_lines``, ``spinorbit``, ``collinear`` (only when not spin-orbit) and ``structure``.
    :return: orbitals, a list of orbitals suitable for setting ProjectionData
    :raise QEOutputParsingError: if a state line does not have the expected format.
    """

    # Format of statelines
    # From PP/src/projwfc.f90: (since Oct. 8 2019)
    #
    # 1000 FORMAT (5x,"state #",i4,": atom ",i3," (",a3,"), wfc ",i2," (l=",i1)
    # IF (lspinorb) THEN
    # 1001 FORMAT (" j=",f3.1," m_j=",f4.1,")")
    # ELSE IF (noncolin) THEN
    # 1002 FORMAT (" m=",i2," s_z=",f4.1,")")
    # ELSE
    # 1003 FORMAT (" m=",i2,")")
    # ENDIF
    #
    # Before:
    # IF (lspinorb) THEN
    # ...
    # 1000    FORMAT (5x,"state #",i4,": atom ",i3," (",a3,"), wfc ",i2, &
    #               " (j=",f3.1," l=",i1," m_j=",f4.1,")")
    # ELSE
    # ...
    # 1500    FORMAT (5x,"state #",i4,": atom ",i3," (",a3,"), wfc ",i2, &
    #               " (l=",i1," m=",i2," s_z=",f4.1,")")
    # ENDIF

    out_file = out_info_dict['out_file']
    spinorbit = out_info_dict['spinorbit']
    # 'collinear' is only consulted when the calculation is not spin-orbit
    noncollinear = not spinorbit and not out_info_dict['collinear']

    atomnum_re = re.compile(r'atom\s*([0-9]+?)[^0-9]')
    element_re = re.compile(r'atom\s*[0-9]+\s*\(\s*([A-Za-z0-9-_]+?)\s*\)')
    lnum_re = re.compile(r'l=\s*([0-9]+?)[^0-9]')
    if spinorbit:
        jnum_re = re.compile(r'j=\s*([0-9.]+?)[^0-9.]')
        mjnum_re = re.compile(r'm_j=\s*([-0-9.]+?)[^-0-9.]')
    else:
        mnum_re = re.compile(r'm=\s*([-0-9]+?)[^-0-9]')
        if noncollinear:
            sznum_re = re.compile(r's_z=\s*([-0-9.]*?)[^-0-9.]')

    state_dicts = []
    for wfc_line in out_info_dict['wfc_lines']:
        state_line = out_file[wfc_line]
        try:
            state_dict = {
                # atom number shifted by -1 to match orbital (0-based) indexing
                'atomnum': int(atomnum_re.findall(state_line)[0]) - 1,
                'kind_name': element_re.findall(state_line)[0].strip(),
                'angular_momentum': int(lnum_re.findall(state_line)[0]),
            }
            if spinorbit:
                state_dict['total_angular_momentum'] = float(
                    jnum_re.findall(state_line)[0])
                state_dict['magnetic_number'] = float(
                    mjnum_re.findall(state_line)[0])
            else:
                if noncollinear:
                    state_dict['spin'] = float(sznum_re.findall(state_line)[0])
                # m shifted by -1 to match orbital (0-based) indexing
                state_dict['magnetic_number'] = int(
                    mnum_re.findall(state_line)[0]) - 1
        except (ValueError, IndexError):
            # IndexError: one of the regexes found no match on this line
            raise QEOutputParsingError(
                'State lines are not formatted in a standard way.')
        state_dicts.append(state_dict)

    # radial_nodes: number of earlier state lines with exactly the same
    # values. Counted in one pass with a dict keyed on the (hashable)
    # sorted items, which is equivalent to comparing the dicts for equality.
    seen_counts = {}
    new_state_dicts = []
    for state_dict in state_dicts:
        key = tuple(sorted(state_dict.items()))
        radial_nodes = seen_counts.get(key, 0)
        seen_counts[key] = radial_nodes + 1
        state_dict = state_dict.copy()
        state_dict['radial_nodes'] = radial_nodes
        new_state_dicts.append(state_dict)
    state_dicts = new_state_dicts

    # replace the atom index by the position of the corresponding site
    structure = out_info_dict['structure']
    for state_dict in state_dicts:
        site_index = state_dict.pop('atomnum')
        state_dict['position'] = structure.sites[site_index].position

    if spinorbit:
        OrbitalCls = OrbitalFactory('spinorbithydrogen')  # pylint: disable=invalid-name
    elif noncollinear:
        OrbitalCls = OrbitalFactory('noncollinearhydrogen')  # pylint: disable=invalid-name
    else:
        OrbitalCls = OrbitalFactory('realhydrogen')  # pylint: disable=invalid-name
    return [OrbitalCls(**state_dict) for state_dict in state_dicts]
コード例 #9
0
    def set_projectiondata(self,
                           list_of_orbitals,
                           list_of_projections=None,
                           list_of_energy=None,
                           list_of_pdos=None,
                           tags=None,
                           bands_check=True):
        """
        Stores the projwfc_array using the projwfc_label, after validating both.

        A single orbital or a single array may be passed instead of a list.
        It is wrapped in a one-element list.

        :param list_of_orbitals: list of orbitals, of class orbital data.
                                 They should be the ones up on which the
                                 projection array corresponds with.

        :param list_of_projections: list of arrays of projections of a atomic
                              wavefunctions onto bloch wavefunctions. Since the
                              projection is for every bloch wavefunction which
                              can be specified by its spin (if used), band, and
                              kpoint the dimensions must be
                              nspin x nbands x nkpoints for the projwfc array.
                              Or nbands x nkpoints if spin is not used.

        :param list_of_energy: list of energy axis arrays for the list_of_pdos,
                               one per orbital; must be given if and only if
                               list_of_pdos is given.

        :param list_of_pdos: a list of projected density of states for the
                             atomic wavefunctions, units in states/eV

        :param tags: A list of strings, one per orbital, stored as the
                     ``tags`` attribute.

        :param bands_check: if false, skips checks of whether the bands has
                            been already set, and whether the sizes match. For
                            use in parsers, where the BandsData has not yet
                            been stored and therefore get_reference_bandsdata
                            cannot be called. The check is applied to the
                            projection arrays only.

        :raise ValidationError: if neither pdos nor projections are given, if
            only one of pdos/energy is given, if array lists are not all
            ndarrays or do not match the number of orbitals, if an orbital has
            no ``_orbital_type``, or if ``tags`` is malformed.
        """

        # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements

        def single_to_list(item):
            """
            Checks if the item is a list or tuple, and converts it to a list
            if it is not already a list or tuple

            :param item: an object which may or may not be a list or tuple
            :return: item_list: the input item unchanged if list or tuple and
                                [item] otherwise
            """
            if isinstance(item, (list, tuple)):
                return item

            return [item]

        def is_provided(item):
            """
            Return whether an optional array argument was given.

            ``None`` and empty lists/tuples count as "not given". A bare
            ``np.ndarray`` always counts as given. Calling ``bool()`` on a
            multi-element array raises ValueError, so it must not be used as
            a truth value here.
            """
            if item is None:
                return False
            if isinstance(item, np.ndarray):
                return True
            return bool(item)

        def array_list_checker(array_list, array_name, orb_length):
            """
            Does basic checks over everything in the array_list. Makes sure that
            all the arrays are np.ndarray and that the length is the same as
            orb_length; raises a ValidationError naming array_name otherwise.
            """
            if not all(isinstance(array, np.ndarray) for array in array_list):
                raise exceptions.ValidationError(
                    '{} was not composed entirely of ndarrays'.format(
                        array_name))
            if len(array_list) != orb_length:
                raise exceptions.ValidationError(
                    '{} did not have the same length as the '
                    'list of orbitals'.format(array_name))

        ##############
        list_of_orbitals = copy.deepcopy(single_to_list(list_of_orbitals))

        has_projections = is_provided(list_of_projections)
        has_pdos = is_provided(list_of_pdos)

        # validates the input data
        if not has_pdos and not has_projections:
            raise exceptions.ValidationError(
                'Must set either pdos or projections')
        if is_provided(list_of_energy) != has_pdos:
            raise exceptions.ValidationError(
                'list_of_pdos and list_of_energy must always be set together')

        orb_length = len(list_of_orbitals)

        # verifies and sets the orbital dicts: each dict is round-tripped
        # through its orbital class so that it is re-validated
        list_of_orbital_dicts = []
        for this_orbital in list_of_orbitals:
            orbital_dict = this_orbital.get_orbital_dict()
            try:
                orbital_type = orbital_dict.pop('_orbital_type')
            except KeyError:
                raise exceptions.ValidationError(
                    'No _orbital_type key found in dictionary: {}'.format(
                        orbital_dict))
            cls = OrbitalFactory(orbital_type)
            test_orbital = cls(**orbital_dict)
            list_of_orbital_dicts.append(test_orbital.get_orbital_dict())
        self.set_attribute('orbital_dicts', list_of_orbital_dicts)

        # verifies and sets the projections
        if has_projections:
            list_of_projections = single_to_list(list_of_projections)
            array_list_checker(list_of_projections, 'projections', orb_length)
            for i, this_projection in enumerate(list_of_projections):
                array_name = self._from_index_to_arrayname(i)
                if bands_check:
                    self._check_projections_bands(this_projection)
                self.set_array('proj_{}'.format(array_name), this_projection)

        # verifies and sets both pdos and energy
        if has_pdos:
            list_of_pdos = single_to_list(list_of_pdos)
            list_of_energy = single_to_list(list_of_energy)
            array_list_checker(list_of_pdos, 'pdos', orb_length)
            array_list_checker(list_of_energy, 'energy', orb_length)
            # NOTE(review): PDOS arrays are not shape-checked against the bands.
            # They are sampled on their own energy axis (list_of_energy).
            # Checking a projection here would raise for PDOS-only input,
            # because no projection array exists in that case.
            for i, (this_pdos, this_energy) in enumerate(
                    zip(list_of_pdos, list_of_energy)):
                array_name = self._from_index_to_arrayname(i)
                self.set_array('pdos_{}'.format(array_name), this_pdos)
                self.set_array('energy_{}'.format(array_name), this_energy)

        # verifies and sets the tags
        if tags is not None:
            try:
                num_tags = len(tags)
            except TypeError:
                raise exceptions.ValidationError('tags must be a list')
            if num_tags != len(list_of_orbitals):
                raise exceptions.ValidationError(
                    'must set as many tags as projections')

            if not all(isinstance(tag, str) for tag in tags):
                raise exceptions.ValidationError(
                    'Tags must set a list of strings')
            self.set_attribute('tags', tags)
コード例 #10
0
def _generate_wannier_orbitals(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements # noqa:  disable=MC0001
        position_cart=None,
        structure=None,
        kind_name=None,
        radial=1,
        ang_mtm_name=None,
        ang_mtm_l_list=None,
        ang_mtm_mr_list=None,
        spin=None,
        zona=None,
        zaxis=None,
        xaxis=None,
        spin_axis=None):
    """
    Use this method to emulate the input style of Wannier90,
    when setting the orbitals (see chapter 3 in the user_guide). Position
    can be provided either in Cartesian coordinates using ``position_cart``
    or can be assigned based on an input structure and ``kind_name``.

    One RealhydrogenOrbital is generated for every combination of position
    and angular momentum (and spin, if given), in that nesting order.

    :param position_cart: position in Cartesian coordinates or list of
                          positions in Cartesian coordinates
    :param structure: input structure for use with kind_names
    :param kind_name: kind_name, for use with the structure; one orbital set
                      is generated for every site of this kind
    :param radial: Wannier90 radial number ``r`` (counted from 1); stored as
                   ``radial_nodes = radial - 1``
    :param ang_mtm_name: orbital name or list of orbital names, cannot
                         be used in conjunction with ang_mtm_l_list or
                         ang_mtm_mr_list
    :param ang_mtm_l_list: angular momentum (either an integer or a list), if
                 ang_mtm_mr_list is not specified will return all orbitals associated with it
    :param ang_mtm_mr_list: Wannier90 ``mr`` numbers (counted from 1); must be
                       specified along with ang_mtm_l_list. Note that if this is specified,
                       ang_mtm_l_list must be an integer and not a list
    :param spin: the spin, spin up can be specified with 1,u or U and
                 spin down can be specified using -1,d,D
    :param zona: as specified in user guide, applied to all orbitals
    :param zaxis: the zaxis, list of three floats
                  as described in wannier user guide
    :param xaxis: the xaxis, list of three floats as described in the
                  wannier user guide
    :param spin_axis: the spin alignment axis, as described in the
                      user guide
    :return: list of RealhydrogenOrbital instances
    :raise InputValidationError: on inconsistent or missing inputs
    """
    from aiida.plugins import DataFactory
    from aiida.plugins import OrbitalFactory
    from aiida.common import InputValidationError

    def convert_to_list(item):
        """
        internal method, checks if the item is already a list or tuple.
        if not returns a tuple containing only item, otherwise returns
        tuple(item)
        """
        if isinstance(item, (list, tuple)):
            return tuple(item)
        return tuple([item])

    def combine_dictlists(dict_list1, dict_list2):
        """
        Creates a list of every dict in dict_list1 updated with every
        dict in dict_list2 (outer loop over dict_list1). If one of the two
        lists contains only empty dicts, the other list is returned as is.
        """
        # exception handling for the case of empty dicts
        dict_list1_empty = not any(bool(x) for x in dict_list1)
        dict_list2_empty = not any(bool(x) for x in dict_list2)
        if dict_list1_empty and dict_list2_empty:
            raise InputValidationError('One dict must not be empty')
        if dict_list1_empty:
            return dict_list2
        if dict_list2_empty:
            return dict_list1

        out_list = []
        for dict_1 in dict_list1:
            for dict_2 in dict_list2:
                temp_1 = dict_1.copy()
                temp_1.update(dict_2)
                out_list.append(temp_1)
        return out_list

    RealhydrogenOrbital = OrbitalFactory('realhydrogen')  # pylint: disable=invalid-name

    #########################################################################
    # Validation of inputs                                                  #
    #########################################################################
    if position_cart is None and kind_name is None:
        raise InputValidationError('Must supply a kind_name or position')
    if position_cart is not None and kind_name is not None:
        raise InputValidationError('Must supply position or kind_name'
                                   ' not both')

    structure_class = DataFactory('structure')
    if kind_name is not None:
        if not isinstance(structure, structure_class):
            raise InputValidationError('Must supply a StructureData as '
                                       'structure if using kind_name')
        if not isinstance(kind_name, six.string_types):
            raise InputValidationError('kind_name must be a string')

    if ang_mtm_name is None and ang_mtm_l_list is None:
        raise InputValidationError(
            "Must supply ang_mtm_name or ang_mtm_l_list")
    if ang_mtm_name is not None and (ang_mtm_l_list is not None
                                     or ang_mtm_mr_list is not None):
        raise InputValidationError(
            "Cannot supply ang_mtm_name together with ang_mtm_l_list"
            " or ang_mtm_mr_list")
    if ang_mtm_l_list is None and ang_mtm_mr_list is not None:
        raise InputValidationError("Cannot supply ang_mtm_mr_list without "
                                   "ang_mtm_l_list")

    ####################################################################
    # Setting up initial basic parameters
    ####################################################################
    projection_dict = {}
    if radial:
        projection_dict['radial_nodes'] = radial - 1
    if xaxis:
        projection_dict['x_orientation'] = xaxis
    if zaxis:
        projection_dict['z_orientation'] = zaxis
    if kind_name:
        projection_dict['kind_name'] = kind_name
    if spin_axis:
        projection_dict['spin_orientation'] = spin_axis
    if zona:
        projection_dict['diffusivity'] = zona

    projection_dicts = [projection_dict]

    #####################################################################
    # Setting up Positions                                              #
    #####################################################################
    # finds all the positions to append the orbitals to (if applicable)
    if kind_name:
        position_list = [
            site.position for site in structure.sites
            if site.kind_name == kind_name
        ]
        if not position_list:
            raise InputValidationError("No valid positions found in structure "
                                       "using {}".format(kind_name))
    # otherwise turns position into position_list
    else:
        position_list = [convert_to_list(position_cart)]
    position_dicts = [{"position": v} for v in position_list]
    projection_dicts = combine_dictlists(projection_dicts, position_dicts)

    #######################################################################
    # Setting up angular momentum                                         #
    #######################################################################
    if ang_mtm_l_list is not None:
        ang_mtm_l_list = convert_to_list(ang_mtm_l_list)
        if ang_mtm_mr_list is not None:
            # explicit mr numbers: only valid for a single l
            if len(ang_mtm_l_list) > 1:
                raise InputValidationError("If you are giving specific"
                                           " magnetic numbers please do"
                                           " not supply more than one"
                                           " angular number.")
            ang_mtm_mr_list = convert_to_list(ang_mtm_mr_list)
            ang_mtm_l_num = ang_mtm_l_list[0]
            # Wannier90 mr is 1-based, magnetic_number is 0-based
            ang_mtm_dicts = [{
                'angular_momentum': ang_mtm_l_num,
                'magnetic_number': j - 1
            } for j in ang_mtm_mr_list]
        else:
            # no mr given: all magnetic numbers of each l. For l >= 0 that
            # is 2l+1 values; for negative (hybrid) l it is -l+1 values.
            ang_mtm_dicts = []
            for ang_mtm_l in ang_mtm_l_list:
                num_mr = 2 * ang_mtm_l + 1 if ang_mtm_l >= 0 else -ang_mtm_l + 1
                ang_mtm_dicts += [{
                    'angular_momentum': ang_mtm_l,
                    'magnetic_number': i
                } for i in range(num_mr)]
    if ang_mtm_name is not None:
        ang_mtm_names = convert_to_list(ang_mtm_name)
        ang_mtm_dicts = []
        for name in ang_mtm_names:
            # get_quantum_numbers_from_name (in AiiDA) might not return
            # a consistent order since it creates the list from a dictionary
            # This might be considered a bug in AiiDA, but since AiiDA is going
            # to drop py2 support soon, this might not be fixed, so we work
            # around the issue here.
            ang_mtm_dicts += sorted(
                RealhydrogenOrbital.get_quantum_numbers_from_name(name),
                key=lambda qnums:
                (qnums['angular_momentum'], qnums['magnetic_number']))
    projection_dicts = combine_dictlists(projection_dicts, ang_mtm_dicts)

    #####################################################################
    # Setting up the spin                                               #
    #####################################################################
    if spin:
        spin_dict = {'U': 1, 'u': 1, 1: 1, 'D': -1, 'd': -1, -1: -1}
        if isinstance(spin, (list, tuple)):
            spin = [spin_dict[x] for x in spin]
        else:
            spin = [spin_dict[spin]]
        spin_dicts = [{'spin': v} for v in spin]
        projection_dicts = combine_dictlists(projection_dicts, spin_dicts)

    # generating and returning a list of all corresponding orbitals
    return [
        RealhydrogenOrbital(**projection_dict)
        for projection_dict in projection_dicts
    ]
コード例 #11
0
def _format_single_projection(orbital):  #pylint: disable=too-many-locals
    """
    Creates an appropriate wannier line from input orbitaldata,
    will raise an exception if the orbital does not contain enough
    information, or the information is badly formated.

    The produced line has the shape
    ``c=X,Y,Z:l=L,mr=MR[:z=..:x=..:r=R:zona=..][(u|d)][[sx,sy,sz]]``,
    where ``mr`` and ``r`` are the stored 0-based ``magnetic_number`` and
    ``radial_nodes`` plus one. The bracketed optional groups are only written
    when at least one of their values is set; an unset value inside the
    colon group is written as an empty field.

    :param orbital: a realhydrogen orbital
    :return: the projection line as a string
    :raise InputValidationError: if the orbital is not a realhydrogen orbital
        or lacks position, angular_momentum or magnetic_number
    """
    from aiida.plugins import OrbitalFactory
    RealhydrogenOrbital = OrbitalFactory("realhydrogen")  # pylint: disable=invalid-name

    if not isinstance(orbital, RealhydrogenOrbital):
        raise InputValidationError(
            "Only realhydrogen orbitals are currently supported for Wannier90 input."
        )
    orb_dict = copy.deepcopy(orbital.get_orbital_dict())

    def _get_attribute(name, required=True):
        """Return ``orb_dict[name]``; raise if it is missing/None and required."""
        res = orb_dict.get(name, None)
        if res is None and required:
            raise InputValidationError(
                "Orbital is missing attribute '{}'.".format(name))
        return res

    def _format_projection_values_float(name, value):
        """
        Return a string for a given key-value pair of the projections block, e.g.
        ``'c=0.132443,1.324823823,0.547423243'``, where we know that values are floats
        that will be formatted with a specific formatting option.
        """
        if value is None:
            return ''
        if not isinstance(value, (tuple, list)):
            value = [value]
        return '{}={}'.format(name,
                              ','.join("{:.10f}".format(x) for x in value))

    def _format_projection_values_generic(name, value):
        """
        Return a string for a given key-value pair of the projections block, e.g.
        ``'l=1'``, where formatting of the values is done without specifying
        a custom format - this is ok for e.g. integers, while for floats it's
        better to use :func:`_format_projection_values_float` function that
        properly formats floats, avoiding differences between python versions.
        """
        if value is None:
            return ''
        if not isinstance(value, (tuple, list)):
            value = [value]
        return '{}={}'.format(name, ','.join("{}".format(x) for x in value))

    # required arguments
    position = _get_attribute("position")
    angular_momentum = _get_attribute("angular_momentum")
    magnetic_number = _get_attribute("magnetic_number")
    wann_string = (
        _format_projection_values_float('c', position) + ':' +
        _format_projection_values_generic('l', angular_momentum) + ',' +
        _format_projection_values_generic('mr', magnetic_number + 1))

    # optional, colon-separated arguments
    zaxis = _get_attribute("z_orientation", required=False)
    xaxis = _get_attribute("x_orientation", required=False)
    radial = _get_attribute("radial_nodes", required=False)
    zona = _get_attribute("diffusivity", required=False)
    if any(arg is not None for arg in [zaxis, xaxis, radial, zona]):
        zaxis_string = _format_projection_values_float('z', zaxis)
        xaxis_string = _format_projection_values_float('x', xaxis)
        # radial may be unset even if another optional field is set; in that
        # case leave the field empty instead of failing on None + 1
        radial_string = _format_projection_values_generic(
            'r', None if radial is None else radial + 1)
        zona_string = _format_projection_values_float('zona', zona)
        wann_string += ':{}:{}:{}:{}'.format(zaxis_string, xaxis_string,
                                             radial_string, zona_string)

    # spin, optional
    # Careful with spin, it is insufficient to set the spin the projection
    # line alone. You must, in addition, apply the appropriate settings:
    # either set spinors=.true. or use spinor_projections, see user guide
    spin = _get_attribute("spin", required=False)
    if spin is not None and spin != 0:
        spin_dict = {-1: "d", 1: "u"}
        wann_string += "({})".format(spin_dict[spin])
    spin_orient = _get_attribute("spin_orientation", required=False)
    if spin_orient is not None:
        wann_string += "[" + ",".join(
            ["{:18.10f}".format(x) for x in spin_orient]) + "]"

    return wann_string