Exemplo n.º 1
0
    def _mpirun_command_validator(self, mpirun_cmd):
        """
        Validates the mpirun_command variable. MUST be called after properly
        checking for a valid scheduler.
        """
        if not isinstance(mpirun_cmd, (tuple, list)) or not all(
                isinstance(i, str) for i in mpirun_cmd):
            raise exceptions.ValidationError(
                'the mpirun_command must be a list of strings')

        try:
            job_resource_keys = self.get_scheduler(
            ).job_resource_class.get_valid_keys()
        except exceptions.EntryPointError:
            raise exceptions.ValidationError(
                'Unable to load the scheduler for this computer')

        subst = {i: 'value' for i in job_resource_keys}
        subst['tot_num_mpiprocs'] = 'value'

        try:
            for arg in mpirun_cmd:
                arg.format(**subst)
        except KeyError as exc:
            raise exceptions.ValidationError(
                f'In workdir there is an unknown replacement field {exc.args[0]}'
            )
        except ValueError as exc:
            raise exceptions.ValidationError(f"Error in the string: '{exc}'")
Exemplo n.º 2
0
    def validate(self):
        """
        Check if the attributes and files retrieved from the DB are valid.
        Raise a ValidationError if something is wrong.

        Must be able to work even before storing: therefore, use the get_attr and similar methods
        that automatically read either from the DB or from the internal attribute cache.

        For the base class, this is always valid. Subclasses will reimplement this.
        In the subclass, always call the super().validate() method first!

        :raises aiida.common.exceptions.ValidationError: if any of the fields
            (label, hostname, description, transport, scheduler, workdir or
            mpirun command) is invalid, or if the metadata in the DB is broken.
        """
        if not self.label.strip():
            raise exceptions.ValidationError('No name specified')

        # Each field validator raises ValidationError on its own on failure.
        self._hostname_validator(self.hostname)
        self._description_validator(self.description)
        self._transport_type_validator(self.transport_type)
        self._scheduler_type_validator(self.scheduler_type)
        self._workdir_validator(self.get_workdir())

        try:
            mpirun_cmd = self.get_mpirun_command()
        except exceptions.DbContentError:
            # Re-map DB corruption to a validation failure for the caller.
            raise exceptions.ValidationError(
                'Error in the DB content of the metadata')

        # To be called AFTER the validation of the scheduler
        self._mpirun_command_validator(mpirun_cmd)
Exemplo n.º 3
0
 def array_list_checker(array_list, array_name, orb_length):
     """
     Run basic sanity checks on every entry of ``array_list``.

     Each element must be a ``np.ndarray`` and the list length must equal
     ``orb_length``; a ``ValidationError`` mentioning ``array_name`` is
     raised on any failure.
     """
     if any(not isinstance(entry, np.ndarray) for entry in array_list):
         raise exceptions.ValidationError(f'{array_name} was not composed entirely of ndarrays')
     if len(array_list) != orb_length:
         raise exceptions.ValidationError(f'{array_name} did not have the same length as the list of orbitals')
Exemplo n.º 4
0
def validate_attribute_extra_key(key):
    """Validate the key for a node attribute or extra.

    :raise aiida.common.ValidationError: if the key is not a string or contains reserved separator character
    """
    # A valid key is a non-empty string...
    if not isinstance(key, str) or not key:
        raise exceptions.ValidationError('key for attributes or extras should be a string')

    # ...that does not contain the reserved separator character.
    if FIELD_SEPARATOR in key:
        raise exceptions.ValidationError(
            'key for attributes or extras cannot contain the character `{}`'.format(FIELD_SEPARATOR)
        )
Exemplo n.º 5
0
    def _validate(self):
        """Ensure that there is one object stored in the repository, whose key matches value set for `filename` attr.

        :raises aiida.common.exceptions.ValidationError: if the `filename`
            attribute is not set, or the repository does not contain exactly
            that single file.
        """
        super(SinglefileData, self)._validate()

        try:
            filename = self.filename
        except AttributeError as exc:
            raise exceptions.ValidationError(
                'the `filename` attribute is not set.') from exc

        objects = self.list_object_names()

        # The repository must contain exactly the one file named by `filename`.
        if [filename] != objects:
            # Fixed typo in the user-facing message: "respository" -> "repository".
            raise exceptions.ValidationError(
                'repository files {} do not match the `filename` attribute {}.'.format(objects, filename)
            )
