def configs(self, kind, asatoms=True):
    """Loads a list of configurations of the specified kind.

    Args:
        kind (str): possible values are ['train', 'holdout', 'super'].
        asatoms (bool): when True, return a :class:`~matdb.atoms.AtomsList`
          object; otherwise just compile the file.

    Returns:
        matdb.atoms.AtomsList: Atoms list for the specified configuration
        class.
    """
    #Map each configuration kind to the sequence method producing the split
    #file, and to the trainer attribute naming the compiled target file.
    fmap = {
        "train": lambda seq, splt: seq.train_file(splt),
        "holdout": lambda seq, splt: seq.holdout_file(splt),
        "super": lambda seq, splt: seq.super_file(splt)
    }
    smap = {t: getattr(self, "_{}file".format(t))
            for t in ["train", "holdout", "super"]}
    cfile = smap[kind]

    if not path.isfile(cfile):
        cfiles = []
        for seq in self.dbs:
            #We need to split to get training data. If the split has already
            #been done as part of a different training run, it won't be done
            #a second time.
            msg.info("Compiling database {} for {}.".format(
                seq.name, self.fqn))
            seq.split()

            #A custom split (if configured for this sequence) overrides the
            #trainer-wide split name.
            splt = (self.cust_splits[seq.name]
                    if seq.name in self.cust_splits else self.split)

            #Grab every file matching the split pattern; '*' means all the
            #splits of the sequence across all three kinds.
            if splt == '*':
                nfiles = [f(seq, dbsplit)
                          for dbsplit in seq.splits
                          for f in fmap.values()]
            else:
                nfiles = [fmap[kind](seq, splt)]

            #Apply any per-atom filters configured for this sequence before
            #including the files.
            cfiles.extend(self._filter_dbs(seq.name, nfiles))

        #The training file also appends extras: files with additional
        #trainer-specific configs to include.
        if kind == "train":
            cfiles.extend(self.extras())

        #Concatenate everything into the single compiled file.
        dbcat(cfiles, cfile)

    if asatoms:
        return AtomsList(cfile)
def __init__(self, name=None, root=None, controller=None, splits=None,
             folder=None, pattern=None, config_type=None, energy="dft_energy",
             force="dft_force", virial="dft_virial", limit=None):
    """Initializes the legacy database: either restores it from an existing
    combined file or builds it from the raw files in `folder`.

    Args:
        name (str): name of the database; also becomes the sub-directory
          under `root` where the combined files are stored.
        root (str): parent directory for the database folder.
        controller: object providing `ran_seed`; when None, a seed of 0 is
          used.
        splits (dict): split specifications; defaults to an empty dict.
        folder (str): directory containing the raw legacy database files.
        pattern (str): glob pattern selecting raw files inside `folder`.
        config_type (str): config type to stamp on the configurations.
        energy (str): parameter name for energies in the raw files.
        force (str): property name for forces in the raw files.
        virial (str): parameter name for virials in the raw files.
        limit (int): when not None, only a random subset of this size is
          kept in the limited database file.
    """
    self.name = name
    self.root = path.join(root, self.name)
    #Make sure the database directory exists before any files are written.
    if not path.isdir(self.root):
        from os import mkdir
        mkdir(self.root)
    self.controller = controller
    self.splits = {} if splits is None else splits
    self.folder = folder

    #Without a controller there is no global seed to inherit; fall back to 0.
    if self.controller is None:
        self.ran_seed = 0
    else:
        self.ran_seed = self.controller.ran_seed

    self._dbfile = path.join(self.root, "legacy-{}.h5".format(limit))
    """str: path to the combined legacy database, with limits included.
    """
    self._dbfull = path.join(self.root, "legacy.h5")
    """str: path to the combined legacy database, *without* limits.
    """
    self.dbfiles = []
    self.config_type = config_type

    from matdb.database.utility import dbconfig
    config = dbconfig(self._dbfull)
    if path.isfile(self._dbfile) and len(config) > 0:
        #A combined database with matching config already exists; restore
        #the source list and config type from its config file.
        self.dbfiles = [db[0] for db in config["sources"]]
        self.config_type = config["config_type"]
        self.folder = folder
    else:
        from matdb.utility import dbcat
        #Build the full (un-limited) combined database first if needed.
        if not path.isfile(self._dbfull):
            self._create_dbfull(folder, pattern, energy, force, virial,
                                config_type)

        if limit is not None:
            #Randomly sample `limit` configurations from the full database
            #into the limited file.
            #NOTE(review): np.random is not re-seeded here with
            #self.ran_seed, so the subset is not reproducible from the seed
            #alone — confirm whether seeding happens upstream.
            msg.std("Slicing limit subset of full {} db.".format(self.name))
            full = AtomsList(self._dbfull)
            N = np.arange(len(full))
            np.random.shuffle(N)
            ids = N[0:limit]
            part = full[ids]
            part.write(self._dbfile)
            dbcat([self._dbfull], self._dbfile, docat=False, limit=limit,
                  ids=ids)
        else:
            #No limit: the limited path is just a symlink to the full file.
            from matdb.utility import symlink
            symlink(self._dbfile, self._dbfull)

    #The rest of matdb expects each database to have an atoms object that is
    #representative. Just take the first config in the combined database.
    self.atoms = Atoms(self._dbfile)
