Example #1
def psml_family(generate_psml_data):
    """Create a pseudopotential family with PsmlData potentials from scratch."""
    from aiida import plugins

    # The names below are used in this excerpt but their imports are not shown; presumably they
    # are module-level imports in the original test file.
    import os
    import tempfile
    from aiida.common import exceptions
    from aiida.common.constants import elements

    PsmlData = plugins.DataFactory('pseudo.psml')  # pylint: disable=invalid-name
    PseudoPotentialFamily = plugins.GroupFactory('pseudo.family')  # pylint: disable=invalid-name
    label = 'nc-sr-04_pbe_standard_psml'

    try:
        family = PseudoPotentialFamily.objects.get(label=label)
    except exceptions.NotExistent:
        pass
    else:
        return family

    with tempfile.TemporaryDirectory() as dirpath:
        for values in elements.values():

            element = values['symbol']
            upf = generate_psml_data(element)
            filename = os.path.join(dirpath, f'{element}.psml')

            with open(filename, 'w+b') as handle:
                with upf.open(mode='rb') as source:
                    handle.write(source.read())
                    handle.flush()

        family = PseudoPotentialFamily.create_from_folder(dirpath, label, pseudo_type=PsmlData)

    return family
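
A test consuming this fixture would simply declare it as an argument. The sketch below is hypothetical; the assertions are illustrative and not part of the original suite.

def test_psml_family_label(psml_family):
    """Hypothetical test using the fixture above."""
    assert psml_family.label == 'nc-sr-04_pbe_standard_psml'
    # ``Group.nodes`` iterates over the PsmlData nodes collected in the family.
    for pseudo in psml_family.nodes:
        assert pseudo.element is not None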
Example #2
def test_spec(workchain):
    """Test that the input specification of all implementations respects the common interface."""
    generator = workchain.get_input_generator()
    generator_spec = generator.spec()

    required_ports = {
        'structure': {
            'valid_type': plugins.DataFactory('structure')
        },
        'protocol': {
            'valid_type': str
        },
        'spin_type': {
            'valid_type': SpinType
        },
        'relax_type': {
            'valid_type': RelaxType
        },
        'electronic_type': {
            'valid_type': ElectronicType
        },
        'magnetization_per_site': {
            'valid_type': list
        },
        'threshold_forces': {
            'valid_type': float
        },
        'threshold_stress': {
            'valid_type': float
        },
        'reference_workchain': {
            'valid_type': orm.WorkChainNode
        },
        'engines': {}
    }

    for port_name, values in required_ports.items():
        assert isinstance(generator_spec.inputs.get_port(port_name),
                          (InputGeneratorPort, engine.PortNamespace))

        if 'valid_type' in values:
            assert generator_spec.inputs.get_port(
                port_name).valid_type is values['valid_type']
Example #3
def pseudo_dojo(generate_jthxml_data):
    """Create a PseudoDojo pseudo potential family from scratch."""
    from aiida import plugins

    # The names below are used in this excerpt but their imports are not shown; presumably they
    # are module-level imports in the original test file.
    import os
    import tempfile
    from aiida.common import exceptions
    from aiida.common.constants import elements

    PseudoDojoFamily = plugins.GroupFactory('pseudo.family.pseudo_dojo')  # pylint: disable=invalid-name
    label = 'PseudoDojo/1.0/PBE/SR/standard/jthxml'

    try:
        family = PseudoDojoFamily.objects.get(label=label)
    except exceptions.NotExistent:
        pass
    else:
        return family

    cutoffs_dict = {'normal': {}}

    with tempfile.TemporaryDirectory() as dirpath:
        for values in elements.values():

            element = values['symbol']
            upf = generate_jthxml_data(element)
            filename = os.path.join(dirpath, f'{element}.jthxml')

            with open(filename, 'w+b') as handle:
                with upf.open(mode='rb') as source:
                    handle.write(source.read())
                    handle.flush()

            cutoffs_dict['normal'][element] = {'cutoff_wfc': 30., 'cutoff_rho': 240.}

        family = PseudoDojoFamily.create_from_folder(dirpath, label, pseudo_type=plugins.DataFactory('pseudo.jthxml'))

    for stringency, cutoffs in cutoffs_dict.items():
        family.set_cutoffs(cutoffs, stringency, unit='Eh')

    return family
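
Once created, the family can be fetched again by its label and queried for the cutoffs registered by the fixture. This is a hedged sketch: ``get_cutoffs`` is assumed to be the read counterpart of the ``set_cutoffs`` call above.

# Hypothetical follow-up to the fixture above.
family = PseudoDojoFamily.objects.get(label='PseudoDojo/1.0/PBE/SR/standard/jthxml')
cutoffs = family.get_cutoffs(stringency='normal')  # assumed counterpart of ``set_cutoffs``
print(cutoffs['Si'])  # e.g. {'cutoff_wfc': 30.0, 'cutoff_rho': 240.0}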
Example #4
from math import pi
import yaml

from aiida import engine
from aiida import orm
from aiida import plugins
from aiida.common import exceptions
from aiida_castep.data import get_pseudos_from_structure
from aiida_castep.data.otfg import OTFGGroup

from ..generator import RelaxInputsGenerator, RelaxType, SpinType, ElectronicType
# pylint: disable=import-outside-toplevel

__all__ = ('CastepRelaxInputGenerator', )

StructureData = plugins.DataFactory('structure')


class CastepRelaxInputGenerator(RelaxInputsGenerator):
    """Input generator for the `CastepRelaxWorkChain`."""

    _default_protocol = 'moderate'
    _calc_types = {
        'relax': {
            'code_plugin': 'castep.castep',
            'description': 'The code to perform the relaxation.'
        }
    }
    _relax_types = {
        RelaxType.ATOMS:
        'Relax only the atomic positions while keeping the cell fixed.',
"""Implementation of `aiida_common_workflows.common.relax.generator.RelaxInputGenerator` for CP2K."""
import collections
import pathlib
from typing import Any, Dict, List
import yaml
import numpy as np

from aiida import engine
from aiida import orm
from aiida import plugins

from ..generator import RelaxInputsGenerator, RelaxType, SpinType, ElectronicType

__all__ = ('Cp2kRelaxInputsGenerator', )

StructureData = plugins.DataFactory('structure')  # pylint: disable=invalid-name
KpointsData = plugins.DataFactory('array.kpoints')  # pylint: disable=invalid-name

EV_A3_TO_GPA = 160.21766208


def dict_merge(dct, merge_dct):
    """ Taken from https://gist.github.com/angstwad/bf22d1822c38a92ec0a9
    Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
    updating only top-level keys, dict_merge recurses down into dicts nested
    to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
    ``dct``.
    :param dct: dict onto which the merge is executed
    :param merge_dct: dct merged into dct
    :return: None
    """
Example #6
from math import pi
import yaml

from aiida import engine
from aiida import orm
from aiida import plugins
from aiida.common import exceptions
from aiida_castep.data import get_pseudos_from_structure
from aiida_castep.data.otfg import OTFGGroup

from ..generator import RelaxInputsGenerator, RelaxType, SpinType, ElectronicType
# pylint: disable=import-outside-toplevel, too-many-branches, too-many-statements

__all__ = ('CastepRelaxInputGenerator',)

StructureData = plugins.DataFactory('structure')  # pylint: disable=invalid-name


class CastepRelaxInputGenerator(RelaxInputsGenerator):
    """Input generator for the `CastepRelaxWorkChain`."""

    _default_protocol = 'moderate'
    _calc_types = {'relax': {'code_plugin': 'castep.castep', 'description': 'The code to perform the relaxation.'}}
    _relax_types = {
        RelaxType.ATOMS: 'Relax only the atomic positions while keeping the cell fixed.',
        RelaxType.ATOMS_CELL: 'Relax both atomic positions and the cell.',
        RelaxType.ATOMS_SHAPE: 'Relax both atomic positions and the shape of the cell, keeping the volume fixed.',
        RelaxType.ATOMS_VOLUME: 'Relax both atomic positions and the volume of the cell, keeping the cell shape fixed.',
        RelaxType.NONE: 'Do not do any relaxation.',
        RelaxType.CELL: 'Only relax the cell with the scaled positions of atoms are kept fixed.',
        RelaxType.SHAPE: 'Only relax the shape of the cell.',
Example #7
    def get_builder(self,
                    structure: StructureData,
                    calc_engines: Dict[str, Any],
                    *,
                    protocol: str = None,
                    relax_type: RelaxType = RelaxType.ATOMS,
                    electronic_type: ElectronicType = ElectronicType.METAL,
                    spin_type: SpinType = SpinType.NONE,
                    magnetization_per_site: List[float] = None,
                    threshold_forces: float = None,
                    threshold_stress: float = None,
                    previous_workchain=None,
                    **kwargs) -> engine.ProcessBuilder:
        """Return a process builder for the corresponding workchain class with inputs set according to the protocol.

        :param structure: the structure to be relaxed.
        :param calc_engines: a dictionary containing the computational resources for the relaxation.
        :param protocol: the protocol to use when determining the workchain inputs.
        :param relax_type: the type of relaxation to perform.
        :param electronic_type: the electronic character that is to be used for the structure.
        :param spin_type: the spin polarization type to use for the calculation.
        :param magnetization_per_site: a list with the initial spin polarization for each site. Float or integer in
            units of electrons. If not defined, the builder will automatically define the initial magnetization if and
            only if `spin_type != SpinType.NONE`.
        :param threshold_forces: target threshold for the forces in eV/Å.
        :param threshold_stress: target threshold for the stress in eV/Å^3.
        :param previous_workchain: a <Code>RelaxWorkChain node.
        :param kwargs: any inputs that are specific to the plugin.
        :return: a `aiida.engine.processes.ProcessBuilder` instance ready to be submitted.
        """
        # pylint: disable=too-many-locals
        protocol = protocol or self.get_default_protocol_name()

        super().get_builder(structure,
                            calc_engines,
                            protocol=protocol,
                            relax_type=relax_type,
                            electronic_type=electronic_type,
                            spin_type=spin_type,
                            magnetization_per_site=magnetization_per_site,
                            threshold_forces=threshold_forces,
                            threshold_stress=threshold_stress,
                            previous_workchain=previous_workchain,
                            **kwargs)

        # Get the protocol that we want to use
        if protocol is None:
            protocol = self._default_protocol
        protocol = self.get_protocol(protocol)

        # Set the builder
        builder = self.process_class.get_builder()

        # Set code
        builder.code = orm.load_code(calc_engines['relax']['code'])

        # Set structure
        builder.structure = structure

        # Set options
        builder.options = plugins.DataFactory('dict')(
            dict=calc_engines['relax']['options'])

        # Set settings
        # Make sure we add forces and stress for the VASP parser
        settings = AttributeDict()
        settings.update(
            {'parser_settings': {
                'add_forces': True,
                'add_stress': True
            }})
        builder.settings = plugins.DataFactory('dict')(dict=settings)

        # Set workchain related inputs, in this case, give more explicit output to report
        builder.verbose = plugins.DataFactory('bool')(True)

        # Set parameters
        builder.parameters = plugins.DataFactory('dict')(
            dict=protocol['parameters'])

        # Set potentials and their mapping
        builder.potential_family = plugins.DataFactory('str')(
            protocol['potential_family'])
        builder.potential_mapping = plugins.DataFactory('dict')(
            dict=self._potential_mapping[protocol['potential_mapping']])

        # Set the kpoint grid from the density in the protocol
        kpoints = plugins.DataFactory('array.kpoints')()
        kpoints.set_kpoints_mesh([1, 1, 1])
        kpoints.set_cell_from_structure(structure)
        rec_cell = kpoints.reciprocal_cell
        kpoints.set_kpoints_mesh(
            fetch_k_grid(rec_cell, protocol['kpoint_distance']))
        builder.kpoints = kpoints

        # Here we set the protocols fast, moderate and precise. These currently have no formal meaning.
        # After a while these will be set in the VASP workchain entrypoints using the convergence workchain etc.
        # However, for now we rely on defaults plane wave cutoffs and a set k-point density for the chosen protocol.
        relax = AttributeDict()
        relax.perform = plugins.DataFactory('bool')(True)
        relax.algo = plugins.DataFactory('str')(protocol['relax']['algo'])

        if relax_type == RelaxType.ATOMS:
            relax.positions = plugins.DataFactory('bool')(True)
            relax.shape = plugins.DataFactory('bool')(False)
            relax.volume = plugins.DataFactory('bool')(False)
        elif relax_type == RelaxType.CELL:
            relax.positions = plugins.DataFactory('bool')(False)
            relax.shape = plugins.DataFactory('bool')(True)
            relax.volume = plugins.DataFactory('bool')(True)
        elif relax_type == RelaxType.ATOMS_CELL:
            relax.positions = plugins.DataFactory('bool')(True)
            relax.shape = plugins.DataFactory('bool')(True)
            relax.volume = plugins.DataFactory('bool')(True)
        else:
            raise ValueError('relaxation type `{}` is not supported'.format(
                relax_type.value))

        if threshold_forces is not None:
            threshold = threshold_forces
        else:
            threshold = protocol['relax']['threshold_forces']
        relax.force_cutoff = plugins.DataFactory('float')(threshold)

        if threshold_stress is not None:
            raise ValueError(
                'Using a stress threshold is not directly available in VASP during relaxation.'
            )

        builder.relax = relax

        return builder
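
A hypothetical invocation of this generator might look as follows; ``generator`` and ``structure`` are assumed to exist already, and the code label and options are placeholders rather than values taken from the original code.

# Hypothetical usage; ``generator`` is an instance of the input generator above and
# ``structure`` an existing StructureData node.
calc_engines = {
    'relax': {
        'code': 'vasp@my-cluster',  # placeholder code label
        'options': {'resources': {'num_machines': 1}, 'max_wallclock_seconds': 3600},
    }
}
builder = generator.get_builder(
    structure=structure,
    calc_engines=calc_engines,
    protocol='moderate',
    relax_type=RelaxType.ATOMS,
)
node = engine.submit(builder)  # run through the AiiDA daemon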
Example #8
class BasisData(plugins.DataFactory('singlefile')):
    """Base class for data types representing bases."""

    _key_element = 'element'
    _key_md5 = 'md5'

    @classmethod
    def get_or_create(cls, stream: typing.BinaryIO, filename: str = None):
        """Get basis data node from database with matching md5 checksum or create a new one if not existent.

        :param stream: a filelike object with the binary content of the file.
        :param filename: optional explicit filename to give to the file stored in the repository.
        :return: instance of ``BasisData``, stored if taken from database, unstored otherwise.
        """
        query = orm.QueryBuilder()
        query.append(cls, subclassing=False, filters={f'attributes.{cls._key_md5}': md5_from_filelike(stream)})

        existing = query.first()

        if existing:
            basis = existing[0]
        else:
            stream.seek(0)
            basis = cls(stream, filename)

        return basis

    @classmethod
    def get_entry_point_name(cls):
        """Return the entry point name associated with this data class.

        :return: the entry point name.
        """
        from aiida.plugins.entry_point import get_entry_point_from_class
        _, entry_point = get_entry_point_from_class(cls.__module__, cls.__name__)
        return entry_point.name

    @staticmethod
    def is_readable_byte_stream(stream) -> bool:
        """Return whether an object appears to be a readable filelike object in binary mode or stream of bytes.

        :param stream: the object to analyse.
        :returns: True if ``stream`` appears to be a readable filelike object in binary mode, False otherwise.
        """
        return (
            isinstance(stream, io.BytesIO) or
            (hasattr(stream, 'read') and hasattr(stream, 'mode') and 'b' in stream.mode)
        )

    @classmethod
    def prepare_source(cls, source: typing.Union[str, pathlib.Path, typing.BinaryIO]) -> typing.BinaryIO:  # pylint: disable=unsubscriptable-object
        """Validate the ``source`` representing a file on disk or a byte stream.

        .. note:: if the ``source`` is a valid file on disk, its content is read and returned as a stream of bytes.

        :raises TypeError: if the source is not a ``str``, ``pathlib.Path`` instance or binary stream.
        :raises FileNotFoundError: if the source is a filepath but does not exist.
        """
        if not isinstance(source, (str, pathlib.Path)) and not cls.is_readable_byte_stream(source):
            raise TypeError(
                f'`source` should be a `str` or `pathlib.Path` filepath on disk or a stream of bytes, got: {source}'
            )

        if isinstance(source, (str, pathlib.Path)):
            filename = pathlib.Path(source).name
            with open(source, 'rb') as handle:
                source = io.BytesIO(handle.read())
                source.name = filename

        return source

    @classmethod
    def validate_element(cls, element: str):
        """Validate the given element symbol.

        :param element: the symbol of the element following the IUPAC naming standard.
        :raises ValueError: if the element symbol is invalid.
        """
        if element not in [values['symbol'] for values in elements.values()]:
            raise ValueError(f'`{element}` is not a valid element.')

    def validate_md5(self, md5: str):
        """Validate that the md5 checksum matches that of the currently stored file.

        :param value: the md5 checksum.
        :raises ValueError: if the md5 does not match that of the currently stored file.
        """
        with self.open(mode='rb') as handle:
            md5_file = md5_from_filelike(handle)
            if md5 != md5_file:
                raise ValueError(f'md5 does not match that of stored file: {md5} != {md5_file}')

    def set_file(self, stream: typing.BinaryIO, filename: str = None, **kwargs):
        """Set the file content.

        :param stream: a filelike object with the binary content of the file.
        :param filename: optional explicit filename to give to the file stored in the repository.
        """
        super().set_file(stream, filename, **kwargs)
        stream.seek(0)
        self.md5 = md5_from_filelike(stream)

    def store(self, **kwargs):
        """Store the node verifying first that all required attributes are set.

        :raises :py:exc:`~aiida.common.StoringNotAllowed`: if no valid element has been defined.
        """
        try:
            self.validate_element(self.element)
        except ValueError as exception:
            raise StoringNotAllowed('no valid element has been defined.') from exception

        try:
            self.validate_md5(self.md5)
        except ValueError as exception:
            raise StoringNotAllowed(exception) from exception

        return super().store(**kwargs)

    @property
    def element(self) -> typing.Union[str, None]:  # pylint: disable=unsubscriptable-object
        """Return the element symbol.

        :return: the symbol of the element following the IUPAC naming standard or None if not defined.
        """
        return self.get_attribute(self._key_element, None)

    @element.setter
    def element(self, value: str):
        """Set the element.

        :param value: the symbol of the element following the IUPAC naming standard.
        :raises ValueError: if the element symbol is invalid.
        """
        self.validate_element(value)
        self.set_attribute(self._key_element, value)

    @property
    def md5(self) -> typing.Union[str, None]:  # pylint: disable=unsubscriptable-object
        """Return the md5.

        :return: the md5 of the stored file.
        """
        return self.get_attribute(self._key_md5, None)

    @md5.setter
    def md5(self, value: str):
        """Set the md5.

        :param value: the md5 checksum.
        :raises ValueError: if the md5 does not match that of the currently stored file.
        """
        self.validate_md5(value)
        self.set_attribute(self._key_md5, value)
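
The class above can be exercised directly with an in-memory stream; a minimal, illustrative sketch (the byte content is a dummy placeholder):

import io

stream = io.BytesIO(b'<dummy basis content for Si>')
basis = BasisData.get_or_create(stream, filename='Si.basis')
if not basis.is_stored:
    basis.element = 'Si'  # a valid element is required by ``store``
    basis.store()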
Example #9
# -*- coding: utf-8 -*-
"""Implementation of `aiida_common_workflows.common.relax.generator.RelaxInputGenerator` for BigDFT."""
from typing import Any, Dict, List

from aiida import engine
from aiida import orm
from aiida import plugins
from aiida.engine import calcfunction

from ..generator import RelaxInputsGenerator, RelaxType, SpinType, ElectronicType

__all__ = ('BigDftRelaxInputsGenerator', )

BigDFTParameters = plugins.DataFactory('bigdft')
StructureData = plugins.DataFactory('structure')


@calcfunction
def ortho_struct(input_struct):
    """Create and update a dict to pass to transform_to_orthorombic,
      and then get back data to the input dict """
    dico = dict()
    dico['name'] = input_struct.sites[0].kind_name
    dico['a'] = round(input_struct.cell_lengths[0], 6)
    dico['alpha'] = round(input_struct.cell_angles[0], 6)
    dico['b'] = round(input_struct.cell_lengths[1], 6)
    dico['beta'] = round(input_struct.cell_angles[1], 6)
    dico['c'] = round(input_struct.cell_lengths[2], 6)
    dico['gamma'] = round(input_struct.cell_angles[2], 6)
    dico['nat'] = len(input_struct.sites)
    # use abc coordinates
Example #10
class PseudoPotentialData(plugins.DataFactory('singlefile')):
    """Base class for data types representing pseudo potentials."""

    _key_element = 'element'
    _key_md5 = 'md5'

    @classmethod
    def get_or_create(cls,
                      source: typing.Union[str, pathlib.Path, typing.BinaryIO],
                      filename: str = None):
        """Get pseudopotenial data node from database with matching md5 checksum or create a new one if not existent.

        :param source: the source pseudopotential content, either a binary stream, or a ``str`` or ``Path`` to the path
            of the file on disk, which can be relative or absolute.
        :param filename: optional explicit filename to give to the file stored in the repository.
        :return: instance of ``PseudoPotentialData``, stored if taken from database, unstored otherwise.
        :raises TypeError: if the source is not a ``str``, ``pathlib.Path`` instance or binary stream.
        :raises FileNotFoundError: if the source is a filepath but does not exist.
        """
        source = cls.prepare_source(source)

        query = orm.QueryBuilder()
        query.append(
            cls,
            subclassing=False,
            filters={f'attributes.{cls._key_md5}': md5_from_filelike(source)})

        existing = query.first()

        if existing:
            pseudo = existing[0]
        else:
            source.seek(0)
            pseudo = cls(source, filename)

        return pseudo

    @classmethod
    def get_entry_point_name(cls):
        """Return the entry point name associated with this data class.

        :return: the entry point name.
        """
        from aiida.plugins.entry_point import get_entry_point_from_class
        _, entry_point = get_entry_point_from_class(cls.__module__,
                                                    cls.__name__)
        return entry_point.name

    @staticmethod
    def is_readable_byte_stream(stream) -> bool:
        """Return whether an object appears to be a readable filelike object in binary mode or stream of bytes.

        :param stream: the object to analyse.
        :returns: True if ``stream`` appears to be a readable filelike object in binary mode, False otherwise.
        """
        return (isinstance(stream, io.BytesIO)
                or (hasattr(stream, 'read') and hasattr(stream, 'mode')
                    and 'b' in stream.mode))

    @classmethod
    def prepare_source(
        cls, source: typing.Union[str, pathlib.Path, typing.BinaryIO]
    ) -> typing.BinaryIO:
        """Validate the ``source`` representing a file on disk or a byte stream.

        .. note:: if the ``source`` is a valid file on disk, its content is read and returned as a stream of bytes.

        :raises TypeError: if the source is not a ``str``, ``pathlib.Path`` instance or binary stream.
        :raises FileNotFoundError: if the source is a filepath but does not exist.
        """
        if not isinstance(source, (str, pathlib.Path)) and not cls.is_readable_byte_stream(source):
            raise TypeError(
                f'`source` should be a `str` or `pathlib.Path` filepath on disk or a stream of bytes, got: {source}'
            )

        if isinstance(source, (str, pathlib.Path)):
            filename = pathlib.Path(source).name
            with open(source, 'rb') as handle:
                source = io.BytesIO(handle.read())
                source.name = filename

        return source

    @classmethod
    def validate_element(cls, element: str):
        """Validate the given element symbol.

        :param element: the symbol of the element following the IUPAC naming standard.
        :raises ValueError: if the element symbol is invalid.
        """
        if element not in [values['symbol'] for values in elements.values()]:
            raise ValueError(f'`{element}` is not a valid element.')

    def validate_md5(self, md5: str):
        """Validate that the md5 checksum matches that of the currently stored file.

        :param value: the md5 checksum.
        :raises ValueError: if the md5 does not match that of the currently stored file.
        """
        with self.open(mode='rb') as handle:
            md5_file = md5_from_filelike(handle)
            if md5 != md5_file:
                raise ValueError(
                    f'md5 does not match that of stored file: {md5} != {md5_file}'
                )

    def set_file(self,
                 source: typing.Union[str, pathlib.Path, typing.BinaryIO],
                 filename: str = None,
                 **kwargs):
        """Set the file content.

        .. note:: this method will first analyse the type of the ``source`` and if it is a filepath will convert it
            to a binary stream of the content located at that filepath, which is then passed on to the superclass. This
            needs to be done first, because it will properly set the file and filename attributes that are expected by
            other methods. Straight after the superclass call, the source seeker needs to be reset to zero if it needs
            to be read again, because the superclass most likely will have read the stream to the end. Finally it is
            important that the ``prepare_source`` is called here before the superclass invocation, because this way the
            conversion from filepath to byte stream will be performed only once. Otherwise, each subclass would perform
            the conversion over and over again.

        :param source: the source pseudopotential content, either a binary stream, or a ``str`` or ``Path`` to the path
            of the file on disk, which can be relative or absolute.
        :param filename: optional explicit filename to give to the file stored in the repository.
        :raises TypeError: if the source is not a ``str``, ``pathlib.Path`` instance or binary stream.
        :raises FileNotFoundError: if the source is a filepath but does not exist.
        """
        source = self.prepare_source(source)
        super().set_file(source, filename, **kwargs)
        source.seek(0)
        self.md5 = md5_from_filelike(source)

    def store(self, **kwargs):
        """Store the node verifying first that all required attributes are set.

        :raises :py:exc:`~aiida.common.StoringNotAllowed`: if no valid element has been defined.
        """
        try:
            self.validate_element(self.element)
        except ValueError as exception:
            raise StoringNotAllowed(
                'no valid element has been defined.') from exception

        try:
            self.validate_md5(self.md5)
        except ValueError as exception:
            raise StoringNotAllowed(exception) from exception

        return super().store(**kwargs)

    @property
    def element(self) -> typing.Optional[str]:
        """Return the element symbol.

        :return: the symbol of the element following the IUPAC naming standard or None if not defined.
        """
        return self.get_attribute(self._key_element, None)

    @element.setter
    def element(self, value: str):
        """Set the element.

        :param value: the symbol of the element following the IUPAC naming standard.
        :raises ValueError: if the element symbol is invalid.
        """
        self.validate_element(value)
        self.set_attribute(self._key_element, value)

    @property
    def md5(self) -> typing.Optional[str]:
        """Return the md5.

        :return: the md5 of the stored file.
        """
        return self.get_attribute(self._key_md5, None)

    @md5.setter
    def md5(self, value: str):
        """Set the md5.

        :param value: the md5 checksum.
        :raises ValueError: if the md5 does not match that of the currently stored file.
        """
        self.validate_md5(value)
        self.set_attribute(self._key_md5, value)
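
Compared to the stream-only variant shown later in this listing, this version of ``get_or_create`` also accepts a filepath thanks to ``prepare_source``. A hedged sketch follows; the path is a placeholder.

import pathlib

# Placeholder path to a pseudopotential file on disk.
filepath = pathlib.Path('/path/to/Si.psp8')
pseudo = PseudoPotentialData.get_or_create(filepath)
if not pseudo.is_stored:
    pseudo.element = 'Si'  # a valid element is required by ``store``
    pseudo.store()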
Example #11
def test_get_entry_point_name(entry_point_name):
    """Test the ``BasisData.get_entry_point_name`` method."""
    cls = plugins.DataFactory(entry_point_name)
    assert cls.get_entry_point_name() == entry_point_name
Example #12
    def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
        """Construct a process builder based on the provided keyword arguments.

        The keyword arguments will have been validated against the input generator specification.
        """
        # pylint: disable=too-many-branches,too-many-statements,too-many-locals
        structure = kwargs['structure']
        engines = kwargs['engines']
        protocol = kwargs['protocol']
        spin_type = kwargs['spin_type']
        relax_type = kwargs['relax_type']
        magnetization_per_site = kwargs.get('magnetization_per_site', None)
        threshold_forces = kwargs.get('threshold_forces', None)
        threshold_stress = kwargs.get('threshold_stress', None)
        reference_workchain = kwargs.get('reference_workchain', None)

        # Get the protocol that we want to use
        if protocol is None:
            protocol = self._default_protocol
        protocol = self.get_protocol(protocol)

        # Set the builder
        builder = self.process_class.get_builder()

        # Set code
        builder.code = orm.load_code(engines['relax']['code'])

        # Set structure
        builder.structure = structure

        # Set options
        builder.options = plugins.DataFactory('dict')(
            dict=engines['relax']['options'])

        # Set settings
        # Make sure the VASP parser is configured for the problem
        settings = AttributeDict()
        settings.update({
            'parser_settings': {
                'add_energies': True,
                'add_forces': True,
                'add_stress': True,
                'add_misc': {
                    'type':
                    'dict',
                    'quantities': [
                        'total_energies', 'maximum_stress', 'maximum_force',
                        'magnetization', 'notifications', 'run_status',
                        'run_stats', 'version'
                    ],
                    'link_name':
                    'misc'
                }
            }
        })
        builder.settings = plugins.DataFactory('dict')(dict=settings)

        # Set workchain related inputs, in this case, give more explicit output to report
        builder.verbose = plugins.DataFactory('bool')(True)

        # Fetch initial parameters from the protocol file.
        # Here we set the protocols fast, moderate and precise. These currently have no formal meaning.
        # After a while these will be set in the VASP workchain entrypoints using the convergence workchain etc.
        # However, for now we rely on plane wave cutoffs and a set k-point density for the chosen protocol.
        # Please consult the protocols.yml file for details.
        parameters_dict = protocol['parameters']

        # Set spin related parameters
        if spin_type == SpinType.NONE:
            parameters_dict['ispin'] = 1
        elif spin_type == SpinType.COLLINEAR:
            parameters_dict['ispin'] = 2

        # Set the magnetization
        if magnetization_per_site is not None:
            parameters_dict['magmom'] = list(magnetization_per_site)

        # Set the parameters on the builder, put it in the code namespace to pass through
        # to the code inputs
        builder.parameters = plugins.DataFactory('dict')(
            dict={
                'incar': parameters_dict
            })

        # Set potentials and their mapping
        builder.potential_family = plugins.DataFactory('str')(
            protocol['potential_family'])
        builder.potential_mapping = plugins.DataFactory('dict')(
            dict=self._potential_mapping[protocol['potential_mapping']])

        # Set the kpoint grid from the density in the protocol
        kpoints = plugins.DataFactory('array.kpoints')()
        kpoints.set_cell_from_structure(structure)
        if reference_workchain:
            previous_kpoints = reference_workchain.inputs.kpoints
            kpoints.set_kpoints_mesh(previous_kpoints.get_attribute('mesh'),
                                     previous_kpoints.get_attribute('offset'))
        else:
            kpoints.set_kpoints_mesh_from_density(protocol['kpoint_distance'])
        builder.kpoints = kpoints

        # Set the relax parameters
        relax = AttributeDict()
        if relax_type != RelaxType.NONE:
            # Perform relaxation of cell or positions
            relax.perform = plugins.DataFactory('bool')(True)
            relax.algo = plugins.DataFactory('str')(protocol['relax']['algo'])
            relax.steps = plugins.DataFactory('int')(
                protocol['relax']['steps'])
            if relax_type == RelaxType.POSITIONS:
                relax.positions = plugins.DataFactory('bool')(True)
                relax.shape = plugins.DataFactory('bool')(False)
                relax.volume = plugins.DataFactory('bool')(False)
            elif relax_type == RelaxType.CELL:
                relax.positions = plugins.DataFactory('bool')(False)
                relax.shape = plugins.DataFactory('bool')(True)
                relax.volume = plugins.DataFactory('bool')(True)
            elif relax_type == RelaxType.VOLUME:
                relax.positions = plugins.DataFactory('bool')(False)
                relax.shape = plugins.DataFactory('bool')(False)
                relax.volume = plugins.DataFactory('bool')(True)
            elif relax_type == RelaxType.SHAPE:
                relax.positions = plugins.DataFactory('bool')(False)
                relax.shape = plugins.DataFactory('bool')(True)
                relax.volume = plugins.DataFactory('bool')(False)
            elif relax_type == RelaxType.POSITIONS_CELL:
                relax.positions = plugins.DataFactory('bool')(True)
                relax.shape = plugins.DataFactory('bool')(True)
                relax.volume = plugins.DataFactory('bool')(True)
            elif relax_type == RelaxType.POSITIONS_SHAPE:
                relax.positions = plugins.DataFactory('bool')(True)
                relax.shape = plugins.DataFactory('bool')(True)
                relax.volume = plugins.DataFactory('bool')(False)
        else:
            # Do not perform any relaxation
            relax.perform = plugins.DataFactory('bool')(False)

        if threshold_forces is not None:
            threshold = threshold_forces
        else:
            threshold = protocol['relax']['threshold_forces']
        relax.force_cutoff = plugins.DataFactory('float')(threshold)

        if threshold_stress is not None:
            raise ValueError(
                'Using a stress threshold is not directly available in VASP during relaxation.'
            )

        builder.relax = relax

        return builder
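
The keyword arguments unpacked at the top of ``_construct_builder`` would normally be validated and forwarded by the public ``get_builder`` method. A hypothetical, minimal set of inputs is sketched below; the code label, options and ``structure`` are placeholders.

# Hypothetical inputs matching the keys read at the top of ``_construct_builder``.
inputs = {
    'structure': structure,             # an existing StructureData node (placeholder)
    'engines': {
        'relax': {
            'code': 'vasp@my-cluster',  # placeholder code label
            'options': {'resources': {'num_machines': 1}},
        }
    },
    'protocol': 'moderate',
    'spin_type': SpinType.NONE,
    'relax_type': RelaxType.POSITIONS,
}
# In practice these are passed to the public ``get_builder`` method, which forwards them here.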
Example #13
    def get_builder(self,
                    structure: StructureData,
                    engines: Dict[str, Any],
                    *,
                    protocol: str = None,
                    relax_type: RelaxType = RelaxType.POSITIONS,
                    electronic_type: ElectronicType = ElectronicType.METAL,
                    spin_type: SpinType = SpinType.NONE,
                    magnetization_per_site: List[float] = None,
                    threshold_forces: float = None,
                    threshold_stress: float = None,
                    reference_workchain=None,
                    **kwargs) -> engine.ProcessBuilder:
        """Return a process builder for the corresponding workchain class with inputs set according to the protocol.

        :param structure: the structure to be relaxed.
        :param engines: a dictionary containing the computational resources for the relaxation.
        :param protocol: the protocol to use when determining the workchain inputs.
        :param relax_type: the type of relaxation to perform.
        :param electronic_type: the electronic character that is to be used for the structure.
        :param spin_type: the spin polarization type to use for the calculation.
        :param magnetization_per_site: a list with the initial spin polarization for each site. Float or integer in
            units of electrons. If not defined, the builder will automatically define the initial magnetization if and
            only if `spin_type != SpinType.NONE`.
        :param threshold_forces: target threshold for the forces in eV/Å.
        :param threshold_stress: target threshold for the stress in eV/Å^3.
        :param reference_workchain: a <Code>RelaxWorkChain node.
        :param kwargs: any inputs that are specific to the plugin.
        :return: a `aiida.engine.processes.ProcessBuilder` instance ready to be submitted.
        """
        # pylint: disable=too-many-locals, too-many-branches, too-many-statements
        protocol = protocol or self.get_default_protocol_name()

        super().get_builder(structure,
                            engines,
                            protocol=protocol,
                            relax_type=relax_type,
                            electronic_type=electronic_type,
                            spin_type=spin_type,
                            magnetization_per_site=magnetization_per_site,
                            threshold_forces=threshold_forces,
                            threshold_stress=threshold_stress,
                            reference_workchain=reference_workchain,
                            **kwargs)

        # Get the protocol that we want to use
        if protocol is None:
            protocol = self._default_protocol
        protocol = self.get_protocol(protocol)

        # Set the builder
        builder = self.process_class.get_builder()

        # Set code
        builder.code = orm.load_code(engines['relax']['code'])

        # Set structure
        builder.structure = structure

        # Set options
        builder.options = plugins.DataFactory('dict')(
            dict=engines['relax']['options'])

        # Set settings
        # Make sure the VASP parser is configured for the problem
        settings = AttributeDict()
        settings.update({
            'parser_settings': {
                'add_energies': True,
                'add_forces': True,
                'add_stress': True,
                'add_misc': {
                    'type':
                    'dict',
                    'quantities': [
                        'total_energies', 'maximum_stress', 'maximum_force',
                        'magnetization', 'notifications', 'run_status',
                        'run_stats', 'version'
                    ],
                    'link_name':
                    'misc'
                }
            }
        })
        builder.settings = plugins.DataFactory('dict')(dict=settings)

        # Set workchain related inputs, in this case, give more explicit output to report
        builder.verbose = plugins.DataFactory('bool')(True)

        # Fetch initial parameters from the protocol file.
        # Here we set the protocols fast, moderate and precise. These currently have no formal meaning.
        # After a while these will be set in the VASP workchain entrypoints using the convergence workchain etc.
        # However, for now we rely on plane wave cutoffs and a set k-point density for the chosen protocol.
        # Please consult the protocols.yml file for details.
        parameters_dict = protocol['parameters']

        # Set spin related parameters
        if spin_type == SpinType.NONE:
            parameters_dict['ispin'] = 1
        elif spin_type == SpinType.COLLINEAR:
            parameters_dict['ispin'] = 2

        # Set the magnetization
        if magnetization_per_site is not None:
            parameters_dict['magmom'] = list(magnetization_per_site)

        # Set the parameters on the builder, put it in the code namespace to pass through
        # to the code inputs
        builder.parameters = plugins.DataFactory('dict')(
            dict={
                'incar': parameters_dict
            })

        # Set potentials and their mapping
        builder.potential_family = plugins.DataFactory('str')(
            protocol['potential_family'])
        builder.potential_mapping = plugins.DataFactory('dict')(
            dict=self._potential_mapping[protocol['potential_mapping']])

        # Set the kpoint grid from the density in the protocol
        kpoints = plugins.DataFactory('array.kpoints')()
        kpoints.set_cell_from_structure(structure)
        if reference_workchain:
            previous_kpoints = reference_workchain.inputs.kpoints
            kpoints.set_kpoints_mesh(previous_kpoints.get_attribute('mesh'),
                                     previous_kpoints.get_attribute('offset'))
        else:
            kpoints.set_kpoints_mesh_from_density(protocol['kpoint_distance'])
        builder.kpoints = kpoints

        # Set the relax parameters
        relax = AttributeDict()
        if relax_type != RelaxType.NONE:
            # Perform relaxation of cell or positions
            relax.perform = plugins.DataFactory('bool')(True)
            relax.algo = plugins.DataFactory('str')(protocol['relax']['algo'])
            relax.steps = plugins.DataFactory('int')(
                protocol['relax']['steps'])
            if relax_type == RelaxType.POSITIONS:
                relax.positions = plugins.DataFactory('bool')(True)
                relax.shape = plugins.DataFactory('bool')(False)
                relax.volume = plugins.DataFactory('bool')(False)
            elif relax_type == RelaxType.CELL:
                relax.positions = plugins.DataFactory('bool')(False)
                relax.shape = plugins.DataFactory('bool')(True)
                relax.volume = plugins.DataFactory('bool')(True)
            elif relax_type == RelaxType.VOLUME:
                relax.positions = plugins.DataFactory('bool')(False)
                relax.shape = plugins.DataFactory('bool')(False)
                relax.volume = plugins.DataFactory('bool')(True)
            elif relax_type == RelaxType.SHAPE:
                relax.positions = plugins.DataFactory('bool')(False)
                relax.shape = plugins.DataFactory('bool')(True)
                relax.volume = plugins.DataFactory('bool')(False)
            elif relax_type == RelaxType.POSITIONS_CELL:
                relax.positions = plugins.DataFactory('bool')(True)
                relax.shape = plugins.DataFactory('bool')(True)
                relax.volume = plugins.DataFactory('bool')(True)
            elif relax_type == RelaxType.POSITIONS_SHAPE:
                relax.positions = plugins.DataFactory('bool')(True)
                relax.shape = plugins.DataFactory('bool')(True)
                relax.volume = plugins.DataFactory('bool')(False)
        else:
            # Do not perform any relaxation
            relax.perform = plugins.DataFactory('bool')(False)

        if threshold_forces is not None:
            threshold = threshold_forces
        else:
            threshold = protocol['relax']['threshold_forces']
        relax.force_cutoff = plugins.DataFactory('float')(threshold)

        if threshold_stress is not None:
            raise ValueError(
                'Using a stress threshold is not directly available in VASP during relaxation.'
            )

        builder.relax = relax

        return builder
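
The ``reference_workchain`` argument is what makes this variant reuse the k-point mesh of an earlier run instead of regenerating it from the k-point distance. A hedged illustration; ``previous_node`` is a placeholder for a completed relax workchain node.

# Hypothetical second call that inherits the k-point mesh of a completed run.
builder = generator.get_builder(
    structure=structure,
    engines=engines,
    protocol='moderate',
    reference_workchain=previous_node,
)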
Example #14
class PseudoPotentialData(plugins.DataFactory('singlefile')):
    """Base class for data types representing pseudo potentials."""

    _key_element = 'element'
    _key_md5 = 'md5'

    @classmethod
    def get_or_create(cls, stream: typing.BinaryIO, filename: str = None):
        """Get pseudopotenial data node from database with matching md5 checksum or create a new one if not existent.

        :param stream: a filelike object with the binary content of the file.
        :param filename: optional explicit filename to give to the file stored in the repository.
        :return: instance of ``PseudoPotentialData``, stored if taken from database, unstored otherwise.
        """
        query = orm.QueryBuilder()
        query.append(
            cls,
            subclassing=False,
            filters={f'attributes.{cls._key_md5}': md5_from_filelike(stream)})

        existing = query.first()

        if existing:
            pseudo = existing[0]
        else:
            stream.seek(0)
            pseudo = cls(stream, filename)

        return pseudo

    @classmethod
    def get_entry_point_name(cls):
        """Return the entry point name associated with this data class.

        :return: the entry point name.
        """
        from aiida.plugins.entry_point import get_entry_point_from_class
        _, entry_point = get_entry_point_from_class(cls.__module__,
                                                    cls.__name__)
        return entry_point.name

    @classmethod
    def validate_element(cls, element: str):
        """Validate the given element symbol.

        :param element: the symbol of the element following the IUPAC naming standard.
        :raises ValueError: if the element symbol is invalid.
        """
        if element not in [values['symbol'] for values in elements.values()]:
            raise ValueError(f'`{element}` is not a valid element.')

    def validate_md5(self, md5: str):
        """Validate that the md5 checksum matches that of the currently stored file.

        :param value: the md5 checksum.
        :raises ValueError: if the md5 does not match that of the currently stored file.
        """
        with self.open(mode='rb') as handle:
            md5_file = md5_from_filelike(handle)
            if md5 != md5_file:
                raise ValueError(
                    f'md5 does not match that of stored file: {md5} != {md5_file}'
                )

    def set_file(self,
                 stream: typing.BinaryIO,
                 filename: str = None,
                 **kwargs):
        """Set the file content.

        :param stream: a filelike object with the binary content of the file.
        :param filename: optional explicit filename to give to the file stored in the repository.
        """
        super().set_file(stream, filename, **kwargs)
        stream.seek(0)
        self.md5 = md5_from_filelike(stream)

    def store(self, **kwargs):
        """Store the node verifying first that all required attributes are set.

        :raises :py:exc:`~aiida.common.StoringNotAllowed`: if no valid element has been defined.
        """
        try:
            self.validate_element(self.element)
        except ValueError as exception:
            raise StoringNotAllowed(
                'no valid element has been defined.') from exception

        try:
            self.validate_md5(self.md5)
        except ValueError as exception:
            raise StoringNotAllowed(exception) from exception

        return super().store(**kwargs)

    @property
    def element(self) -> typing.Union[str, None]:
        """Return the element symbol.

        :return: the symbol of the element following the IUPAC naming standard or None if not defined.
        """
        return self.get_attribute(self._key_element, None)

    @element.setter
    def element(self, value: str):
        """Set the element.

        :param value: the symbol of the element following the IUPAC naming standard.
        :raises ValueError: if the element symbol is invalid.
        """
        self.validate_element(value)
        self.set_attribute(self._key_element, value)

    @property
    def md5(self) -> typing.Union[str, None]:
        """Return the md5.

        :return: the md5 of the stored file.
        """
        return self.get_attribute(self._key_md5, None)

    @md5.setter
    def md5(self, value: str):
        """Set the md5.

        :param value: the md5 checksum.
        :raises ValueError: if the md5 does not match that of the currently stored file.
        """
        self.validate_md5(value)
        self.set_attribute(self._key_md5, value)
Example #15
    @classmethod
    def define(cls, spec):
        """Define the specification of the input generator.

        The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method.
        """
        super().define(spec)
        spec.input('structure',
                   valid_type=plugins.DataFactory('structure'),
                   help='The structure whose geometry should be optimized.')
        spec.input(
            'protocol',
            valid_type=ChoiceType(('fast', 'moderate', 'precise')),
            default='moderate',
            help=
            'The protocol to use for the automated input generation. This value indicates the level of precision '
            'of the results and computational cost that the input parameters will be selected for.',
        )
        spec.input(
            'spin_type',
            valid_type=SpinType,
            serializer=SpinType,
            default=SpinType.NONE,
            help='The type of spin polarization to be used.',
        )
        spec.input(
            'relax_type',
            valid_type=RelaxType,
            serializer=RelaxType,
            default=RelaxType.POSITIONS,
            help=
            'The degrees of freedom during the geometry optimization process.',
        )
        spec.input(
            'electronic_type',
            valid_type=ElectronicType,
            serializer=ElectronicType,
            default=ElectronicType.METAL,
            help='The electronic character of the system.',
        )
        spec.input(
            'magnetization_per_site',
            valid_type=list,
            required=False,
            help=
            'The initial magnetization of the system. Should be a list of floats, where each float represents the '
            'spin polarization in units of electrons, meaning the difference between spin up and spin down '
            'electrons, for the site. This also corresponds to the magnetization of the site in Bohr magnetons '
            '(μB).',
        )
        spec.input(
            'threshold_forces',
            valid_type=float,
            required=False,
            help=
            'A real positive number indicating the target threshold for the forces in eV/Å. If not specified, '
            'the protocol specification will select an appropriate value.',
        )
        spec.input(
            'threshold_stress',
            valid_type=float,
            required=False,
            help=
            'A real positive number indicating the target threshold for the stress in eV/Å^3. If not specified, '
            'the protocol specification will select an appropriate value.',
        )
        spec.input(
            'reference_workchain',
            valid_type=orm.WorkChainNode,
            required=False,
            help=
            'The node of a previously completed process of the same type whose inputs should be taken into '
            'account when generating inputs. This is important for particular workflows where certain inputs have '
            'to be kept constant between successive iterations.',
        )
        spec.input_namespace(
            'engines',
            help='Inputs for the quantum engines',
        )
        spec.input_namespace(
            'engines.relax',
            help=
            'Inputs for the quantum engine performing the geometry optimization.',
        )
        spec.input(
            'engines.relax.code',
            valid_type=orm.Code,
            serializer=orm.load_code,
            help='The code instance to use for the geometry optimization.',
        )
        spec.input(
            'engines.relax.options',
            valid_type=dict,
            required=False,
            help='Options for the geometry optimization calculation jobs.',
        )
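
Given the specification above, the inputs accepted by the generator mirror these ports. An illustrative dictionary follows; all values are placeholders and ``generator`` is assumed to be an instance of the class defining this spec.

# Illustrative inputs matching the ports declared in ``define`` above; values are placeholders.
inputs = {
    'structure': structure,                  # a StructureData node
    'protocol': 'moderate',                  # one of 'fast', 'moderate', 'precise'
    'spin_type': SpinType.NONE,
    'relax_type': RelaxType.POSITIONS,
    'electronic_type': ElectronicType.METAL,
    'engines': {
        'relax': {
            'code': 'relax-code@my-cluster',  # a code label, serialized by ``orm.load_code``
            'options': {'resources': {'num_machines': 1}},
        }
    },
}
# Assuming the generator exposes a ``get_builder(**inputs)`` entry point (not shown in this excerpt):
builder = generator.get_builder(**inputs)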