Exemplo n.º 6
0
    def delete_many(self, filters):
        """
        Delete all Logs matching ``filters``.

        :param filters: similar to QueryBuilder filter
        :type filters: dict

        :return: (former) ``PK`` s of deleted Logs
        :rtype: list

        :raises TypeError: if ``filters`` is not a `dict`
        :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty
        """
        from aiida.orm import Log, QueryBuilder

        # Validate the input before touching the database.
        if not isinstance(filters, dict):
            raise TypeError('filters must be a dictionary')
        if not filters:
            raise exceptions.ValidationError('filters must not be empty')

        # Collect the PKs of every matching log, then delete them one by one.
        query = QueryBuilder().append(Log, filters=filters, project='id')
        deleted_pks = query.all(flat=True)
        for pk in deleted_pks:
            self.delete(pk)

        # Hand the (former) PKs back so the caller can verify the deletion.
        return deleted_pks
Exemplo n.º 7
0
    def _workdir_validator(cls, workdir):
        """
        Validates the transport string.
        """
        if not workdir.strip():
            raise exceptions.ValidationError('No workdir specified')

        try:
            convertedwd = workdir.format(username='******')
        except KeyError as exc:
            raise exceptions.ValidationError('In workdir there is an unknown replacement field {}'.format(exc.args[0]))
        except ValueError as exc:
            raise exceptions.ValidationError("Error in the string: '{}'".format(exc))

        if not os.path.isabs(convertedwd):
            raise exceptions.ValidationError('The workdir must be an absolute path')
Exemplo n.º 8
0
 def set_group_label_prefix(self, label_prefix):
     """
     Set the label prefix of the group to be created.

     :param label_prefix: the prefix; must be a ``str``.
     :raises aiida.common.exceptions.ValidationError: if not a string.
     """
     if isinstance(label_prefix, str):
         self._group_label_prefix = label_prefix
     else:
         raise exceptions.ValidationError('group label must be a string')
Exemplo n.º 9
0
 def set_group_name(self, gname):
     """
     Set the name of the group to be created.

     :param gname: the group name; must be a ``str``.
     :raises aiida.common.exceptions.ValidationError: if not a string.
     """
     if isinstance(gname, str):
         self.group_name = gname
     else:
         raise exceptions.ValidationError('group name must be a string')
Exemplo n.º 10
0
    def get(cls, **kwargs):
        """
        Custom get for group which can be used to get a group with the given attributes

        :param kwargs: the attributes to match the group to
        :return: the single matching group
        :raises aiida.common.exceptions.MultipleObjectsError: if more than one group matches
        :raises aiida.common.exceptions.NotExistent: if no group matches
        """
        from aiida.orm import QueryBuilder

        # When given, type_string must be a plain string.
        if 'type_string' in kwargs:
            if not isinstance(kwargs['type_string'], six.string_types):
                raise exceptions.ValidationError(
                    'type_string must be {}, you provided an object of type '
                    '{}'.format(str, type(kwargs['type_string'])))

        # Every keyword argument becomes a query filter.
        query = QueryBuilder()
        query.append(cls, filters=dict(kwargs))

        results = query.all()
        if len(results) > 1:
            raise exceptions.MultipleObjectsError(
                "Found {} groups matching criteria '{}'".format(
                    len(results), kwargs))
        if not results:
            raise exceptions.NotExistent(
                "No group found matching criteria '{}'".format(kwargs))
        return results[0][0]
Exemplo n.º 11
0
 def validate(strings):
     """Validate the list of strings passed to set_include and set_exclude."""
     if strings is None:
         return
     valid_prefixes = {
         'aiida.node', 'aiida.calculations', 'aiida.workflows', 'aiida.data'
     }
     for entry in strings:
         # Every filter must look like "<prefix>:<rest>".
         parts = entry.split(':')
         if len(parts) != 2:
             raise exceptions.ValidationError(
                 "'{}' is not a valid include/exclude filter, must contain two parts split by a colon"
                 .format(entry))
         if parts[0] not in valid_prefixes:
             raise exceptions.ValidationError(
                 "'{}' has an invalid prefix, must be among: {}".format(
                     entry, sorted(valid_prefixes)))
Exemplo n.º 12
0
 def _scheduler_type_validator(cls, scheduler_type):
     """
     Validate the scheduler string against the registered scheduler entry points.
     """
     from aiida.plugins.entry_point import get_entry_point_names
     valid_schedulers = get_entry_point_names('aiida.schedulers')
     if scheduler_type not in valid_schedulers:
         raise exceptions.ValidationError(
             'The specified scheduler is not a valid one')
Exemplo n.º 13
0
 def set_exclude(self, exclude):
     """Set the list of classes to exclude from autogrouping."""
     the_exclude_classes = self._validate(exclude)
     include = self.get_include()
     # Excluding everything while also including everything is contradictory.
     if include is not None and 'all.' in include and 'all.' in the_exclude_classes:
         raise exceptions.ValidationError(
             'Cannot exclude and include all classes')
     self.exclude = the_exclude_classes
