Пример #1
0
def get_nkpts(xmltree, schema_dict, logger=None):
    """
    Get the number of kpoints of the kpoint set that is currently selected in the
    given fleur XML file.

    The name of the selected set is read from the ``listName`` attribute. The
    ``kPointList`` element carrying this name is then looked up and the value
    of its ``count`` attribute is returned.

    :param xmltree: etree representing the fleur xml file
    :param schema_dict: schema dictionary corresponding to the file version
                        of the xmltree
    :param logger: logger object for logging warnings, errors

    :returns: int with the number of kpoints

    :raises ValueError: If the file contains no kPointList at all or if no
                        kPointList with the selected name exists
    """
    from masci_tools.util.schema_dict_util import eval_simple_xpath
    from masci_tools.util.schema_dict_util import evaluate_attribute
    from masci_tools.util.xml.common_functions import clear_xml

    # A full element tree is passed through clear_xml first; anything else is
    # used directly as the root node
    if isinstance(xmltree, etree._ElementTree):
        xmltree, _ = clear_xml(xmltree)
        root = xmltree.getroot()
    else:
        root = xmltree

    #Get the name of the current selected kPointSet
    list_name = evaluate_attribute(root,
                                   schema_dict,
                                   'listName',
                                   logger=logger)

    kpointlists = eval_simple_xpath(root,
                                    schema_dict,
                                    'kPointList',
                                    list_return=True,
                                    logger=logger)

    if not kpointlists:
        raise ValueError('No Kpoint lists found in the given inp.xml')

    labels = [kpoint_set.attrib.get('name') for kpoint_set in kpointlists]
    if list_name not in labels:
        # The two parts of the message need an explicit separator, adjacent
        # string literals are joined without any whitespace
        raise ValueError(
            f'Selected Kpoint list with the name: {list_name} does not exist. '
            f'Available list names: {labels}')

    # labels and kpointlists share the same ordering, the first match is used
    kpoint_set = kpointlists[labels.index(list_name)]

    return evaluate_attribute(kpoint_set, schema_dict, 'count', logger=logger)
def test_schema_dict_util_abs_to_rel_path():
    """
    Check the schema_dict_util functions when they are called on species nodes
    instead of the root of the tree (absolute to relative xpath conversion)
    """
    from lxml import etree
    from masci_tools.util.schema_dict_util import eval_simple_xpath, get_number_of_nodes, tag_exists, \
                                                  evaluate_attribute, evaluate_tag, evaluate_parent_tag, \
                                                  evaluate_text

    schema_dict = schema_dict_34

    xml_parser = etree.XMLParser(attribute_defaults=True, recover=False, encoding='utf-8')
    root = etree.parse(TEST_INPXML_PATH, xml_parser).getroot()

    species = eval_simple_xpath(root, schema_dict, 'species')
    first_species = species[0]
    second_species = species[1]

    # Both species contain local orbitals
    for node in (first_species, second_species):
        assert tag_exists(node, schema_dict, 'lo')

    # ... but a different number of them
    for node, expected_count in ((first_species, 2), (second_species, 1)):
        assert get_number_of_nodes(node, schema_dict, 'lo') == expected_count

    species_name = evaluate_attribute(first_species, schema_dict, 'name', constants=FLEUR_DEFINED_CONSTANTS)
    assert species_name == 'Fe-1'

    core_config = evaluate_text(first_species, schema_dict, 'coreConfig', constants=FLEUR_DEFINED_CONSTANTS)
    assert core_config == ['[Ne]']

    expected_lo = {'eDeriv': 0, 'l': 1, 'n': 5, 'type': 'SCLO'}
    lo_tag = evaluate_tag(second_species, schema_dict, 'lo', constants=FLEUR_DEFINED_CONSTANTS)
    assert lo_tag == expected_lo

    expected_parent = {'atomicNumber': 78, 'element': 'Pt', 'name': 'Pt-1'}
    lo_parent = evaluate_parent_tag(second_species, schema_dict, 'lo', constants=FLEUR_DEFINED_CONSTANTS)
    assert lo_parent == expected_parent
Пример #3
0
def switch_kpointset(xmltree, schema_dict, list_name):
    """
    Switch the used k-point set

    .. warning::
        This method is only supported for input versions after the Max5 release

    :param xmltree: xml tree that represents inp.xml
    :param schema_dict: InputSchemaDict containing all information about the structure of the input
    :param list_name: name of the kPoint set to use

    :returns: an xmltree of the inp.xml file with changes.

    :raises ValueError: If no kPointList with the name ``list_name`` exists in the given tree
    """
    from masci_tools.util.schema_dict_util import evaluate_attribute

    # Names of all kPointList elements present in the file
    existing_labels = evaluate_attribute(xmltree, schema_dict, 'name', contains='kPointList', list_return=True)

    if list_name not in existing_labels:
        # One single message string; passing two arguments to ValueError would
        # make the error render as a tuple of strings
        raise ValueError(f'The given kPointList {list_name} does not exist. '
                         f'Available kPointLists: {existing_labels}')

    return set_first_attrib_value(xmltree, schema_dict, 'listName', list_name)
Пример #4
0
def get_relaxation_information(xmltree, schema_dict, logger=None):
    """
    Collect the relaxation information stored in the given fleur XML file, i.e.
    the displacements, the energies and the posforce entries of every relaxation step

    :param xmltree: etree representing the fleur xml file
    :param schema_dict: schema dictionary corresponding to the file version
                        of the xmltree
    :param logger: logger object for logging warnings, errors

    :returns: dict with the keys ``displacements``, ``energies`` and ``posforces``

    :raises ValueError: If no relaxation section is included in the xml tree
    """
    from masci_tools.util.schema_dict_util import tag_exists, read_constants, evaluate_text, eval_simple_xpath
    from masci_tools.util.schema_dict_util import evaluate_attribute
    from masci_tools.util.xml.common_functions import clear_xml

    root = xmltree
    if isinstance(xmltree, etree._ElementTree):
        xmltree, _ = clear_xml(xmltree)
        root = xmltree.getroot()
    constants = read_constants(root, schema_dict, logger=logger)

    if not tag_exists(root, schema_dict, 'relaxation', logger=logger):
        raise ValueError(
            'No relaxation information included in the given XML file')

    relaxation = eval_simple_xpath(root, schema_dict, 'relaxation', logger=logger)

    # Keyword arguments shared by all value evaluations below
    value_kwargs = {'list_return': True, 'constants': constants, 'logger': logger}

    displacements = evaluate_text(relaxation, schema_dict, 'displace', **value_kwargs)
    energies = evaluate_attribute(relaxation, schema_dict, 'energy', **value_kwargs)

    # One list of posforce entries per relaxation step
    steps = eval_simple_xpath(relaxation, schema_dict, 'step', list_return=True, logger=logger)
    posforces = [evaluate_text(step, schema_dict, 'posforce', **value_kwargs) for step in steps]

    return {
        'displacements': displacements,
        'energies': energies,
        'posforces': posforces,
    }
Пример #5
0
def get_kpoints_data_max4(xmltree, schema_dict, logger=None):
    """
    Read the kpoints and weights of the first kPointList (alternative kpoint
    sets are excluded) from the given fleur xml file.

    .. note::
        This function is specific to file version before and including the
        Max4 release of fleur

    :param xmltree: etree representing the fleur xml file
    :param schema_dict: schema dictionary corresponding to the file version
                        of the xmltree
    :param logger: logger object for logging warnings, errors

    :returns: tuple containing the kpoint information

    :raises ValueError: If no kPointList is found in the file

    The tuple contains the following entries:

        1. :kpoints: list containing the coordinates of the kpoints
        2. :weights: list containing the weights of the kpoints
        3. :cell: numpy array, bravais matrix of the given system
        4. :pbc: list of booleans, determines in which directions periodic boundary conditions are applicable

    """
    from masci_tools.util.schema_dict_util import read_constants, eval_simple_xpath
    from masci_tools.util.schema_dict_util import evaluate_text, evaluate_attribute
    from masci_tools.util.xml.common_functions import clear_xml

    root = xmltree
    if isinstance(xmltree, etree._ElementTree):
        xmltree, _ = clear_xml(xmltree)
        root = xmltree.getroot()

    constants = read_constants(root, schema_dict, logger=logger)

    cell, pbc = get_cell(root, schema_dict, logger=logger)

    found_lists = eval_simple_xpath(root,
                                    schema_dict,
                                    'kPointList',
                                    list_return=True,
                                    not_contains='altKPoint',
                                    logger=logger)

    if not found_lists:
        raise ValueError('No Kpoint lists found in the given inp.xml')

    # Only the first list is evaluated
    first_list = found_lists[0]

    # Both evaluations share these options
    read_kwargs = {
        'constants': constants,
        'not_contains': 'altKPoint',
        'list_return': True,
        'logger': logger,
    }
    kpoints = evaluate_text(first_list, schema_dict, 'kPoint', **read_kwargs)
    weights = evaluate_attribute(first_list, schema_dict, 'weight', **read_kwargs)

    return kpoints, weights, cell, pbc
