Example No. 1
    def load(self, f):
        """
        Parameters
        ----------
        f : str or file-like
            The file to load the spectrum from, or a str that specifies the
            file name
        """
        if hasattr(f, "name"):
            fname = f.name
        else:
            fname = f

        with possibly_open_file(f, "r") as g:
            self.orso = load_orso(g)

        _data = self.orso[0].data[:, :4].T
        # ORSO files save resolution information as SD; internally refnx
        # uses FWHM (FWHM = 2 * sqrt(2 * ln 2) * SD ~ 2.3548 * SD).
        # after the transpose the columns are rows, so check axis 0
        if _data.shape[0] > 3:
            _data[3] *= 2.3548

        self.data = _data
        self.filename = fname
        self.name = os.path.splitext(os.path.basename(fname))[0]
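All of these snippets funnel file access through `possibly_open_file` (it lives in refnx's private `_lib` utilities). Here is a minimal sketch of the behaviour the examples assume — open a path string in the requested mode, pass an existing handle through unchanged and leave it open — with the binary-write default inferred from the `pickle.dump` examples below, which pass no mode:

    from contextlib import contextmanager

    @contextmanager
    def possibly_open_file(f, mode="wb"):
        # a sketch of the assumed behaviour, not refnx's actual implementation
        if f is None:
            # the sampling examples below pass f=None straight through,
            # so None has to survive the round trip
            yield None
        elif hasattr(f, "read") or hasattr(f, "write"):
            # already file-like: yield unchanged and don't close on exit
            yield f
        else:
            g = open(f, mode)
            try:
                yield g
            finally:
                g.close()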
Example No. 2
    def load(self, f):
        """
        Loads a dataset from file. Must be 2 to 4 column ASCII.

        Parameters
        ----------
        f : file-handle or string
            File to load the dataset from.

        """
        # see if there are header rows
        with possibly_open_file(f, 'rb') as g:
            header_lines = 0
            for i, line in enumerate(g):
                try:
                    nums = [float(tok) for tok in
                            re.split(r"\s|,", line.decode('utf-8'))
                            if len(tok)]
                    if len(nums) >= 2:
                        header_lines = i
                        break
                except ValueError:
                    continue

        # if f was an open handle, the header scan above advanced it;
        # rewind before handing it to np.loadtxt
        if hasattr(f, 'seek'):
            f.seek(0)

        self.data = np.loadtxt(f, unpack=True, skiprows=header_lines)

        if hasattr(f, 'read'):
            fname = f.name
        else:
            fname = f

        self.filename = fname
        self.name = os.path.splitext(os.path.basename(fname))[0]
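The header detection above can be exercised in isolation with an in-memory file (the data values here are invented for illustration):

    import io
    import re

    import numpy as np

    # two header lines, then four 3-column data rows
    raw = (b"# reflectivity of sample A\n"
           b"Q R dR\n"
           b"0.01 1.0 0.05\n"
           b"0.02 0.8 0.04\n"
           b"0.03 0.5 0.03\n"
           b"0.04 0.3 0.02\n")

    with io.BytesIO(raw) as g:
        header_lines = 0
        for i, line in enumerate(g):
            try:
                nums = [float(tok) for tok in
                        re.split(r"\s|,", line.decode('utf-8'))
                        if len(tok)]
                if len(nums) >= 2:
                    header_lines = i
                    break
            except ValueError:
                continue

    data = np.loadtxt(io.BytesIO(raw), unpack=True, skiprows=header_lines)
    print(header_lines)   # 2
    print(data.shape)     # (3, 4) -- columns become rows with unpack=True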
Example No. 3
    def save_xml(self, f, start_time=0):
        """
        Saves the reflectivity data to an XML file.

        Parameters
        ----------
        f : str or file-like
            The file to write the spectrum to, or a str that specifies the file
            name
        start_time: int, optional
            Epoch time specifying when the sample started
        """
        s = string.Template(_template_ref_xml)
        self.time = time.strftime('%Y-%m-%dT%H:%M:%S',
                                  time.localtime(start_time))
        # self.time = time.strftime(
        # datetime.fromtimestamp(start_time).isoformat()
        # filename = 'c_PLP{:07d}_{:d}.xml'.format(self._rnumber[0], 0)

        self._ydata = repr(self.y.tolist()).strip(',[]')
        self._xdata = repr(self.x.tolist()).strip(',[]')
        self._ydataSD = repr(self.y_err.tolist()).strip(',[]')
        self._xdataSD = repr(self.x_err.tolist()).strip(',[]')

        thefile = s.safe_substitute(self.__dict__)

        with possibly_open_file(f, 'wb') as g:
            if 'b' in g.mode:
                thefile = thefile.encode('utf-8')

            g.write(thefile)
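The XML writing relies on `string.Template.safe_substitute`, which fills `$`-placeholders from a mapping and, unlike `substitute`, leaves unknown placeholders intact instead of raising `KeyError`. A toy version (the template text is invented; the real `_template_ref_xml` is much larger):

    import string

    template = string.Template(
        '<REFdata name="$name">\n'
        '  <Qz>$_xdata</Qz>\n'
        '  <R>$_ydata</R>\n'
        '</REFdata>\n'
    )

    fields = {
        'name': 'c_PLP0000708',
        '_xdata': '0.01 0.02 0.03',
        # '_ydata' deliberately missing
    }

    # the unmatched $_ydata survives verbatim, which is why save_xml can
    # pass self.__dict__ without pre-filtering its keys
    print(template.safe_substitute(fields))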
Example No. 4
    def load_model(self, f):
        """
        Load a serialised model.

        Parameters
        ----------
        f: file like or str
            pickle file to load model from.
        """
        # pickle requires a binary-mode handle (cf. Example No. 8 below)
        with possibly_open_file(f, "rb") as g:
            reflect_model = pickle.load(g)
            self.set_model(reflect_model)
        self._print(repr(self.objective))
Example No. 5
def load_chain(f):
    """
    Loads a chain from disk. Does not change the state of a CurveFitter
    object.

    Parameters
    ----------
    f : str or file-like
        File containing the chain.

    Returns
    -------
    chain : array
        The loaded chain - `(nsteps, nwalkers, ndim)` or
        `(nsteps, ntemps, nwalkers, ndim)`
    """
    with possibly_open_file(f, 'r') as g:
        # read header
        header = g.readline()
        expr = re.compile(r'(\d+)')
        matches = expr.findall(header)
        if matches:
            if len(matches) == 3:
                ntemps, nwalkers, ndim = map(int, matches)
                chain_size = ntemps * nwalkers * ndim
            elif len(matches) == 2:
                ntemps = None
                nwalkers, ndim = map(int, matches)
                chain_size = nwalkers * ndim
        else:
            raise ValueError("Couldn't read header line of chain file")

        # accumulate the samples row by row in a plain double array
        read_arr = array.array("d")

        for i, line in enumerate(g, 1):
            # np.fromstring(..., sep=' ') is deprecated for text input,
            # so parse the whitespace-separated tokens directly
            read_arr.extend(float(tok) for tok in line.split()[:chain_size])

        # np.float was removed from numpy; the builtin float (float64)
        # matches the array.array("d") buffer
        chain = np.frombuffer(read_arr, dtype=float, count=len(read_arr))

        if ntemps is not None:
            chain = np.reshape(chain, (i, ntemps, nwalkers, ndim))
        else:
            chain = np.reshape(chain, (i, nwalkers, ndim))

        return chain
Example No. 6
    def save_model(self, f=None):
        """
        Serialise a model to a pickle file.

        Parameters
        ----------
        f: file like or str
            File to save model to.
        """
        if f is None:
            f = 'model_' + datetime.datetime.now().isoformat() + '.pkl'
            if self.dataset is not None:
                f = 'model_' + self.dataset.name + '.pkl'

        with possibly_open_file(f) as g:
            pickle.dump(self.model, g)
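Underneath, `save_model`/`load_model` are plain pickle; a minimal round trip showing why binary-mode handles are required (the model object here is just a stand-in):

    import pickle

    model = {'scale': 1.0, 'bkg': 1e-7}   # stand-in for a ReflectModel

    with open('model_demo.pkl', 'wb') as g:   # binary write, as in save_model
        pickle.dump(model, g)

    with open('model_demo.pkl', 'rb') as g:   # binary read, as in load_model
        restored = pickle.load(g)

    assert restored == model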
Example No. 7
    def save_model(self, *args, f=None):
        """
        Serialise a model to a pickle file.
        If `f` is not specified then the file name is constructed from the
        current dataset name; if there is no current dataset then the filename
        is constructed from the current time. These constructed filenames are
        placed in the current working directory; for a specific save location,
        `f` must be provided.

        Parameters
        ----------
        f: file like or str, optional
            File to save model to.
        """
        if f is None:
            f = 'model_' + datetime.datetime.now().isoformat() + '.pkl'
            if self.dataset is not None:
                f = 'model_' + self.dataset.name + '.pkl'

        with possibly_open_file(f) as g:
            pickle.dump(self.model, g)
Example No. 8
    def load_model(self, *args, f=None):
        """
        Load a serialised model.
        If `f` is not specified then an attempt will be made to find a model
        corresponding to the current dataset name,
        `'model_' + self.dataset.name + '.pkl'`. If there is no current
        dataset then the most recent model will be loaded.
        This method is only intended to be used to deserialise models created
        by this interactive Jupyter widget modeller; it will not successfully
        load complicated `ReflectModel` objects created outside of the
        interactive modeller.

        Parameters
        ----------
        f: file like or str, optional
            pickle file to load model from.
        """
        if f is None and self.dataset is not None:
            # try and load the model corresponding to the current dataset
            f = "model_" + self.dataset.name + ".pkl"
        elif f is None:
            # load the most recent model file
            files = list(filter(os.path.isfile, glob.glob("model_*.pkl")))
            files.sort(key=os.path.getmtime, reverse=True)
            if len(files):
                f = files[0]

        if f is None:
            self._print("No model file is specified/available.")
            return

        try:
            with possibly_open_file(f, "rb") as g:
                reflect_model = pickle.load(g)
            self.set_model(reflect_model)
        except (RuntimeError, FileNotFoundError) as exc:
            # RuntimeError if the file isn't a ReflectModel
            # FileNotFoundError if the specified file name wasn't found
            self._print(repr(exc), repr(f))
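The "most recent matching file" fallback is a reusable idiom on its own:

    import glob
    import os

    # newest-first by modification time; empty list if nothing matches
    files = sorted(
        filter(os.path.isfile, glob.glob("model_*.pkl")),
        key=os.path.getmtime,
        reverse=True,
    )
    newest = files[0] if files else None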
Example No. 9
def load_chain(f):
    """
    Loads a chain from disk. Does not change the state of a CurveFitter
    object.

    Parameters
    ----------
    f : str or file-like
        File containing the chain.

    Returns
    -------
    chain : array
        The loaded chain - `(nsteps, nwalkers, ndim)` or
        `(nsteps, ntemps, nwalkers, ndim)`
    """
    with possibly_open_file(f, "r") as g:
        # read header
        header = g.readline()
        expr = re.compile(r"(\d+)")
        matches = expr.findall(header)
        if matches:
            if len(matches) == 3:
                ntemps, nwalkers, ndim = map(int, matches)
            elif len(matches) == 2:
                ntemps = None
                nwalkers, ndim = map(int, matches)
        else:
            raise ValueError("Couldn't read header line of chain file")

    # np.loadtxt skips the '#' header by default when f is a filename;
    # an open handle continues from just after the header line
    chain = np.loadtxt(f)

    if ntemps is not None:
        chain = np.reshape(chain, (-1, ntemps, nwalkers, ndim))
    else:
        chain = np.reshape(chain, (-1, nwalkers, ndim))

    return chain
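The chain file format both `load_chain` versions expect is a one-line shape header followed by one flattened sample per row; a self-contained round trip (writing the format by hand rather than via `CurveFitter.sample`):

    import io

    import numpy as np

    nsteps, nwalkers, ndim = 5, 8, 3
    chain_in = np.random.random((nsteps, nwalkers, ndim))

    # write: "# nwalkers, ndim" header, then one flattened step per row
    buf = io.StringIO()
    buf.write("# {}, {}\n".format(nwalkers, ndim))
    for step in chain_in:
        buf.write(" ".join(map(str, step.ravel())) + "\n")

    # read it back the way load_chain does
    buf.seek(0)
    buf.readline()                    # consume the shape header
    chain_out = np.loadtxt(buf).reshape(-1, nwalkers, ndim)

    assert np.allclose(chain_in, chain_out)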
Example No. 10
    def write_offspecular(self, f, scanpoint=0):
        """
        Writes the off-specular map for a single scanpoint to an XML file.
        """
        d = dict()
        d['time'] = strftime("%a, %d %b %Y %H:%M:%S +0000", gmtime())
        d['_rnumber'] = self.reflected_beam.datafile_number
        d['_numpointsz'] = np.size(self.m_ref, 1)
        d['_numpointsy'] = np.size(self.m_ref, 2)

        s = string.Template(_template_ref_xml)

        # filename = 'off_PLP{:07d}_{:d}.xml'.format(self._rnumber, index)
        d['_r'] = repr(self.m_ref[scanpoint].tolist()).strip(',[]')
        d['_qz'] = repr(self.m_qz[scanpoint].tolist()).strip(',[]')
        d['_dr'] = repr(self.m_ref_err[scanpoint].tolist()).strip(',[]')
        d['_qx'] = repr(self.m_qx[scanpoint].tolist()).strip(',[]')

        thefile = s.safe_substitute(d)

        with possibly_open_file(f, 'wb') as g:
            if 'b' in g.mode:
                thefile = thefile.encode('utf-8')

            g.write(thefile)
            # if f was a pre-opened handle, discard any stale content
            # remaining beyond what was just written
            g.truncate()
Example No. 11
    def save_model(self, *args, f=None):
        """
        Serialise a model to a pickle file.
        If `f` is not specified then the file name is constructed from the
        current dataset name; if there is no current dataset then the filename
        is constructed from the current time. These constructed filenames are
        placed in the current working directory; for a specific save location,
        `f` must be provided.
        This method is only intended to be used to serialise models created by
        this interactive Jupyter widget modeller.

        Parameters
        ----------
        f: file like or str, optional
            File to save model to.
        """
        if f is None:
            f = "model_" + datetime.datetime.now().isoformat() + ".pkl"
            if self.dataset is not None:
                f = "model_" + self.dataset.name + ".pkl"

        with possibly_open_file(f) as g:
            pickle.dump(self.model, g)
Example No. 12
    def write_offspecular(self, f, scanpoint=0):
        """
        Writes the off-specular map for a single scanpoint to an XML file.
        """
        d = dict()
        d["time"] = strftime("%a, %d %b %Y %H:%M:%S +0000", gmtime())
        d["_rnumber"] = self.reflected_beam.datafile_number
        d["_numpointsz"] = np.size(self.m_ref, 1)
        d["_numpointsy"] = np.size(self.m_ref, 2)

        s = string.Template(_template_ref_xml)

        # filename = 'off_PLP{:07d}_{:d}.xml'.format(self._rnumber, index)
        d["_r"] = repr(self.m_ref[scanpoint].tolist()).strip(",[]")
        d["_qz"] = repr(self.m_qz[scanpoint].tolist()).strip(",[]")
        d["_dr"] = repr(self.m_ref_err[scanpoint].tolist()).strip(",[]")
        d["_qx"] = repr(self.m_qx[scanpoint].tolist()).strip(",[]")

        thefile = s.safe_substitute(d)

        with possibly_open_file(f, "wb") as g:
            if "b" in g.mode:
                thefile = thefile.encode("utf-8")

            g.write(thefile)
            # if f was a pre-opened handle, discard any stale content
            # remaining beyond what was just written
            g.truncate()
Example No. 13
    def load_model(self, *args, f=None):
        """
        Load a serialised model.
        If `f` is not specified then an attempt will be made to find a model
        corresponding to the current dataset name,
        `'model_' + self.dataset.name + '.pkl'`. If there is no current
        dataset then the most recent model will be loaded.

        Parameters
        ----------
        f: file like or str, optional
            pickle file to load model from.
        """
        if f is None and self.dataset is not None:
            # try and load the model corresponding to the current dataset
            f = 'model_' + self.dataset.name + '.pkl'
        elif f is None:
            # load the most recent model file
            files = list(filter(os.path.isfile, glob.glob("model_*.pkl")))
            files.sort(key=os.path.getmtime, reverse=True)
            if len(files):
                f = files[0]

        if f is None:
            self._print("No model file is specified/available.")
            return

        try:
            with possibly_open_file(f, 'rb') as g:
                reflect_model = pickle.load(g)
            self.set_model(reflect_model)
        except (RuntimeError, FileNotFoundError) as exc:
            # RuntimeError if the file isn't a ReflectModel
            # FileNotFoundError if the specified file name wasn't found
            self._print(repr(exc), repr(f))
Example No. 14
    def load(self, f):
        """
        Loads a dataset from file. Must be 2 to 4 column ASCII.

        Parameters
        ----------
        f : file-handle or string
            File to load the dataset from.

        """
        # it would be nicer to simply use np.loadtxt, but this is an
        # attempt to auto-ignore header lines.
        with possibly_open_file(f, "r") as g:
            lines = list(reversed(g.readlines()))
            x = list()
            y = list()
            y_err = list()
            x_err = list()

            # a marker for how many columns in the data there will be
            numcols = 0
            for i, line in enumerate(lines):
                try:
                    # parse a line for numerical tokens separated by whitespace
                    # or comma
                    nums = [
                        float(tok) for tok in re.split(r"\s|,", line)
                        if len(tok)
                    ]
                    if len(nums) in [0, 1]:
                        # might be trailing newlines at the end of the file,
                        # just ignore those
                        continue
                    if not numcols:
                        # figure out how many columns one has
                        numcols = len(nums)
                    elif len(nums) != numcols:
                        # if the number of columns changes there's an issue
                        break
                    x.append(nums[0])
                    y.append(nums[1])
                    if len(nums) > 2:
                        y_err.append(nums[2])
                    if len(nums) > 3:
                        x_err.append(nums[3])
                except ValueError:
                    # you should drop into this if you can't parse tokens into
                    # a series of floats. But the text may be meta-data, so
                    # try to carry on.
                    continue

        x.reverse()
        y.reverse()
        y_err.reverse()
        x_err.reverse()

        if len(x) == 0:
            raise RuntimeError("Datafile didn't appear to contain any data (or"
                               " was the wrong format)")

        if numcols < 3:
            y_err = None
        if numcols < 4:
            x_err = None

        self.data = (x, y, y_err, x_err)

        if hasattr(f, "read"):
            fname = f.name
        else:
            fname = f

        self.filename = fname
        self.name = os.path.splitext(os.path.basename(fname))[0]
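The bottom-up scan means header text never has to be recognised explicitly, only stepped past when parsing fails or the column count changes; a compact demonstration on an in-memory file (values invented):

    import io
    import re

    raw = ("sample A, reduced 2024-01-01\n"
           "Q R dR\n"
           "0.01 1.0 0.05\n"
           "0.02, 0.8, 0.04\n"
           "0.03 0.5 0.03\n"
           "\n")

    x, y, y_err = [], [], []
    numcols = 0
    for line in reversed(io.StringIO(raw).readlines()):
        try:
            nums = [float(tok) for tok in re.split(r"\s|,", line)
                    if len(tok)]
            if len(nums) in (0, 1):
                continue              # trailing blank lines
            if not numcols:
                numcols = len(nums)   # the last data row fixes the count
            elif len(nums) != numcols:
                break                 # column count changed; stop
            x.append(nums[0])
            y.append(nums[1])
            y_err.append(nums[2])
        except ValueError:
            continue                  # non-numeric header text; keep going

    x.reverse()
    y.reverse()
    y_err.reverse()
    print(x)   # [0.01, 0.02, 0.03]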
Example No. 15
    def sample(
        self,
        steps,
        nthin=1,
        random_state=None,
        f=None,
        callback=None,
        verbose=True,
        pool=-1,
    ):
        """
        Performs sampling from the objective.

        Parameters
        ----------
        steps : int
            Collect `steps` samples into the chain. The sampler will run a
            total of `steps * nthin` moves.
        nthin : int, optional
            Each chain sample is separated by `nthin` iterations.
        random_state : {int, `np.random.RandomState`, `np.random.Generator`}
            If `random_state` is not specified the `~np.random.RandomState`
            singleton is used.
            If `random_state` is an int, a new ``RandomState`` instance is
            used, seeded with random_state.
            If `random_state` is already a ``RandomState`` or a ``Generator``
            instance, then that object is used.
            Specify `random_state` for repeatable minimizations.
        f : file-like or str
            File to incrementally save chain progress to. Each row in the file
            is a flattened array of size `(nwalkers, ndim)` or
            `(ntemps, nwalkers, ndim)`. There are `steps` rows in the
            file.
        callback : callable
            callback function to be called at each iteration step. Has the
            signature `callback(coords, logprob)`.
        verbose : bool, optional
            Gives updates on the sampling progress
        pool : int or map-like object, optional
            If `pool` is an `int` then it specifies the number of threads to
            use for parallelization. If `pool == -1`, then all CPUs are used.
            If pool is a map-like callable that follows the same calling
            sequence as the built-in map function, then this pool is used for
            parallelisation.

        Notes
        -----
        Please see :class:`emcee.EnsembleSampler` for its detailed behaviour.

        >>> # we'll burn the first 500 steps
        >>> fitter.sample(500)
        >>> # after you've run those, then discard them by resetting the
        >>> # sampler.
        >>> fitter.sampler.reset()
        >>> # Now collect 40 steps, each step separated by 50 sampler
        >>> # generations.
        >>> fitter.sample(40, nthin=50)

        One can also burn and thin in `CurveFitter.process_chain`.
        """
        self._check_vars_unchanged()

        # setup a random number generator
        rng = check_random_state(random_state)

        if self._state is None:
            self.initialise(random_state=rng)

        # for saving progress to file
        def _callback_wrapper(state, h=None):
            if callback is not None:
                callback(state.coords, state.log_prob)

            if h is not None:
                h.write(" ".join(map(str, state.coords.ravel())))
                h.write("\n")

        # remove chains from each of the parameters because they slow down
        # pickling but only if they are parameter objects.
        flat_params = f_unique(flatten(self.objective.parameters))
        flat_params = [param for param in flat_params if is_parameter(param)]
        # zero out all the old parameter stderrs
        for param in flat_params:
            param.stderr = None
            param.chain = None

        # make sure the checkpoint file exists
        if f is not None:
            with possibly_open_file(f, "w") as h:
                # write the shape of each step of the chain
                h.write("# ")
                shape = self._state.coords.shape
                h.write(", ".join(map(str, shape)))
                h.write("\n")

        # set the random state of the sampler
        # normally one could give this as an argument to the sample method
        # but PTSampler didn't historically accept that...
        if isinstance(rng, np.random.RandomState):
            rstate0 = rng.get_state()
            self._state.random_state = rstate0
            self.sampler.random_state = rstate0

        # using context manager means we kill off zombie pool objects
        # but does mean that the pool has to be specified each time.
        with MapWrapper(pool) as g, possibly_open_file(f, "a") as h:
            # these kwargs are provided to the sampler.sample method
            kwargs = {"iterations": steps, "thin": nthin}

            # if you're not creating more than 1 thread, then don't bother with
            # a pool.
            if isinstance(self.sampler, emcee.EnsembleSampler):
                if pool == 1:
                    self.sampler.pool = None
                else:
                    self.sampler.pool = g
            else:
                kwargs["mapper"] = g

            # new emcee arguments
            sampler_args = getargspec(self.sampler.sample).args
            if "progress" in sampler_args and verbose:
                kwargs["progress"] = True
                verbose = False

            if "thin_by" in sampler_args:
                kwargs["thin_by"] = nthin
                kwargs.pop("thin", 0)

            # perform the sampling
            for state in self.sampler.sample(self._state, **kwargs):
                self._state = state
                _callback_wrapper(state, h=h)

        if isinstance(self.sampler, emcee.EnsembleSampler):
            self.sampler.pool = None

        # sets parameter value and stderr
        return process_chain(self.objective, self.chain)
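The `pool` argument accepts any callable with the built-in map's calling sequence, so an externally managed pool can be handed in. A self-contained illustration of that contract (the log-likelihood is a stand-in for a real per-walker evaluation):

    from multiprocessing import Pool

    def log_likelihood(theta):
        # stand-in for an expensive per-walker evaluation
        return -0.5 * sum(t * t for t in theta)

    if __name__ == "__main__":
        positions = [(0.1, 0.2), (0.3, 0.4), (0.5, 0.6)]
        with Pool(2) as p:
            print(list(p.map(log_likelihood, positions)))

The same `p.map` could be handed to the sampler as `fitter.sample(steps, pool=p.map)`, with the context manager guaranteeing the workers are cleaned up.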
Example No. 16
    def sample(self, steps, nthin=1, random_state=None, f=None, callback=None,
               verbose=True, pool=0):
        """
        Performs sampling from the objective.

        Parameters
        ----------
        steps : int
            Collect `steps` samples into the chain. The sampler will run a
            total of `steps * nthin` moves.
        nthin : int, optional
            Each chain sample is separated by `nthin` iterations.
        random_state : int or `np.random.RandomState`, optional
            If `random_state` is an int, a new `np.random.RandomState` instance
            is used, seeded with `random_state`.
            If `random_state` is already a `np.random.RandomState` instance,
            then that `np.random.RandomState` instance is used. Specify
            `random_state` for repeatable sampling
        f : file-like or str
            File to incrementally save chain progress to. Each row in the file
            is a flattened array of size `(nwalkers, ndim)` or
            `(ntemps, nwalkers, ndim)`. There are `steps` rows in the
            file.
        callback : callable
            callback function to be called at each iteration step
        verbose : bool, optional
            Gives updates on the sampling progress
        pool : int or map-like object, optional
            If `pool` is an `int` then it specifies the number of threads to
            use for parallelization. If `pool == 0`, then all CPUs are used.
            If pool is an object with a map method that follows the same
            calling sequence as the built-in map function, then this pool is
            used for parallelisation.

        Notes
        -----
        Please see :class:`emcee.EnsembleSampler` for its detailed behaviour.

        >>> # we'll burn the first 500 steps
        >>> fitter.sample(500)
        >>> # after you've run those, then discard them by resetting the
        >>> # sampler.
        >>> fitter.sampler.reset()
        >>> # Now collect 40 steps, each step separated by 50 sampler
        >>> # generations.
        >>> fitter.sample(40, nthin=50)

        One can also burn and thin in `CurveFitter.process_chain`.
        """
        self._check_vars_unchanged()

        if self._state is None:
            self.initialise()

        self.__pt_iterations = 0
        if isinstance(self.sampler, PTSampler):
            steps *= nthin

        # for saving progress to file
        def _callback_wrapper(state, h=None):
            if callback is not None:
                callback(state.coords, state.log_prob)

            if h is not None:
                # if you're parallel tempering, then you only
                # want to save every nthin
                if isinstance(self.sampler, PTSampler):
                    self.__pt_iterations += 1
                    if self.__pt_iterations % nthin:
                        return None

                h.write(' '.join(map(str, state.coords.ravel())))
                h.write('\n')

        # set the random state of the sampler
        # normally one could give this as an argument to the sample method
        # but PTSampler didn't historically accept that...
        if random_state is not None:
            rstate0 = check_random_state(random_state).get_state()
            self._state.random_state = rstate0
            if isinstance(self.sampler, PTSampler):
                self.sampler._random = rstate0

        # remove chains from each of the parameters because they slow down
        # pickling but only if they are parameter objects.
        flat_params = f_unique(flatten(self.objective.parameters))
        flat_params = [param for param in flat_params if is_parameter(param)]
        # zero out all the old parameter stderrs
        for param in flat_params:
            param.stderr = None
            param.chain = None

        # make sure the checkpoint file exists
        if f is not None:
            with possibly_open_file(f, 'w') as h:
                # write the shape of each step of the chain
                h.write('# ')
                shape = self._state.coords.shape
                h.write(', '.join(map(str, shape)))
                h.write('\n')

        # using context manager means we kill off zombie pool objects
        # but does mean that the pool has to be specified each time.
        with possibly_create_pool(pool) as g, possibly_open_file(f, 'a') as h:
            # if you're not creating more than 1 thread, then don't bother with
            # a pool.
            if pool == 1:
                self.sampler.pool = None
            else:
                self.sampler.pool = g

            # these kwargs are provided to the sampler.sample method
            kwargs = {'iterations': steps,
                      'thin': nthin}

            # new emcee arguments
            sampler_args = getargspec(self.sampler.sample).args
            if 'progress' in sampler_args and verbose:
                kwargs['progress'] = True
                verbose = False

            if 'thin_by' in sampler_args:
                kwargs['thin_by'] = nthin
                kwargs.pop('thin', 0)

            # ptemcee returns coords, lnprob
            # emcee returns a State object
            if isinstance(self.sampler, PTSampler):
                for result in self.sampler.sample(self._state.coords,
                                                  **kwargs):
                    self._state = State(result[0],
                                        log_prob=result[1] + result[2],
                                        random_state=self.sampler._random)
                    _callback_wrapper(self._state, h=h)
            else:
                for state in self.sampler.sample(self._state,
                                                 **kwargs):
                    self._state = state
                    _callback_wrapper(state, h=h)

        self.sampler.pool = None

        # finish off the progress bar
        if verbose:
            sys.stdout.write("\n")

        # sets parameter value and stderr
        return process_chain(self.objective, self.chain)
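Each checkpoint row written by `_callback_wrapper` is just `state.coords.ravel()`; the shape header emitted beforehand is what lets `load_chain` undo the flattening, including in the parallel-tempering case:

    import numpy as np

    ntemps, nwalkers, ndim = 2, 10, 4
    coords = np.random.random((ntemps, nwalkers, ndim))

    # one checkpoint row, exactly as _callback_wrapper emits it
    row = " ".join(map(str, coords.ravel()))

    # load_chain recovers the structure using the header's shape
    recovered = np.array(row.split(), dtype=float)
    recovered = recovered.reshape(ntemps, nwalkers, ndim)
    assert np.allclose(coords, recovered)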