Exemplo n.º 14
0
    def clean_builtin(val):
        """
        A function to clean built-in python values (`BaseType`).

        It mainly checks that we don't store NaN or Inf.

        :param val: value to clean; ``None``, ``bool``, ``str``, integral and
            real numbers are supported.
        :return: a json-serializable equivalent of ``val``.
        :raises aiida.common.exceptions.ValidationError: for NaN/Inf values or
            any unsupported type.
        """
        # This is a whitelist of all the things we understand currently.
        # NOTE: `six.string_types` was replaced by plain `str`: the surrounding
        # code already relies on Python-3-only features (f-strings, bare super()).
        if val is None or isinstance(val, (bool, str)):
            return val

        # This fixes #2773 - in python3, ``numpy.int64(-1)`` cannot be json-serialized
        # Note that `numbers.Integral` also match booleans but they are already returned above
        if isinstance(val, numbers.Integral):
            return int(val)

        if isinstance(val, numbers.Real) and (math.isnan(val)
                                              or math.isinf(val)):
            # see https://www.postgresql.org/docs/current/static/datatype-json.html#JSON-TYPE-MAPPING-TABLE
            raise exceptions.ValidationError(
                'nan and inf/-inf can not be serialized to the database')

        # This is for float-like types, like ``numpy.float128`` that are not json-serializable
        # Note that `numbers.Real` also match booleans but they are already returned above
        if isinstance(val, numbers.Real):
            string_representation = '{{:.{}g}}'.format(
                AIIDA_FLOAT_PRECISION).format(val)
            new_val = float(string_representation)
            if 'e' in string_representation and new_val.is_integer():
                # This is indeed often quite unexpected, because it is going to change the type of the data
                # from float to int. But anyway clean_value is changing some types, and we are also bound to what
                # our current backends do.
                # Currently, in both Django and SQLA (with JSONB attributes), if we store 1.e1, ..., 1.e14, 1.e15,
                # they will be stored as floats; instead 1.e16, 1.e17, ... will all be stored as integer anyway,
                # even if we don't run this clean_value step.
                # So, for consistency, it's better if we do the conversion ourselves here, and we do it for a bit
                # smaller numbers than python+[SQL+JSONB] would do (the AiiDA float precision is here 14), so the
                # results are consistent, and the hashing will work also after a round trip as expected.
                return int(new_val)
            return new_val

        # Anything else we do not understand and we refuse
        raise exceptions.ValidationError(
            'type `{}` is not supported as it is not json-serializable'.format(
                type(val)))
Exemplo n.º 15
0
    def _default_mpiprocs_per_machine_validator(cls, def_cpus_per_machine):
        """
        Validates the default number of CPUs per machine (node)
        """
        if def_cpus_per_machine is None:
            return

        if not isinstance(def_cpus_per_machine,
                          int) or def_cpus_per_machine <= 0:
            raise exceptions.ValidationError(
                'Invalid value for default_mpiprocs_per_machine, must be a positive integer, or an empty string if you '
                'do not want to provide a default value.')
Exemplo n.º 16
0
    def _validate(self):
        """Validate the code: a local code needs a stored executable file,
        a remote code needs a computer and an executable path and no files."""
        super()._validate()

        if self.is_local() is None:
            raise exceptions.ValidationError(
                'You did not set whether the code is local or remote')

        if self.is_local():
            # Local code: the executable must be declared and be among the stored files.
            if not self.get_local_executable():
                raise exceptions.ValidationError(
                    'You have to set which file is the local executable '
                    'using the set_exec_filename() method')
            if self.get_local_executable() not in self.list_object_names():
                raise exceptions.ValidationError(
                    "The local executable '{}' is not in the list of "
                    'files of this code'.format(self.get_local_executable()))
            return

        # Remote code: no stored files, and both computer and path must be set.
        if self.list_object_names():
            raise exceptions.ValidationError(
                'The code is remote but it has files inside')
        if not self.get_remote_computer():
            raise exceptions.ValidationError(
                'You did not specify a remote computer')
        if not self.get_remote_exec_path():
            raise exceptions.ValidationError(
                'You did not specify a remote executable')