Пример #6
0
def get_kpoints_data(xmltree, schema_dict, name=None, index=None, logger=None):
    """
    Get the kpoint sets defined in the given fleur xml file.

    .. warning::
        For file versions before Max5 the name argument is not valid

    :param xmltree: etree representing the fleur xml file
    :param schema_dict: schema dictionary corresponding to the file version
                        of the xmltree
    :param name: str, optional, if given only the kpoint set with the given name
                 is returned
    :param index: int, optional, if given only the kpoint set with the given index
                  is returned
    :param logger: logger object for logging warnings, errors

    :returns: tuple containing the kpoint information

    :raises ValueError: If both ``name`` and ``index`` are given, if the file contains
                        no kPointList, or if the requested name/index does not exist

    The tuple contains the following entries:

        1. :kpoints: dict or list (list if there is only one kpoint set),
                     containing the coordinates of the kpoints
        2. :weights: dict or list (list if there is only one kpoint set),
                     containing the weights of the kpoints
        3. :cell: numpy array, bravais matrix of the given system
        4. :pbc: list of booleans, determines in which directions periodic boundary conditions are applicable

    """
    from masci_tools.util.schema_dict_util import read_constants, eval_simple_xpath
    from masci_tools.util.schema_dict_util import evaluate_text, evaluate_attribute
    from masci_tools.util.xml.common_functions import clear_xml

    if name is not None and index is not None:
        raise ValueError(
            'Only provide one of index or name to select kpoint lists')

    if isinstance(xmltree, etree._ElementTree):
        xmltree, _ = clear_xml(xmltree)
        root = xmltree.getroot()
    else:
        root = xmltree

    constants = read_constants(root, schema_dict, logger=logger)

    cell, pbc = get_cell(root, schema_dict, logger=logger)

    kpointlists = eval_simple_xpath(root,
                                    schema_dict,
                                    'kPointList',
                                    list_return=True,
                                    logger=logger)

    if not kpointlists:
        raise ValueError('No Kpoint lists found in the given inp.xml')

    labels = [kpoint_set.attrib.get('name') for kpoint_set in kpointlists]
    if name is not None and name not in labels:
        # Explicit separator needed, adjacent string literals are joined
        # without any whitespace in between
        raise ValueError(f'Found no Kpoint list with the name: {name}. '
                         f'Available list names: {labels}')

    if index is not None:
        try:
            kpointlists = [kpointlists[index]]
        except IndexError as exc:
            # kpointlists is still the complete list here, since the assignment failed
            raise ValueError(f'No kPointList with index {index} found.'
                             f' Only {len(kpointlists)} available') from exc

    kpoints_data = {}
    weights_data = {}
    for kpointlist in kpointlists:

        label = evaluate_attribute(kpointlist,
                                   schema_dict,
                                   'name',
                                   logger=logger)

        # Selection by name: skip all other sets
        if name is not None and name != label:
            continue

        kpoints = evaluate_text(kpointlist,
                                schema_dict,
                                'kPoint',
                                constants=constants,
                                list_return=True,
                                logger=logger)
        weights = evaluate_attribute(kpointlist,
                                     schema_dict,
                                     'weight',
                                     constants=constants,
                                     list_return=True,
                                     logger=logger)

        # NOTE(review): presumably a set with a single kpoint comes back as a flat
        # list of coordinates; it is wrapped so that every entry is a list of kpoints
        if not isinstance(kpoints[0], list):
            kpoints = [kpoints]
            weights = [weights]

        kpoints_data[label] = kpoints
        weights_data[label] = weights

    # With exactly one set the dict wrapper is dropped and the plain lists are returned
    if len(kpoints_data) == 1:
        _, kpoints_data = kpoints_data.popitem()
        _, weights_data = weights_data.popitem()

    return kpoints_data, weights_data, cell, pbc
Пример #7
0
def get_structure_data(xmltree, schema_dict, logger=None):
    """
    Read the structure (atom positions with their elements, cell and periodic
    boundary conditions) from the given fleur xml file.

    .. warning::
        Only the explicit definition of the Bravais matrix is supported.
        Old inputs containing the `latnam` definitions are not supported

    :param xmltree: etree representing the fleur xml file
    :param schema_dict: schema dictionary corresponding to the file version
                        of the xmltree
    :param logger: logger object for logging warnings, errors

    :returns: tuple containing the structure information

    :raises ValueError: If the number of species names and elements differ or
                        if an atom group contains no positions

    The tuple contains the following entries:

        1. :atom_data: list of tuples containing the absolute positions and symbols of the atoms
        2. :cell: numpy array, bravais matrix of the given system
        3. :pbc: list of booleans, determines in which directions periodic boundary conditions are applicable

    """
    from masci_tools.util.schema_dict_util import read_constants, eval_simple_xpath
    from masci_tools.util.schema_dict_util import evaluate_text, evaluate_attribute
    from masci_tools.util.xml.common_functions import clear_xml
    from masci_tools.io.common_functions import rel_to_abs, rel_to_abs_f

    root = xmltree
    if isinstance(xmltree, etree._ElementTree):
        xmltree, _ = clear_xml(xmltree)
        root = xmltree.getroot()
    constants = read_constants(root, schema_dict, logger=logger)
    cell, pbc = get_cell(root, schema_dict, logger=logger)

    species_kwargs = {'constants': constants, 'contains': 'species', 'logger': logger}
    species_names = evaluate_attribute(root, schema_dict, 'name', **species_kwargs)
    species_elements = evaluate_attribute(root, schema_dict, 'element', **species_kwargs)

    # A single result is not returned as a list, so it is wrapped here
    species_names = species_names if isinstance(species_names, list) else [species_names]
    species_elements = species_elements if isinstance(species_elements, list) else [species_elements]

    if len(species_names) != len(species_elements):
        raise ValueError(
            f'Failed to read in species names and elements. Got {len(species_names)} names and {len(species_elements)} elements'
        )

    # Mapping species name -> element symbol
    species_dict = dict(zip(species_names, species_elements))

    position_kwargs = {'constants': constants, 'list_return': True, 'logger': logger, 'optional': True}

    atom_data = []
    for group in eval_simple_xpath(root, schema_dict, 'atomGroup', list_return=True, logger=logger):

        group_species = evaluate_attribute(group, schema_dict, 'species', constants=constants, logger=logger)

        absolute_positions = evaluate_text(group, schema_dict, 'absPos', **position_kwargs)
        relative_positions = evaluate_text(group, schema_dict, 'relPos', **position_kwargs)
        film_positions = evaluate_text(group, schema_dict, 'filmPos', **position_kwargs)

        # Absolute positions are taken as they are, the others are converted
        # with the cell and appended to the same list
        atom_positions = absolute_positions
        atom_positions.extend(rel_to_abs(rel_pos, cell) for rel_pos in relative_positions)
        atom_positions.extend(rel_to_abs_f(film_pos, cell) for film_pos in film_positions)

        if not atom_positions:
            raise ValueError('Failed to read atom positions for group')

        element = species_dict[group_species]
        for pos in atom_positions:
            atom_data.append((pos, element))

    return atom_data, cell, pbc
