def get_formation_energy(chempots, doc, energy_key='enthalpy_per_atom'):
    """ From given chemical potentials, calculate the simplest
    formation energy per atom of the desired document.

    Note:
        recursive_get(doc, energy_key) MUST return an energy per atom for the
        target doc and the chemical potentials.

    Parameters:
        chempots (list of dict): list of chempot structures, must be unique.
        doc (dict): structure to evaluate.

    Keyword arguments:
        energy_key (str or list): name of energy field to use to calculate
            formation energy. Can use a list of keys/subkeys/indices to
            query nested dicts with `matador.utils.cursor_utils.recursive_get`.

    Returns:
        float: formation energy per atom.

    Raises:
        KeyError: if `energy_key` cannot be resolved in `doc` or in any
            of the chemical potentials.

    """
    from matador.utils.cursor_utils import recursive_get

    # warn user if per_atom energy is not found; key types other than
    # str/list/tuple are passed through without a warning
    if isinstance(energy_key, (list, tuple)):
        maybe_not_per_atom = not any('per_atom' in str(key) for key in energy_key)
    elif isinstance(energy_key, str):
        maybe_not_per_atom = 'per_atom' not in energy_key
    else:
        maybe_not_per_atom = False

    if maybe_not_per_atom:
        warnings.warn('Requested energy key {} in get_formation_energy may'
                      ' not be per atom, if so results will be incorrect.'.format(energy_key))

    try:
        formation = recursive_get(doc, energy_key)
    except KeyError:
        # use .get so that a doc without a 'source' cannot mask the real error
        print('Doc {} missing key {}'.format(doc.get('source'), energy_key))
        raise

    # see if num chempots has been set and try to reuse it
    if 'num_chempots' in doc:
        num_chempots = doc['num_chempots']
    else:
        num_chempots = get_number_of_chempots(doc, chempots)

    num_atoms_per_fu = get_atoms_per_fu(doc)
    for ind, mu in enumerate(chempots):
        num_atoms_per_mu = get_atoms_per_fu(mu)
        try:
            mu_energy = recursive_get(mu, energy_key)
        except KeyError as exc:
            # the caught exception is an instance and cannot be called;
            # raise a fresh KeyError carrying the context, chained to the cause
            raise KeyError('Chemical potential {} missing key {}'
                           .format(mu.get('source'), energy_key)) from exc
        # subtract this chemical potential's share, rescaled from per-atom-of-mu
        # to per-atom-of-doc
        formation -= mu_energy * num_chempots[ind] * num_atoms_per_mu / num_atoms_per_fu

    return formation
def construct_phase_diagram(self):
    """ Create a phase diagram with arbitrary chemical potentials.

    Expects self.cursor to be populated with structures and chemical potential
    labels to be set under self.species.

    Sets the attributes `phase_diagram`, `structures`, `hull_dist`,
    `convex_hull` and `hull_cursor`, and stores the per-atom and extensive
    formation energies on every document in `self.cursor`.

    Raises:
        NotImplementedError: if non-elemental chemical potentials are used
            together with the `voltage` or `volume` subcommands.

    """
    self.set_chempots()
    self.cursor = filter_cursor_by_chempots(self.species, self.cursor)

    formation_key = 'formation_{}'.format(self.energy_key)
    extensive_formation_key = 'formation_{}'.format(self._extensive_energy_key)

    for ind, doc in enumerate(self.cursor):
        self.cursor[ind][formation_key] = get_formation_energy(
            self.chempot_cursor, doc, energy_key=self.energy_key)
        self.cursor[ind][extensive_formation_key] = doc[formation_key] * doc['num_atoms']

    if self._non_elemental and self.args.get('subcmd') in ['voltage', 'volume']:
        raise NotImplementedError(
            'Pseudo-binary/pseudo-ternary voltages not yet implemented.')

    self.phase_diagram = PhaseDiagram(self.cursor, formation_key, self._dimension)

    # aliases for data stored in phase diagram
    self.structures = self.phase_diagram.structures
    self.hull_dist = self.phase_diagram.hull_dist
    self.convex_hull = self.phase_diagram.convex_hull

    # collect every structure within hull_cutoff (plus tolerance) of the hull
    hull_cursor = [
        self.cursor[idx]
        for idx in np.where(self.hull_dist <= self.hull_cutoff + EPS)[0]
    ]

    # ensure hull cursor is sorted by concentration, then by energy
    # TODO: check why this fails when the opposite way around
    hull_cursor = sorted(
        hull_cursor,
        key=lambda doc: (doc['concentration'], recursive_get(doc, self.energy_key)))

    # by default hull cursor includes all structures within hull_cutoff
    # if summary requested and we're in hulldiff mode, filter hull_cursor for lowest per stoich
    # (.get avoids a KeyError when 'subcmd' was never set, as elsewhere in this method)
    if self.args.get('summary') and self.args.get('subcmd') == 'hulldiff':
        tmp_hull_cursor = []
        compositions = set()
        for member in hull_cursor:
            formula = get_formula_from_stoich(member['stoichiometry'], sort=True)
            # keep only the first entry found for each formula
            if formula not in compositions:
                compositions.add(formula)
                tmp_hull_cursor.append(member)
        hull_cursor = tmp_hull_cursor

    self.hull_cursor = hull_cursor
