Exemplo n.º 1
0
def get_formation_energy(chempots, doc, energy_key='enthalpy_per_atom'):
    """ From given chemical potentials, calculate the simplest
    formation energy per atom of the desired document.

    Note:
        recursive_get(doc, energy_key) MUST return an energy per atom for
        the target doc and the chemical potentials.

    Parameters:
        chempots (list of dict): list of chempot structures, must be unique.
        doc (dict): structure to evaluate.

    Keyword arguments:
        energy_key (str or list): name of energy field to use to calculate
            formation energy. Can use a list of keys/subkeys/indices to
            query nested dicts with `matador.utils.cursor_utils.recursive_get`.

    Raises:
        KeyError: if `energy_key` is missing from `doc` or from any of the
            chemical potentials.

    Returns:
        float: formation energy per atom.

    """
    from matador.utils.cursor_utils import recursive_get

    # Warn the user if the requested key does not look like a per-atom
    # quantity: the subtraction below is only meaningful for per-atom energies.
    if isinstance(energy_key, (list, tuple)):
        looks_per_atom = any('per_atom' in str(key) for key in energy_key)
    elif isinstance(energy_key, str):
        looks_per_atom = 'per_atom' in energy_key
    else:
        # other key types are not inspected
        looks_per_atom = True
    if not looks_per_atom:
        warnings.warn('Requested energy key {} in get_formation_energy may'
                      ' not be per atom, if so results will be incorrect.'.format(energy_key))

    try:
        formation = recursive_get(doc, energy_key)
    except KeyError:
        # use .get so a missing 'source' field does not mask the real error
        print('Doc {} missing key {}'.format(doc.get('source'), energy_key))
        raise

    # see if num chempots has been set and try to reuse it
    if 'num_chempots' in doc:
        num_chempots = doc['num_chempots']
    else:
        num_chempots = get_number_of_chempots(doc, chempots)

    num_atoms_per_fu = get_atoms_per_fu(doc)
    for ind, mu in enumerate(chempots):
        num_atoms_per_mu = get_atoms_per_fu(mu)
        try:
            mu_energy = recursive_get(mu, energy_key)
        except KeyError as exc:
            # re-raise as a fresh KeyError with context; the caught exception
            # is an instance and cannot be called like a class
            raise KeyError('Chemical potential {} missing key {}'
                           .format(mu.get('source'), energy_key)) from exc

        # subtract this chempot's contribution, rescaled from per-atom of the
        # chempot to per-atom of the target formula unit
        formation -= mu_energy * num_chempots[ind] * num_atoms_per_mu / num_atoms_per_fu

    return formation
Exemplo n.º 2
0
    def construct_phase_diagram(self):
        """ Create a phase diagram with arbitrary chemical potentials.

        Expects self.cursor to be populated with structures and chemical potential
        labels to be set under self.species.

        Computes per-atom and extensive formation energies for every document
        in self.cursor, then sets self.phase_diagram, the aliases
        self.structures, self.hull_dist and self.convex_hull, and
        self.hull_cursor (all structures within hull_cutoff of the hull).

        Raises:
            NotImplementedError: if voltage or volume curves are requested
                with non-elemental chemical potentials.

        """
        self.set_chempots()
        self.cursor = filter_cursor_by_chempots(self.species, self.cursor)

        formation_key = 'formation_{}'.format(self.energy_key)
        extensive_formation_key = 'formation_{}'.format(
            self._extensive_energy_key)
        for doc in self.cursor:
            doc[formation_key] = get_formation_energy(
                self.chempot_cursor, doc, energy_key=self.energy_key)
            doc[extensive_formation_key] = doc[formation_key] * doc['num_atoms']

        if self._non_elemental and self.args.get('subcmd') in [
                'voltage', 'volume'
        ]:
            raise NotImplementedError(
                'Pseudo-binary/pseudo-ternary voltages not yet implemented.')

        self.phase_diagram = PhaseDiagram(self.cursor, formation_key,
                                          self._dimension)
        # aliases for data stored in phase diagram
        self.structures = self.phase_diagram.structures
        self.hull_dist = self.phase_diagram.hull_dist
        self.convex_hull = self.phase_diagram.convex_hull

        # collect all structures within hull_cutoff of the hull, then sort
        # them by concentration and, within one concentration, by energy
        hull_cursor = [
            self.cursor[idx]
            for idx in np.where(self.hull_dist <= self.hull_cutoff + EPS)[0]
        ]
        # TODO: check why this fails when the opposite way around
        hull_cursor = sorted(
            hull_cursor,
            key=lambda doc:
            (doc['concentration'], recursive_get(doc, self.energy_key)))

        # by default hull cursor includes all structures within hull_cutoff
        # if summary requested and we're in hulldiff mode, filter hull_cursor for lowest per stoich
        if self.args.get('summary') and self.args.get('subcmd') == 'hulldiff':
            tmp_hull_cursor = []
            compositions = set()
            for member in hull_cursor:
                formula = get_formula_from_stoich(member['stoichiometry'],
                                                  sort=True)
                # hull_cursor is energy-sorted within each concentration, so
                # the first occurrence of a formula is its lowest-energy one
                if formula not in compositions:
                    compositions.add(formula)
                    tmp_hull_cursor.append(member)

            hull_cursor = tmp_hull_cursor

        self.hull_cursor = hull_cursor