Пример #8
0
def get_parameter_data(xmltree,
                       schema_dict,
                       inpgen_ready=True,
                       write_ids=True,
                       logger=None):
    """
    Build a python dictionary from the inp.xml file containing the parameters
    needed to set up a new inp.xml from an inpgen input file, which produces the
    same input (for parameters that the inpgen can control)

    :param xmltree: etree representing the fleur xml file
    :param schema_dict: schema dictionary corresponding to the file version
                        of the xmltree
    :param inpgen_ready: Bool, return a dict which can be inputed into inpgen while setting atoms
                         (if False the atomic number is added as ``z`` to each atom namelist)
    :param write_ids: Bool, if True the atom ids are added to the atom namelists
                      (only for atomic numbers occurring in more than one species)
    :param logger: logger object for logging warnings, errors

    :returns: dict, which will lead to the same inp.xml (in case if other defaults,
              which can not be controlled by input for inpgen, were changed)

    """
    from masci_tools.util.schema_dict_util import read_constants, eval_simple_xpath
    from masci_tools.util.schema_dict_util import evaluate_attribute, evaluate_text
    from masci_tools.util.xml.common_functions import clear_xml
    from masci_tools.util.xml.converters import convert_fleur_lo
    from masci_tools.io.common_functions import filter_out_empty_dict_entries

    # TODO: convert econfig
    # TODO: parse kpoints, somehow count is bad (if symmetry changes), mesh is not known, path cannot be specified

    root = xmltree
    if isinstance(xmltree, etree._ElementTree):
        xmltree, _ = clear_xml(xmltree)
        root = xmltree.getroot()
    constants = read_constants(root, schema_dict, logger=logger)

    parameters = {}

    # &input: not needed here for now; film is set by the plugin depending on the structure

    # &comp
    # (attribute name, optional) in the order they are read; the namelist key
    # is identical to the attribute name
    comp_attributes = (
        ('jspins', False),
        ('frcor', True),
        ('ctail', False),
        ('kcrel', True),
        ('gmax', False),
        ('gmaxxc', False),
        ('kmax', False),
    )
    comp_dict = {}
    for attrib_name, is_optional in comp_attributes:
        extra_kwargs = {'optional': True} if is_optional else {}
        comp_dict[attrib_name] = evaluate_attribute(root,
                                                    schema_dict,
                                                    attrib_name,
                                                    constants=constants,
                                                    logger=logger,
                                                    **extra_kwargs)
    parameters['comp'] = filter_out_empty_dict_entries(comp_dict)

    # &atoms
    species_list = eval_simple_xpath(root, schema_dict, 'species', list_return=True, logger=logger)

    # First pass: count how many species share the same atomic number
    species_several = {}
    for species in species_list:
        atom_z = evaluate_attribute(species, schema_dict, 'atomicNumber', constants, logger=logger)
        if atom_z in species_several:
            species_several[atom_z] += 1
        else:
            species_several[atom_z] = 1

    # (namelist key, attribute name) pairs read for every species
    grid_attributes = (
        ('rmt', 'radius'),
        ('dx', 'logIncrement'),
        ('jri', 'gridPoints'),
        ('lmax', 'lmax'),
        ('lnonsph', 'lnonsphr'),
    )

    species_count = {}
    for indx, species in enumerate(species_list):
        atom_dict = {}
        atom_z = evaluate_attribute(species, schema_dict, 'atomicNumber', constants=constants, logger=logger)
        if not inpgen_ready:
            atom_dict['z'] = atom_z

        # Running number of this species among the ones with the same atomic number
        species_count[atom_z] = species_count.get(atom_z, 0) + 1
        if write_ids and species_several[atom_z] > 1:
            atom_dict['id'] = f'{atom_z}.{species_count[atom_z]}'

        if schema_dict.inp_version <= (0, 31):
            atom_dict['ncst'] = evaluate_attribute(species, schema_dict, 'coreStates', constants, logger=logger)

        for key, attrib_name in grid_attributes:
            atom_dict[key] = evaluate_attribute(species,
                                                schema_dict,
                                                attrib_name,
                                                constants=constants,
                                                logger=logger)

        atom_dict['bmu'] = evaluate_attribute(species,
                                              schema_dict,
                                              'magMom',
                                              constants,
                                              logger=logger,
                                              optional=True)
        atom_dict['element'] = evaluate_attribute(species,
                                                  schema_dict,
                                                  'element',
                                                  constants=constants,
                                                  logger=logger)

        atom_lo = eval_simple_xpath(species, schema_dict, 'lo', list_return=True, logger=logger)
        if len(atom_lo) != 0:
            atom_dict['lo'] = convert_fleur_lo(atom_lo)

        parameters[f'atom{indx}'] = filter_out_empty_dict_entries(atom_dict)

    # &soc
    soc = evaluate_attribute(root, schema_dict, 'l_soc', constants=constants, logger=logger, optional=True)
    soc_kwargs = {'constants': constants, 'contains': 'soc', 'logger': logger, 'optional': True}
    theta = evaluate_attribute(root, schema_dict, 'theta', **soc_kwargs)
    phi = evaluate_attribute(root, schema_dict, 'phi', **soc_kwargs)
    if soc:
        parameters['soc'] = {'theta': theta, 'phi': phi}

    # &kpt
    # TODO: not parsed yet (nkpt, kpts, div1, div2, div3, tkb, tria)

    # title
    title = evaluate_text(root, schema_dict, 'comment', constants=constants, logger=logger, optional=True)
    if title:
        parameters['title'] = title.replace('\n', '').strip()

    # &exco
    # TODO: only xctyp is read, relxc is missing
    exco_dict = {
        'xctyp': evaluate_attribute(root,
                                    schema_dict,
                                    'name',
                                    constants,
                                    contains='xcFunctional',
                                    logger=logger)
    }
    parameters['exco'] = filter_out_empty_dict_entries(exco_dict)

    # &film
    # TODO

    # &qss
    # TODO

    # lattice, not supported?

    return parameters
Пример #9
0
def get_nkpts_max4(xmltree, schema_dict, logger=None):
    """
    Determine the number of kpoints used by the calculation described in the given
    fleur XML file. Implementation for file versions Max4 or older.

    If the calculation runs in band or gw mode, an ``altKPointSet`` whose ``purpose``
    matches that mode is preferred. If no such set exists (or it contains neither a
    ``kPointList`` nor a ``kPointCount``), the kpoint definition outside of any
    ``altKPointSet`` is used.

    .. warning::
        For file versions before Max5 only kPointList or kPointCount tags will work. However,
        for kPointCount there is no real guarantee that for every occasion it will correspond
        to the number of kpoints. So a warning is written out

    :param xmltree: etree representing the fleur xml file
    :param schema_dict: schema dictionary corresponding to the file version
                        of the xmltree
    :param logger: logger object for logging warnings, errors

    :returns: int with the number of kpoints

    :raises ValueError: if neither a kPointList nor a kPointCount tag is found
                        outside of the alternative kpoint sets
    """
    from masci_tools.util.schema_dict_util import evaluate_attribute, eval_simple_xpath
    from masci_tools.util.xml.common_functions import clear_xml
    import warnings

    # Accept both a full tree and an already extracted element
    root = xmltree
    if isinstance(xmltree, etree._ElementTree):
        xmltree, _ = clear_xml(xmltree)
        root = xmltree.getroot()

    modes = get_fleur_modes(root, schema_dict, logger=logger)

    # Look for an alternative kpoint set dedicated to the active mode
    alt_kpt_set = None
    if modes['band'] or modes['gw']:
        expected_mode = 'bands' if modes['band'] else 'gw'
        candidates = eval_simple_xpath(root,
                                       schema_dict,
                                       'altKPointSet',
                                       list_return=True,
                                       logger=logger)
        alt_kpt_set = next(
            (candidate for candidate in candidates
             if evaluate_attribute(candidate, schema_dict, 'purpose', logger=logger) == expected_mode), None)

    kpt_tag = None
    if alt_kpt_set is not None:
        kpt_tag = eval_simple_xpath(alt_kpt_set, schema_dict, 'kPointList', list_return=True, logger=logger)
        if len(kpt_tag) == 0:
            kpt_tag = eval_simple_xpath(alt_kpt_set, schema_dict, 'kPointCount', list_return=True, logger=logger)
            if len(kpt_tag) > 0:
                warnings.warn(
                    'kPointCount is not guaranteed to result in the given number of kpoints'
                )
            else:
                # Nothing usable in the alternative set; fall back to the main definition
                kpt_tag = None

    if kpt_tag is None:
        main_set_options = {
            'not_contains': 'altKPointSet',
            'list_return': True,
            'logger': logger,
        }
        kpt_tag = eval_simple_xpath(root, schema_dict, 'kPointList', **main_set_options)
        if len(kpt_tag) == 0:
            kpt_tag = eval_simple_xpath(root, schema_dict, 'kPointCount', **main_set_options)
            if len(kpt_tag) == 0:
                raise ValueError('No kPointList or kPointCount found')
            warnings.warn(
                'kPointCount is not guaranteed to result in the given number of kpoints'
            )

    # Only the first matching tag is evaluated
    return evaluate_attribute(kpt_tag[0], schema_dict, 'count', logger=logger)
