Example #1
File: test_io.py  Project: hensldm/gblearn
#Shared imports for the test_io.py snippets below (assumed from how the code is used).
import numpy as np
import pytest
from os import path

from gblearn.io import ResultStore


def test_scatter(store):
    """Tests the saving and restoration of random 2x2 Scatter matrices
    for each of the GBs.
    """
    #There shouldn't be anything before it is set.
    with store.Scatter[store.gbids[0]] as stored:
        assert stored is None

    #First, we generate the random matrices, then we set Scatter and get Scatter
    #and make sure they match.
    Scatters = {gbid: np.random.random((2, 2)) for gbid in store.gbids}
    store.Scatter = Scatters

    #Check that the directory exists and that it contains the relevant
    #file for each GB.
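    #The directory name below appears to encode the scatter parameters in order
    #(density, Layers, SPH_L, n_trans, n_angle1, n_angle2); this is an assumption
    #based on the `configure` call in the `store` fixture.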
    target = path.join(store.root, "scatter", "Scatter", "0.50_2_6_8_8_8")
    assert path.isdir(target)
    for gbid in store.gbids:
        gbpath = path.join(target, "{}.npy".format(gbid))
        assert path.isfile(gbpath)

    #Ask for a new store so that we can load the arrays from disk and
    #check their equality.
    nstore = ResultStore(range(1, 8), store.root)
    nstore.configure("scatter",
                     density=0.5,
                     Layers=2,
                     SPH_L=6,
                     n_trans=8,
                     n_angle1=8,
                     n_angle2=8)
    assert len(nstore.gbids) == len(nstore.Scatter)
    for gbid in nstore.gbids:
        with nstore.Scatter[gbid] as stored:
            assert np.allclose(stored, Scatters[gbid])
Example #2
File: test_io.py  Project: hensldm/gblearn
def test_soap(store):
    """Tests the saving and restoration of random 2x2 SOAP matrices
    for each of the GBs.
    """
    #There shouldn't be anything before it is set.
    with store.P[store.gbids[0]] as stored:
        assert stored is None

    #First, we generate the random matrices, then we set P and get P
    #and make sure they match.
    Ps = {gbid: np.random.random((2, 2)) for gbid in store.gbids}
    store.P = Ps

    #Check that the directory exists and that it contains the relevant
    #file for each GB.
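    #The directory name below appears to encode the SOAP parameters
    #(lmax, nmax, rcut) from the `configure` call in the `store` fixture.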
    target = path.join(store.root, "soap", "P", "8_8_4.30")
    assert path.isdir(target)
    for gbid in store.gbids:
        gbpath = path.join(target, "{}.npy".format(gbid))
        assert path.isfile(gbpath)

    #Ask for a new store so that we can load the arrays from disk and
    #check their equality.
    nstore = ResultStore(range(1, 8), store.root)
    nstore.configure("soap", lmax=8, nmax=8, rcut=4.3)
    assert len(nstore.gbids) == len(nstore.P)
    for gbid in nstore.gbids:
        with nstore.P[gbid] as stored:
            assert np.allclose(stored, Ps[gbid])
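
The two tests above exercise the same persistence pattern. Below is a minimal standalone sketch of that round trip, assuming only the ResultStore API shown in these examples; the `/tmp/rndstore` path is illustrative.

import numpy as np
from gblearn.io import ResultStore

#Create a disk-backed store for GBs 1..7 and configure the SOAP parameters.
store = ResultStore(range(1, 8), "/tmp/rndstore")
store.configure("soap", lmax=8, nmax=8, rcut=4.3)

#Assigning a dict keyed by gbid writes one .npy file per GB under the store root.
store.P = {gbid: np.random.random((2, 2)) for gbid in store.gbids}

#A fresh store with the same root and configuration reloads the matrices from disk;
#results are accessed through a context manager so the arrays can be released afterwards.
reloaded = ResultStore(range(1, 8), "/tmp/rndstore")
reloaded.configure("soap", lmax=8, nmax=8, rcut=4.3)
with reloaded.P[1] as P1:
    print(P1.shape)  # (2, 2)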
Example #3
File: test_io.py  Project: hensldm/gblearn
#The @pytest.fixture decorator is assumed here; the tests above request `store` by name.
@pytest.fixture
def store(tmpdir):
    root = tmpdir.join("rndstore")
    res = ResultStore(range(1, 8), str(root))
    res.configure("soap", lmax=8, nmax=8, rcut=4.3)
    res.configure("scatter",
                  density=0.5,
                  Layers=2,
                  SPH_L=6,
                  n_trans=8,
                  n_angle1=8,
                  n_angle2=8)
    return res
Example #4
File: test_io.py  Project: hensldm/gblearn
#The @pytest.fixture decorator is assumed here as well.
@pytest.fixture
def memstore():
    """Returns a memory-only result store.
    """
    res = ResultStore(range(1, 8))
    res.configure("soap", lmax=8, nmax=8, rcut=4.3)
    res.configure("scatter",
                  density=0.5,
                  Layers=2,
                  SPH_L=6,
                  n_trans=8,
                  n_angle1=8,
                  n_angle2=8)
    return res
Example #5
File: test_io.py  Project: hensldm/gblearn
def test_ASR(store):
    """Tests the aggregated, single matrix storage.
    """
    ASR = np.random.random((7, 6))
    store.ASR = ASR

    target = path.join(store.root, "soap", "ASR", "8_8_4.30.npy")
    assert path.isfile(target)

    nstore = ResultStore(range(1, 8), store.root)
    nstore.configure("soap", lmax=8, nmax=8, rcut=4.3)
    assert np.allclose(nstore.ASR, ASR)
    assert np.allclose(store.ASR, ASR)
Example #6
File: test_io.py  Project: hensldm/gblearn
def test_LER(store):
    """Tests the parameterized, aggregated storage for U and LER.
    """
    eps = [1.1, 2.2, 3.3]
    LER = {e: np.random.random((7, 5)) for e in eps}
    store.LER = LER

    #Ask for a new store so that we can load the arrays from disk and
    #check their equality.
    nstore = ResultStore(range(1, 8), store.root)
    nstore.configure("soap", lmax=8, nmax=8, rcut=4.3)
    for e in eps:
        assert np.allclose(nstore.LER[e], LER[e])
Example #7
File: gb.py  Project: jayspendlove/gblearn
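#Note: this excerpt from gb.py omits the module-level imports and helper functions
#(e.g. the `_scatter_mp` worker used below), so it is not runnable on its own.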
class GrainBoundaryCollection(OrderedDict):
    """Represents a collection of grain boundaries and the unique environments
    between them.

    .. warning:: If you don't specify a path for `store`, any results (such as
      SOAP matrices, ASR, LER, etc.) will *not* be saved to disk. They also
      won't be loaded from disk if they already exist.

    Args:
        name (str): identifier for this collection.
        root (str): path to the directory where the raw GB atomic descriptions
          are located.
        store (str): path to the :class:`~gblearn.io.ResultStore` root
          directory that this collection's results are stored in. To use a
          memory-only store, leave this as `None`.
        rxgbid (str): regex pattern for extracting the `gbid` for each GB. Any
          files that don't match the regex are automatically excluded. The regex
          should include a named group `(?P<gbid>...)` so that the GB id can be
          extracted correctly. If not specified, the file name is used as the
          `gbid`.
        sortkey (function): when `root` is investigated to load GBs, the file
          names are first sorted; here you can specify a custom sorting
          function.
        reverse (bool): for GB file name sorting (see `sortkey`), whether to
          reverse the order.
        seed (numpy.ndarray): seed SOAP vector for calculating unique LAEs. This
          is usually the SOAP vector of the perfect bulk crystal in the GB.
        padding (float): amount of perfect bulk to include as padding around
          the grain boundary before the representation is made.

    Attributes:
        name (str): identifier for this collection.
        root (str): path to the directory where the raw GB atomic descriptions
          are located.
        store (gblearn.io.ResultStore): storage manager for all the
          intermediate results generated by this collection.
        unique (dict): keys are `float` values of `epsilon` for comparing
          environments; values are themselves dictionaries that have keys as
          `tuple` of `(gbid, aid)` with `aid` the id of the atom (row id) in the
          SOAP vector for the GB; value is the SOAP vector already found to be
          unique for that value of `epsilon`.
        equivalent (dict): keys are `float` values of `epsilon` for comparing
          environments; values are themselves dictionaries with `gbid` keys, whose
          values are a `dict` with keys linked to :attr:`unique` and values a
          list of `aid` in the GB whose LAEs are equivalent to the unique LAE
          represented by the key.
        repargs (dict): keys are the representation names, while the values are
          the parameters used to compute that representation.
        properties (dict): keys are property names, values are `dict` keyed by
          `gbid` with values being the property value for each GB.
        LAE (dict): keys are `int` LAE ids, and the values are the LAE objects
          corresponding to those ids.
        others (dict): keys are ids, while values are :class:`GrainBoundary`
           objects that do not belong to the original collection.
        axis (int): the dimension along which the grain boundary was parsed.
    """
    def __init__(self,
                 name,
                 root,
                 store=None,
                 rxgbid=None,
                 sortkey=None,
                 reverse=False,
                 seed=None,
                 padding=10.0):
        super(GrainBoundaryCollection, self).__init__()
        self.name = name
        self.root = path.abspath(path.expanduser(root))
        self._sortkey = sortkey
        """function: when `root` is investigated to load GBs, the file names are first
          sorted; here you can specify a custom sorting function.
        """
        self._reverse = reverse
        """bool: for GB file name sorting (see `sortkey`), whether to reverse
        the order.
        """
        self._rxgbid = None
        """_sre.SRE_Pattern: compiled regex for the gbid pattern matching
        string.
        """
        self.unique = {}
        self.equivalent = {}
        self.properties = {}
        self.repargs = {}
        self.LAE = {}
        self.seed = seed
        self.others = {}
        self.axis = None

        if rxgbid is not None:
            import re
            self._rxgbid = re.compile(rxgbid)

        #Search for all the GBs in the specified root folder.
        self.gbfiles = OrderedDict()
        self._find_gbs()

        from gblearn.io import ResultStore
        self.store = ResultStore(self.gbfiles.keys(), store, padding=padding)
        #self.mutlistore = ResultStore(self.gbfiles.keys(), "multires_"+ store, padding=padding)
        self.padding = padding

    def get_property(self, name):
        """Builds a value vector for a property in this collection.

        Args:
            name (str): name of the property to build the vector for.
        """
        if name in self.properties:
            values = self.properties[name]
            return np.array([values[gbid] for gbid in self])
        else:
            values = []
            scalar = False
            for gb in self.values():
                if hasattr(gb, name):
                    values.append(getattr(gb, name))
                elif name in gb.params:
                    values.append(gb.params[name])
                    scalar = True
                else:  # pragma: no cover
                    break

            if scalar:
                return np.array(values)
            else:
                return values

    def add_property(self,
                     name,
                     filename=None,
                     values=None,
                     colindex=1,
                     delimiter=None,
                     cast=float,
                     skip=0):
        """Adds a property to each GB in the collection from file or an existing
        dictionary.

        .. note:: You must specify either `filename` or `values`.

        Args:
            name (str): name of the property to index under.
            filename (str): path to the file to import from. First column should
              be the `gbid`. Values are taken from `colindex` and `cast` to the
              specified data type.
            values (dict): keys are `gbid`, values are property values.
            colindex (int): index in the text file to extract values from.
            delimiter (str): delimiter to split on for each row in the file.
            cast: function to apply to the value for this property.
            skip (int): number of rows to skip before reading data.
        """
        if values is not None:
            self.properties[name] = values
            return

        #Extract the gbids directly from the first column before we do the array
        #loading.
        pdict = {}
        iskip = 0
        with open(filename) as f:
            for line in f:
                if iskip < skip:  # pragma: no cover
                    iskip += 1
                    continue

                if delimiter is None:
                    rvals = line.split()
                else:  # pragma: no cover
                    rvals = line.split(delimiter)

                gbid = rvals[0]
                pval = cast(rvals[colindex])
                pdict[gbid] = pval

        self.properties[name] = pdict

    def _find_gbs(self):
        """Finds all the GBs in the root directory using the regex.
        """
        from os import walk
        allfiles = []
        for (dirpath, dirnames, filenames) in walk(self.root):
            allfiles.extend(filenames)
            break

        for fname in sorted(allfiles, key=self._sortkey,
                            reverse=self._reverse):
            if self._rxgbid is not None:
                gbmatch = self._rxgbid.match(fname)
                if gbmatch:
                    try:
                        gbid = gbmatch.group("gbid")
                        self.gbfiles[gbid] = path.join(self.root, fname)
                    except IndexError:  # pragma: no cover
                        pass
            else:
                self.gbfiles[fname] = path.join(self.root, fname)

        msg.info("Found {} grain boundaries.".format(len(self.gbfiles)))

    def load(self,
             parser=None,
             custids=None,
             name=None,
             fname=None,
             **selectargs):
        """Loads the GBs from their files to create :class:`GrainBoundary`
        objects.

        .. note:: The :class:`GrainBoundary` objects are stored in this object's
          dictionary (it inherits from :class:`~collections.OrderedDict`). Thus
          :attr:`keys` are the `gbid` and :attr:`values` are the
          :class:`GrainBoundary` instances, in the sorted order that they were
          discovered.

        .. note:: If 'coord' is not given as a selectarg, load will determine
          and use the coordinate along the longest dimension.

        Args:
            parser: object used to parse the raw GB file. Defaults to
              :class:`gblearn.lammps.Timestep`. Class should have a method `gb`
              that constructs a :class:`GrainBoundary` instance.
            custids (dict or str): if `dict`, keys are `str` GB ids and values
              are the custom selection method to use. If `str`, then a TSV file
              where the first column is GB id and the second is the custom
              selection method to use.
            name (str): unique id of the external grain boundary.
            fname (str): filename of the grain boundary file.

              .. warning:: the root path is automatically prepended to `fname`.
            selectargs (dict): keyword arguments passed to `parser` when
              isolating grain boundary atoms.
        """
        if parser is None:
            from gblearn.lammps import Timestep
            parser = Timestep

        if custids is not None and isinstance(custids, six.string_types):
            rawids = np.loadtxt(custids, dtype=str).tolist()
            custids = {g: m for g, m in enumerate(rawids)}

        if fname is not None:
            if name is None:  # pragma: no cover
                msg.info(
                    "Name not specified, using {} as unique identifier".format(
                        fname))
                name = fname
            gbpath = path.join(self.root, fname)
            self.others[name] = self._parse_gb(gbpath, parser, **selectargs)
            return

        for gbid, gbpath in tqdm(self.gbfiles.items()):
            if custids is not None and gbid in custids:
                selectargs["method"] = custids[gbid]
            self[gbid] = self._parse_gb(gbpath, parser, **selectargs)

    def _parse_gb(self, gbpath, parser, **selectargs):
        """Parses a given file into a :class: `GrainBoundary` object
        """
        t = parser(gbpath)
        self.axis = selectargs.get('coord', None)
        if self.axis is None:
            dif = t.xyz.max(axis=0) - t.xyz.min(axis=0)
            self.axis = dif.argmax()
            selectargs['coord'] = self.axis
        return t.gb(padding=self.padding, **selectargs)

    def trim(self):
        """Removes the atoms from each grain boundary that were included as
        padding for the SOAP vectors.
        """
        for gbid, gb in self.items():
            gb.trim()

    def soap(self,
             lmax=10,
             nmax=10,
             rcut=5.,
             autotrim=True,
             multires=None,
             **soapargs):
        """Calculates the SOAP vector matrix for the atomic environments at
        each grain boundary.

        Args:
            autotrim (boolean): if True, automatically calls :meth:`trim` when done.
            soapargs (dict): key-value pairs of the SOAP parameters (see :class:`SOAPCalculator`).
        """
        defargs = {'lmax': lmax, 'nmax': nmax, 'rcut': rcut}
        soapargs.update(defargs)
        if multires is not None:
            self.repargs["soap"] = multires
            self.store.configure("soap", multires)
            for args in multires:
                assert abs(args["rcut"] - self.padding / 2.) < 1e-8
        else:
            self.repargs["soap"] = soapargs
            self.store.configure("soap", **soapargs)
            assert abs(soapargs["rcut"] - self.padding / 2.) < 1e-8
        P = self.store.P

        if len(P) == len(self):
            if autotrim:
                self.trim()
            for gbid, gb in self.items():
                self[gbid].rep_params["soap"] = soapargs
                #No need to recompute if the store has the result.
            return P

        if multires is not None:
            for gbid, gb in tqdm(self.items()):
                soap = []
                for args in multires:
                    isoap = gb.soap(cache=False, **args)
                    soap.append(isoap)
                res = np.hstack(soap)
                P[gbid] = res
        else:
            for gbid, gb in tqdm(self.items()):
                P[gbid] = gb.soap(cache=False, **soapargs)

        if autotrim:
            self.trim()

    def scatter(self,
                density=0.5,
                Layers=2,
                SPH_L=6,
                n_trans=8,
                n_angle1=8,
                n_angle2=8,
                threads=0,
                multires=None,
                **scatterargs):
        """Calculates the Scatter vectors for each grain boundary.

        Args:
            threads (int): the number of threads to use. If 0, the number of cores
              is determined from :func:`multiprocessing.cpu_count`; if that fails,
              1 thread is used.
            scatterargs (dict): key-value pairs of the Scatter parameters (see :mod:`SNET`).
        """
        defargs = {
            "density": density,
            "Layers": Layers,
            "SPH_L": SPH_L,
            "n_trans": n_trans,
            "n_angle1": n_angle1,
            "n_angle2": n_angle2
        }
        scatterargs.update(defargs)
        if multires is not None:
            self.repargs["scatter"] = multires
            self.store.configure("scatter", multires)
        else:
            self.repargs["scatter"] = scatterargs
            self.store.configure("scatter", **scatterargs)
        Scatter = self.store.Scatter

        if len(Scatter) == len(self):
            for gbid, gb in self.items():
                self[gbid].rep_params["scatter"] = scatterargs
            #No need to recompute if the store has the result.
            return Scatter

        if threads == 0:  # pragma: no cover
            try:
                threads = mp.cpu_count()
            except NotImplementedError:
                msg.warn("Unable able to determine number of available CPU's, "
                         "resorting to 1 thread")
                threads = 1

        pbar = tqdm(total=len(self))

        def _update(res):
            """Updates the tqdm bar and
                adds the completed scatter vector to the storce
            """
            pbar.update()
            Scatter[res[0]] = res[1]

        pool = mp.Pool(processes=threads)
        result = {}

        if multires is not None:
            for gbid, gb in self.items():
                pool.apply_async(_mutlires_scatter_mp,
                                 args=(gbid, gb, multires),
                                 callback=_update)
        else:
            for gbid, gb in self.items():
                pool.apply_async(_scatter_mp,
                                 args=(gbid, gb, scatterargs),
                                 callback=_update)
        pool.close()
        pool.join()

    @property
    def Scatter(self):
        """Returns the computed Scatter vectors for each GB in the collection
        """
        result = self.store.Scatter
        if len(result) == 0:
            msg.info("The Scatter vectors haven't been computed yet. Use "
                     ":meth:`scatter`.")

        return result

    @property
    def SM(self):
        """Returns the Scatter feature matrix based on the current Scatter parameters
        """
        Scatter = self.Scatter
        if len(Scatter) == 0:
            return
        size = 0
        with Scatter[next(iter(self))] as scat:  #first gbid in the collection
            size = len(scat)
        matrix = np.zeros((len(self), size))
        for id, gbid in enumerate(self):
            with Scatter[gbid] as scat:
                matrix[id] = scat

        return matrix

    @property
    def P(self):
        """Returns the computed SOAP matrices for each GB in the collection.
        """
        result = self.store.P
        if len(result) == 0:
            msg.info("The SOAP matrices haven't been computed yet. Use "
                     ":meth:`soap`.")

        return result

    @property
    def ASR(self):
        """Returns the ASR for the GB collection.
        """
        result = self.store.ASR
        P = self.P

        if result is None and len(P) > 0:
            soaps = []
            for gbid in P.gbids:
                with P[gbid] as Pi:
                    soaps.append(np.sum(Pi, axis=0))

            result = np.vstack(soaps)
            self.store.ASR = result

        return result

    def U(self, eps, **kwargs):
        """Returns the uniquified set of environments for the GB collection and
        current `soapargs`.

        .. note:: This method also assigns and adds the LAE number to each atom in
          each grain boundary in the collection.

        Args:
            eps (float): similarity threshold parameter.
            kwargs (dict): used to pass the desired parameters to the
              Locality Sensitive Hashing algorithm used in :meth:`uniquify`.
        """
        result = None
        U = self.store.U
        if eps in U:
            result = U[eps]

        if result is None:
            result = self.uniquify(eps, **kwargs)
            U[eps] = result
            self.store.U = U

        for gbid in self.gbfiles:
            LAEs = result["GBs"][gbid]
            self._assign(self[gbid], LAEs, result["U"])

        #Fills the GrainBoundaryCollection LAE property with the LAEs.
        for id, soap in enumerate(result["U"].values()):
            self.LAE[id] = lae.LAE(id, soap)

        return result

    def _assign(self, gb, LAEs, U):
        """Assigns and fills LAEs for the specified Grain Boundary
        """
        #Just grab the atom ids from each list and then assign that
        #particular atom the corresponding unique signature. Note that
        #each unique signature atom list has the unique signature as the
        #first element, which is why the range starts at 1.
        #This also populates the atoms objects with their corresponding LAE numbers

        for u, elist in LAEs.items():
            for PID, VID in elist[1:]:
                gb.LAEs[VID] = u
        #LAE = [U.keys().index(x) for x in gb.LAEs]
        #gb.atoms.add_property("LAE",LAE)

    def uniquify(self, eps, **kwargs):
        """Extracts all the unique LAEs in the entire GB system using the
        specified `epsilon` similarity value.

        .. warning:: This method does not verify the completion status of any
          previous :meth:`uniquify` attempts. It just re-runs everything and
          clobbers any existing results for the specified value of `epsilon`.

        .. note:: This method implements a Locality Sensitive Hashing algorithm
          to approximate the nearest cluster for each SOAP vector. For the library's
          documentation, refer to https://falconn-lib.org/.

        Args:
            eps (float): similarity scores below this value are considered
              identical. Two actually identical GBs will have a similarity score
              of `0` by this metric, so smaller is more similar.
            kwargs (dict): holds the values that are passed to falconn to set up
              the hash tables in :meth:`setup_hash_tables`.

        Returns:
            dict: with keys `U` and `GBs`. The `U` key has a dictionary of
            `(PID, VID)` identifiers for the unique LAEs in the GB collection. The
            values are the corresponding SOAP vectors. `GBs` is a dictionary with
            `gbid` keys and values being a `dict` keyed by unique LAEs with values a
            list of `(PID, VID)` identifiers from the global GB collection.
        """
        from tqdm import tqdm
        result = {"U": None, "GBs": {}}

        #We pre-seed the list of unique environments with perfect FCC.
        if self.seed is None:
            raise ValueError("Cannot uniquify LAEs without a seed LAE.")

        U = OrderedDict()
        U[('0', 0)] = self.seed

        for gbid in tqdm(self.gbfiles):
            with self.P[gbid] as NP:
                self._uniquify(NP, gbid, U, eps)

        #Create the hash tables and the query object needed for the LSH algorithm
        used = {k: False for k in U}
        query = self.setup_hash_tables(np.vstack(U.values()), **kwargs)

        #With the algorithm set up, loop through all the vectors to find each one's
        #approximate nearest unique neighbor.
        for gbid in tqdm(self.gbfiles):
            with self.P[gbid] as NP:
                LAEs = self._classify(NP, gbid, U, query, used)
                result["GBs"][gbid] = LAEs

        #Now, remove any LAEs from U that didn't get used. We shouldn't really
        #have many of these.
        for k, v in used.items():
            if not v:  # pragma: no cover
                del U[k]

        #Populate the result dict with the final unique LAEs. We want to store
        #these ordered by similarity to the seed U.
        from operator import itemgetter
        K = {u: dissimilarity(v, self.seed) for u, v in U.items()}
        Us = OrderedDict(sorted(K.items(), key=itemgetter(1), reverse=True))
        result["U"] = OrderedDict([(u, U[u]) for u in Us])
        return result

    def setup_hash_tables(self, data, threads=0, probes=50):
        """Creates hash tables for an efficient approximate nearest neighbor
        search

        Args:
            data (numpy.ndarray): matrix where each row is a unique vector.
            threads (int): the number of threads used to set up the Locality
                 Sensitive Hashing tables. If 0 (the default), the maximum number
                 of available hardware threads is used, up to the number of hash
                 tables (10).
            probes (int): the number of probes each query makes over all the hash
                 tables. The higher the number of probes, the more accurate the
                 search, but the longer it takes.

        Returns:
            A falconn query object used to search the created tables.
        """
        import falconn
        params = falconn.get_default_parameters(data.shape[0], len(self.seed))
        params.num_setup_threads = threads
        table = falconn.LSHIndex(params)
        table.setup(data)
        query = table.construct_query_object()
        query.set_num_probes(probes)
        return query

    def _uniquify(self, NP, gbid, uni, eps):
        """Runs the first unique identification pass through the collection. Calculates
        the unique SOAP vectors in the given GB relative to the current set of
        unique ones.

        .. note:: This version includes refactoring by Jonathan Priedemann.

        Args:
            NP (numpy.ndarray): matrix of SOAP vectors for the grain boundary.
            gbid (str): id of the grain boundary in the publication set.
            uni (dict): keys are `tuple` of (PID, VID) with `PID` the
              publication Id of the grain boundary and `VID` the id of the SOAP
              vector in that GBs descriptor matrix. Value is the actual SOAP
              vector already found to be unique for some value of `eps`.
            eps (float): cutoff value for deciding whether two vectors are unique.

        Returns:
            dict: keys are `tuple` of (PID, VID) linked to `uni`; values are a
            list of `tuple` (PID, VID) of vectors similar to the key.
        """
        for i in range(len(NP)):
            Pv = NP[i, :]
            for u in list(uni.keys()):
                uP = uni[u]
                K = dissimilarity(Pv, uP)
                if K < eps:
                    #This vector already has at least one possible classification
                    break
            else:
                #Numpy slicing increases the ref count, so if the sliced array is not
                #copied, then the original array cannot be garbage collected when the
                #context manager deletes it.
                uni[(gbid, i)] = np.copy(Pv)

    def _classify(self, NP, PID, uni, query, used=None):
        """Runs through the collection a second time to find the aproximate nearest
        unique LAE identified in :meth:`_uniquify`.
        """
        result = {}

        for u in uni:
            result[u] = [u]

        for i in range(len(NP)):
            Pv = NP[i, :]
            neighbor = list(uni.keys())[query.find_nearest_neighbor(Pv)]
            result[neighbor].append((PID, i))
            if used is not None:
                used[neighbor] = True

        return result

    def features(self, eps):
        """Calculates the feature descriptor for the given `eps` value and
        places it in the store.

        Args:
            eps (float): cutoff value for deciding whether two vectors are
              unique.
        """
        result = None
        features = self.store.features
        if eps in features:
            result = features[eps]

        if result is None:
            U = self.U(eps)
            result = list(U["U"].keys())
            features[eps] = result
            self.store.features = features
            self._create_feature_map(eps)

        return result

    def LER(self, eps, **kwargs):
        """Produces the LAE fingerprint for each GB in the system. The LAE
        figerprint is the percentage of the GBs local environments that belong to
        each unique LAE type.

        .. note:: the method :meth:`soap` must be called before, otherwise the
            store might not be configured properly.

        Args:
            eps (float): cutoff value for deciding whether two vectors are
              unique.
            kwargs (dict): used to pass the desired parameters to the
              Locality Sensitive Hashing algorithm used in :meth:`uniquify`.

        Returns:
            numpy.ndarray: rows represent GBs; columns are the percentage of unique
              local environments of each type in each GB.
        """
        result = None
        LER = self.store.LER
        if eps in LER:
            result = LER[eps]

        if result is None:
            U = self.U(eps, **kwargs)

            #Next, loop over each GB and count how many of each kind it has.
            result = np.zeros((len(self), len(U["U"])))
            for gbi, gbid in enumerate(self):
                result[gbi] = self._LER(self[gbid], U["U"], False)[:]

            LER[eps] = result
            self.store.LER = LER

        return result

    def _LER(self, gb, U, cache=True):
        """Calculates the LER for the specified Grain Boundary

            Args:
                gb (GrainBoundary): an instance of class:`GrainBoundary` on which to
                    calculate the LER
                U (dict): keys are the unique vector ids in the form (gbid, pid),
                    while the values are the actual unique vectors
                cache (boolean): set to false if the result should not be cached in memory
        """
        result = np.zeros(len(U))
        for ui, uid in enumerate(U):
            result[ui] = gb.LAEs.count(uid)
        #Normalize by the total number of atoms of each type
        N = np.sum(result[:])
        assert N == len(gb)
        result[:] /= N
        if cache:
            gb.LER = result
        return result

    def analyze_other(self, name, analysis, **kwargs):
        """Analyzies the given GB based on the given anlysis argument

        Args:
            name (string): the id of the grain boundary to analyze, which
                corresponds to the key of the GB in the others dictionary
            analysis (string): the analysis desired. This must be from the list
                (LER, ASR, Scatter), and the necessary analysis must also be
                done for the GrainBoundaryCollection.
            kwargs (dict): any additional arguments needed for the analysis
                requested.

        Returns:
            The analysis given by the specified method.
        """
        analysismap = {
            "LER": self._other_LER,
            "ASR": self._other_ASR,
            "Scatter": self._other_Scatter
        }
        return analysismap[analysis](name, self.others[name], **kwargs)

    def _other_LER(self, name, gb, cache=True, **kwargs):
        """Analyzies the Given GB based on LER

        .. note: this will automatically use the soap args stored in self

        Args:
            name (string): the unique id of the GrainBoundary
            gb (GrainBoundary): an instance of class:`GrainBoundary` on which to
              perform the LER analysis.
            eps (float): `eps` value used in finding the set of unique LAEs in
              the GB system.
            cache (boolean): set to false if the LER should not be stored in the
               class:`GrainBoundary` object itself as gb.LER
             kwargs (dict): the arguments sent to meth:`setup_hash_tables` and the
               epsilon value

        Returns:
            The LER of the specified GrainBoundary

        """
        if 'eps' not in kwargs:  # pragma: no cover
            raise ValueError("Epsilon is required for LER analysis")
        eps = kwargs.pop('eps')
        gb.soap(**self.repargs["soap"])
        gb.trim()
        U = self.U(eps)['U']

        query = self.setup_hash_tables(np.vstack(U.values()), **kwargs)
        LAEs = self._classify(gb.P, name, U, query)
        self._assign(gb, LAEs, U)

        return self._LER(gb, U)

    def _other_ASR(self, name, gb, cache=True, **kwargs):
        """Analyzies the given GrainBoundary with the ASR representation

        """
        raise NotImplementedError()

    def _other_Scatter(self, name, gb, cache=True, **kwargs):
        """Analyzies the given GrainBoundary with the Scattering Transformation

        """
        raise NotImplementedError()

    def feature_map_file(self, eps):
        """Returns the full path to the feature map file.

        Args:
            eps (float): `eps` value used in finding the set of unique LAEs in
              the GB system.
        """
        filename = "{0:.5f}-features.dat".format(eps)
        return path.join(self.store.features_, filename)

    def _create_feature_map(self, eps):
        """Creates a feature map file that interoperates with the XGBoost boosters
        dump method.

        .. note:: It is important that the list of features has the *same order* as
          the features in the matrix that the model was trained on.

        Args:
            eps (float): `eps` value used in finding the set of unique LAEs in
              the GB system.
        """
        with open(self.feature_map_file(eps), 'w') as outfile:
            for i, feat in enumerate(self.store.features[eps]):
                outfile.write('{0}\t{1}-{2}\tq\n'.format(i, *feat))

    def importance(self, eps, model):
        """Calculates the feature importances based on the specified XGBoost
        model.

        .. note:: The model needs to have been fitted to the data before
          calling this method

        Args:
            eps (float): `eps` value used in finding the set of unique LAEs in
              the GB system.
            model: one of :class:`xgboost.XGBClassifier` or
              :class:`xgboost.XGBRegressor`.
        """
        from gblearn.analysis import order_features_by_gains
        mapfile = self.feature_map_file(eps)
        gains = order_features_by_gains(model.get_booster(), mapfile)
        result = {"cover": [], "gain": []}
        for key, gdict in gains:
            result["cover"].append((key, gdict["cover"]))
            result["gain"].append((key, gdict["gain"]))

        return result
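
Taken together, the docstrings above describe a load → represent → uniquify → fingerprint workflow. The sketch below is a minimal, hypothetical usage example: the `gblearn.gb` import path is assumed from the file name, and the directory paths, regex, seed file, and `eps` value are all illustrative.

import numpy as np
from gblearn.gb import GrainBoundaryCollection

#Collection rooted at a directory of raw GB files; results are cached under ./store.
gbc = GrainBoundaryCollection("mygbs", "./gbfiles", store="./store",
                              rxgbid=r"gb_(?P<gbid>\d+)\.dump",
                              seed=np.load("seed.npy"),  #SOAP vector of the perfect bulk
                              padding=10.0)

gbc.load()                              #parse each file into a GrainBoundary object
gbc.soap(lmax=10, nmax=10, rcut=5.0)    #rcut must equal padding / 2
asr = gbc.ASR                           #one summed SOAP row per GB
ler = gbc.LER(eps=0.05)                 #LAE fingerprint matrix, rows ordered like gbc

#Scalar properties can be attached and pulled back out as vectors aligned with the rows.
gbc.add_property("energy", filename="energies.txt")
y = gbc.get_property("energy")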
Example #8
File: test_io.py  Project: hensldm/gblearn
def test_errors():
    """Tests raising of exceptions for faulty values.
    """
    xstore = ResultStore(range(1, 8))
    with pytest.raises(KeyError):
        xstore.configure("soap", lmax=8, nmax=None, rcut=4.)