def __init__(self, cursor, data_key, energy_key='enthalpy_per_atom', chempot_energy_key=None,
             num_samples=None, parameter_key=None, species=None, voltage=False,
             verbosity=None, **kwargs):
    """ Initialise EnsembleHull from a cursor, with other keywords
    following QueryConvexHull.

    Parameters:
        cursor (list[dict]): list of matador documents containing
            variable parameter data for energies.
        data_key (str): the key under which all parameter data is
            stored to the variable parameter, e.g. `_beef` or `_temperature`.

    Keyword arguments:
        energy_key (str): the key under `parameter_key` to use to create
            the hulls.
        chempot_energy_key (str): the key used to create the first convex hull.
        parameter_key (str): the key pertaining to the variable parameter
            itself, e.g. `temperature` or `thetas`.
        num_samples (int): use up to this many samples in creating the hull.
        species (list[str]): list of elements/chempots to use, in
            the desired order.
        voltage (bool): whether or not to compute voltage curves.
        verbosity (int): stored on the instance as `self.verbosity`.
        kwargs (dict): other arguments to pass to QueryConvexHull
            (any extra keyword, e.g. `plot_kwargs`, is forwarded unchanged).

    Raises:
        RuntimeError: if no parameter data can be found on the first
            chemical potential, or if a document's parameter value does
            not match that of the first chemical potential.

    """
    # sometimes the first hull needs to be made with a different key
    if chempot_energy_key is not None:
        self.chempot_energy_key = chempot_energy_key
    else:
        self.chempot_energy_key = energy_key

    super().__init__(cursor=cursor,
                     energy_key=self.chempot_energy_key,
                     species=species,
                     voltage=voltage,
                     no_plot=True,
                     lazy=False,
                     **kwargs)

    # the parent was initialised with the chempot key; switch to the ensemble key
    self.energy_key = energy_key

    # drop unset parent attributes rather than leaving them as None
    if self.phase_diagram is None:
        del self.phase_diagram
    if self.hull_dist is None:
        del self.hull_dist

    self.from_cursor = True
    self.verbosity = verbosity

    # set up relative keys
    self.formation_key = 'formation_' + self.energy_key
    self.data_key = data_key
    self.parameter_key = parameter_key

    if self.parameter_key is None:
        self._parameter_keys = None
    else:
        self._parameter_keys = [self.data_key] + [parameter_key]
    self._formation_keys = [self.data_key] + [self.formation_key]
    self._hulldist_keys = [self.data_key] + ['hull_distance']
    self._energy_keys = [self.data_key] + [self.energy_key]

    self.phase_diagrams = []

    self.set_chempots(energy_key=self.chempot_energy_key)
    self.cursor = filter_cursor_by_chempots(self.species, self.cursor)
    self.cursor = sorted(
        self.cursor,
        key=lambda doc: (recursive_get(doc, self.chempot_energy_key), doc['concentration']))

    # the first chemical potential defines how many samples are available
    if self.parameter_key is None:
        _keys = self._energy_keys
    else:
        _keys = self._parameter_keys
    parameter_iterable = recursive_get(self.chempot_cursor[0], _keys)

    if parameter_iterable is None:
        raise RuntimeError(
            f"Could not find any data for keys {_keys} in {self.chempot_cursor[0]}."
        )

    print(
        f"Found {len(parameter_iterable)} entries under data key: {self.data_key}."
    )

    # allocate formation energy and hull distance arrays
    for doc in self.cursor:
        num_entries = len(recursive_get(doc, self._energy_keys))
        recursive_set(doc, self._formation_keys, [None] * num_entries)
        recursive_set(doc, self._hulldist_keys, [None] * num_entries)

    n_hulls = len(parameter_iterable)
    if num_samples is not None:
        parameter_iterable = parameter_iterable[:num_samples]
        # report how many are really used: the slice may be shorter than requested
        print(
            f"Using {len(parameter_iterable)} out of {n_hulls} possible phase diagrams."
        )

    for param_ind, parameter in enumerate(tqdm.tqdm(parameter_iterable)):
        sample_energy_keys = self._energy_keys + [param_ind]
        sample_formation_keys = self._formation_keys + [param_ind]

        for ind, doc in enumerate(self.cursor):
            if self.parameter_key is not None:
                # explicit check rather than `assert`, which is stripped under -O
                doc_parameter = recursive_get(doc, self._parameter_keys + [param_ind])
                if not doc_parameter == parameter:
                    raise RuntimeError(
                        f"Parameter mismatch at index {param_ind}: document has "
                        f"{doc_parameter}, expected {parameter}."
                    )

            formation_energy = get_formation_energy(
                self.chempot_cursor, doc, energy_key=sample_energy_keys)
            recursive_set(self.cursor[ind], sample_formation_keys, formation_energy)

        # one phase diagram per sample, with hull distances written back to the cursor
        self.phase_diagrams.append(
            PhaseDiagram(self.cursor, sample_formation_keys, self._dimension))
        set_cursor_from_array(self.cursor,
                              self.phase_diagrams[-1].hull_dist,
                              self._hulldist_keys + [param_ind])

    self.stability_histogram = self.generate_stability_statistics()