Пример #10
0
def get_fleur_modes(xmltree, schema_dict, logger=None):
    """
    Collect the calculation modes of fleur for the given xml file. Calculation modes
    are settings that change the produced files or the output in the out.xml file

    :param xmltree: etree representing the fleur xml file
    :param schema_dict: schema dictionary corresponding to the file version
                        of the xmltree
    :param logger: logger object for logging warnings, errors

    :returns: dictionary with all the extracted calculation modes

    The returned dictionary contains the following keys:

        - `jspin`: value of the `jspins` attribute
        - `noco`: value of `l_noco` (False if not present)
        - `soc`: value of `l_soc` (False if not present)
        - `relax`: value of `l_f` (False if not present)
        - `gw`: True if the `gw` attribute is present and non-zero
        - `force_theorem`: does a `forceTheorem` tag exist (input versions after 0.27)
        - `cf_coeff`: is any `potential` or `chargeDensity` attribute of the
          `cFCoeffs` tags set (input versions 0.33 and later)
        - `plot`: plotting switch taken from the `iplot` attribute
        - `film`: does a `filmPos` tag exist
        - `ldau`: does a `ldaU` tag exist for a species
        - `dos`: value of the `dos` attribute
        - `band`: value of the `band` attribute
        - `bz_integration`: value of the `mode` attribute of `bzIntegration`
        - `greensf`: does a greens function related tag exist for a species
          (input versions 0.32 and later)
        - `ldahia`: does a `ldaHIA` tag exist for a species (input versions 0.32 and later)

    """
    from masci_tools.util.schema_dict_util import read_constants
    from masci_tools.util.schema_dict_util import evaluate_attribute, tag_exists
    from masci_tools.util.xml.common_functions import clear_xml

    root = xmltree
    if isinstance(xmltree, etree._ElementTree):
        xmltree, _ = clear_xml(xmltree)
        root = xmltree.getroot()
    constants = read_constants(root, schema_dict)

    def optional_attribute(attrib_name):
        """Evaluate an attribute that is allowed to be missing (None in that case)"""
        return evaluate_attribute(root,
                                  schema_dict,
                                  attrib_name,
                                  constants=constants,
                                  logger=logger,
                                  optional=True)

    def species_tag_exists(tag_name):
        """Check for the existence of a tag inside the species section"""
        return tag_exists(root, schema_dict, tag_name, contains='species', logger=logger)

    inp_version = schema_dict.inp_version

    fleur_modes = {}
    fleur_modes['jspin'] = evaluate_attribute(root, schema_dict, 'jspins', logger=logger, constants=constants)

    # Optional switches defaulting to False when the attribute is absent
    for mode_name, attrib_name in (('noco', 'l_noco'), ('soc', 'l_soc'), ('relax', 'l_f')):
        switch = optional_attribute(attrib_name)
        fleur_modes[mode_name] = False if switch is None else switch

    gw_value = optional_attribute('gw')
    fleur_modes['gw'] = gw_value is not None and gw_value != 0

    if inp_version > (0, 27):
        fleur_modes['force_theorem'] = tag_exists(root, schema_dict, 'forceTheorem', logger=logger)
    else:
        fleur_modes['force_theorem'] = False

    if inp_version >= (0, 33) and tag_exists(root, schema_dict, 'cFCoeffs', logger=logger):
        # The second attribute is only looked at if the first one is not set anywhere
        fleur_modes['cf_coeff'] = any(
            any(
                evaluate_attribute(root,
                                   schema_dict,
                                   attrib_name,
                                   contains='cFCoeffs',
                                   logger=logger,
                                   list_return=True,
                                   optional=True)) for attrib_name in ('potential', 'chargeDensity'))
    else:
        fleur_modes['cf_coeff'] = False

    plot_value = None
    if tag_exists(root, schema_dict, 'plotting', logger=logger):
        plot_value = evaluate_attribute(root, schema_dict, 'iplot', logger=logger, optional=True)

    if inp_version >= (0, 29) and plot_value is not None:
        # From this version on the attribute is reduced to a boolean here
        plot_value = isinstance(plot_value, int) and plot_value != 0

    fleur_modes['plot'] = False if plot_value is None else plot_value

    fleur_modes['film'] = tag_exists(root, schema_dict, 'filmPos', logger=logger)
    fleur_modes['ldau'] = species_tag_exists('ldaU')
    for mode_name in ('dos', 'band'):
        fleur_modes[mode_name] = evaluate_attribute(root,
                                                    schema_dict,
                                                    mode_name,
                                                    constants=constants,
                                                    logger=logger)
    fleur_modes['bz_integration'] = evaluate_attribute(root,
                                                       schema_dict,
                                                       'mode',
                                                       constants=constants,
                                                       tag_name='bzIntegration',
                                                       logger=logger)

    has_greensf = False
    if inp_version >= (0, 32):
        #We make the assumption that the existence of a greensfCalculation
        #tag implies the existence of a greens function calculation
        has_greensf = species_tag_exists('greensfCalculation') or species_tag_exists('torgueCalculation')
    fleur_modes['greensf'] = has_greensf

    has_ldahia = False
    if inp_version >= (0, 32):
        has_ldahia = species_tag_exists('ldaHIA')
    fleur_modes['ldahia'] = has_ldahia

    return fleur_modes
def test_evaluate_attribute(caplog):
    """
    Test of the evaluate_attribute function

    Covers unique attributes, attributes selected via contains/not_contains,
    selection via an explicit tag name, evaluation relative to a sub node of an
    output file and the error/logging behaviour for ambiguous or missing attributes
    """
    from lxml import etree
    from masci_tools.util.schema_dict_util import evaluate_attribute

    schema_dict = schema_dict_34

    parser = etree.XMLParser(attribute_defaults=True, recover=False, encoding='utf-8')
    root = etree.parse(TEST_INPXML_PATH, parser).getroot()
    outroot = etree.parse(TEST_OUTXML_PATH2, parser).getroot()

    def attrib(node, schema, name, **selection):
        """Evaluate the attribute with the predefined fleur constants"""
        return evaluate_attribute(node, schema, name, FLEUR_DEFINED_CONSTANTS, **selection)

    half_pi_pair = [np.pi / 2.0, np.pi / 2.0]

    #Attributes with a unique path
    assert attrib(root, schema_dict, 'jspins') == 2
    assert attrib(root, schema_dict, 'l_noco')
    assert attrib(root, schema_dict, 'mode') == 'hist'
    assert pytest.approx(attrib(root, schema_dict, 'radius', contains='species')) == [2.2, 2.2]

    #Ambiguous without further specification
    with pytest.raises(
            ValueError,
            match=
            'The attrib beta has multiple possible paths with the current specification.'
    ):
        attrib(root, schema_dict, 'beta', exclude=['unique'])

    assert pytest.approx(
        attrib(root,
               schema_dict,
               'beta',
               exclude=['unique'],
               contains='nocoParams',
               not_contains='species')) == half_pi_pair

    #Selection via the tag name
    assert pytest.approx(attrib(root, schema_dict, 'beta', tag_name='nocoParams',
                                not_contains='species')) == half_pi_pair

    with pytest.raises(ValueError,
                       match='No attribute TEST found at tag nocoParams'):
        attrib(
            root,
            schema_dict,
            'TEST',
            tag_name='nocoParams',
            not_contains='species',
        )

    #Evaluation on the output file
    assert pytest.approx(
        attrib(outroot,
               outschema_dict_34,
               'beta',
               exclude=['unique'],
               contains='nocoParams',
               not_contains='species')) == half_pi_pair

    iteration = outroot.xpath('//iteration')[0]

    assert attrib(iteration, outschema_dict_34, 'units', tag_name='Forcetheorem_SSDISP') == 'Htr'

    with pytest.raises(
            ValueError,
            match='No attribute TEST found at tag Forcetheorem_SSDISP'):
        attrib(iteration, outschema_dict_34, 'TEST', tag_name='Forcetheorem_SSDISP')

    with pytest.raises(
            ValueError,
            match=
            'The attrib spinf has multiple possible paths with the current specification.'
    ):
        attrib(root, schema_dict, 'spinf')

    #Without a logger the missing value is an error
    with pytest.raises(ValueError,
                       match='No values found for attribute radius'):
        attrib(root, schema_dict, 'radius', not_contains='species')

    #With a logger the problem is only reported and None comes back
    with caplog.at_level(logging.WARNING):
        assert attrib(root, schema_dict, 'radius', not_contains='species', logger=LOGGER) is None
    assert 'No values found for attribute radius' in caplog.text
