コード例 #1
0
def test_AtomsList_io(tmpdir):
    """Tests the AtomsList writing and reading from file.

    A list of four configurations is written to HDF5 and XYZ, read back and
    compared against the originals. Finally a single :class:`Atoms` object is
    read from the HDF5 file and *its* positions are checked.
    """
    from os import path, mkdir
    import numpy as np
    from matdb.atoms import Atoms, AtomsList

    at1 = Atoms("Si8",positions=[[0,0,0],[0.25,0.25,0.25],[0.5,0.5,0],[0.75,0.75,0.25],
                                  [0.5,0,0.5],[0.75,0.25,0.75],[0,0.5,0.5],[0.25,0.75,0.75]],
                 cell=[5.43,5.43,5.43],info={"rand":10})

    at2 = Atoms("S6",positions=[[0,0,0],[0.25,0.25,0.25],[0.5,0.5,0],[0.75,0.75,0.25],
                                  [0.5,0,0.5],[0.75,0.25,0.75]],
                 cell=[6.43,5.43,4.43],info={"rand":10})

    at3 = Atoms("CNi",positions=[[0,0,0],[0.5,0.5,0.5]])
    at4 = Atoms()
    at4.copy_from(at3)

    originals = [at1, at2, at3, at4]
    al1 = AtomsList(originals)

    target = str(tmpdir.join("atomList_to_hdf5"))
    if not path.isdir(target):
        mkdir(target)
    h5file = path.join(target, "temp.h5")
    xyzfile = path.join(target, "temp.xyz")

    def _has_match(candidates, reference):
        """Returns True if any same-length candidate array is close to
        `reference` (the read order is not assumed to match the write order)."""
        return any(np.allclose(cand, reference) for cand in candidates
                   if len(cand) == len(reference))

    # HDF5 round trip: every original configuration must come back.
    al1.write(h5file)

    aR = AtomsList()
    aR.read(h5file)
    assert len(aR) == len(al1)

    alpos = aR.positions
    for atoms in originals:
        assert _has_match(alpos, atoms.positions)

    # XYZ round trip; reading a second time appends to the same list.
    al1.write(xyzfile)

    aR = AtomsList()
    aR.read(xyzfile)
    assert len(aR) == len(al1)

    aR.read(xyzfile)
    assert len(aR) == 2*len(al1)

    # Test reading in of a single atoms object. Check the positions of the
    # object just loaded, not those of the list read earlier.
    aR1 = Atoms(h5file)
    assert isinstance(aR1,Atoms)
    assert _has_match([atoms.positions for atoms in originals], aR1.positions)
コード例 #2
0
def h5cat(files, target):
    """Merges several h5 AtomsList files into a single output file.

    Configurations are appended in the order the files are listed, and the
    merged list is written to `target`.

    Args:
        files (list): list of `string` file paths to combine.
        target (str): name/path of the output file that will include all of the
          combined files.
    """
    # Imported here instead of at module level to avoid a cyclic import.
    from matdb.atoms import AtomsList

    merged = AtomsList()
    for chunk in map(AtomsList, files):
        merged.extend(chunk)
    merged.write(target)
コード例 #3
0
    def _filter_dbs(self, seqname, dbfiles):
        """Filters each of the database files specified so that they conform to
        any specified filters.

        Args:
            seqname (str): name of the sequence that the database files are
              from.
            dbfiles (list): list of `str` paths to database files to filter.

        Returns:
            list: list of `str` paths to include in the database from this sequence.
        """
        if len(self.dbfilter) > 0 and seqname in self._dbfilters:
            filtered = []
            #The filters values have a function and a list of the actual values
            #used in the formula replacement. Extract the parameters; we can't
            #serialize the actual functions.
            filters = self._dbfilters[seqname].items()
            params = {k: v[1] for k, v in filters}

            for dbfile in dbfiles:
                dbname = path.basename(path.dirname(dbfile))
                filtdb = path.join(self.root, "__{}.h5".format(dbname))
                if path.isfile(filtdb):
                    continue

                al = AtomsList(dbfile)
                nl = AtomsList()
                for a in al:
                    #The 0 index here gets the function out; see comment above
                    #about the filters dictionary.
                    if not any(opf[0](getattr(a, attr))
                               for attr, opf in filters):
                        nl.append(a)

                if len(nl) != len(al):
                    nl.write(filtdb)
                    dN, N = (len(al) - len(nl), len(nl))
                    dbcat([dbfile], filtdb, filters=params, dN=dN, N=N)
                    filtered.append(filtdb)
                else:
                    filtered.append(nfile)
        else:
            filtered = dbfiles

        return filtered
コード例 #4
0
def test_AtomsList_empty_io(tmpdir):
    """Checks that an empty AtomsList can be written to an HDF5 file and that
    reading that file back yields an empty list."""
    from matdb.atoms import Atoms, AtomsList
    from os import path

    target = str(tmpdir.join("empty_AtomsList"))
    globals_setup(target)
    if not path.isdir(target):
        mkdir(target)
    h5file = path.join(target, "temp.h5")

    # Writing must still create the file even though there is nothing in it.
    nothing = AtomsList([])
    nothing.write(h5file)
    assert len(nothing) == 0
    assert path.isfile(h5file)

    reread = AtomsList()
    reread.read(h5file)
    assert len(reread) == 0