Exemplo n.º 3
0
    def __init__(self,
                 cursor,
                 data_key,
                 energy_key='enthalpy_per_atom',
                 chempot_energy_key=None,
                 num_samples=None,
                 parameter_key=None,
                 species=None,
                 voltage=False,
                 verbosity=None,
                 **kwargs):
        """ Initialise EnsembleHull from a cursor, with other keywords
        following QueryConvexHull.

        Parameters:
            cursor (list[dict]): list of matador documents containing
                variable parameter data for energies.
            data_key (str): the key under which all parameter data is
                stored to the variable parameter, e.g. `_beef` or `_temperature`.

        Keyword arguments:
            energy_key (str): the key under `parameter_key` to use to create
                the hulls.
            chempot_energy_key (str): the key used to create the first convex hull.
            parameter_key (str): the key pertaining to the variable parameter
                itself, e.g. `temperature` or `thetas`.
            num_samples (int): use up to this many samples in creating the hull.
            species (list[str]): list of elements/chempots to use, in
                the desired order.
            voltage (bool): whether or not to compute voltage curves.
            plot_kwargs (dict): arguments to pass to plot_hull function.
            kwargs (dict): other arguments to pass to QueryConvexHull.

        Raises:
            RuntimeError: if no sample data can be found under the data keys
                of the first chemical potential.

        """
        # sometimes the first hull needs to be made with a different key
        if chempot_energy_key is not None:
            self.chempot_energy_key = chempot_energy_key
        else:
            self.chempot_energy_key = energy_key

        super().__init__(cursor=cursor,
                         energy_key=self.chempot_energy_key,
                         species=species,
                         voltage=voltage,
                         no_plot=True,
                         lazy=False,
                         **kwargs)

        self.energy_key = energy_key

        # drop single-hull attributes the parent left unset (None);
        # NOTE(review): presumably so they are not mistaken for real data,
        # since this class keeps one phase diagram per sample — confirm
        if self.phase_diagram is None:
            del self.phase_diagram
        if self.hull_dist is None:
            del self.hull_dist

        self.from_cursor = True
        self.verbosity = verbosity
        # set up relative keys
        self.formation_key = 'formation_' + self.energy_key
        self.data_key = data_key
        self.parameter_key = parameter_key

        if self.parameter_key is None:
            self._parameter_keys = None
        else:
            self._parameter_keys = [self.data_key] + [parameter_key]
        self._formation_keys = [self.data_key] + [self.formation_key]
        self._hulldist_keys = [self.data_key] + ['hull_distance']
        self._energy_keys = [self.data_key] + [self.energy_key]

        self.phase_diagrams = []

        self.set_chempots(energy_key=self.chempot_energy_key)
        self.cursor = filter_cursor_by_chempots(self.species, self.cursor)
        self.cursor = sorted(self.cursor,
                             key=lambda doc:
                             (recursive_get(doc, self.chempot_energy_key), doc[
                                 'concentration']))

        # the first chemical potential defines the samples to iterate over;
        # report the exact key path that was queried if nothing is found
        if self.parameter_key is None:
            _keys = self._energy_keys
        else:
            _keys = self._parameter_keys
        parameter_iterable = recursive_get(self.chempot_cursor[0], _keys)

        if parameter_iterable is None:
            raise RuntimeError(
                f"Could not find any data for keys {_keys} in {self.chempot_cursor[0]}."
            )

        print(
            f"Found {len(parameter_iterable)} entries under data key: {self.data_key}."
        )

        # allocate formation energy and hull distance arrays
        for doc in self.cursor:
            num_entries = len(recursive_get(doc, self._energy_keys))
            recursive_set(doc, self._formation_keys, [None] * num_entries)
            recursive_set(doc, self._hulldist_keys, [None] * num_entries)

        n_hulls = len(parameter_iterable)
        if num_samples is not None:
            # slicing cannot yield more than n_hulls samples, so report that
            num_samples = min(num_samples, n_hulls)
            parameter_iterable = parameter_iterable[:num_samples]
            print(
                f"Using {num_samples} out of {n_hulls} possible phase diagrams."
            )

        for param_ind, parameter in enumerate(tqdm.tqdm(parameter_iterable)):
            for doc in self.cursor:
                if self.parameter_key is not None:
                    # every document must be sampled at the same parameter values
                    assert recursive_get(doc, self._parameter_keys +
                                         [param_ind]) == parameter

                formation_energy = get_formation_energy(
                    self.chempot_cursor,
                    doc,
                    energy_key=self._energy_keys + [param_ind])
                recursive_set(doc,
                              self._formation_keys + [param_ind],
                              formation_energy)
            self.phase_diagrams.append(
                PhaseDiagram(self.cursor, self._formation_keys + [param_ind],
                             self._dimension))
            set_cursor_from_array(self.cursor,
                                  self.phase_diagrams[-1].hull_dist,
                                  self._hulldist_keys + [param_ind])

        self.stability_histogram = self.generate_stability_statistics()