Пример #12
0
def set_kpointlist(xmltree,
                   schema_dict,
                   kpoints,
                   weights,
                   name=None,
                   kpoint_type='path',
                   special_labels=None,
                   switch=False,
                   overwrite=False):
    """
    Explicitly create a kPointList from the given kpoints and weights and add it
    to the inp.xml under the given name.

    .. warning::
        For input versions Max4 and older **all** keyword arguments are not valid (`name`, `kpoint_type`,
        `special_labels`, `switch` and `overwrite`)

    :param xmltree: xml tree that represents inp.xml
    :param schema_dict: InputSchemaDict containing all information about the structure of the input
    :param kpoints: list or array containing the **relative** coordinates of the kpoints
    :param weights: list or array containing the weights of the kpoints
    :param name: str for the name of the list, if not given a default name is generated
                 from the number of already existing lists
    :param kpoint_type: str specifying the type of the kPointList ('path', 'mesh', 'spex', 'tria', ...)
    :param special_labels: dict mapping indices to labels. The labels will be inserted for the kpoints
                           corresponding to the given index
    :param switch: bool, if True the kPointlist will be used by Fleur when starting the next calculation
    :param overwrite: bool, if True and a kPointlist with the given name already exists it will be overwritten

    :returns: an xmltree of the inp.xml file with changes.

    :raises ValueError: if kpoints/weights are no lists or arrays, differ in length or
                        the name is already taken and ``overwrite`` is not set
    """
    from masci_tools.util.schema_dict_util import evaluate_attribute
    from masci_tools.util.xml.converters import convert_text_to_xml, convert_attribute_to_xml
    from masci_tools.util.xml.xml_setters_basic import xml_delete_tag
    from lxml import etree
    import numpy as np

    accepted_types = (list, np.ndarray)
    if not (isinstance(kpoints, accepted_types) and isinstance(weights, accepted_types)):
        raise ValueError('kPoints and weights have to be given as a list or array')

    nkpts = len(kpoints)
    if nkpts != len(weights):
        raise ValueError('kPoints and weights do not have the same length')

    kpointlist_xpath = get_tag_xpath(schema_dict, 'kPointList')

    labels_by_index = {} if special_labels is None else special_labels

    existing_labels = evaluate_attribute(xmltree, schema_dict, 'name', contains='kPointList', list_return=True)

    if name is None:
        name = f'default-{len(existing_labels)+1}'

    if name in existing_labels:
        if not overwrite:
            raise ValueError(f'kPointList named {name} already exists. Use overwrite=True to ignore')

        # Drop the old list so that the new one replaces it
        xmltree = xml_delete_tag(xmltree, f"{kpointlist_xpath}[@name='{name}']")

    new_list = etree.Element('kPointList', name=name, count=f'{nkpts:d}', type=kpoint_type)
    for position, (coordinates, weight) in enumerate(zip(kpoints, weights)):
        weight_str, _ = convert_attribute_to_xml(weight, ['float', 'float_expression'])
        attributes = {'weight': weight_str}
        if position in labels_by_index:
            attributes['label'] = labels_by_index[position]
        kpoint_elem = etree.SubElement(new_list, 'kPoint', **attributes)
        kpoint_elem.text, _ = convert_text_to_xml(coordinates, [{
            'type': ['float', 'float_expression'],
            'length': 3
        }])

    xmltree = create_tag(xmltree, schema_dict, new_list)

    if not switch:
        return xmltree

    return switch_kpointset(xmltree, schema_dict, name)
Пример #13
0
def outxml_parser(outxmlfile,
                  version=None,
                  parser_info_out=None,
                  iteration_to_parse=None,
                  strict=False,
                  debug=False,
                  **kwargs):
    """
    Parses the out.xml file to a dictionary based on the version and the given tasks

    This function only manages the lifetime of the logging handler that collects
    messages into ``parser_info_out``; the parsing itself is done in
    :py:func:`_outxml_parser_run`. The handler is detached from the module logger
    again in every case (normal return, early return and raised exceptions).

    :param outxmlfile: either path to the out.xml file, opened file handle or a xml etree to be parsed
    :param version: version string to enforce that a given schema is used
    :param parser_info_out: dict, with warnings, info, errors, ...
    :param iteration_to_parse: either str or int, (optional, default 'last')
                               determines which iteration should be parsed.
                               Accepted are 'all', 'first', 'last' or an index for the iteration
    :param strict: bool if True no logger is used, warnings are emitted via the warnings
                   module and any encountered error is immediately raised
    :param debug: bool if True additional information is printed out in the logs

    Kwargs:
        :param ignore_validation: bool, if True schema validation errors are only logged
        :param minimal_mode: bool, if True only total Energy, iteration number and distances are parsed
        :param list_return: bool, if True one-item lists in the output dict are not converted to simple values
        :param additional_tasks: dict to define custom parsing tasks. For detailed explanation
                                 See :py:mod:`~masci_tools.io.parsers.fleur.default_parse_tasks`.
        :param overwrite: bool, if True and keys in additional_tasks collide with defaults
                          The defaults will be overwritten
        :param append: bool, if True and keys in additional_tasks collide with defaults
                       The inner tasks will be written into the dict. If inner keys collide
                       they are overwritten

    :return: python dictionary with the information parsed from the out.xml

    :raises ValueError: If the validation against the schema failed, or an irrecoverable error
                        occurred during parsing
    :raises FileNotFoundError: If no Schema file for the given version was found
    :raises KeyError: If an unknown task is encountered

    """

    __parser_version__ = '0.5.0'

    # ``logger`` is set to None in strict mode, so a separate reference to the
    # module logger is kept to be able to detach the handler afterwards
    module_logger = logging.getLogger(__name__)
    logger = module_logger

    parser_log_handler = None
    if parser_info_out is not None or not strict:
        if parser_info_out is None:
            parser_info_out = {}

        logging_level = logging.DEBUG if debug else logging.INFO
        module_logger.setLevel(logging_level)

        parser_log_handler = DictHandler(parser_info_out,
                                         WARNING='parser_warnings',
                                         ERROR='parser_errors',
                                         INFO='parser_info',
                                         DEBUG='parser_debug',
                                         CRITICAL='parser_critical',
                                         ignore_unknown_levels=True,
                                         level=logging_level)

        module_logger.addHandler(parser_log_handler)

    if strict:
        logger = None

    try:
        if logger is not None:
            logger.info('Masci-Tools Fleur out.xml Parser v%s', __parser_version__)

        return _outxml_parser_run(outxmlfile, version, iteration_to_parse, logger, kwargs)
    finally:
        # Without this the handler (and the dict it writes to) would pile up on the
        # module logger for every call that does not reach the regular end
        if parser_log_handler is not None:
            module_logger.removeHandler(parser_log_handler)


