Beispiel #1
0
    def refine_frequency(self, time, amplitude, guess, verbose=False,
                         max_delta=3**2):
        """
        Refine a frequency estimate by least-squares fitting a sinusoid.

        Fits ``a*sin(2*pi*f*t) + b*cos(2*pi*f*t)`` to ``amplitude`` sampled
        at ``time``. The fit starts from ``a=1, b=1, f=guess`` and is
        minimized with scipy's Powell method.

        Args:
            time: sample times.
            amplitude: observed values at ``time``.
            guess: starting frequency for the fit.
            verbose: forwarded to ``fmin_powell`` as ``disp``.
            max_delta: largest allowed absolute difference between the
                fitted frequency and ``guess``. The default 9 matches the
                previously hard-coded ``3**2``.

        Returns:
            The fitted frequency ``f``.

        Raises:
            ValueError: if the fitted frequency is more than ``max_delta``
                away from ``guess`` in either direction.
        """
        # Set up a minimizer fit to find the best frequency.
        p = ParamState(
            't',
            'y_true',
            a=1,
            b=1,
            f=guess,
        )
        p.given(t=time, y_true=amplitude)

        def model(p):
            phase = 2 * np.pi * p.f * p.t
            return p.a * np.sin(phase) + p.b * np.cos(phase)

        def cost(args, p):
            # Load the optimizer's current guess into p, then return the SSE.
            p.ingest(args)
            err = model(p) - p.y_true
            return np.sum(err**2)

        x0 = p.array
        xf = fmin_powell(cost, x0, args=(p, ), disp=verbose)
        p.ingest(xf)
        # Compare the magnitude of the drift: a fit that wandered far below
        # the guess is just as suspect as one that wandered above it.
        if abs(p.f - guess) > max_delta:
            raise ValueError(
                f'Guess freq: {guess}, Fit Freq: {p.f}  too far apart')
        return p.f