Exemplo n.º 4
0
    def set_chempots(self, energy_key=None):
        """ Search for chemical potentials that match the structures in
        the query cursor and add them to the cursor.

        Chemical potentials are taken from faked values (``args['chempots']``),
        from the lowest-energy matching structures already in :attr:`cursor`
        (when :attr:`from_cursor`), or from the lowest-energy matching
        database entries. Sets :attr:`chempot_cursor`, :attr:`elements` and
        :attr:`num_elements`.

        Keyword arguments:
            energy_key (str): key used to rank candidate chemical potentials
                (lowest first); defaults to :attr:`energy_key`.

        Raises:
            RuntimeError: if a chemical potential cannot be found for every
                entry in :attr:`species`.

        """
        if energy_key is None:
            energy_key = self.energy_key
        query = self._query
        query_dict = dict()
        species_stoich = [sorted(get_stoich_from_formula(spec, sort=False)) for spec in self.species]
        self.chempot_cursor = []

        if self.args.get('chempots') is not None:
            self.fake_chempots(custom_elem=self.species)

        elif self.from_cursor:
            # candidates sorted by energy, so the first match per stoichiometry
            # is the lowest-energy one
            chempot_cursor = sorted([doc for doc in self.cursor if doc['stoichiometry'] in species_stoich],
                                    key=lambda doc: recursive_get(doc, energy_key))

            for species in species_stoich:
                for doc in chempot_cursor:
                    if doc['stoichiometry'] == species:
                        self.chempot_cursor.append(doc)
                        break

            if len(self.chempot_cursor) != len(self.species):
                raise RuntimeError('Found {} of {} required chemical potentials'
                                   .format(len(self.chempot_cursor), len(self.species)))

            if self.args.get('debug'):
                print([mu['stoichiometry'] for mu in self.chempot_cursor])

        else:
            print(60 * '─')
            self.chempot_cursor = len(self.species) * [None]
            # scan for suitable chem pots in database
            for ind, elem in enumerate(self.species):

                print('Scanning for suitable', elem, 'chemical potential...')
                query_dict['$and'] = deepcopy(list(query.calc_dict['$and']))

                if not self.args.get('ignore_warnings'):
                    query_dict['$and'].append(query.query_quality())

                if len(species_stoich[ind]) == 1:
                    query_dict['$and'].append(query.query_composition(custom_elem=[elem]))
                else:
                    query_dict['$and'].append(query.query_stoichiometry(custom_stoich=[elem]))

                # if oqmd, only query composition, not parameters
                if query.args.get('tags') is not None:
                    query_dict['$and'].append(query.query_tags())

                # fetch the lowest-energy match directly; this avoids
                # Cursor.count(), which was removed in PyMongo 4
                mu = query.repo.find_one(SON(query_dict), sort=[(energy_key, pm.ASCENDING)])
                if mu is None:
                    print_notify('Failed... searching without spin polarization field...')
                    # relax the query by dropping any top-level spin_polarized filter
                    for clause in query_dict['$and']:
                        clause.pop('spin_polarized', None)
                    mu = query.repo.find_one(SON(query_dict), sort=[(energy_key, pm.ASCENDING)])

                if mu is None:
                    raise RuntimeError('No chemical potentials found for {}...'.format(elem))

                self.chempot_cursor[ind] = mu
                print('Using', ''.join([mu['text_id'][0], ' ', mu['text_id'][1]]), 'as chem pot for', elem)
                print(60 * '─')

            for mu in self.chempot_cursor:
                mu[self._extensive_energy_key + '_per_b'] = mu[energy_key]
                mu['num_a'] = 0

            self.chempot_cursor[0]['num_a'] = float('inf')

        # don't check for IDs if we're loading from cursor
        if not self.from_cursor:
            ids = [doc['_id'] for doc in self.cursor]
            # first chempot goes to the front of the cursor, the rest at the end
            if self.chempot_cursor[0]['_id'] is None or self.chempot_cursor[0]['_id'] not in ids:
                self.cursor.insert(0, self.chempot_cursor[0])
            for match in self.chempot_cursor[1:]:
                if match['_id'] is None or match['_id'] not in ids:
                    self.cursor.append(match)

        # add faked chempots to overall cursor
        elif self.args.get('chempots') is not None:
            self.cursor.insert(0, self.chempot_cursor[0])
            self.cursor.extend(self.chempot_cursor[1:])

        # find all elements present in the chemical potentials, in first-seen order
        elements = []
        for mu in self.chempot_cursor:
            for elem, _ in mu['stoichiometry']:
                if elem not in elements:
                    elements.append(elem)
        self.elements = elements
        self.num_elements = len(elements)