def set_chempots(self, energy_key=None):
    """ Search for chemical potentials that match the structures in
    the query cursor and add them to the cursor. Also set the concentration
    of chemical potentials in :attr:`cursor`, if not already set.

    Keyword arguments:
        energy_key (str or list): energy field used to rank candidate
            chemical potentials; defaults to `self.energy_key`. The
            database branch passes it to the cursor `sort`, so it must
            be a plain field name there.

    Raises:
        RuntimeError: if a required chemical potential cannot be found,
            either in the cursor or in the database.

    """
    if energy_key is None:
        energy_key = self.energy_key

    query = self._query
    query_dict = dict()
    species_stoich = [sorted(get_stoich_from_formula(spec, sort=False)) for spec in self.species]
    self.chempot_cursor = []

    if self.args.get('chempots') is not None:
        self.fake_chempots(custom_elem=self.species)

    elif self.from_cursor:
        # lowest-energy candidates first, so the first match per species is the best one
        chempot_cursor = sorted(
            [doc for doc in self.cursor if doc['stoichiometry'] in species_stoich],
            key=lambda doc: recursive_get(doc, energy_key))

        for species in species_stoich:
            for doc in chempot_cursor:
                if doc['stoichiometry'] == species:
                    self.chempot_cursor.append(doc)
                    break

        if len(self.chempot_cursor) != len(self.species):
            raise RuntimeError('Found {} of {} required chemical potentials'
                               .format(len(self.chempot_cursor), len(self.species)))

        if self.args.get('debug'):
            print([mu['stoichiometry'] for mu in self.chempot_cursor])

    else:
        print(60 * '─')
        self.chempot_cursor = len(self.species) * [None]
        # scan for suitable chem pots in database
        for ind, elem in enumerate(self.species):
            print('Scanning for suitable', elem, 'chemical potential...')
            # copy so that the base query is never mutated
            query_dict['$and'] = deepcopy(list(query.calc_dict['$and']))

            if not self.args.get('ignore_warnings'):
                query_dict['$and'].append(query.query_quality())

            if len(species_stoich[ind]) == 1:
                query_dict['$and'].append(query.query_composition(custom_elem=[elem]))
            else:
                query_dict['$and'].append(query.query_stoichiometry(custom_stoich=[elem]))

            # if oqmd, only query composition, not parameters
            if query.args.get('tags') is not None:
                query_dict['$and'].append(query.query_tags())

            # take the first sorted match directly rather than calling
            # Cursor.count(), which is deprecated/removed in newer PyMongo
            mu_cursor = query.repo.find(SON(query_dict)).sort(energy_key, pm.ASCENDING)
            best_mu = next(iter(mu_cursor), None)

            if best_mu is None:
                print_notify('Failed... searching without spin polarization field...')
                # relax the query by dropping any spin_polarized constraint
                for sub_query in query_dict['$and']:
                    sub_query.pop('spin_polarized', None)
                mu_cursor = query.repo.find(SON(query_dict)).sort(energy_key, pm.ASCENDING)
                best_mu = next(iter(mu_cursor), None)

            if best_mu is None:
                raise RuntimeError('No chemical potentials found for {}...'.format(elem))

            self.chempot_cursor[ind] = best_mu
            print('Using',
                  ''.join([best_mu['text_id'][0], ' ', best_mu['text_id'][1]]),
                  'as chem pot for', elem)
            print(60 * '─')

    for i, mu in enumerate(self.chempot_cursor):
        # recursive_get keeps nested energy keys working, as in the sort above
        self.chempot_cursor[i][self._extensive_energy_key + '_per_b'] = recursive_get(mu, energy_key)
        self.chempot_cursor[i]['num_a'] = 0

    self.chempot_cursor[0]['num_a'] = float('inf')

    # don't check for IDs if we're loading from cursor
    if not self.from_cursor:
        ids = [doc['_id'] for doc in self.cursor]
        if self.chempot_cursor[0]['_id'] is None or self.chempot_cursor[0]['_id'] not in ids:
            self.cursor.insert(0, self.chempot_cursor[0])
        for match in self.chempot_cursor[1:]:
            if match['_id'] is None or match['_id'] not in ids:
                self.cursor.append(match)

    # add faked chempots to overall cursor
    elif self.args.get('chempots') is not None:
        self.cursor.insert(0, self.chempot_cursor[0])
        self.cursor.extend(self.chempot_cursor[1:])

    # find all elements present in the chemical potentials
    elements = []
    for mu in self.chempot_cursor:
        for elem, _ in mu['stoichiometry']:
            if elem not in elements:
                elements.append(elem)

    self.elements = elements
    self.num_elements = len(elements)