Exemplo n.º 17
0
    def _validate(self, param, is_exact=True):
        """
        Used internally to verify the sanity of exclude, include lists.

        :param param: list of strings, each starting with 'calculation',
            'code' or 'data', or equal to 'all'.
        :param is_exact: if True, also check that the class named by each
            entry can actually be loaded through its factory.
        :return: the list of entries, each with a trailing '.' appended.
        :raises aiida.common.exceptions.ValidationError: on any invalid entry.
        """
        from aiida.plugins import CalculationFactory, DataFactory

        for entry in param:
            # Fixed garbled message ("allow prefixes  are") in the original.
            if entry != 'all' and not entry.startswith(('calculation', 'code', 'data')):
                raise exceptions.ValidationError(
                    'Module not recognized, allowed prefixes '
                    'are: calculation, code or data')
        the_param = [entry + '.' for entry in param]

        # Map each prefix to its factory directly; the previous implementation
        # fetched the factories through locals(), which is fragile.
        factorydict = {
            'calculation': CalculationFactory,
            'data': DataFactory,
        }

        for entry in the_param:
            base, module = entry.split('.', 1)
            if base == 'code':
                if module:
                    raise exceptions.ValidationError(
                        'Cannot have subclasses for codes')
            elif base == 'all':
                continue
            elif is_exact:
                try:
                    factorydict[base](module.rstrip('.'))
                except exceptions.EntryPointError as exc:
                    raise exceptions.ValidationError(
                        'Cannot find the class to be excluded') from exc
        return the_param
Exemplo n.º 18
0
    def __init__(self,
                 time,
                 loggername,
                 levelname,
                 dbnode_id,
                 message='',
                 metadata=None,
                 backend=None):  # pylint: disable=too-many-arguments
        """Construct a new log entry and store it immediately.

        :param time: timestamp of the entry (:class:`!datetime.datetime`)
        :param loggername: name of the logger (str, must be non-empty)
        :param levelname: name of the log level (str, must be non-empty)
        :param dbnode_id: id of the database node the entry refers to (int)
        :param message: the log message (str)
        :param metadata: optional metadata (dict)
        :param backend: database backend
            (:class:`aiida.orm.implementation.Backend`)
        """
        from aiida.common import exceptions

        # Basic validation of the user-supplied fields.
        if metadata is not None and not isinstance(metadata, dict):
            raise TypeError('metadata must be a dict')

        if not loggername or not levelname:
            raise exceptions.ValidationError(
                'The loggername and levelname cannot be empty')

        # Fall back to the default backend when none is given explicitly.
        backend = backend or get_manager().get_backend()
        log_model = backend.logs.create(
            time=time,
            loggername=loggername,
            levelname=levelname,
            dbnode_id=dbnode_id,
            message=message,
            metadata=metadata)
        super().__init__(log_model)
        self.store()  # Logs are immutable and automatically stored
Exemplo n.º 19
0
    def set_include(self, include):
        """Set the list of classes to include in the autogrouping.

        :param include: a single entry point string, a list of such strings
          (possibly containing '%' to be matched with SQL's ``LIKE``
          pattern-making logic), or ``None`` to specify no include list.
        """
        include_list = [include] if isinstance(include, str) else include
        self.validate(include_list)
        if include_list is not None and self.get_exclude() is not None:
            # It's ok to set None, both as a default, or to 'undo' the include list
            raise exceptions.ValidationError(
                'Cannot both specify exclude and include')
        self._include = include_list
Exemplo n.º 20
0
    def __init__(self,
                 label=None,
                 user=None,
                 description='',
                 type_string=GroupTypeString.USER.value,
                 backend=None):
        """Create and initialize a new group.

        :param label: the group label, required on creation
        :type label: str

        :param description: the group description (empty string by default)
        :type description: str

        :param user: the owner of the group (by default, the automatic user)
        :type user: :class:`aiida.orm.User`

        :param type_string: string identifying the type of group (by default,
            an user-defined group)
        :type type_string: str
        """
        if not label:
            raise ValueError('Group label must be provided')

        # Only plain strings are accepted as type_string.
        if not isinstance(type_string, six.string_types):
            raise exceptions.ValidationError(
                'type_string must be {}, you provided an object of type '
                '{}'.format(str, type(type_string)))

        backend = backend or get_manager().get_backend()
        user = user or users.User.objects(backend).get_default()
        type_check(user, users.User)

        group_model = backend.groups.create(
            label=label,
            user=user.backend_entity,
            description=description,
            type_string=type_string)
        super(Group, self).__init__(group_model)
Exemplo n.º 21
0
    def _check_projections_bands(self, projection_array):
        """
        Checks to make sure that a reference bandsdata is already set, and that
        projection_array is of the same shape of the bands data

        :param projwfc_arrays: nk x nb x nwfc array, to be
                               checked against bands

        :raise: AttributeError if energy is not already set
        :raise: AttributeError if input_array is not of same shape as
                dos_energy
        """
        try:
            shape_bands = np.shape(self.get_reference_bandsdata())
        except AttributeError:
            raise exceptions.ValidationError('Bands must be set first, then projwfc')
        # The [0:2] is so that each array, and not collection of arrays
        # is used to make the comparison
        if np.shape(projection_array) != shape_bands:
            raise AttributeError('These arrays are not the same shape as' ' the bands')