def _outxml_parser_run(outxmlfile, version, iteration_to_parse, logger, kwargs):
    """
    Perform the actual parsing for :py:func:`outxml_parser`

    :param outxmlfile: either path to the out.xml file, opened file handle or a xml etree to be parsed
    :param version: version string to enforce that a given schema is used (or None)
    :param iteration_to_parse: 'all', 'first', 'last', an int index or None (treated as 'last')
    :param logger: logger to report to or None, in which case problems are raised/warned directly
    :param kwargs: dict with the additional keyword arguments given to :py:func:`outxml_parser`.
                   The key ``additional_tasks`` is removed from it

    :return: python dictionary with the information parsed from the out.xml
             (empty dict if a broken file could not be repaired and a logger is available)
    """
    outfile_broken = False

    if isinstance(outxmlfile, etree._ElementTree):
        xmltree = outxmlfile
    else:
        parser = etree.XMLParser(attribute_defaults=True,
                                 recover=False,
                                 encoding='utf-8')

        try:
            xmltree = etree.parse(outxmlfile, parser)
        except etree.XMLSyntaxError:
            outfile_broken = True
            if logger is None:
                warnings.warn('The out.xml file is broken I try to repair it.')
            else:
                logger.warning(
                    'The out.xml file is broken I try to repair it.')

        if outfile_broken:
            # repair xmlfile and try to parse what is possible.
            parser = etree.XMLParser(attribute_defaults=True,
                                     recover=True,
                                     encoding='utf-8')
            try:
                xmltree = etree.parse(outxmlfile, parser)
            except etree.XMLSyntaxError:
                if logger is None:
                    raise
                logger.exception('Skipping the parsing of the xml file. '
                                 'Repairing was not possible.')
                return {}

    if version is None:
        out_version = eval_xpath(xmltree,
                                 '//@fleurOutputVersion',
                                 logger=logger)
        # The check has to happen before the conversion to str, otherwise it can
        # never trigger. NOTE(review): presumably eval_xpath hands back an empty
        # list if nothing matches - confirm against eval_xpath
        if out_version is None or (isinstance(out_version, list) and not out_version):
            if logger is not None:
                logger.error('Failed to extract outputVersion')
            raise ValueError('Failed to extract outputVersion')
        out_version = str(out_version)
    else:
        out_version = version

    if out_version == '0.27':
        # Several releases wrote '0.27' as output version; the real file version
        # is deduced from the program version instead
        program_version = eval_xpath(xmltree,
                                     '//programVersion/@version',
                                     logger=logger)
        if program_version == 'fleur 32':
            #Max5 release (before bugfix)
            out_version = '0.33'
            inp_version = '0.33'
            ignore_validation = True
            if logger is not None:
                logger.warning(
                    "Ignoring '0.27' outputVersion for MaX5.0 release")
            else:
                warnings.warn(
                    "Ignoring '0.27' outputVersion for MaX5.0 release")
        elif program_version == 'fleur 31':
            #Max4 release
            out_version = '0.31'
            inp_version = '0.31'
            ignore_validation = True
            if logger is not None:
                logger.warning(
                    "Ignoring '0.27' outputVersion for MaX4.0 release")
            else:
                warnings.warn(
                    "Ignoring '0.27' outputVersion for MaX4.0 release")
        elif program_version == 'fleur 30':
            #Max3.1 release
            out_version = '0.30'
            inp_version = '0.30'
            ignore_validation = True
            if logger is not None:
                logger.warning(
                    "Ignoring '0.27' outputVersion for MaX3.1 release")
            else:
                warnings.warn(
                    "Ignoring '0.27' outputVersion for MaX3.1 release")
        elif program_version == 'fleur 27':
            #Before the Max3.1 release
            out_version = '0.29'
            inp_version = '0.29'
            ignore_validation = True
            if logger is not None:
                logger.warning(
                    "Found version before MaX3.1 release falling back to file version '0.29'"
                )
            warnings.warn(
                'out.xml files before the MaX3.1 release are not explicitely supported.'
                ' No guarantee is given that the parser will work without error',
                UserWarning)
        else:
            if logger is not None:
                logger.error(
                    "Unknown fleur version: File-version '%s' Program-version '%s'",
                    out_version, program_version)
            raise ValueError(
                f"Unknown fleur version: File-version '{out_version}' Program-version '{program_version}'"
            )
    else:
        ignore_validation = False
        inp_version = eval_xpath(xmltree,
                                 '//@fleurInputVersion',
                                 logger=logger)
        # Same ordering issue as for the output version: check before str()
        if inp_version is None or (isinstance(inp_version, list) and not inp_version):
            if logger is not None:
                logger.error('Failed to extract inputVersion')
            raise ValueError('Failed to extract inputVersion')
        inp_version = str(inp_version)

    ignore_validation = kwargs.get('ignore_validation', ignore_validation)

    #Load schema_dict (inp and out)
    outschema_dict = OutputSchemaDict.fromVersion(out_version,
                                                  inp_version=inp_version,
                                                  logger=logger)

    # If a different schema than requested was loaded, a failing validation
    # is not treated as fatal
    if outschema_dict['out_version'] != out_version or \
       outschema_dict['inp_version'] != inp_version:
        ignore_validation = True
        out_version = outschema_dict['out_version']
        inp_version = outschema_dict['inp_version']

    if logger is not None:
        logger.info('Found fleur out file with the versions out: %s; inp: %s',
                    out_version, inp_version)

    xmltree, _ = clear_xml(xmltree)
    root = xmltree.getroot()

    errmsg = ''
    try:
        validate_xml(
            xmltree,
            outschema_dict.xmlschema,
            error_header='Output file does not validate against the schema')
    except etree.DocumentInvalid as err:
        errmsg = str(err)
        if logger is not None:
            logger.warning(errmsg)
        if not ignore_validation:
            if logger is not None:
                logger.exception(errmsg)
            raise ValueError(errmsg) from err

    if not outschema_dict.xmlschema.validate(xmltree) and errmsg == '':
        msg = 'Output file does not validate against the schema: Reason is unknown'
        if logger is not None:
            logger.warning(msg)
        if not ignore_validation:
            if logger is not None:
                # No exception is being handled here, so there is no traceback to attach
                logger.error(msg)
            raise ValueError(msg)

    parser = ParseTasks(out_version)
    additional_tasks = kwargs.pop('additional_tasks', {})
    for task_name, task_definition in additional_tasks.items():
        parser.add_task(task_name, task_definition, **kwargs)

    out_dict, constants = parse_general_information(
        root,
        parser,
        outschema_dict,
        logger=logger,
        iteration_to_parse=iteration_to_parse,
        **kwargs)

    out_dict['input_file_version'] = outschema_dict['inp_version']
    # get all iterations in out.xml file
    iteration_nodes = eval_simple_xpath(root,
                                        outschema_dict,
                                        'iteration',
                                        logger=logger,
                                        list_return=True)
    n_iters = len(iteration_nodes)

    # parse only last stable interation
    # (if modes (dos and co) maybe parse anyway if broken?)
    if outfile_broken and (n_iters >= 2):
        # NOTE(review): this drops the last two iterations, leaving nothing for
        # n_iters == 2 - confirm that [:-2] (and not [:-1]) is intended
        iteration_nodes = iteration_nodes[:-2]
        if logger is not None:
            logger.info('The last parsed iteration is %s', n_iters - 2)
    elif outfile_broken and (n_iters == 1):
        iteration_nodes = [iteration_nodes[0]]
        if logger is not None:
            logger.info('The last parsed iteration is %s', n_iters)
    elif not outfile_broken and (n_iters >= 1):
        pass
    else:  # there was no iteration found.
        # only the starting charge density could be generated
        msg = 'There was no iteration found in the outfile, either just a ' \
              'starting density was generated or something went wrong.'
        if logger is None:
            raise ValueError(msg)
        logger.error(msg)

    if iteration_to_parse is None:
        iteration_to_parse = 'last'  #This is the default from the aiida_fleur parser

    # Slices are used for 'last'/'first' so that an empty list of iterations
    # (already reported above) results in no iteration being parsed instead of
    # an IndexError
    if iteration_to_parse == 'last':
        iteration_nodes = iteration_nodes[-1:]
    elif iteration_to_parse == 'first':
        iteration_nodes = iteration_nodes[:1]
    elif iteration_to_parse == 'all':
        pass
    elif isinstance(iteration_to_parse, int):
        try:
            iteration_nodes = iteration_nodes[iteration_to_parse]
        except IndexError as exc:
            if logger is not None:
                logger.exception(exc)
            raise ValueError(
                f"Invalid value for iteration_to_parse: Got '{iteration_to_parse}'"
                f"; but only '{len(iteration_nodes)}' iterations are available"
            ) from exc
    else:
        if logger is not None:
            logger.error(
                "Invalid value for iteration_to_parse: Got '%s' "
                "Valid values are: 'first', 'last', 'all', or int",
                iteration_to_parse)
        raise ValueError(
            f"Invalid value for iteration_to_parse: Got '{iteration_to_parse}' "
            "Valid values are: 'first', 'last', 'all', or int")

    # Selecting by index yields a single node
    if not isinstance(iteration_nodes, list):
        iteration_nodes = [iteration_nodes]

    # The adapter reads the iteration from this dict, which is updated per node
    logger_info = {'iteration': 'unknown'}
    iteration_logger = OutParserLogAdapter(logger, logger_info)

    for node in iteration_nodes:
        iteration_number = evaluate_attribute(node,
                                              outschema_dict,
                                              'numberForCurrentRun',
                                              optional=True)

        if iteration_number is not None:
            logger_info['iteration'] = iteration_number

        out_dict = parse_iteration(node,
                                   parser,
                                   outschema_dict,
                                   out_dict,
                                   constants,
                                   logger=iteration_logger,
                                   **kwargs)

        logger_info['iteration'] = 'unknown'

    if not kwargs.get('list_return', False):
        #Convert one item lists to simple values (top level and one level below)
        for key, value in out_dict.items():
            if isinstance(value, list):
                if len(value) == 1:
                    out_dict[key] = value[0]
            elif isinstance(value, dict):
                for subkey, subvalue in value.items():
                    if isinstance(subvalue, list) and len(subvalue) == 1:
                        value[subkey] = subvalue[0]

    return out_dict