Exemplo n.º 5
0
    def test_recursive_get_set(self):
        """ recursive_get/recursive_set should walk nested dicts and lists
        (including negative list indices), reflect both recursive and direct
        writes, and raise IndexError/KeyError for invalid paths.
        """
        from matador.utils.cursor_utils import recursive_get, recursive_set

        nested_dict = dict(
            source=["blah", "foo"],
            lattice_cart=[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
            _beef=dict(
                total_energy=[1, 2, 3, 4],
                thetas=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                foo=dict(blah="bloop"),
            ),
        )
        # paths are stored as tuples and converted to fresh lists per call
        last_theta = ("_beef", "thetas", -1)
        third_energy = ("_beef", "total_energy", 2)
        foo_blah = ("_beef", "foo", "blah")

        # plain reads
        read_cases = [
            (last_theta, [7, 8, 9]),
            (third_energy, 3),
            (foo_blah, "bloop"),
        ]
        for path, expected in read_cases:
            self.assertEqual(recursive_get(nested_dict, list(path)), expected)

        # writes through recursive_set
        recursive_set(nested_dict, list(last_theta), [1, 2, 3])
        self.assertEqual(recursive_get(nested_dict, list(last_theta)), [1, 2, 3])
        recursive_set(nested_dict, list(third_energy), 3.5)
        self.assertEqual(recursive_get(nested_dict, list(third_energy)), 3.5)

        # direct mutation is seen by recursive_get
        nested_dict["_beef"]["total_energy"][2] = 3
        self.assertEqual(recursive_get(nested_dict, list(third_energy)), 3)
        nested_dict["_beef"]["foo"]["blah"] = (1, 2)
        self.assertEqual(recursive_get(nested_dict, list(foo_blah)), (1, 2))

        # invalid paths
        bad_paths = [
            (IndexError, ["_beef", "thetas", 4]),
            (KeyError, ["_beef", "thetaz", 4]),
            (KeyError, ["_beef", "foo", "blahp"]),
        ]
        for expected_error, path in bad_paths:
            with self.assertRaises(expected_error):
                recursive_get(nested_dict, path)
Exemplo n.º 6
0
    def __init__(self, *args, settings=None):
        """ Set up arguments and initialise DB client, then run the import.

        Notes:
            Several arguments can be passed to this class from the command-line,
            and here are interpreted through *args (only ``args[0]``, a dict of
            options, is read). Scanning, duplicate filtering, conversion,
            indexing and report writing all happen inside this constructor.

        Parameters:
            db (str): the name of the collection to import to.
            scan (bool): whether or not to just scan the directory, rather
                than importing (automatically sets dryrun to true).
            dryrun (bool): perform whole process, up to actually importing to
                the database.
            tags (str): apply this tag to each structure added to database.
            force (bool): override rules about which folders can be imported
                into main database.
            recent_only (bool): if true, sort file lists by modification
                date and stop scanning when a file that already exists in
                database is found.

        Keyword arguments:
            settings (dict): pre-loaded settings; if None, they are loaded
                with `load_custom_settings`.

        Raises:
            RuntimeError: if importing to the default collection from outside
                `default_collection_file_path` without dryrun or force.
            SystemExit: if the target is an oqmd collection, more than one
                collection was requested, or the prototype flag does not match
                the existing contents of the collection.

        """
        self.args = args[0]
        self.dryrun = self.args.get('dryrun')
        self.scan = self.args.get('scan')
        self.recent_only = self.args.get('recent_only')
        if self.scan:
            self.dryrun = True
        self.debug = self.args.get('debug')
        self.verbosity = self.args.get('verbosity') or 0
        self.config_fname = self.args.get('config')
        self.tags = self.args.get('tags')
        self.prototype = self.args.get('prototype')
        self.tag_dict = dict()
        self.tag_dict['tags'] = self.tags
        self.import_count = 0
        self.skipped = 0
        self.exclude_patterns = ['bad_castep', 'input']
        self.errors = 0
        self.struct_list = []
        self.path_list = []

        self.log = logging.getLogger('spatula')
        loglevel = {3: "DEBUG", 2: "INFO", 1: "WARN", 0: "ERROR"}
        # verbosities above the table's maximum should stay at DEBUG rather
        # than falling back to the WARN default
        if isinstance(self.verbosity, int) and self.verbosity > 3:
            level = "DEBUG"
        else:
            level = loglevel.get(self.verbosity, "WARN")
        self.log.setLevel(level)

        # the named logger is process-global: only attach a handler once so
        # that repeated instantiation does not duplicate every log line
        if not self.log.handlers:
            handler = logging.StreamHandler(sys.stdout)
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.log.addHandler(handler)

        # I/O files
        if not self.dryrun:
            self.num_words = len(WORDS)
            self.num_nouns = len(NOUNS)

        if not self.scan:
            self.logfile = tempfile.NamedTemporaryFile(mode='w+t',
                                                       delete=False)
            self.manifest = tempfile.NamedTemporaryFile(mode='w+t',
                                                        delete=False)

        if settings is None:
            self.settings = load_custom_settings(
                config_fname=self.config_fname,
                debug=self.debug,
                no_quickstart=self.args.get('no_quickstart'))
        else:
            self.settings = settings

        result = make_connection_to_collection(
            self.args.get('db'),
            check_collection=False,
            import_mode=True,
            override=self.args.get('no_quickstart'),
            mongo_settings=self.settings)

        self.client, self.db, self.collections = result
        # perform some relevant collection-dependent checks
        assert len(self.collections) == 1, 'Can only import to one collection.'
        self.repo = list(self.collections.values())[0]

        # if trying to import to the default repo, without doing a dryrun or forcing it, then
        # check if we're in the protected directory, i.e. the only one that is allowed to import
        # to the default collection
        default_collection = recursive_get(self.settings,
                                           ['mongo', 'default_collection'])
        try:
            default_collection_file_path = recursive_get(
                self.settings, ['mongo', 'default_collection_file_path'])
        except KeyError:
            default_collection_file_path = None

        if self.args.get('db') is None or self.args.get(
                'db') == default_collection:
            if not self.dryrun and not self.args.get('force'):
                # if using default collection, check we are in the correct path
                if default_collection_file_path is not None:
                    if not os.getcwd().startswith(
                            os.path.expanduser(default_collection_file_path)):
                        print(80 * '!')
                        print(
                            'You shouldn\'t be importing to the default database from this folder! '
                            'Please use --db <YourDBName> to create a new collection, '
                            'or copy these files to the correct place!')
                        print(80 * '!')
                        raise RuntimeError(
                            'Failed to import to default collection from '
                            'current directory, import must be called from {}'.
                            format(default_collection_file_path))
        else:
            if self.args.get('db') is not None:
                # sys.exit rather than the site-provided exit(), which is only
                # meant for the interactive interpreter
                if any('oqmd' in db for db in self.args.get('db')):
                    sys.exit('Cannot import directly to oqmd repo')
                elif len(self.args.get('db')) > 1:
                    sys.exit('Can only import to one collection.')

        num_prototypes_in_db = self.repo.count_documents({'prototype': True})
        num_objects_in_db = self.repo.count_documents({})
        self.log.info(
            "Found {} existing objects in database.".format(num_objects_in_db))
        if self.args.get('prototype'):
            if num_prototypes_in_db != num_objects_in_db:
                raise SystemExit(
                    'I will not import prototypes to a non-prototype database!'
                )
        else:
            if num_prototypes_in_db != 0:
                raise SystemExit(
                    'I will not import DFT calculations into a prototype database!'
                )

        if not self.dryrun:
            # either drop and recreate or create spatula report collection
            self.db.spatula.drop()
            self.report = self.db.spatula

        # scan directory on init
        self.file_lists = self._scan_dir()
        # if import, as opposed to rebuild, scan for duplicates and remove from list
        if not self.args.get('subcmd') == 'rebuild':
            self.file_lists, skipped = self._scan_dupes(self.file_lists)
            self.skipped += skipped

        # print number of files found
        self._display_import()

        # only create dicts if not just scanning
        if not self.scan:
            # convert to dict and db if required
            self._files2db(self.file_lists)
        if self.import_count == 0:
            print('No new structures imported!')
        if not self.dryrun and self.import_count > 0:
            print('Successfully imported', self.import_count, 'structures!')
            # index by enthalpy for faster/larger queries
            num_indexes = len(list(self.repo.list_indexes()))
            # ignore default id index
            if num_indexes > 1:
                self.log.info('Index found, rebuilding...')
                # Collection.reindex() was removed in PyMongo 4; issue the
                # reIndex command directly, as PyMongo's migration guide advises
                self.repo.database.command('reIndex', self.repo.name)
            else:
                self.log.info('Building index...')
                self.repo.create_index([('enthalpy_per_atom', pm.ASCENDING)])
                self.repo.create_index([('stoichiometry', pm.ASCENDING)])
                self.repo.create_index([('cut_off_energy', pm.ASCENDING)])
                self.repo.create_index([('species_pot', pm.ASCENDING)])
                self.repo.create_index([('kpoints_mp_spacing', pm.ASCENDING)])
                self.repo.create_index([('xc_functional', pm.ASCENDING)])
                self.repo.create_index([('elems', pm.ASCENDING)])
                # index by source for rebuilds
                self.repo.create_index([('source', pm.ASCENDING)])
                self.log.info('Done!')
        elif self.dryrun:
            self.log.info('Dryrun complete!')

        # one line in the logfile per error; defined up-front so the report
        # below never sees an unbound name
        errors = 0
        if not self.scan:
            self.logfile.seek(0)
            errors = sum(1 for _ in self.logfile)
            self.errors += errors
            if errors == 1:
                self.log.warning('There is 1 error to view in {}'.format(
                    self.logfile.name))
            elif errors == 0:
                self.log.warning('There are no errors to view in {}'.format(
                    self.logfile.name))
            elif errors > 1:
                self.log.warning('There are {} errors to view in {}'.format(
                    errors, self.logfile.name))

        # temporary files only exist when not scanning; closing is best-effort
        for handle in (getattr(self, 'logfile', None), getattr(self, 'manifest', None)):
            if handle is not None:
                try:
                    handle.close()
                except OSError:
                    pass

        if not self.dryrun:
            # construct dictionary in spatula_report collection to hold info
            report_dict = dict()
            # timezone-aware UTC timestamp (utcnow() is deprecated); BSON
            # stores it as the same UTC instant
            report_dict['last_modified'] = datetime.datetime.now(
                datetime.timezone.utc).replace(microsecond=0)
            report_dict['num_success'] = self.import_count
            report_dict['num_errors'] = errors
            report_dict['version'] = __version__
            self.report.insert_one(report_dict)