Exemplo n.º 22
0
        def get_or_create(self, label=None, **kwargs):
            """
            Try to retrieve a group from the DB with the given arguments;
            create (and store) a new group if such a group was not present yet.

            :param label: group label
            :type label: str

            :return: (group, created) where group is the group (new or existing,
              in any case already stored) and created is a boolean saying
            :rtype: (:class:`aiida.orm.Group`, bool)
            """
            if not label:
                raise ValueError('Group label must be provided')

            filters = {'label': label}

            # Optional type_string filter, validated as a string first.
            if 'type_string' in kwargs:
                type_string = kwargs['type_string']
                if not isinstance(type_string, six.string_types):
                    raise exceptions.ValidationError(
                        'type_string must be {}, you provided an object of type '
                        '{}'.format(str, type(type_string)))
                filters['type_string'] = type_string

            existing = self.find(filters=filters)

            if not existing:
                new_group = Group(label, backend=self.backend, **kwargs)
                return new_group.store(), True

            if len(existing) > 1:
                raise exceptions.MultipleObjectsError(
                    'More than one groups found in the database')

            return existing[0], False
Exemplo n.º 23
0
    def prepare_for_submission(self, folder):
        """Prepare the z2pack calculation for submission.

        Builds the ``CalcInfo``/``CodeInfo`` datastructures, fills the
        retrieve and remote copy/symlink lists depending on the type of the
        parent calculation (PWscf scf start vs. Z2pack restart), and writes
        the input files through the ``prepare_*`` helper functions.

        :param folder: sandbox folder in which the input files are written
            (passed on to the ``prepare_*`` helpers).
        :return: the populated :class:`CalcInfo` instance.
        :raises aiida.common.exceptions.InputValidationError: if the
            `z2pack_settings` input is missing.
        :raises aiida.common.exceptions.ValidationError: if the parent folder
            was not created by a PWscf or Z2pack calculation.
        """
        # Force the parser and the stdin/stdout filenames used by z2pack.
        self.inputs.metadata.options.parser_name = 'z2pack.z2pack'
        self.inputs.metadata.options.output_filename = self._OUTPUT_Z2PACK_FILE
        self.inputs.metadata.options.input_filename = self._INPUT_Z2PACK_FILE

        calcinfo = datastructures.CalcInfo()

        codeinfo = datastructures.CodeInfo()
        codeinfo.stdout_name = self._OUTPUT_Z2PACK_FILE
        codeinfo.stdin_name = self._INPUT_Z2PACK_FILE
        codeinfo.code_uuid = self.inputs.code.uuid
        calcinfo.codes_info = [codeinfo]

        calcinfo.codes_run_mode = datastructures.CodeRunMode.SERIAL
        calcinfo.cmdline_params = []

        # Start from empty lists; they are filled below depending on the parent.
        calcinfo.retrieve_list = []
        calcinfo.retrieve_temporary_list = []
        calcinfo.local_copy_list = []
        calcinfo.remote_copy_list = []
        calcinfo.remote_symlink_list = []

        # Input files copied from a previous z2pack run on restart (see the
        # Z2packCalculation branch below).
        inputs = [
            self._INPUT_PW_NSCF_FILE,
            self._INPUT_OVERLAP_FILE,
            self._INPUT_W90_FILE,
        ]
        outputs = [
            self._OUTPUT_Z2PACK_FILE,
            self._OUTPUT_SAVE_FILE,
            self._OUTPUT_RESULT_FILE,
        ]
        # Error files live inside the `build` subdirectory on the remote.
        errors = [
            os.path.join('build', a)
            for a in [self._ERROR_W90_FILE, self._ERROR_PW_FILE]
        ]

        calcinfo.retrieve_list.extend(outputs)
        calcinfo.retrieve_list.extend(errors)

        parent = self.inputs.parent_folder
        rpath = parent.get_remote_path()
        uuid = parent.computer.uuid
        parent_type = parent.creator.process_class

        # Dispatch on the parent type to inherit the appropriate inputs.
        if parent_type == Z2packCalculation:
            self._set_inputs_from_parent_z2pack()
        elif parent_type == PwCalculation:
            self._set_inputs_from_parent_scf()

        pw_dct = _lowercase_dict(self.inputs.pw_parameters.get_dict(),
                                 'pw_dct')
        sys = pw_dct['system']
        # Non-collinear spin-orbit calculations require spinor wannier functions.
        if sys.get('noncolin', False) and sys.get('lspinorb', False):
            self._blocked_keywords_wannier90.append(('spinors', True))

        try:
            settings = _lowercase_dict(self.inputs.z2pack_settings.get_dict(),
                                       'z2pack_settings')
        except AttributeError:
            raise exceptions.InputValidationError(
                'Must provide `z2pack_settings` input for `scf` calculation.')
        symlink = settings.get('parent_folder_symlink', False)
        self.restart_mode = settings.get('restart_mode', True)
        # Symlink instead of copying the parent folders, when requested.
        ptr = calcinfo.remote_symlink_list if symlink else calcinfo.remote_copy_list

        if parent_type == PwCalculation:
            # Fresh start from an scf calculation: generate all input files.
            prepare_nscf(self, folder)
            prepare_overlap(self, folder)
            prepare_wannier90(self, folder)
        elif parent_type == Z2packCalculation:
            # Restart from a previous z2pack run: reuse its save file
            # (if restarting) and its input files.
            if self.restart_mode:
                calcinfo.remote_copy_list.append((
                    uuid,
                    os.path.join(rpath, self._OUTPUT_SAVE_FILE),
                    self._OUTPUT_SAVE_FILE,
                ))

            calcinfo.remote_copy_list.extend([(uuid, os.path.join(rpath,
                                                                  inp), inp)
                                              for inp in inputs])
        else:
            raise exceptions.ValidationError(
                'parent node must be either from a PWscf or a Z2pack calculation.'
            )

        # Pseudopotential and output subfolders always come from the parent.
        parent_files = [self._PSEUDO_SUBFOLDER, self._OUTPUT_SUBFOLDER]
        ptr.extend([(uuid, os.path.join(rpath, fname), fname)
                    for fname in parent_files])

        prepare_z2pack(self, folder)

        return calcinfo