コード例 #5
0
ファイル: utility.py プロジェクト: HallLabs/tracy_matdb
def split(atlist,
          splits,
          targets,
          dbdir,
          ran_seed,
          dbfile=None,
          recalc=0,
          nonsplit=None):
    """Splits the :class:`~matdb.atoms.AtomsList` multiple times, one for
    each `split` setting in the database specification.

    For each split the configuration ids are shuffled and partitioned into
    three disjoint subsets:

    * *train*: ``Ntrain = ceil(Ntot * train_perc)`` configurations;
    * *holdout*: ``Nhold = ceil((Ntot - Ntrain) * train_perc)``;
    * *super*: the remaining ``Nsuper = Ntot - Ntrain - Nhold``.

    The shuffled ids are pickled to ``<dbdir>/<name>-ids.pkl`` so that the
    *exact* same split can be recreated later. A split whose three files
    already exist is skipped unless `recalc` is positive. Either way, the
    remaining splits are still processed.

    Args:
        atlist (AtomsList or list): the list of :class:`matdb.atoms.Atoms`
          objects to be split or a list of the files containing the atoms
          objects.
        splits (dict): maps each split name to the fraction used for the
          training subset (see the formulas above).
        targets (dict): callables that return the output file path for a
          given split name. The dictionary must have the keys ``"train"``,
          ``"holdout"`` and ``"super"``; each value is called as
          ``targets[key](name)``.
        dbdir (str): the root *splits* directory for the database.
        ran_seed (int or float): the random seed for the splits (i.e. the
          controllers random seed). It is only recorded in the pickled id
          file; the shuffle itself uses NumPy's global random state.
        dbfile (str): the _dbfile for a legacy database.
        recalc (int): when positive, re-split the data even if all three
          split files exist. The existing files are first renamed with a
          uuid suffix. This parameter decreases as rewrites proceed down the
          stack. To re-calculate lower-level h5 files, increase this value.
        nonsplit (AtomsList): a list of atoms to include in the training
          set "as-is" because they cannot be split (they only have meaning
          together).
    """
    from matdb.utility import dbcat

    assert nonsplit is None or isinstance(nonsplit, AtomsList)

    def _subset(confs, sub_ids):
        """Returns ``confs[sub_ids]``, or an empty AtomsList when `sub_ids` is
        empty so that we never index with an empty id array."""
        if len(sub_ids) > 0:
            return confs[sub_ids]
        return AtomsList()

    def _record(fname, sub_ids, count):
        """Uses dbcat to write the database version/config info for `fname`."""
        sources = [dbfile] if dbfile is not None else []
        dbcat(sources, fname, docat=False, ids=sub_ids, N=count)

    for name, train_perc in splits.items():
        train_file = targets["train"](name)
        holdout_file = targets["holdout"](name)
        super_file = targets["super"](name)
        idfile = path.join(dbdir, "{0}-ids.pkl".format(name))

        if (path.isfile(train_file) and path.isfile(holdout_file)
                and path.isfile(super_file)):
            if recalc <= 0:
                #This split is complete. Move on to the next one; returning
                #here would skip any splits that have not been written yet.
                continue
            #Move the existing files aside, suffixed with the uuid of the
            #split that produced them, so they can be regenerated.
            if path.isfile(idfile):
                with open(idfile, 'rb') as f:
                    backup_id = load(f)["uuid"]
            else:
                #No id file to take the uuid from; use a fresh one rather than
                #an undefined (or another split's) uuid.
                backup_id = str(uuid4())
            for fname in [train_file, holdout_file, super_file]:
                new_name = fname.replace(
                    name, "{0}_{1}".format(name, backup_id))
                rename(fname, new_name)
            if path.isfile(idfile):
                remove(idfile)

        #Compile a list of all the sub-configurations we can include in the
        #training.
        if not isinstance(atlist, AtomsList):
            subconfs = AtomsList(atlist)
        else:
            subconfs = atlist

        if path.isfile(idfile):
            with open(idfile, 'rb') as f:
                data = load(f)
            subconfs = data["subconfs"]
            ids = data["ids"]
            Ntrain = data["Ntrain"]
            Nhold = data["Nhold"]
            Ntot = data["Ntot"]
            Nsuper = data["Nsuper"]
        else:
            Ntot = len(subconfs)
            Ntrain = int(np.ceil(Ntot * train_perc))
            ids = np.arange(Ntot)
            Nhold = int(np.ceil((Ntot - Ntrain) * train_perc))
            Nsuper = Ntot - Ntrain - Nhold
            #NOTE(review): uses NumPy's global RNG; `ran_seed` is not applied
            #here, so presumably the caller seeds it -- confirm.
            np.random.shuffle(ids)

            #We need to save these ids so that we don't mess up the statistics on
            #the training and validation sets.
            data = {
                "uuid": str(uuid4()),
                "subconfs": subconfs,
                "ids": ids,
                "Ntrain": Ntrain,
                "Nhold": Nhold,
                "Ntot": Ntot,
                "Nsuper": Nsuper,
                "ran_seed": ran_seed
            }
            with open(idfile, 'wb') as f:
                dump(data, f)

        #Explicit start/end indices: the negative-index forms ids[Ntrain:-Nsuper]
        #and ids[-Nsuper:] give an empty holdout and an all-ids super set when
        #Nsuper == 0.
        hold_end = Ntrain + Nhold

        #Only write the minimum necessary files. Use dbcat to create the
        #database version and configuration information. There is duplication
        #here because we also store the ids again. We retain the pkl file above
        #so that we can recreate *exactly* the same split again later.
        if not path.isfile(train_file):
            tids = ids[0:Ntrain]
            altrain = _subset(subconfs, tids)
            #Add the unsplittable configurations to the training set as-is.
            Nunsplit = 0
            if nonsplit is not None:
                altrain.extend(nonsplit)
                Nunsplit = len(nonsplit)
            altrain.write(train_file)
            _record(train_file, tids, Ntrain + Nunsplit)

        if not path.isfile(holdout_file):
            hids = ids[Ntrain:hold_end]
            _subset(subconfs, hids).write(holdout_file)
            _record(holdout_file, hids, Nhold)

        if not path.isfile(super_file):
            sids = ids[hold_end:]
            _subset(subconfs, sids).write(super_file)
            _record(super_file, sids, Nsuper)