def validate_nmmpmat(xmltree, nmmplines, schema_dict):
    """
    Checks that the given nmmp_lines is valid with the given xmltree

    Checks that the number of blocks is as expected from the inp.xml and each
    block does not contain non-zero elements outside their size given by the
    orbital quantum number in the inp.xml. Additionally the occupations, i.e.
    diagonal elements are checked that they are in between 0 and the maximum
    possible occupation

    Layout assumed by the checks below: every density matrix block spans 14 lines,
    each matrix row (m=-3,...,3) is spread over two consecutive lines with 7 numbers
    each. The first line holds real/imaginary parts for m'=-3 up to the real part
    of m'=0, the second one the imaginary part of m'=0 up to m'=3.

    :param xmltree: an xmltree that represents inp.xml
    :param nmmplines: list of lines in the n_mmp_mat file. Empty strings are removed
                      from this list in place. If None, nothing is validated
    :param schema_dict: schema dictionary used to evaluate the attributes/tags of the inp.xml

    :raises ValueError: if any of the above checks are violated.
    """
    from masci_tools.util.xml.common_functions import get_xml_attribute
    from masci_tools.util.schema_dict_util import evaluate_attribute, eval_simple_xpath, attrib_exists

    nspins = evaluate_attribute(xmltree, schema_dict, 'jspins')
    if 'l_mtnocoPot' in schema_dict['attrib_types']:
        if attrib_exists(xmltree, schema_dict, 'l_mtnocoPot', contains='Setup'):
            if evaluate_attribute(xmltree, schema_dict, 'l_mtnocoPot', contains='Setup'):
                #A third (spin-offdiagonal) block is present in this case
                nspins = 3

    all_ldau = eval_simple_xpath(xmltree, schema_dict, 'ldaU', contains='species', list_return=True)
    numRows = nspins * 14 * len(all_ldau)

    tol = 0.01
    maxOcc = 1.0 if nspins > 1 else 2.0

    if nmmplines is None:
        return

    #Remove blank lines (in place, so the caller sees the cleaned up list as before)
    nmmplines[:] = [line for line in nmmplines if line != '']

    #Check that numRows matches the number of lines in nmmp_lines
    if numRows != len(nmmplines):
        raise ValueError('The number of lines in n_mmp_mat does not match the number expected from '
                         'the inp.xml file.')

    #Now check for each block if the numbers make sense
    #(no numbers outside the valid area and no nonsensical occupations)
    for ldau_index, ldau in enumerate(all_ldau):

        orbital = evaluate_attribute(ldau, schema_dict, 'l', contains='species')
        species_name = get_xml_attribute(ldau.getparent(), 'name')

        for spin in range(nspins):
            startRow = (spin * len(all_ldau) + ldau_index) * 14

            for currentLine in range(14):
                currentRow = currentLine // 2
                first_half = currentLine % 2 == 0

                #Split on arbitrary whitespace; the column separation depends on the
                #width of the individual numbers and is not always four blanks
                values = nmmplines[startRow + currentLine].split()
                if len(values) != 7:
                    raise ValueError(f'Line {startRow + currentLine + 1} of n_mmp_mat does not contain '
                                     f'the expected 7 values (found {len(values)})')
                nmmp = np.array([float(x) for x in values])

                if abs(currentRow - 3) > orbital:
                    #The complete row lies outside of the (2l+1)x(2l+1) block
                    outside = nmmp
                elif first_half:
                    #m=-3 to m=0 real part
                    outside = nmmp[:(3 - orbital) * 2]
                else:
                    #m=0 imag part to m=3
                    outside = nmmp[orbital * 2 + 1:]

                if np.any(np.abs(outside) > 1e-12):
                    raise ValueError(f'Found value outside of valid range in for species {species_name}, spin {spin+1}'
                                     f' and l={orbital}')

                #Occupations are only meaningful for the spin-diagonal blocks
                if spin >= 2:
                    continue

                #Only one of the two lines of a row contains the (real part of the)
                #diagonal element. The other line must not be checked, since it only
                #holds off-diagonal elements, which can legitimately be negative
                if first_half and currentRow <= 3:
                    diag_index = currentRow * 2
                elif not first_half and currentRow > 3:
                    diag_index = (currentRow - 3) * 2 - 1
                else:
                    continue

                if nmmp[diag_index] < -tol or nmmp[diag_index] > maxOcc + tol:
                    raise ValueError(f'Found invalid diagonal element for species {species_name}, spin {spin+1}'
                                     f' and l={orbital}')