def test_dbcate():
    """Tests missing lines in dbcat.
    """
    from matdb.utility import dbcat

    src, dst = ['temp1.txt'], 'temp2.txt'

    #When the declared sources don't match the cat-able files, the config
    #json ('temp2.txt.json') must not be written.
    dbcat(src, dst, sources=["temp3.txt"])
    assert not path.isfile(dst + '.json')
    remove(dst)

    #A plain cat produces the output file as usual.
    dbcat(src, dst)
    assert path.isfile(dst)
    remove(dst)
def _filter_dbs(self, seqname, dbfiles): """Filters each of the database files specified so that they conform to any specified filters. Args: seqname (str): name of the sequence that the database files are from. dbfiles (list): list of `str` paths to database files to filter. Returns: list: list of `str` paths to include in the database from this sequence. """ if len(self.dbfilter) > 0 and seqname in self._dbfilters: filtered = [] #The filters values have a function and a list of the actual values #used in the formula replacement. Extract the parameters; we can't #serialize the actual functions. filters = self._dbfilters[seqname].items() params = {k: v[1] for k, v in filters} for dbfile in dbfiles: dbname = path.basename(path.dirname(dbfile)) filtdb = path.join(self.root, "__{}.h5".format(dbname)) if path.isfile(filtdb): continue al = AtomsList(dbfile) nl = AtomsList() for a in al: #The 0 index here gets the function out; see comment above #about the filters dictionary. if not any(opf[0](getattr(a, attr)) for attr, opf in filters): nl.append(a) if len(nl) != len(al): nl.write(filtdb) dN, N = (len(al) - len(nl), len(nl)) dbcat([dbfile], filtdb, filters=params, dN=dN, N=N) filtered.append(filtdb) else: filtered.append(nfile) else: filtered = dbfiles return filtered
def split(atlist, splits, targets, dbdir, ran_seed, dbfile=None, recalc=0,
          nonsplit=None):
    """Splits the :class:`~matdb.atoms.AtomsList` multiple times, one for
    each `split` setting in the database specification.

    Args:
        atlist (AtomsList or list): the list of :class:`matdb.atoms.Atoms`
          objects to be split or a list to the files containing the atoms
          objects.
        splits (dict): the splits to perform.
        targets (dict): the files to save the splits in, these should
          contain a {} in the name which will be replaced with the split
          name. The dictionary must have the format
          {"train": file_name, "holdout": file_name, "super": file_name}.
        dbdir (str): the root *splits* directory for the database.
        dbfile (str): the _dbfile for a legacy database.
        ran_seed (int or float): the random seed for the splits (i.e. the
          controllers random seed).
        recalc (int): when non-zero, re-split the data and overwrite any
          existing *.h5 files. This parameter decreases as rewrites proceed
          down the stack. To re-calculate lower-level h5 files, increase
          this value.
        nonsplit (AtomsList): a list of atoms to include in the training
          set "as-is" because they cannot be split (they only have meaning
          together).
    """
    from matdb.utility import dbcat
    assert nonsplit is None or isinstance(nonsplit, AtomsList)

    for name, train_perc in splits.items():
        train_file = targets["train"](name)
        holdout_file = targets["holdout"](name)
        super_file = targets["super"](name)
        idfile = path.join(dbdir, "{0}-ids.pkl".format(name))

        if (path.isfile(train_file) and path.isfile(holdout_file)
                and path.isfile(super_file)):
            if recalc <= 0:
                #All three files already exist and no recalculation was
                #requested.
                #NOTE(review): this returns from the whole function, so any
                #*later* splits in `splits` are skipped too — confirm this
                #is intended rather than `continue`.
                return
            else:
                #Archive the previous split under its uuid before
                #re-splitting.
                if path.isfile(idfile):
                    with open(idfile, 'rb') as f:
                        data = load(f)
                    for fname in [train_file, holdout_file, super_file]:
                        new_name = fname.replace(
                            name, "{0}_{1}".format(name, data["uuid"]))
                        rename(fname, new_name)
                    remove(idfile)

        #Compile a list of all the sub-configurations we can include in the
        #training.
        if not isinstance(atlist, AtomsList):
            subconfs = AtomsList(atlist)
        else:
            subconfs = atlist

        if path.isfile(idfile):
            #Restore the exact previous split so the statistics of the
            #training and validation sets are preserved.
            with open(idfile, 'rb') as f:
                data = load(f)
            subconfs = data["subconfs"]
            ids = data["ids"]
            Ntrain = data["Ntrain"]
            Nhold = data["Nhold"]
            Ntot = data["Ntot"]
            Nsuper = data["Nsuper"]
        else:
            Ntot = len(subconfs)
            Ntrain = int(np.ceil(Ntot * train_perc))
            ids = np.arange(Ntot)
            Nhold = int(np.ceil((Ntot - Ntrain) * train_perc))
            Nsuper = Ntot - Ntrain - Nhold
            np.random.shuffle(ids)

            #We need to save these ids so that we don't mess up the
            #statistics on the training and validation sets.
            data = {
                "uuid": str(uuid4()),
                "subconfs": subconfs,
                "ids": ids,
                "Ntrain": Ntrain,
                "Nhold": Nhold,
                "Ntot": Ntot,
                "Nsuper": Nsuper,
                "ran_seed": ran_seed
            }
            with open(idfile, 'wb') as f:
                dump(data, f)

        #Only write the minimum necessary files. Use dbcat to create the
        #database version and configuration information. There is
        #duplication here because we also store the ids again. We retain
        #the pkl file above so that we can recreate *exactly* the same
        #split again later.
        if not path.isfile(train_file):
            tids = ids[0:Ntrain]
            #Make sure that we have some atoms to write in the first place!
            if len(tids) > 0:
                altrain = subconfs[tids]
            else:
                altrain = AtomsList()
            #Add the unsplittable configurations to the training set as-is.
            Nunsplit = 0
            if nonsplit is not None:
                altrain.extend(nonsplit)
                Nunsplit = len(nonsplit)
            altrain.write(train_file)

            if dbfile is not None:
                dbcat([dbfile], train_file, docat=False, ids=tids,
                      N=Ntrain + Nunsplit)
            else:
                dbcat([], train_file, docat=False, ids=tids,
                      N=Ntrain + Nunsplit)

        if not path.isfile(holdout_file):
            #BUG FIX: the previous slice ids[Ntrain:-Nsuper] collapses to
            #an empty slice when Nsuper == 0 (and ids[-Nsuper:] then grabs
            #*all* ids for the super set). Use explicit offsets, which are
            #identical to the old behavior whenever Nsuper > 0 since
            #Ntot == Ntrain + Nhold + Nsuper.
            hids = ids[Ntrain:Ntrain + Nhold]
            alhold = subconfs[hids]
            alhold.write(holdout_file)
            if dbfile is not None:
                dbcat([dbfile], holdout_file, docat=False, ids=hids, N=Nhold)
            else:
                dbcat([], holdout_file, docat=False, ids=hids, N=Nhold)

        if not path.isfile(super_file):
            sids = ids[Ntrain + Nhold:]
            alsuper = subconfs[sids]
            alsuper.write(super_file)
            if dbfile is not None:
                dbcat([dbfile], super_file, docat=False, ids=sids, N=Nsuper)
            else:
                dbcat([], super_file, docat=False, ids=sids, N=Nsuper)