Beispiel #2
0
class Fitter:
    """
    This class is a convenience wrapper for scipy optimization.
    Look at Fitter.examples attribute to see usage.

    The constructor takes exactly the same arguments as ParamState.
    """
    OPTIMIZER_NAMES = {'fmin', 'fmin_powell', 'fmin_cg', 'fmin_bfgs'}

    DEFAULT_VERBOSE = True

    examples = examples()

    def __init__(self, *args, **kwargs):
        """
        *args and **kwargs define/initialize the model variables.
        They are passed directly to a ParamState constructor
        """
        from easier import ParamState
        self._params = ParamState(*args, **kwargs)
        self._algorithm = 'fmin'
        self._optimizer_kwargs = {}
        # Possibly-wrapped model used by the default cost during fitting.
        self._model = None
        # Unwrapped user model used by predict(); None until fit() is
        # called with a model function.
        self._raw_model = None
        self._givens = {}
        self._cost = self._default_cost
        self._verbose = self.DEFAULT_VERBOSE
        # Live-plotting state; only used when fit(plot_every=...) is given.
        self.plot_every = None
        self.plot_counter = 0
        self.pipe = None

    @property
    def all_params(self):
        """
        Show all parameters (even data variables).

        Returns a clone, so modifying it does not affect the fitter.
        """
        p = self._params.clone()
        return p

    @property
    def params(self):
        """
        Show all params except the data variables ``x`` and ``y``.
        """
        p = self.all_params
        p.drop('x')
        p.drop('y')
        return p

    def extra(self, **kwargs):
        """
        Put extra attributes on the params object.  These params will
        be completely ignored by the optimizer.  This is good for passing
        utility functions to your model or cost functions.
        """
        for key, val in kwargs.items():
            setattr(self._params, key, val)

    def given(self, **kwargs):
        """
        Supply kwargs that will set constant variables.  These can either
        be new variables or already defined variables.  This allows you
        to easily turn on and off variables to optimize.

        Each call replaces any previously supplied givens.  Returns self.
        """
        self._givens = kwargs
        return self

    def optimizer_kwargs(self, **kwargs):
        """
        Additional kwargs to pass to optimizer.

        Note that ``disp`` is always overridden by fit()'s ``verbose``.
        Returns self.
        """
        self._optimizer_kwargs = kwargs
        return self

    def _cost_wrapper(self, args, p):
        """
        Allows users to define their cost functions as only functions
        of p, the ParamState object.
        """
        # Load the optimizer's current parameter vector into p first.
        p.ingest(args)
        return self._cost(p)

    def _default_cost(self, p):
        """
        Weighted least-squares cost used when no custom cost is supplied.

        Residuals ``model(p) - p.y`` are multiplied by ``p.weights`` (or by
        uniform weights when ``p`` has no ``weights`` attribute), divided by
        the total weight, and the squares of the result are summed.
        """
        import numpy as np
        if hasattr(p, 'weights'):
            w = p.weights
        else:
            w = np.ones_like(p.y)

        z = (self._model(p) - p.y) * w / np.sum(w)
        return np.sum(z**2)

    def _plotter(self, *args, **kwargs):
        """
        A function to plot fits in real time.

        Used as the DynamicMap callback.  ``kwargs['data']`` is whatever
        was last sent on ``self.pipe``.  That is either the initial empty
        list or a tuple whose last two items are the observed x and y.
        """
        import holoviews as hv
        import easier as ezr
        if kwargs['data']:
            xd, yd = kwargs['data'][-2:]

        else:
            xd, yd = [0], [0]

        fit_line = hv.Curve(*args, **kwargs)
        data = hv.Scatter((xd, yd))
        overlay = hv.Overlay([fit_line, data])

        overlay = overlay.opts(
            hv.opts.Curve(color=ezr.cc.b),
            hv.opts.Scatter(color=ezr.cc.a, size=5, alpha=.5))
        return overlay

    def _model_wrapper(self, model):
        """
        The model wrapper will wrap the model with a function
        that updates real time plots of fit if needed.

        Returns ``model`` unchanged when it is None or when live plotting
        is off.
        """
        if model is None:
            return model

        if self.plot_every is None:
            return model
        else:
            self.plot_counter = 0

            def wrapped(*args, **kwargs):
                yfit = model(*args, **kwargs)
                # Only push every ``plot_every``-th evaluation to the plot.
                if self.plot_counter % self.plot_every == 0:
                    self.pipe.send(
                        (self._params.x, yfit, self._params.x, self._params.y))
                self.plot_counter += 1
                return yfit

            return wrapped

    def _start_live_plot(self):
        """Display a DynamicMap fed by ``self.pipe`` for real-time fits."""
        import numpy as np
        import holoviews as hv
        from holoviews.streams import Pipe
        from IPython.display import display

        x, y = self._params.x, self._params.y
        xmin, xmax = np.min(x), np.max(x)
        ymin, ymax = np.min(y), np.max(y)

        # Pad each axis by 10% of the data span on both sides.
        scale = .1
        delta_x = scale * (xmax - xmin)
        delta_y = scale * (ymax - ymin)
        xlim = (xmin - delta_x, xmax + delta_x)
        ylim = (ymin - delta_y, ymax + delta_y)

        self.pipe = Pipe(data=[])
        try:
            dmap = hv.DynamicMap(self._plotter, streams=[self.pipe])
            dmap.opts(hv.opts.Overlay(xlim=xlim, ylim=ylim))
            display(dmap)
        except AttributeError as err:
            raise RuntimeError(
                'You must import holoviews and set bokeh backround for plotting to work'
            ) from err

    def fit(self,
            *,
            x=None,
            y=None,
            weights=None,
            model=None,
            cost=None,
            plot_every=None,
            algorithm='fmin',
            verbose=True):
        """
        The method to use for training the fitter.
        Args:
                     x: the independent data
                     y: the dependent data
               weights: optional per-point weights used by the default cost
                 model: a model function to fit the data to
                  cost: a custom cost function (defaults to least squares)
            plot_every: Plot solution in real time every this number of iterations
             algorithm: Scipy optimization routine to use.  Enter nonsense to see list of valid.
               verbose: Print convergence information
        Returns:
            self, so calls can be chained.
        Raises:
            ValueError: if ``algorithm`` is not in OPTIMIZER_NAMES, or if
                neither ``model`` nor ``cost`` is supplied.
        """
        import numpy as np
        from scipy import optimize

        if algorithm not in self.OPTIMIZER_NAMES:
            raise ValueError(
                f'Invalid optimizer {algorithm}.  Choose one of {self.OPTIMIZER_NAMES}'
            )
        # Validate before any side effects (givens, live-plot display).
        if model is None and cost is None:
            raise ValueError(
                'You must supply either a model function or a cost function')

        self.plot_every = plot_every
        self._algorithm = algorithm

        givens = dict(x=x, y=y)
        if weights is not None:
            givens.update({'weights': weights})
        # Givens set via given() take precedence over x/y/weights here.
        givens.update(self._givens)

        self._params.given(**givens)

        # Plotting libraries are only needed (and imported) for live plots.
        if plot_every is not None:
            self._start_live_plot()

        self._raw_model = model
        self._model = self._model_wrapper(self._raw_model)

        # Reset to the default cost when none is given, so a custom cost
        # from an earlier fit() call does not silently carry over.
        self._cost = cost if cost is not None else self._default_cost

        a0 = self._params.array

        optimizer = getattr(optimize, algorithm)
        # Merge into a fresh dict instead of mutating the stored kwargs.
        optimizer_kwargs = {**self._optimizer_kwargs, 'disp': verbose}
        a_fit = optimizer(self._cost_wrapper,
                          a0,
                          args=(self._params, ),
                          **optimizer_kwargs)
        # Guarantee at least a 1-d array before handing it to ingest().
        a_fit = np.array(a_fit, ndmin=1)
        self._params.ingest(a_fit)
        return self

    def predict(self, x=None):
        """
        Returns an array of predictions based on trained models.  If
        x is not supplied, then the fit over x values in training set is used.

        Raises:
            RuntimeError: if there is no model function, either because
                fit() has not run or because it ran with only a cost function.
        """
        if self._raw_model is None:
            raise RuntimeError(
                'No model available; call fit() with a model function first')
        if x is not None:
            # Work on a clone so the training x is left untouched.
            p = self.all_params
            p.x = x
        else:
            p = self._params
        return self._raw_model(p)

    def df(self, x=None):
        """
        Return a pandas DataFrame with columns ``x`` and ``y`` (predictions).

        If ``x`` is not supplied, the training x values are used.
        """
        import pandas as pd
        p = self._params

        if x is None:
            x = p.x

        y = self.predict(x)
        return pd.DataFrame({'x': x, 'y': y})

    def plot(self,
             *,
             x=None,
             scale_factor=1,
             label=None,
             line_color=None,
             scatter_color=None,
             size=10,
             xlabel='x',
             ylabel='y',
             as_components=False):
        """
        Draw plots for model fit results.
        Params:
            x: A custom x over which to draw fits
            scale_factor: Scale all y values by this factor
            label: A string with which to label the fit
            line_color: Color for fit line
            scatter_color: Color for data points
            size: size of scatter points
            xlabel: x axis label
            ylabel: y axis label
            as_components: if True, return chart components rather than overlay
        """
        import easier as ezr
        p = self._params

        if x is None:
            x = p.x

        line_color = line_color if line_color else ezr.cc.b
        scatter_color = scatter_color if scatter_color else ezr.cc.a

        import holoviews as hv

        label_val = label if label else 'Fit'

        try:
            scatter = hv.Scatter((p.x, scale_factor * p.y),
                                 xlabel,
                                 ylabel,
                                 label=label_val).options(color=scatter_color,
                                                          size=size,
                                                          alpha=.5)
        except Exception as err:
            raise RuntimeError(
                'You must import holoviews and set bokeh backround for plotting to work'
            ) from err

        line = hv.Curve((x, scale_factor * self.predict(x)),
                        label=label_val).options(color=line_color)
        traces = [scatter, line]
        if as_components:
            return traces
        else:
            return hv.Overlay(traces)