Exemplo n.º 24
0
def validate_traversal_rules(ruleset=GraphTraversalRules.DEFAULT, **kwargs):
    """
    Validates the keywords with a ruleset template and returns a parsed dictionary
    ready to be used.

    :type ruleset: :py:class:`aiida.common.links.GraphTraversalRules`
    :param ruleset: Ruleset template used to validate the set of rules.
    :param bool input_calc_forward: will traverse INPUT_CALC links in the forward direction.
    :param bool input_calc_backward: will traverse INPUT_CALC links in the backward direction.
    :param bool create_forward: will traverse CREATE links in the forward direction.
    :param bool create_backward: will traverse CREATE links in the backward direction.
    :param bool return_forward: will traverse RETURN links in the forward direction.
    :param bool return_backward: will traverse RETURN links in the backward direction.
    :param bool input_work_forward: will traverse INPUT_WORK links in the forward direction.
    :param bool input_work_backward: will traverse INPUT_WORK links in the backward direction.
    :param bool call_calc_forward: will traverse CALL_CALC links in the forward direction.
    :param bool call_calc_backward: will traverse CALL_CALC links in the backward direction.
    :param bool call_work_forward: will traverse CALL_WORK links in the forward direction.
    :param bool call_work_backward: will traverse CALL_WORK links in the backward direction.
    """
    from aiida.common import exceptions

    if not isinstance(ruleset, GraphTraversalRules):
        raise TypeError(
            'ruleset input must be of type aiida.common.links.GraphTraversalRules\ninstead, it is: {}'.format(
                type(ruleset)
            )
        )

    rules_applied = {}
    links_forward = []
    links_backward = []

    for name, rule in ruleset.value.items():
        if name not in kwargs:
            # Not overridden by the caller: keep the ruleset default.
            follow = rule.default
        else:
            # Caller override: the rule must be toggleable and boolean.
            if not rule.toggleable:
                raise ValueError('input rule {} is not toggleable for ruleset {}'.format(name, ruleset))
            follow = kwargs.pop(name)
            if not isinstance(follow, bool):
                raise ValueError('the value of rule {} must be boolean, but it is: {}'.format(name, follow))

        if follow:
            if rule.direction == 'forward':
                links_forward.append(rule.link_type)
            elif rule.direction == 'backward':
                links_backward.append(rule.link_type)
            else:
                raise exceptions.InternalError(
                    'unrecognized direction `{}` for graph traversal rule'.format(rule.direction)
                )

        rules_applied[name] = follow

    # Anything left in kwargs did not match a rule name.
    if kwargs:
        error_message = 'unrecognized keywords: {}'.format(', '.join(kwargs.keys()))
        raise exceptions.ValidationError(error_message)

    return {
        'rules_applied': rules_applied,
        'forward': links_forward,
        'backward': links_backward,
    }
Exemplo n.º 25
0
 def _hostname_validator(cls, hostname):
     """
     Validates the hostname.
     """
     if not hostname.strip():
         raise exceptions.ValidationError('No hostname specified')
Exemplo n.º 26
0
 def _name_validator(cls, name):
     """
     Validates the name.
     """
     if not name.strip():
         raise exceptions.ValidationError('No name specified')