def test_recursive_get_set(self):
    """ Check that recursive_get/recursive_set read and write nested
    dicts and lists through a list of keys/indices, and that missing
    keys or out-of-range indices raise the underlying exception.

    """
    from matador.utils.cursor_utils import recursive_get, recursive_set

    beef = {
        "total_energy": [1, 2, 3, 4],
        "thetas": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        "foo": {"blah": "bloop"},
    }
    nested_dict = {
        "source": ["blah", "foo"],
        "lattice_cart": [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
        "_beef": beef,
    }

    thetas_last = ["_beef", "thetas", -1]
    third_energy = ["_beef", "total_energy", 2]
    foo_blah = ["_beef", "foo", "blah"]

    # plain lookups on the untouched structure
    initial_lookups = (
        (thetas_last, [7, 8, 9]),
        (third_energy, 3),
        (foo_blah, "bloop"),
    )
    for path, expected in initial_lookups:
        self.assertEqual(recursive_get(nested_dict, path), expected)

    # values written with recursive_set must be visible to recursive_get
    for path, new_value in ((thetas_last, [1, 2, 3]), (third_energy, 3.5)):
        recursive_set(nested_dict, path, new_value)
        self.assertEqual(recursive_get(nested_dict, path), new_value)

    # direct mutation of the underlying containers is also picked up
    nested_dict["_beef"]["total_energy"][2] = 3
    self.assertEqual(recursive_get(nested_dict, third_energy), 3)

    nested_dict["_beef"]["foo"]["blah"] = (1, 2)
    self.assertEqual(recursive_get(nested_dict, foo_blah), (1, 2))

    # bad indices/keys propagate the native exception type
    failing_lookups = (
        (["_beef", "thetas", 4], IndexError),
        (["_beef", "thetaz", 4], KeyError),
        (["_beef", "foo", "blahp"], KeyError),
    )
    for path, expected_error in failing_lookups:
        with self.assertRaises(expected_error):
            recursive_get(nested_dict, path)
def __init__(self, *args, settings=None):
    """ Set up arguments and initialise DB client.

    Notes:
        Several arguments can be passed to this class from the command-line,
        and here are interpreted through *args:

    Parameters:
        db (str): the name of the collection to import to.
        scan (bool): whether or not to just scan the directory, rather
            than importing (automatically sets dryrun to true).
        dryrun (bool): perform whole process, up to actually importing to
            the database.
        tags (str): apply this tag to each structure added to database.
        force (bool): override rules about which folders can be imported
            into main database.
        recent_only (bool): if true, sort file lists by modification
            date and stop scanning when a file that already exists in
            database is found.

    Keyword arguments:
        settings (dict): settings to use for the database connection; if
            None, they are loaded with `load_custom_settings`.

    Raises:
        RuntimeError: when importing to the default collection from
            outside its protected directory without `force` or `dryrun`.
        SystemExit: when the target collection is not allowed (oqmd,
            more than one collection, or a prototype/non-prototype mismatch).

    """
    self.args = args[0]
    self.dryrun = self.args.get('dryrun')
    self.scan = self.args.get('scan')
    self.recent_only = self.args.get('recent_only')
    # scanning never writes anything, so it implies a dryrun
    if self.scan:
        self.dryrun = True
    self.debug = self.args.get('debug')
    self.verbosity = self.args.get('verbosity') or 0
    self.config_fname = self.args.get('config')
    self.tags = self.args.get('tags')
    self.prototype = self.args.get('prototype')
    self.tag_dict = dict()
    self.tag_dict['tags'] = self.tags
    self.import_count = 0
    self.skipped = 0
    self.exclude_patterns = ['bad_castep', 'input']
    self.errors = 0
    self.struct_list = []
    self.path_list = []

    self.log = logging.getLogger('spatula')
    loglevel = {3: "DEBUG", 2: "INFO", 1: "WARN", 0: "ERROR"}
    self.log.setLevel(loglevel.get(self.verbosity, "WARN"))
    # NOTE(review): a new handler is attached to the shared 'spatula' logger
    # on every instantiation, so repeated construction may duplicate output.
    handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    self.log.addHandler(handler)

    # I/O files
    if not self.dryrun:
        self.num_words = len(WORDS)
        self.num_nouns = len(NOUNS)

    if not self.scan:
        # delete=False: the files persist so they can be inspected after the run
        self.logfile = tempfile.NamedTemporaryFile(mode='w+t', delete=False)
        self.manifest = tempfile.NamedTemporaryFile(mode='w+t', delete=False)

    if settings is None:
        self.settings = load_custom_settings(
            config_fname=self.config_fname,
            debug=self.debug,
            no_quickstart=self.args.get('no_quickstart'))
    else:
        self.settings = settings

    result = make_connection_to_collection(
        self.args.get('db'),
        check_collection=False,
        import_mode=True,
        override=self.args.get('no_quickstart'),
        mongo_settings=self.settings)

    self.client, self.db, self.collections = result
    # perform some relevant collection-dependent checks
    assert len(self.collections) == 1, 'Can only import to one collection.'
    self.repo = list(self.collections.values())[0]

    # if trying to import to the default repo, without doing a dryrun or forcing it, then
    # check if we're in the protected directory, i.e. the only one that is allowed to import
    # to the default collection
    default_collection = recursive_get(self.settings, ['mongo', 'default_collection'])
    try:
        default_collection_file_path = recursive_get(
            self.settings, ['mongo', 'default_collection_file_path'])
    except KeyError:
        default_collection_file_path = None

    if self.args.get('db') is None or self.args.get('db') == default_collection:
        if not self.dryrun and not self.args.get('force'):
            # if using default collection, check we are in the correct path
            if default_collection_file_path is not None:
                if not os.getcwd().startswith(
                        os.path.expanduser(default_collection_file_path)):
                    print(80 * '!')
                    print(
                        'You shouldn\'t be importing to the default database from this folder! '
                        'Please use --db <YourDBName> to create a new collection, '
                        'or copy these files to the correct place!')
                    print(80 * '!')
                    raise RuntimeError(
                        'Failed to import to default collection from '
                        'current directory, import must be called from {}'.
                        format(default_collection_file_path))
    else:
        if self.args.get('db') is not None:
            # raise SystemExit directly: the `exit` builtin comes from `site`
            # and is not guaranteed to exist in every interpreter
            if any('oqmd' in db for db in self.args.get('db')):
                raise SystemExit('Cannot import directly to oqmd repo')
            elif len(self.args.get('db')) > 1:
                raise SystemExit('Can only import to one collection.')

    num_prototypes_in_db = self.repo.count_documents({'prototype': True})
    num_objects_in_db = self.repo.count_documents({})
    self.log.info(
        "Found {} existing objects in database.".format(num_objects_in_db))

    # prototypes and DFT calculations must never share a collection
    if self.args.get('prototype'):
        if num_prototypes_in_db != num_objects_in_db:
            raise SystemExit(
                'I will not import prototypes to a non-prototype database!'
            )
    else:
        if num_prototypes_in_db != 0:
            raise SystemExit(
                'I will not import DFT calculations into a prototype database!'
            )

    if not self.dryrun:
        # either drop and recreate or create spatula report collection
        self.db.spatula.drop()
        self.report = self.db.spatula

    # scan directory on init
    self.file_lists = self._scan_dir()
    # if import, as opposed to rebuild, scan for duplicates and remove from list
    if self.args.get('subcmd') != 'rebuild':
        self.file_lists, skipped = self._scan_dupes(self.file_lists)
        self.skipped += skipped

    # print number of files found
    self._display_import()

    # only create dicts if not just scanning
    if not self.scan:
        # convert to dict and db if required
        self._files2db(self.file_lists)

    if self.import_count == 0:
        print('No new structures imported!')

    if not self.dryrun and self.import_count > 0:
        print('Successfully imported', self.import_count, 'structures!')
        # index by enthalpy for faster/larger queries
        num_indexes = sum(1 for _ in self.repo.list_indexes())
        # ignore default id index
        if num_indexes > 1:
            self.log.info('Index found, rebuilding...')
            self.repo.reindex()
        else:
            self.log.info('Building index...')
            self.repo.create_index([('enthalpy_per_atom', pm.ASCENDING)])
            self.repo.create_index([('stoichiometry', pm.ASCENDING)])
            self.repo.create_index([('cut_off_energy', pm.ASCENDING)])
            self.repo.create_index([('species_pot', pm.ASCENDING)])
            self.repo.create_index([('kpoints_mp_spacing', pm.ASCENDING)])
            self.repo.create_index([('xc_functional', pm.ASCENDING)])
            self.repo.create_index([('elems', pm.ASCENDING)])
            # index by source for rebuilds
            self.repo.create_index([('source', pm.ASCENDING)])
            self.log.info('Done!')

    elif self.dryrun:
        self.log.info('Dryrun complete!')

    if not self.scan:
        # each line in the temporary logfile is counted as one error
        self.logfile.seek(0)
        errors = sum(1 for line in self.logfile)
        self.errors += errors
        if errors == 1:
            self.log.warning('There is 1 error to view in {}'.format(
                self.logfile.name))
        elif errors == 0:
            self.log.warning('There are no errors to view in {}'.format(
                self.logfile.name))
        elif errors > 1:
            self.log.warning('There are {} errors to view in {}'.format(
                errors, self.logfile.name))

        # closing is best-effort: a failure here must not abort the import
        try:
            self.logfile.close()
        except Exception:
            pass
        try:
            self.manifest.close()
        except Exception:
            pass

    if not self.dryrun:
        # construct dictionary in spatula_report collection to hold info
        # (`errors` is always bound here: not dryrun implies not scan)
        report_dict = dict()
        report_dict['last_modified'] = datetime.datetime.utcnow().replace(
            microsecond=0)
        report_dict['num_success'] = self.import_count
        report_dict['num_errors'] = errors
        report_dict['version'] = __version__
        self.report.insert_one(report_dict)