コード例 #6
0
    def _create_dbfull(self, folder, pattern, energy, force, virial, config_type):
        """Builds the single combined database file `self._dbfull`.

        Every file in `folder` matching the glob `pattern` is checked with
        `_atoms_conform`. If a file needs its parameter names changed, or its
        force property moved, the configurations are conformed in memory and
        written as an ``.h5`` copy under `self.root`. All configurations are
        then concatenated and written to `self._dbfull`; conformed copies are
        used where they exist, otherwise the originals. `dbcat` records the
        configuration information for each rewrite and for the combined file.

        Args:
            folder (str): directory containing the source database files.
            pattern (str): glob pattern, relative to `folder`, selecting the
              database files; the matches are stored in `self.dbfiles`.
            energy (str): passed to `_atoms_conform`.
            force (str): name of the force property in the source files; when
              conforming forces it is moved to ``ref_force``.
            virial (str): passed to `_atoms_conform`.
            config_type (str): when not `None` and a ``config_type`` rename is
              requested, this value is set as the ``config_type`` parameter.
        """
        from matdb.utility import chdir, dbcat
        from glob import glob
        from tqdm import tqdm
        from os import path

        #NB! There is a subtle bug here: if you try and open a matdb.atoms.Atoms
        #within the context manager of `chdir`, something messes up with the
        #memory sharing in fortran and it dies. This has to be separate.
        with chdir(folder):
            self.dbfiles = glob(pattern)
        conformed = set()

        for dbfile in self.dbfiles:
            #Inspect the database to determine whether it matches the energy,
            #force, virial and config type parameter names.
            src = path.join(folder, dbfile)
            renames, doforce = _atoms_conform(src, energy, force, virial)
            if not (len(renames) > 0 or doforce):
                continue

            msg.std("Conforming database file {}.".format(src))
            configs = AtomsList(src)
            dest = path.join(self.root, dbfile.replace(".xyz",".h5"))
            for conf in tqdm(configs):
                for new_key, old_key in renames.items():
                    if new_key == "config_type" and config_type is not None:
                        conf.params[new_key] = config_type
                        continue
                    conf.add_param(new_key, conf.params[old_key])
                    del conf.params[old_key]
                    if old_key in conf.info: #pragma: no cover
                        #Safeguard only: if the atoms object handled the
                        #rename correctly this branch is never reached.
                        msg.warn("The atoms object didn't properly "
                                 "update the parameters of the legacy "
                                 "atoms object.")
                        del conf.info[old_key]

                if doforce:
                    conf.add_property("ref_force", conf.properties[force])
                    del conf.properties[force]

            configs.write(dest)

            #Remember that this db was non-conforming so the rewritten copy is
            #used below.
            conformed.add(dbfile)

            dbcat([src], dest, docat=False, renames=renames,
                  doforce=doforce)

        # We want a single file to hold all of the data for all the atoms in the database.
        all_atoms = AtomsList()
        for dbfile in self.dbfiles:
            if dbfile in conformed:
                source = path.join(self.root, dbfile.replace(".xyz",".h5"))
            else:
                source = path.join(folder, dbfile)
            all_atoms.extend(AtomsList(source))

        all_atoms.write(self._dbfull)

        #Finally, create the config file.
        #NOTE(review): this uses `self.config_type`, not the `config_type`
        #argument -- confirm that is intended.
        with chdir(folder):
            dbcat(self.dbfiles, self._dbfull, config_type=self.config_type, docat=False)