Exemplo n.º 27
0
 def _transport_type_validator(cls, transport_type):
     """
     Validate the transport string against the list of known transports.
     """
     valid_transports = transports.Transport.get_valid_transports()
     if transport_type not in valid_transports:
         raise exceptions.ValidationError('The specified transport is not a valid one')
Exemplo n.º 28
0
 def _scheduler_type_validator(cls, scheduler_type):
     """
     Validate the scheduler string against the list of known schedulers.
     """
     valid_schedulers = schedulers.Scheduler.get_valid_schedulers()
     if scheduler_type not in valid_schedulers:
         raise exceptions.ValidationError('The specified scheduler is not a valid one')
Exemplo n.º 29
0
    def set_projectiondata(self,
                           list_of_orbitals,
                           list_of_projections=None,
                           list_of_energy=None,
                           list_of_pdos=None,
                           tags=None,
                           bands_check=True):
        """
        Stores the projwfc_array using the projwfc_label, after validating both.

        :param list_of_orbitals: list of orbitals, of class orbital data.
                                 They should be the ones up on which the
                                 projection array corresponds with.

        :param list_of_projections: list of arrays of projections of a atomic
                              wavefunctions onto bloch wavefunctions. Since the
                              projection is for every bloch wavefunction which
                              can be specified by its spin (if used), band, and
                              kpoint the dimensions must be
                              nspin x nbands x nkpoints for the projwfc array.
                              Or nbands x nkpoints if spin is not used.

        :param list_of_energy: list of energy axes for the list_of_pdos

        :param list_of_pdos: a list of projected density of states for the
                             atomic wavefunctions, units in states/eV

        :param tags: A list of tags, not supported currently.

        :param bands_check: if false, skips checks of whether the bands has
                            been already set, and whether the sizes match. For
                            use in parsers, where the BandsData has not yet
                            been stored and therefore get_reference_bandsdata
                            cannot be called

        :raises aiida.common.exceptions.ValidationError: on any inconsistency
            in the supplied data.
        """

        # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements

        def single_to_list(item):
            """
            Checks if the item is a list or tuple, and converts it to a list
            if it is not already a list or tuple

            :param item: an object which may or may not be a list or tuple
            :return: item_list: the input item unchanged if list or tuple and
                                [item] otherwise
            """
            if isinstance(item, (list, tuple)):
                return item

            return [item]

        def array_list_checker(array_list, array_name, orb_length):
            """
            Does basic checks over everything in the array_list. Makes sure that
            all the arrays are np.ndarray floats, that the length is same as
            required_length, raises exception using array_name if there is
            a failure
            """
            if not all(isinstance(_, np.ndarray) for _ in array_list):
                raise exceptions.ValidationError(
                    '{} was not composed entirely of ndarrays'.format(
                        array_name))
            if len(array_list) != orb_length:
                raise exceptions.ValidationError(
                    '{} did not have the same length as the '
                    'list of orbitals'.format(array_name))

        ##############
        list_of_orbitals = single_to_list(list_of_orbitals)
        list_of_orbitals = copy.deepcopy(list_of_orbitals)

        # validates the input data
        if not list_of_pdos and not list_of_projections:
            raise exceptions.ValidationError(
                'Must set either pdos or projections')
        if bool(list_of_energy) != bool(list_of_pdos):
            raise exceptions.ValidationError(
                'list_of_pdos and list_of_energy must always be set together')

        orb_length = len(list_of_orbitals)

        # verifies and sets the orbital dicts: round-trip every dict through
        # the orbital factory to make sure it describes a constructible orbital
        list_of_orbital_dicts = []
        for i, _ in enumerate(list_of_orbitals):
            this_orbital = list_of_orbitals[i]
            orbital_dict = this_orbital.get_orbital_dict()
            try:
                orbital_type = orbital_dict.pop('_orbital_type')
            except KeyError:
                raise exceptions.ValidationError(
                    'No _orbital_type key found in dictionary: {}'.format(
                        orbital_dict))
            orbital_class = OrbitalFactory(orbital_type)  # renamed from `cls` to avoid shadowing
            test_orbital = orbital_class(**orbital_dict)
            list_of_orbital_dicts.append(test_orbital.get_orbital_dict())
        self.set_attribute('orbital_dicts', list_of_orbital_dicts)

        # verifies and sets the projections
        if list_of_projections:
            list_of_projections = single_to_list(list_of_projections)
            array_list_checker(list_of_projections, 'projections', orb_length)
            for i, _ in enumerate(list_of_projections):
                this_projection = list_of_projections[i]
                array_name = self._from_index_to_arrayname(i)
                if bands_check:
                    self._check_projections_bands(this_projection)
                self.set_array('proj_{}'.format(array_name), this_projection)

        # verifies and sets both pdos and energy
        if list_of_pdos:
            list_of_pdos = single_to_list(list_of_pdos)
            list_of_energy = single_to_list(list_of_energy)
            array_list_checker(list_of_pdos, 'pdos', orb_length)
            array_list_checker(list_of_energy, 'energy', orb_length)
            for i, _ in enumerate(list_of_pdos):
                this_pdos = list_of_pdos[i]
                this_energy = list_of_energy[i]
                array_name = self._from_index_to_arrayname(i)
                if bands_check:
                    # NOTE(review): `this_projection` here is a leftover from
                    # the projections loop above and is a NameError when only
                    # pdos are supplied — confirm intended variable before
                    # changing, since the pdos arrays may have a different
                    # shape than the bands.
                    self._check_projections_bands(this_projection)
                self.set_array('pdos_{}'.format(array_name), this_pdos)
                self.set_array('energy_{}'.format(array_name), this_energy)

        # verifies and sets the tags
        if tags is not None:
            try:
                if len(tags) != len(list_of_orbitals):
                    raise exceptions.ValidationError(
                        'must set as many tags as projections')
            except TypeError:
                # Fixed two bugs: `len()` raises TypeError (not IndexError)
                # for unsized objects, and the ValidationError must be raised,
                # not returned.
                raise exceptions.ValidationError('tags must be a list')

            if not all(isinstance(_, str) for _ in tags):
                raise exceptions.ValidationError(
                    'Tags must set a list of strings')
            self.set_attribute('tags', tags)
