Example #1
    def configs(self, kind, asatoms=True):
        """Loads a list of configurations of the specified kind.

        Args:
            kind (str): possible values are ['train', 'holdout', 'super'].
            asatoms (bool): when True, return a :class:`~matdb.atoms.AtomsList`
              object; otherwise, just compile the configuration file to disk
              and return None.

        Returns:
            matdb.atoms.AtomsList: Atoms list for the specified configuration class.
        """
        fmap = {
            "train": lambda seq, splt: seq.train_file(splt),
            "holdout": lambda seq, splt: seq.holdout_file(splt),
            "super": lambda seq, splt: seq.super_file(splt)
        }
        smap = {
            t: getattr(self, "_{}file".format(t))
            for t in ["train", "holdout", "super"]
        }
        cfile = smap[kind]

        if not path.isfile(cfile):
            cfiles = []
            for seq in self.dbs:
                #We need to split to get training data. If the split has already
                #been done as part of a different training run, then it won't be
                #done a second time.
                msg.info("Compiling database {} for {}.".format(
                    seq.name, self.fqn))
                seq.split()
                if seq.name in self.cust_splits:
                    splt = self.cust_splits[seq.name]
                else:
                    splt = self.split

                #We grab a list of all the files that match the particular split
                #pattern. Then we apply any filters to individual atoms objects
                #within each of the databases.
                if splt == '*':
                    nfiles = []
                    for dbsplit in seq.splits:
                        nfiles.extend([f(seq, dbsplit) for f in fmap.values()])
                else:
                    nfiles = [fmap[kind](seq, splt)]

                filtered = self._filter_dbs(seq.name, nfiles)
                cfiles.extend(filtered)

            #If this is the training file, we need to append any extras; these
            #are files that have additional trainer-specific configs to include.
            if kind == "train":
                cfiles.extend(self.extras())

            #First, save the configurations to a single file.
            dbcat(cfiles, cfile)

        if asatoms:
            return AtomsList(cfile)
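
A minimal usage sketch for `configs`; the `trainer` object and its database wiring are assumptions standing in for whatever trainer class owns this method.

# Hypothetical caller: `trainer` is an instance of the class defining configs.
train_al = trainer.configs("train")          # compiles (if needed) and loads the AtomsList
trainer.configs("holdout", asatoms=False)    # only ensures the holdout file is compiled
print("Loaded {} training configurations.".format(len(train_al)))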
Example #2
    def __init__(self, name=None, root=None, controller=None, splits=None,
                 folder=None, pattern=None, config_type=None, energy="dft_energy",
                 force="dft_force", virial="dft_virial", limit=None):
        self.name = name
        self.root = path.join(root, self.name)
        if not path.isdir(self.root):
            from os import mkdir
            mkdir(self.root)

        self.controller = controller
        self.splits = {} if splits is None else splits
        self.folder = folder

        if self.controller is None:
            self.ran_seed = 0
        else:
            self.ran_seed = self.controller.ran_seed

        self._dbfile = path.join(self.root, "legacy-{}.h5".format(limit))
        """str: path to the combined legacy database, with limits included.
        """
        self._dbfull = path.join(self.root, "legacy.h5")
        """str: path to the combined legacy database, *without* limits.
        """
        self.dbfiles = []
        self.config_type = config_type

        from matdb.database.utility import dbconfig
        config = dbconfig(self._dbfull)
        if path.isfile(self._dbfile) and len(config) > 0:
            self.dbfiles = [db[0] for db in config["sources"]]
            self.config_type = config["config_type"]
            self.folder = folder
        else:
            from matdb.utility import dbcat
            if not path.isfile(self._dbfull):
                self._create_dbfull(folder, pattern, energy, force, virial, config_type)

            if limit is not None:
                msg.std("Slicing limit subset of full {} db.".format(self.name))
                full = AtomsList(self._dbfull)
                N = np.arange(len(full))
                np.random.shuffle(N)
                ids = N[0:limit]
                part = full[ids]
                part.write(self._dbfile)
                dbcat([self._dbfull], self._dbfile, docat=False, limit=limit,
                      ids=ids)
            else:
                from matdb.utility import symlink
                symlink(self._dbfile, self._dbfull)

        #The rest of matdb expects each database to have an atoms object that is
        #representative. Just take the first config in the combined database.
        self.atoms = Atoms(self._dbfile)
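
A hedged construction sketch; `LegacyDatabase` is an assumed name for the class owning this `__init__`, and the paths are illustrative.

# Hypothetical instantiation of the class that owns the __init__ above.
legacy = LegacyDatabase(name="prior", root="/tmp/matdb", controller=None,
                        folder="/data/legacy", pattern="*.xyz", limit=100)
# With limit=100, a shuffled 100-config subset lands in legacy-100.h5;
# with limit=None, legacy-None.h5 is just a symlink to the full legacy.h5.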
Example #3
def test_dbcate():
    """Tests missing lines in dbcat.
    """
    from matdb.utility import dbcat
    # Test to make just that 'temp2.txt.json' doesn't get written if
    # files can't be cated.
    dbcat(['temp1.txt'],'temp2.txt',sources=["temp3.txt"])

    assert not path.isfile('temp2.txt.json')
    remove("temp2.txt")
    
    dbcat(['temp1.txt'], 'temp2.txt')

    assert path.isfile('temp2.txt')
    remove("temp2.txt")
Example #4
    def _filter_dbs(self, seqname, dbfiles):
        """Filters each of the database files specified so that they conform to
        any specified filters.

        Args:
            seqname (str): name of the sequence that the database files are
              from.
            dbfiles (list): list of `str` paths to database files to filter.

        Returns:
            list: list of `str` paths to include in the database from this sequence.
        """
        if len(self.dbfilter) > 0 and seqname in self._dbfilters:
            filtered = []
            #The filters values have a function and a list of the actual values
            #used in the formula replacement. Extract the parameters; we can't
            #serialize the actual functions.
            filters = self._dbfilters[seqname].items()
            params = {k: v[1] for k, v in filters}

            for dbfile in dbfiles:
                dbname = path.basename(path.dirname(dbfile))
                filtdb = path.join(self.root, "__{}.h5".format(dbname))
                if path.isfile(filtdb):
                    #The filtered version already exists from a previous run;
                    #reuse it instead of filtering again.
                    filtered.append(filtdb)
                    continue

                al = AtomsList(dbfile)
                nl = AtomsList()
                for a in al:
                    #The 0 index here gets the function out; see comment above
                    #about the filters dictionary.
                    if not any(opf[0](getattr(a, attr))
                               for attr, opf in filters):
                        nl.append(a)

                if len(nl) != len(al):
                    nl.write(filtdb)
                    dN, N = (len(al) - len(nl), len(nl))
                    dbcat([dbfile], filtdb, filters=params, dN=dN, N=N)
                    filtered.append(filtdb)
                else:
                    filtered.append(dbfile)
        else:
            filtered = dbfiles

        return filtered
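
A minimal sketch of the `_dbfilters` entry shape the comments above imply: each attribute maps to a `(predicate, params)` tuple, where index 0 is the callable applied to the atoms attribute and index 1 holds the serializable parameters. The attribute name and threshold are illustrative assumptions.

# Hypothetical filter entry: drop any config whose dft_energy exceeds -1.0.
# Index [0] is the predicate; index [1] keeps the raw parameters because the
# function itself can't be serialized.
_dbfilters = {
    "liquid": {
        "dft_energy": (lambda e: e > -1.0, [">", -1.0]),
    }
}
filters = _dbfilters["liquid"].items()
params = {k: v[1] for k, v in filters}   # what _filter_dbs hands to dbcat
excluded = any(opf[0](-0.5) for attr, opf in filters)  # True: -0.5 > -1.0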
Example #5
def split(atlist,
          splits,
          targets,
          dbdir,
          ran_seed,
          dbfile=None,
          recalc=0,
          nonsplit=None):
    """Splits the :class:`~matdb.atoms.AtomsList` multiple times, one for
    each `split` setting in the database specification.

    Args:
        atlist (AtomsList or list): the list of :class:`matdb.atoms.Atoms` objects
          to be split, or a list of paths to the files containing the atoms
          objects.
        splits (dict): the splits to perform.
        targets (dict): callables that return the file in which to save a
          split, given the split name. The dictionary must have the format
          {"train": callable, "holdout": callable, "super": callable}.
        dbdir (str): the root *splits* directory for the database.
        ran_seed (int or float): the random seed for the splits (i.e., the
          controller's random seed).
        dbfile (str): the _dbfile for a legacy database.
        recalc (int): when non-zero, re-split the data and overwrite any
          existing *.h5 files. This parameter decreases as
          rewrites proceed down the stack. To re-calculate
          lower-level h5 files, increase this value.
        nonsplit (AtomsList): a list of atoms to include in the training
          set "as-is" because they cannot be split (they only have meaning
          together).
    """
    from matdb.utility import dbcat

    assert nonsplit is None or isinstance(nonsplit, AtomsList)
    for name, train_perc in splits.items():
        train_file = targets["train"](name)
        holdout_file = targets["holdout"](name)
        super_file = targets["super"](name)
        idfile = path.join(dbdir, "{0}-ids.pkl".format(name))

        if (path.isfile(train_file) and path.isfile(holdout_file)
                and path.isfile(super_file)):
            if recalc <= 0:
                #Skip this split; other splits in the loop may still need
                #to be generated.
                continue
            else:
                #Archive the existing files under the uuid of the split they
                #came from; without the id file there is no uuid to rename with.
                if path.isfile(idfile):
                    with open(idfile, 'rb') as f:
                        data = load(f)
                    for fname in [train_file, holdout_file, super_file]:
                        new_name = fname.replace(
                            name, "{0}_{1}".format(name, data["uuid"]))
                        rename(fname, new_name)
                    remove(idfile)

        #Compile a list of all the sub-configurations we can include in the
        #training.
        if not isinstance(atlist, AtomsList):
            subconfs = AtomsList(atlist)
        else:
            subconfs = atlist

        if path.isfile(idfile):
            with open(idfile, 'rb') as f:
                data = load(f)
            subconfs = data["subconfs"]
            ids = data["ids"]
            Ntrain = data["Ntrain"]
            Nhold = data["Nhold"]
            Ntot = data["Ntot"]
            Nsuper = data["Nsuper"]
        else:
            Ntot = len(subconfs)
            Ntrain = int(np.ceil(Ntot * train_perc))
            ids = np.arange(Ntot)
            Nhold = int(np.ceil((Ntot - Ntrain) * train_perc))
            Nsuper = Ntot - Ntrain - Nhold
            np.random.shuffle(ids)

            #We need to save these ids so that we don't mess up the statistics on
            #the training and validation sets.
            data = {
                "uuid": str(uuid4()),
                "subconfs": subconfs,
                "ids": ids,
                "Ntrain": Ntrain,
                "Nhold": Nhold,
                "Ntot": Ntot,
                "Nsuper": Nsuper,
                "ran_seed": ran_seed
            }
            with open(idfile, 'wb') as f:
                dump(data, f)

        #Only write the minimum necessary files. Use dbcat to create the
        #database version and configuration information. There is duplication
        #here because we also store the ids again. We retain the pkl file above
        #so that we can recreate *exactly* the same split again later.
        if not path.isfile(train_file):
            tids = ids[0:Ntrain]
            #Make sure that we have some atoms to write in the first place!
            if len(tids) > 0:
                altrain = subconfs[tids]
            else:
                altrain = AtomsList()
            #Add the unsplittable configurations to the training set as-is.
            Nunsplit = 0
            if nonsplit is not None:
                altrain.extend(nonsplit)
                Nunsplit = len(nonsplit)
            altrain.write(train_file)

            if dbfile is not None:
                dbcat([dbfile],
                      train_file,
                      docat=False,
                      ids=tids,
                      N=Ntrain + Nunsplit)
            else:
                dbcat([],
                      train_file,
                      docat=False,
                      ids=tids,
                      N=Ntrain + Nunsplit)
        if not path.isfile(holdout_file):
            hids = ids[Ntrain:-Nsuper]
            alhold = subconfs[hids]
            alhold.write(holdout_file)
            if dbfile is not None:
                dbcat([dbfile], holdout_file, docat=False, ids=hids, N=Nhold)
            else:
                dbcat([], holdout_file, docat=False, ids=hids, N=Nhold)
        if not path.isfile(super_file):
            sids = ids[-Nsuper:]
            alsuper = subconfs[sids]
            alsuper.write(super_file)
            if dbfile is not None:
                dbcat([dbfile], super_file, docat=False, ids=sids, N=Nsuper)
            else:
                dbcat([], super_file, docat=False, ids=sids, N=Nsuper)
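
A hedged invocation sketch for split(); the targets values are callables that map a split name to its file path, matching how the function calls targets["train"](name) above. Paths and the split name are illustrative.

# Hypothetical call: a single split named "gp" that keeps 80% for training.
targets = {
    "train":   lambda n: "/tmp/splits/{}-train.h5".format(n),
    "holdout": lambda n: "/tmp/splits/{}-holdout.h5".format(n),
    "super":   lambda n: "/tmp/splits/{}-super.h5".format(n),
}
split(atlist=["/tmp/db/legacy.h5"], splits={"gp": 0.8}, targets=targets,
      dbdir="/tmp/splits", ran_seed=42)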
Example #6
    def _create_dbfull(self, folder, pattern, energy, force, virial, config_type):
        """Creates the full combined database.
        """
        from matdb.utility import chdir, dbcat
        from glob import glob
        from tqdm import tqdm
        from os import path

        #NB! There is a subtle bug here: if you try to open a matdb.atoms.Atoms
        #within the context manager of `chdir`, something goes wrong with the
        #memory sharing in fortran and it dies. This has to be separate.
        with chdir(folder):
            self.dbfiles = glob(pattern)
        rewrites = []

        for dbfile in self.dbfiles:
            #Look at the first configuration in the atoms list to
            #determine if it matches the energy, force, virial and
            #config type parameter names.
            dbpath = path.join(folder, dbfile)
            params, doforce = _atoms_conform(dbpath, energy, force, virial)
            if len(params) > 0 or doforce:
                msg.std("Conforming database file {}.".format(dbpath))
                al = AtomsList(dbpath)
                outpath = path.join(self.root, dbfile.replace(".xyz",".h5"))
                for ai in tqdm(al):
                    for target, source in params.items():
                        if (target == "config_type" and
                            config_type is not None):
                            ai.params[target] = config_type
                        else:
                            ai.add_param(target,ai.params[source])
                            del ai.params[source]
                            if source in ai.info: #pragma: no cover
                                #If the atoms object did things correctly,
                                #this branch should never run; it exists
                                #mainly as a safeguard.
                                msg.warn("The atoms object didn't properly "
                                         "update the parameters of the legacy "
                                         "atoms object.")
                                del ai.info[source]

                    if doforce:
                        ai.add_property("ref_force",ai.properties[force])
                        del ai.properties[force]

                al.write(outpath)

                #Mark this db as non-conforming to record that we created a
                #new version of it.
                rewrites.append(dbfile)

                dbcat([dbpath], outpath, docat=False, renames=params,
                      doforce=doforce)

        # We want a single file to hold all of the data for all the atoms in the database.
        all_atoms = AtomsList()
        for dbfile in self.dbfiles:
            if dbfile in rewrites:
                infile = dbfile.replace(".xyz",".h5")
                all_atoms.extend(AtomsList(path.join(self.root, infile)))
            else:
                dbpath = path.join(folder, dbfile)
                all_atoms.extend(AtomsList(dbpath))

        all_atoms.write(self._dbfull)

        #Finally, create the config file; dbcat was already imported above.
        with chdir(folder):
            dbcat(self.dbfiles, self._dbfull, config_type=self.config_type, docat=False)
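
The config file written here is what example #2 reads back with dbconfig; a short sketch of that round trip, assuming the "sources" and "config_type" keys shown there. The path is illustrative.

from matdb.database.utility import dbconfig

# Read back the provenance that dbcat stored for the combined database.
config = dbconfig("/path/to/legacy.h5")
if len(config) > 0:
    sources = [src[0] for src in config["sources"]]
    print(config["config_type"], sources)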