def set_nmmpmat(xmltree,
                nmmplines,
                schema_dict,
                species_name,
                orbital,
                spin,
                state_occupations=None,
                orbital_occupations=None,
                denmat=None,
                phi=None,
                theta=None):
    """Routine sets the block in the n_mmp_mat file specified by species_name, orbital and spin
    to the desired density matrix

    Exactly one of ``state_occupations``, ``orbital_occupations`` or ``denmat`` is used to
    construct the new block; they are considered in this order and the first one that is
    not None wins.

    :param xmltree: an xmltree that represents inp.xml
    :param nmmplines: list of lines in the n_mmp_mat file. The list is modified in place
                      (blank lines are removed and the selected blocks are overwritten).
                      If None, a matrix filled with zeros is created
    :param schema_dict: InputSchemaDict containing all information about the structure of the input
    :param species_name: string, name of the species you want to change ('all' selects all species,
                         'all-<string>' all species whose name contains the string)
    :param orbital: integer, orbital quantum number of the LDA+U procedure to be modified
    :param spin: integer, specifies which spin block should be modified (counting starts at 1)
    :param state_occupations: list, sets the diagonal elements of the density matrix and everything
                              else to zero
    :param orbital_occupations: list, occupations handed to ``write_nmmpmat_from_orbitals``
                                to construct the density matrix
    :param denmat: matrix, specify the density matrix explicitly
    :param phi: float, optional angle (radian), by which to rotate the density matrix before writing it
    :param theta: float, optional angle (radian), by which to rotate the density matrix before writing it

    :raises ValueError: If something in the input is wrong
    :raises KeyError: If no LDA+U procedure is found on a species

    :returns: list with modified nmmplines
    """
    from masci_tools.util.xml.common_functions import eval_xpath, get_xml_attribute
    from masci_tools.util.schema_dict_util import evaluate_attribute, eval_simple_xpath, attrib_exists
    from masci_tools.io.io_nmmpmat import write_nmmpmat, write_nmmpmat_from_states, write_nmmpmat_from_orbitals

    #All lda+U procedures have to be considered since we need to keep the order
    species_base_path = get_tag_xpath(schema_dict, 'species')

    if species_name == 'all':
        species_xpath = species_base_path
    elif species_name[:4] == 'all-':  #format all-<string>
        species_xpath = f'{species_base_path}[contains(@name,"{species_name[4:]}")]'
    else:
        species_xpath = f'{species_base_path}[@name = "{species_name}"]'

    all_species = eval_xpath(xmltree, species_xpath, list_return=True)

    nspins = evaluate_attribute(xmltree, schema_dict, 'jspins')
    if 'l_mtnocoPot' in schema_dict['attrib_types']:
        if attrib_exists(xmltree, schema_dict, 'l_mtnocoPot', contains='Setup'):
            if evaluate_attribute(xmltree, schema_dict, 'l_mtnocoPot', contains='Setup'):
                nspins = 3

    if spin > nspins:
        raise ValueError(f'Invalid input: spin {spin} requested, but input has only {nspins} spins')
    if spin < 1:
        #A non-positive spin would lead to a negative start row below and the slice
        #assignment would silently insert/overwrite lines of a wrong block
        raise ValueError(f'Invalid input: spin {spin} requested, but spin has to be between 1 and {nspins}')

    all_ldau = eval_simple_xpath(xmltree, schema_dict, 'ldaU', contains='species', list_return=True)
    numRows = nspins * 14 * len(all_ldau)

    if state_occupations is not None:
        new_nmmpmat_entry = write_nmmpmat_from_states(orbital, state_occupations, phi=phi, theta=theta)
    elif orbital_occupations is not None:
        new_nmmpmat_entry = write_nmmpmat_from_orbitals(orbital, orbital_occupations, phi=phi, theta=theta)
    elif denmat is not None:
        new_nmmpmat_entry = write_nmmpmat(orbital, denmat, phi=phi, theta=theta)
    else:
        raise ValueError('Invalid definition of density matrix. Provide either state_occupations, '
                         'orbital_occupations or denmat')

    #Check that numRows matches the number of lines in nmmp_lines_copy
    #If not either there was an n_mmp_mat file present in Fleurinp before and a lda+u calculation
    #was added or removed or the n_mmp_mat file was initialized and after the fact lda+u procedures were added
    #or removed. In both cases the resolution of this modification is very involved so we throw an error
    if nmmplines is not None:
        #Remove blank lines (in place, the given list is also the one that is returned)
        nmmplines[:] = [line for line in nmmplines if line != '']
        if numRows != len(nmmplines):
            raise ValueError('The number of lines in n_mmp_mat does not match the number expected from '
                             'the inp.xml file. Either remove the existing file before making modifications '
                             'and only use set_nmmpmat after all modifications to the inp.xml')

    for species in all_species:
        current_name = get_xml_attribute(species, 'name')

        #Determine the place at which the given U procedure occurs
        #(if there are multiple matches the last one is used)
        ldau_index = None
        for index, ldau in enumerate(all_ldau):
            ldau_species = get_xml_attribute(ldau.getparent(), 'name')
            ldau_orbital = evaluate_attribute(ldau, schema_dict, 'l', contains='species')
            if current_name == ldau_species and ldau_orbital == orbital:
                ldau_index = index

        if ldau_index is None:
            raise KeyError(f'No LDA+U procedure found on species {current_name} with l={orbital}')

        #check if fleurinp has a specified n_mmp_mat file if not initialize it with 0
        if nmmplines is None:
            zero_line = f'{0.0:20.13f}' * 7
            nmmplines = [zero_line for _ in range(numRows)]

        #Select the right block from n_mmp_mat and overwrite it with denmatpad
        startRow = ((spin - 1) * len(all_ldau) + ldau_index) * 14

        nmmplines[startRow:startRow + 14] = new_nmmpmat_entry

    return nmmplines
def rotate_nmmpmat(xmltree, nmmplines, schema_dict, species_name, orbital, phi, theta):
    """
    Apply a rotation by the angles phi and theta to the density matrix blocks
    of the selected species and orbital (for all spin blocks)

    :param xmltree: an xmltree that represents inp.xml
    :param nmmplines: list of lines in the n_mmp_mat file (modified in place)
    :param schema_dict: InputSchemaDict containing all information about the structure of the input
    :param species_name: string, name of the species you want to change ('all' selects all species,
                         'all-<string>' all species whose name contains the string)
    :param orbital: integer, orbital quantum number of the LDA+U procedure to be modified
    :param phi: float, angle (radian), by which to rotate the density matrix
    :param theta: float, angle (radian), by which to rotate the density matrix

    :raises ValueError: If something in the input is wrong
    :raises KeyError: If no LDA+U procedure is found on a species

    :returns: list with modified nmmplines
    """
    from masci_tools.util.xml.common_functions import eval_xpath, get_xml_attribute
    from masci_tools.util.schema_dict_util import evaluate_attribute, eval_simple_xpath, attrib_exists
    from masci_tools.io.io_nmmpmat import read_nmmpmat_block, rotate_nmmpmat_block, format_nmmpmat

    base_xpath = get_tag_xpath(schema_dict, 'species')

    #Build the xpath selecting the requested species
    if species_name == 'all':
        selection_xpath = base_xpath
    elif species_name[:4] == 'all-':  #format all-<string>
        selection_xpath = f'{base_xpath}[contains(@name,"{species_name[4:]}")]'
    else:
        selection_xpath = f'{base_xpath}[@name = "{species_name}"]'

    selected_species = eval_xpath(xmltree, selection_xpath, list_return=True)

    #Number of spin blocks stored per LDA+U procedure
    nspins = evaluate_attribute(xmltree, schema_dict, 'jspins')
    noco_pot_known = 'l_mtnocoPot' in schema_dict['attrib_types']
    if noco_pot_known and attrib_exists(xmltree, schema_dict, 'l_mtnocoPot', contains='Setup') \
       and evaluate_attribute(xmltree, schema_dict, 'l_mtnocoPot', contains='Setup'):
        nspins = 3

    ldau_tags = eval_simple_xpath(xmltree, schema_dict, 'ldaU', contains='species', list_return=True)
    n_ldau = len(ldau_tags)
    expected_lines = nspins * 14 * n_ldau

    if nmmplines is None:
        raise ValueError('rotate_nmmpmat has to be called with a initialized density matrix')

    #Drop blank lines from the given list
    for _ in range(nmmplines.count('')):
        nmmplines.remove('')

    #A mismatch means LDA+U procedures were added or removed after the n_mmp_mat
    #file was created; resolving this is very involved so an error is thrown
    if len(nmmplines) != expected_lines:
        raise ValueError('The number of lines in n_mmp_mat does not match the number expected from '
                         'the inp.xml file. Either remove the existing file before making modifications '
                         'and only use set_nmmpmat after all modifications to the inp.xml')

    for species_node in selected_species:
        name = get_xml_attribute(species_node, 'name')

        #Position of the matching U procedure among all procedures (last match wins)
        block_position = None
        for position, ldau_node in enumerate(ldau_tags):
            parent_name = get_xml_attribute(ldau_node.getparent(), 'name')
            ldau_l = evaluate_attribute(ldau_node, schema_dict, 'l', contains='species')
            if parent_name == name and ldau_l == orbital:
                block_position = position

        if block_position is None:
            raise KeyError(f'No LDA+U procedure found on species {name} with l={orbital}')

        #Read, rotate and write back the block of every spin
        for spin_index in range(nspins):
            block_index = spin_index * n_ldau + block_position
            first_line = block_index * 14

            block = read_nmmpmat_block(nmmplines, block_index)
            rotated = rotate_nmmpmat_block(block, orbital, phi=phi, theta=theta)

            nmmplines[first_line:first_line + 14] = format_nmmpmat(rotated)

    return nmmplines