Exemplo n.º 30
0
def validate_traversal_rules(
        ruleset: GraphTraversalRules = GraphTraversalRules.DEFAULT,
        **traversal_rules: bool) -> dict:
    """
    Validate the traversal keywords against a ruleset template and return a parsed
    dictionary ready to be used.

    :param ruleset: Ruleset template used to validate the set of rules.
    :param input_calc_forward: will traverse INPUT_CALC links in the forward direction.
    :param input_calc_backward: will traverse INPUT_CALC links in the backward direction.
    :param create_forward: will traverse CREATE links in the forward direction.
    :param create_backward: will traverse CREATE links in the backward direction.
    :param return_forward: will traverse RETURN links in the forward direction.
    :param return_backward: will traverse RETURN links in the backward direction.
    :param input_work_forward: will traverse INPUT_WORK links in the forward direction.
    :param input_work_backward: will traverse INPUT_WORK links in the backward direction.
    :param call_calc_forward: will traverse CALL_CALC links in the forward direction.
    :param call_calc_backward: will traverse CALL_CALC links in the backward direction.
    :param call_work_forward: will traverse CALL_WORK links in the forward direction.
    :param call_work_backward: will traverse CALL_WORK links in the backward direction.
    :return: dict with keys ``rules_applied``, ``forward`` and ``backward``.
    :raises TypeError: if ``ruleset`` is not a ``GraphTraversalRules`` member.
    :raises ValueError: if a non-toggleable rule is overridden, or an override is not boolean.
    :raises ValidationError: if an unrecognized keyword is supplied.
    """
    if not isinstance(ruleset, GraphTraversalRules):
        raise TypeError(
            f'ruleset input must be of type aiida.common.links.GraphTraversalRules\ninstead, it is: {type(ruleset)}'
        )

    rules_applied: Dict[str, bool] = {}
    links_forward: List[LinkType] = []
    links_backward: List[LinkType] = []

    # Dispatch table: each recognized direction maps to the list collecting its link types.
    direction_buckets = {'forward': links_forward, 'backward': links_backward}

    for rule_name, rule_def in ruleset.value.items():

        # Start from the template default; an explicit keyword may override it below.
        follow_link = rule_def.default

        if rule_name in traversal_rules:

            if not rule_def.toggleable:
                raise ValueError(
                    f'input rule {rule_name} is not toggleable for ruleset {ruleset}'
                )

            # Pop so that any keyword left afterwards is known to be unrecognized.
            follow_link = traversal_rules.pop(rule_name)

            if not isinstance(follow_link, bool):
                raise ValueError(
                    f'the value of rule {rule_name} must be boolean, but it is: {follow_link}'
                )

        if follow_link:
            bucket = direction_buckets.get(rule_def.direction)
            if bucket is None:
                raise exceptions.InternalError(
                    f'unrecognized direction `{rule_def.direction}` for graph traversal rule'
                )
            bucket.append(rule_def.link_type)

        rules_applied[rule_name] = follow_link

    # Anything remaining did not match a rule name in the template.
    if traversal_rules:
        raise exceptions.ValidationError(
            f"unrecognized keywords: {', '.join(traversal_rules.keys())}"
        )

    return {
        'rules_applied': rules_applied,
        'forward': links_forward,
        'backward': links_backward,
    }