def _create_dbfull(self, folder, pattern, energy, force, virial, config_type):
    """Creates the full combined database.

    Args:
        folder (str): directory containing the raw legacy database files.
        pattern (str): glob pattern selecting raw files inside `folder`.
        energy (str): parameter name for energies in the raw files.
        force (str): property name for forces in the raw files.
        virial (str): parameter name for virials in the raw files.
        config_type (str): when not None, overrides the config_type
          parameter on each conformed configuration.
    """
    from matdb.utility import chdir, dbcat
    from glob import glob
    from tqdm import tqdm
    from os import path

    #NB! There is a subtle bug here: if you try and open a matdb.atoms.Atoms
    #within the context manager of `chdir`, something messes up with the
    #memory sharing in fortran and it dies. This has to be separate.
    with chdir(folder):
        self.dbfiles = glob(pattern)
    rewrites = []

    for dbfile in self.dbfiles:
        #Look at the first configuration in the atoms list to determine if
        #it matches the energy, force, virial and config type parameter
        #names.
        dbpath = path.join(folder, dbfile)
        params, doforce = _atoms_conform(dbpath, energy, force, virial)
        if len(params) > 0 or doforce:
            #Non-conforming parameter/property names need to be renamed on
            #every configuration; write the conformed copy as .h5 under
            #self.root.
            msg.std("Conforming database file {}.".format(dbpath))
            al = AtomsList(dbpath)
            outpath = path.join(self.root, dbfile.replace(".xyz",".h5"))

            for ai in tqdm(al):
                for target, source in params.items():
                    if (target == "config_type" and
                        config_type is not None):
                        #A user-supplied config_type wins over whatever the
                        #legacy file declared.
                        ai.params[target] = config_type
                    else:
                        ai.add_param(target,ai.params[source])
                        del ai.params[source]
                        if source in ai.info:
                            #pragma: no cover (if things were done
                            #correctly by the atoms object this should
                            #never be used. It exists mainly as a
                            #safeguard.)
                            msg.warn("The atoms object didn't properly "
                                     "update the parameters of the legacy "
                                     "atoms object.")
                            del ai.info[source]

                if doforce:
                    ai.add_property("ref_force",ai.properties[force])
                    del ai.properties[force]

            al.write(outpath)
            #Mark this db as non-conforming so that we created a new
            #version of it.
            rewrites.append(dbfile)
            dbcat([dbpath], outpath, docat=False, renames=params,
                  doforce=doforce)

    #We want a single file to hold all of the data for all the atoms in the
    #database; read conformed copies from self.root and untouched files
    #straight from `folder`.
    all_atoms = AtomsList()
    for dbfile in self.dbfiles:
        if dbfile in rewrites:
            infile = dbfile.replace(".xyz",".h5")
            all_atoms.extend(AtomsList(path.join(self.root, infile)))
        else:
            dbpath = path.join(folder, dbfile)
            all_atoms.extend(AtomsList(dbpath))

    all_atoms.write(self._dbfull)

    #Finally, create the config file.
    from matdb.utility import dbcat
    with chdir(folder):
        dbcat(self.dbfiles, self._dbfull, config_type=self.config_type,
              docat=False)