Exemple #1
0
    def test_constant_smearing(self):
        # A constant fractional resolution should produce the same smeared
        # curve whether it is given point-by-point through ``x_err`` or as the
        # model's constant ``dq`` setting (5.0 here, matching the 0.05 factor).
        pointwise_dq = 0.05 * self.qvals
        smearing_model = ReflectModel(self.structure, quad_order="ultimate")
        pointwise = smearing_model.model(self.qvals, x_err=pointwise_dq)

        smearing_model.dq = 5.0
        constant = smearing_model.model(self.qvals)

        assert_allclose(pointwise, constant, rtol=0.011)
Exemple #2
0
    def test_reflectivity_model(self):
        # Unsmeared (dq=0) reflectivity must reproduce reference values that
        # were generated by Motofit.
        unsmeared = ReflectModel(self.structure, dq=0)

        # the thread count is expected to default to -1
        assert unsmeared.threads == -1

        calculated = unsmeared.model(self.qvals)
        assert_almost_equal(calculated, self.rvals)
Exemple #3
0
def resolution_test(slabs, data, backend):
    """
    Check resolution-smeared reflectivity from a calculation backend against
    reference data.

    Parameters
    ----------
    slabs : sequence of sequences
        One entry per layer. ``slab[1]`` and ``slab[2]`` form the complex SLD.
        ``slab[0]`` and ``slab[-1]`` are passed to the SLD object's call,
        presumably as thickness and roughness.
    data : np.ndarray
        2-D array. Column 0 holds Q and column 1 the expected reflectivity.
        The last column is presumably the Gaussian standard deviation of the
        Q resolution; it is converted to a FWHM before being passed as
        ``x_err``.
    backend : str
        Reflectivity calculation backend selected with
        ``use_reflect_backend``.

    Raises
    ------
    AssertionError
        If the calculated reflectivity differs from ``data[:, 1]`` by more
        than 3 % (relative).
    """
    # 2 * sqrt(2 ln 2) converts a Gaussian standard deviation into a FWHM
    sigma_to_fwhm = 2 * np.sqrt(2 * np.log(2))

    structure = Structure()
    for slab in slabs:
        material = SLD(complex(slab[1], slab[2]))
        structure |= material(slab[0], slab[-1])

    with use_reflect_backend(backend):
        model = ReflectModel(structure, bkg=0.0)
        model.quad_order = 17
        R = model.model(data[:, 0], x_err=data[:, -1] * sigma_to_fwhm)
        np.testing.assert_allclose(R, data[:, 1], rtol=0.03)
Exemple #4
0
    def test_smearedabeles(self):
        # Test smeared reflectivity against values generated by Motofit.
        # Motofit created those reference values with a quadrature precision
        # order of 13, so the same order is used here.
        theoretical = np.loadtxt(
            os.path.join(self.pth, "smeared_theoretical.txt"))
        # columns of the reference file: Q, R, dQ
        qvals, rvals, dqvals = np.hsplit(theoretical, 3)

        rff = ReflectModel(self.structure, quad_order=13)
        calc = rff.model(qvals.flatten(), x_err=dqvals.flatten())

        assert_almost_equal(rvals.flatten(), calc)
Exemple #5
0
    def test_smearedabeles_reshape(self):
        # Test smeared reflectivity against values generated by Motofit when
        # Q, R and dQ are supplied as 2-D arrays instead of 1-D arrays.
        # Motofit created the reference values with a quadrature precision
        # order of 13, so the same order is used here.
        theoretical = np.loadtxt(os.path.join(self.pth,
                                              'smeared_theoretical.txt'))
        # columns of the reference file: Q, R, dQ
        qvals, rvals, dqvals = np.hsplit(theoretical, 3)

        # split each column into two rows; -1 lets numpy infer the row length
        # rather than hard-coding the number of reference points
        reshaped_q = np.reshape(qvals, (2, -1))
        reshaped_r = np.reshape(rvals, (2, -1))
        reshaped_dq = np.reshape(dqvals, (2, -1))

        rff = ReflectModel(self.structure, quad_order=13)
        calc = rff.model(reshaped_q, x_err=reshaped_dq)

        assert_almost_equal(calc, reshaped_r)
class Motofit(object):
    """
    An interactive slab modeller (Jupyter/ipywidgets based) for Neutron and
    X-ray reflectometry data.

    The interactive modeller is designed to be used in a Jupyter notebook.

    >>> # specify that plots are in a separate graph window
    >>> %matplotlib qt

    >>> # alternately if you want the graph to be embedded in the notebook use
    >>> # %matplotlib notebook

    >>> from refnx.reflect import Motofit
    >>> # create an instance of the modeller
    >>> app = Motofit()
    >>> # display it in the notebook by calling the object with a datafile.
    >>> app('dataset1.txt')
    >>> # lets fit a different dataset
    >>> app2 = Motofit()
    >>> app2('dataset2.txt')

    The `Motofit` instance has several useful attributes that can be used in
    other cells. For example, one can access the `objective` and `curvefitter`
    attributes for more advanced fitting functionality than is available in the
    GUI. A `code` attribute can be used to retrieve a Python code fragment that
    can be used as a basis for developing more complicated models, such as
    interparameter constraints, global fitting, etc.

    Attributes
    ----------
    dataset: :class:`refnx.dataset.Data1D`
        The dataset associated with the modeller
    model: :class:`refnx.reflect.ReflectModel`
        Calculates a theoretical model, from an interfacial structure
        (`model.Structure`).
    objective: :class:`refnx.analysis.Objective`
        The Objective that allows one to compare the model against the data.
    curvefitter: :class:`refnx.analysis.CurveFitter`
        Object for fitting the data based on the objective.
    fig: :class:`matplotlib.figure.Figure`
        Graph displaying the data.
    code: str
        A Python code fragment capable of fitting the data.

    """
    def __init__(self):
        # Q range and number of points used to draw the theoretical curve
        self.qmin = 0.005
        self.qmax = 0.5
        self.qpnt = 1000
        self.fig = None

        self.ax_data = None
        self.ax_residual = None
        self.ax_sld = None
        # gridspecs specify how the plots are laid out. Gridspec1 is when the
        # residuals plot is displayed. Gridspec2 is when it's not visible
        self._gridspec1 = gridspec.GridSpec(2,
                                            2,
                                            height_ratios=[5, 1],
                                            width_ratios=[1, 1],
                                            hspace=0.01)
        self._gridspec2 = gridspec.GridSpec(1, 2)

        self.theoretical_plot = None
        self.theoretical_plot_sld = None

        # attributes for a user dataset
        self.dataset = None
        self.objective = None
        self._curvefitter = None
        self.data_plot = None
        self.residuals_plot = None
        self.data_plot_sld = None

        self.dataset_name = widgets.Text(description="dataset:")
        self.dataset_name.disabled = True
        self.chisqr = widgets.FloatText(description="chi-squared:")
        self.chisqr.disabled = True

        # default three-slab model: fronting | layer | backing
        slab0 = Slab(0, 0, 0)
        slab1 = Slab(25, 3.47, 3)
        slab2 = Slab(0, 2.07, 3)

        structure = slab0 | slab1 | slab2
        rename_params(structure)
        self.model = ReflectModel(structure)

        # give some default parameter limits
        self.model.scale.bounds = (0.1, 2)
        self.model.bkg.bounds = (1e-8, 2e-5)
        self.model.dq.bounds = (0, 20)
        for slab in self.model.structure:
            slab.thick.bounds = (0, 2 * slab.thick.value)
            slab.sld.real.bounds = (0, 2 * slab.sld.real.value)
            slab.sld.imag.bounds = (0, 2 * slab.sld.imag.value)
            slab.rough.bounds = (0, 2 * slab.rough.value)

        # the main GUI widget
        self.display_box = widgets.VBox()

        self.tab = widgets.Tab()
        self.tab.set_title(0, "Model")
        self.tab.set_title(1, "Limits")
        self.tab.set_title(2, "Options")
        self.tab.observe(self._on_tab_changed, names="selected_index")

        # an output area for messages.
        self.output = widgets.Output()

        # options tab
        self.plot_type = widgets.Dropdown(
            options=["lin", "logY", "YX4", "YX2"],
            value="lin",
            description="Plot Type:",
            disabled=False,
        )
        self.plot_type.observe(self._on_plot_type_changed, names="value")
        self.use_weights = widgets.RadioButtons(
            options=["Yes", "No"],
            value="Yes",
            description="use dataset weights?",
            style={"description_width": "initial"},
        )
        self.use_weights.observe(self._on_use_weights_changed, names="value")
        self.transform = Transform("lin")
        self.display_residuals = widgets.Checkbox(
            value=False, description="Display residuals")
        self.display_residuals.observe(self._on_display_residuals_changed,
                                       names="value")

        self.model_view = None
        self.set_model(self.model)

    def save_model(self, *args, f=None):
        """
        Serialise a model to a pickle file.
        If `f` is not specified then the file name is constructed from the
        current dataset name; if there is no current dataset then the filename
        is constructed from the current time. These constructed filenames will
        be in the current working directory, for a specific save location `f`
        must be provided.
        This method is only intended to be used to serialise models created by
        this interactive Jupyter widget modeller.

        Parameters
        ----------
        args
            Ignored. Present so the method can be used as a button callback.
        f: file like or str, optional
            File to save model to.
        """
        if f is None:
            if self.dataset is not None:
                f = "model_" + self.dataset.name + ".pkl"
            else:
                # isoformat() contains ':' which is not permitted in Windows
                # file names, so replace it.
                timestamp = datetime.datetime.now().isoformat()
                f = "model_" + timestamp.replace(":", "-") + ".pkl"

        with possibly_open_file(f) as g:
            pickle.dump(self.model, g)

    def load_model(self, *args, f=None):
        """
        Load a serialised model.
        If `f` is not specified then an attempt will be made to find a model
        corresponding to the current dataset name,
        `'model_' + self.dataset.name + '.pkl'`. If there is no current
        dataset then the most recent model will be loaded.
        This method is only intended to be used to deserialise models created
        by this interactive Jupyter widget modeller, and will not successfully
        load complicated ReflectModel created outside of the interactive
        modeller.

        Parameters
        ----------
        args
            Ignored. Present so the method can be used as a button callback.
        f: file like or str, optional
            pickle file to load model from.

        Notes
        -----
        Unpickling can execute arbitrary code; only load model files that you
        trust (e.g. ones written by `save_model`).
        """
        if f is None and self.dataset is not None:
            # try and load the model corresponding to the current dataset
            f = "model_" + self.dataset.name + ".pkl"
        elif f is None:
            # load the most recently modified model file
            files = [p for p in glob.glob("model_*.pkl") if os.path.isfile(p)]
            if files:
                f = max(files, key=os.path.getmtime)

        if f is None:
            self._print("No model file is specified/available.")
            return

        try:
            with possibly_open_file(f, "rb") as g:
                reflect_model = pickle.load(g)
            self.set_model(reflect_model)
        except (RuntimeError, FileNotFoundError, pickle.UnpicklingError,
                EOFError) as exc:
            # RuntimeError if the file isn't a ReflectModel
            # FileNotFoundError if the specified file name wasn't found
            # UnpicklingError/EOFError if the file is corrupt or not a pickle
            self._print(repr(exc), repr(f))

    def set_model(self, model):
        """
        Change the `refnx.reflect.ReflectModel` associated with the `Motofit`
        instance.

        Parameters
        ----------
        model: refnx.reflect.ReflectModel

        Raises
        ------
        RuntimeError
            If `model` is not a `ReflectModel`.
        """
        if not isinstance(model, ReflectModel):
            raise RuntimeError("`model` was not an instance of ReflectModel")

        if self.model_view is not None:
            self.model_view.unobserve_all()

        # figure out if the reflect_model is a different instance. If it is
        # then the objective has to be updated.
        if model is not self.model:
            self.model = model
            self._update_analysis_objects()

        self.model_view = ReflectModelView(self.model)
        self.model_view.observe(self.update_model, names=["view_changed"])
        self.model_view.observe(self.redraw, names=["view_redraw"])

        # observe when the number of varying parameters changed. This
        # invalidates a curvefitter, and a new one has to be produced.
        self.model_view.observe(self._on_num_varying_changed,
                                names=["num_varying"])

        self.model_view.do_fit_button.on_click(self.do_fit)
        self.model_view.to_code_button.on_click(self._to_code)
        self.model_view.save_model_button.on_click(self.save_model)
        self.model_view.load_model_button.on_click(self.load_model)

        self.redraw(None)

    def update_model(self, change):
        """
        Updates the plots when the parameters change

        Parameters
        ----------
        change
            Ignored; the traitlets change notification.

        """
        if not self.fig:
            return

        q = np.linspace(self.qmin, self.qmax, self.qpnt)
        theoretical = self.model.model(q)
        yt, _ = self.transform(q, theoretical)

        z, sld = self.model.structure.sld_profile()
        if self.theoretical_plot is not None:
            self.theoretical_plot.set_data(q, yt)

            self.theoretical_plot_sld.set_data(z, sld)
            self.ax_sld.relim()
            self.ax_sld.autoscale_view()

        if self.dataset is not None:
            residuals = self.objective.residuals()
            self.chisqr.value = np.sum(residuals**2)

            # the residuals line is created by load_data once a figure exists
            if self.residuals_plot is not None:
                self.residuals_plot.set_data(self.dataset.x, residuals)
                self.ax_residual.relim()
                self.ax_residual.autoscale_view()

        self.fig.canvas.draw()

    def _on_num_varying_changed(self, change):
        # observe when the number of varying parameters changed. This
        # invalidates a curvefitter, and a new one has to be produced.
        if change["new"] != change["old"]:
            self._curvefitter = None

    def _update_analysis_objects(self):
        # rebuild the objective from the current model/dataset/options; any
        # existing curvefitter refers to the old objective so drop it.
        use_weights = self.use_weights.value == "Yes"
        self.objective = Objective(
            self.model,
            self.dataset,
            transform=self.transform,
            use_weights=use_weights,
        )
        self._curvefitter = None

    def __call__(self, data=None, model=None):
        """
        Display the `Motofit` GUI in a Jupyter notebook cell.

        Parameters
        ----------
        data: refnx.dataset.Data1D
            The dataset to associate with the `Motofit` instance.

        model: refnx.reflect.ReflectModel or str or file-like
            A model to associate with the data.
            If `model` is a `str` or `file`-like then the `load_model` method
            will be used to try and load the model from file. This assumes that
            the file is a pickle of a `ReflectModel`
        """
        # the theoretical model
        # display the main graph
        import matplotlib.pyplot as plt

        self.fig = plt.figure(figsize=(9, 4))

        # grid specs depending on whether the residuals are displayed
        if self.display_residuals.value:
            d_gs = self._gridspec1[0, 0]
            sld_gs = self._gridspec1[:, 1]
        else:
            d_gs = self._gridspec2[0, 0]
            sld_gs = self._gridspec2[0, 1]

        self.ax_data = self.fig.add_subplot(d_gs)
        self.ax_data.set_xlabel(r"$Q/\AA^{-1}$")
        self.ax_data.set_ylabel("Reflectivity")

        self.ax_data.grid(True, color="b", linestyle="--", linewidth=0.1)

        self.ax_sld = self.fig.add_subplot(sld_gs)
        self.ax_sld.set_ylabel(r"$\rho/10^{-6}\AA^{-2}$")
        self.ax_sld.set_xlabel(r"$z/\AA$")

        self.ax_residual = self.fig.add_subplot(self._gridspec1[1, 0],
                                                sharex=self.ax_data)
        self.ax_residual.set_xlabel(r"$Q/\AA^{-1}$")
        self.ax_residual.grid(True, color="b", linestyle="--", linewidth=0.1)
        self.ax_residual.set_visible(self.display_residuals.value)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.fig.tight_layout()

        q = np.linspace(self.qmin, self.qmax, self.qpnt)
        theoretical = self.model.model(q)
        yt, _ = self.transform(q, theoretical)

        self.theoretical_plot = self.ax_data.plot(q, yt, zorder=2)[0]
        self.ax_data.set_yscale("log")

        z, sld = self.model.structure.sld_profile()
        self.theoretical_plot_sld = self.ax_sld.plot(z, sld)[0]

        # the figure has been reset, so remove ref to the data_plot,
        # residual_plot
        self.data_plot = None
        self.residuals_plot = None

        self.dataset = None
        if data is not None:
            self.load_data(data)

        if isinstance(model, ReflectModel):
            self.set_model(model)
            return self.display_box
        elif model is not None:
            # must be passed by keyword: a positional argument would be
            # swallowed by load_model's *args and silently ignored.
            self.load_model(f=model)
            return self.display_box

        self.redraw(None)
        return self.display_box

    def load_data(self, data):
        """
        Load a dataset into the `Motofit` instance.

        Parameters
        ----------
        data: refnx.dataset.Data1D, or str, or file-like
        """
        if isinstance(data, ReflectDataset):
            self.dataset = data
        else:
            self.dataset = ReflectDataset(data)

        self.dataset_name.value = self.dataset.name

        # loading a dataset changes the objective and curvefitter
        self._update_analysis_objects()

        self.qmin = np.min(self.dataset.x)
        self.qmax = np.max(self.dataset.x)
        if self.fig is not None:
            yt, et = self.transform(self.dataset.x, self.dataset.y)

            if self.data_plot is None:
                (self.data_plot, ) = self.ax_data.plot(
                    self.dataset.x,
                    yt,
                    label=self.dataset.name,
                    ms=2,
                    marker="o",
                    ls="",
                    zorder=1,
                )
                self.data_plot.set_label(self.dataset.name)
                self.ax_data.legend()

                # no need to calculate residuals here, that'll be updated in
                # the redraw method
                (self.residuals_plot, ) = self.ax_residual.plot(self.dataset.x)
            else:
                self.data_plot.set_xdata(self.dataset.x)
                self.data_plot.set_ydata(yt)

            # calculate theoretical model over same range as data
            # use redraw over update_model because it ensures chi2 widget gets
            # displayed
            self.redraw(None)
            self.ax_data.relim()
            self.ax_data.autoscale_view()
            self.ax_residual.relim()
            self.ax_residual.autoscale_view()
            self.fig.canvas.draw()

    def redraw(self, change):
        """
        Redraw the Jupyter GUI associated with the `Motofit` instance.
        """
        self._update_display_box(self.display_box)
        self.update_model(None)

    @property
    def curvefitter(self):
        """
        class:`CurveFitter` : Object for fitting the data based on the
        objective.
        """
        if self.objective is not None and self._curvefitter is None:
            self._curvefitter = CurveFitter(self.objective)

        return self._curvefitter

    def _print(self, *strings):
        """
        Print to the output widget, clearing its previous contents first.

        Accepts the same positional arguments as the builtin `print`.
        """
        from IPython.display import clear_output

        with self.output:
            clear_output()
            print(*strings)

    def do_fit(self, *args):
        """
        Ask the Motofit object to perform a fit (differential evolution).

        Parameters
        ----------
        args
            Ignored. Present so the method can be used as a button callback.

        Notes
        -----
        After performing the fit the Jupyter display is updated.

        """
        if self.dataset is None:
            return

        if not self.model.parameters.varying_parameters():
            self._print("No parameters are being varied")
            return

        try:
            logp = self.objective.logp()
            if not np.isfinite(logp):
                self._print("One of your parameter values lies outside its"
                            " bounds. Please adjust the value, or the bounds.")
                return
        except ZeroDivisionError:
            self._print("One parameter has equal lower and upper bounds."
                        " Either alter the bounds, or don't let that"
                        " parameter vary.")
            return

        def callback(xk, convergence):
            # keep the chi-squared widget updated as the fit progresses
            self.chisqr.value = self.objective.chisqr(xk)

        self.curvefitter.fit("differential_evolution", callback=callback)

        # need to update the widgets as the model will be updated.
        # this also redraws GUI.
        self.set_model(self.model)

        self._print(str(self.objective))

    def _to_code(self, change=None):
        self._print(self.code)

    @property
    def code(self):
        """
        str : A Python code fragment capable of fitting the data.
        Executable Python code fragment for the GUI model.
        """
        if self.objective is None:
            self._update_analysis_objects()

        return to_code(self.objective)

    def _on_tab_changed(self, change):
        pass

    def _on_plot_type_changed(self, change):
        """
        User would like to plot and fit as logR/linR/RQ4/RQ2, etc
        """
        self.transform = Transform(change["new"])
        if self.objective is not None:
            self.objective.transform = self.transform

        # nothing to redraw until the GUI figure has been created
        if self.fig is None:
            return

        if self.dataset is not None and self.data_plot is not None:
            yt, _ = self.transform(self.dataset.x, self.dataset.y)

            self.data_plot.set_xdata(self.dataset.x)
            self.data_plot.set_ydata(yt)

        self.update_model(None)

        # the logY transform already takes the logarithm, so the data axis
        # goes linear for it; every other transform is shown on a log axis.
        if change["new"] == "logY":
            self.ax_data.set_yscale("linear")
        else:
            self.ax_data.set_yscale("log")

        self.ax_data.relim()
        self.ax_data.autoscale_view()
        self.fig.canvas.draw()

    def _on_use_weights_changed(self, change):
        self._update_analysis_objects()
        self.update_model(None)

    def _on_display_residuals_changed(self, change):
        # the axes only exist once the GUI figure has been created; the
        # checkbox value is still honoured when __call__ builds the figure.
        if self.fig is None:
            return

        import matplotlib.pyplot as plt

        if change["new"]:
            self.ax_residual.set_visible(True)
            self.ax_data.set_position(self._gridspec1[0, 0].get_position(
                self.fig))
            self.ax_sld.set_position(self._gridspec1[:,
                                                     1].get_position(self.fig))
            plt.setp(self.ax_data.get_xticklabels(), visible=False)
        else:
            self.ax_residual.set_visible(False)
            self.ax_data.set_position(self._gridspec2[:, 0].get_position(
                self.fig))
            self.ax_sld.set_position(self._gridspec2[:,
                                                     1].get_position(self.fig))
            plt.setp(self.ax_data.get_xticklabels(), visible=True)

    @property
    def _options_box(self):
        return widgets.VBox(
            [self.plot_type, self.use_weights, self.display_residuals])

    def _update_display_box(self, box):
        """
        Redraw the Jupyter GUI associated with the `Motofit` instance
        """
        vbox_widgets = []

        if self.dataset is not None:
            vbox_widgets.append(widgets.HBox([self.dataset_name, self.chisqr]))

        self.tab.children = [
            self.model_view.model_box,
            self.model_view.limits_box,
            self._options_box,
        ]

        vbox_widgets.append(self.tab)
        vbox_widgets.append(self.output)
        box.children = tuple(vbox_widgets)
Exemple #7
0
model = ReflectModel(structure, bkg=9e-6, scale=1.)
model.bkg.setp(vary=True, bounds=(1e-8, 1e-5))
model.scale.setp(vary=True, bounds=(0.9, 1.1))

# Compare model and data on a log(R) scale, still honouring the dataset's
# uncertainties as weights.
objective = Objective(
    model,
    data,
    transform=Transform('logY'),
    use_weights=True,
)

# Build the fitter and run a differential evolution fit.
fitter = CurveFitter(objective)
res = fitter.fit(method='differential_evolution')

# Report the fitted parameter values.
print(objective)

fig = plt.figure()

# Upper panel: data and resolution-smeared fit on a logarithmic y axis.
ax = fig.add_subplot(2, 1, 1)
ax.scatter(data.x, data.y, label=DATASET_NAME)
ax.semilogy()
ax.plot(data.x, model.model(data.x, x_err=data.x_err), label='fit')
ax.set_xlabel('Q')
ax.set_ylabel('logR')
ax.legend()

# Lower panel: SLD profile of the fitted structure.
ax2 = fig.add_subplot(2, 1, 2)
z, rho_z = structure.sld_profile()
ax2.plot(z, rho_z)
Exemple #8
0
def plot_distmodel(objective, refl_mode='rq4', maxd=1000):
    """
    Plot a distribution model for maximum introspection.

    Three panels are drawn side by side:

    1. the SLD profile of every structure in the distribution model;
    2. the reflectivity of every structure, the total model and the data;
    3. the distribution pdf, with the thickness of each structure's third
       component (``struct[2]``) marked on it.

    Parameters
    ----------
    objective : refnx.analysis.Objective
        An objective containing a MetaModel, configured in the arbitrary
        method used by Gresham circa 2019.
    refl_mode : string, optional
        The method for plotting the reflectometry profiles. 'rq4' plots
        R * Q**4; any other value (e.g. 'log') plots R.
        The default is 'rq4'.
    maxd : float, optional
        The maximum separation at which the distribution is plotted.
        The default is 1000.

    Returns
    -------
    None.

    """
    fig, [ax1, ax2, ax3] = plt.subplots(1, 3, figsize=(10, 3), dpi=150)

    q = objective.data.x
    r = objective.data.y
    r_err = objective.data.y_err

    # every reflectivity curve is multiplied by q4 before plotting
    if refl_mode == 'rq4':
        q4 = q**4
        txt_loc2 = 'bottom'
    else:
        q4 = 1
        txt_loc2 = 'top'

    if isinstance(objective.model, MetaModel):
        # presumably model 0 is the distribution model and model 1 the water
        # contribution; verify against how the objective is assembled
        distmodel = objective.model.models[0]
        distscale = objective.model.scales[0]
        h2omodel = objective.model.models[1]
        h2oscale = objective.model.scales[1]

        ax1.plot(*h2omodel.structure.sld_profile(), color='b', alpha=1, lw=1)
        ax2.plot(q, h2omodel(q) * q4 * h2oscale, color='b', alpha=1, lw=1)
        ax2.plot(q, distmodel(q) * q4 * distscale, color='red', alpha=1, lw=1)

        scales = objective.model.scales
        ax2.text(0.95,
                 0.23,
                 'mScales: %0.3f, %0.3f' % (scales[0].value, scales[1].value),
                 ha='right',
                 va=txt_loc2,
                 size='small',
                 transform=ax2.transAxes)

    else:
        distmodel = objective.model
        distscale = 1

    maxscale = np.max(distmodel.scales)

    # Need to call the model to refresh parameters
    objective.model(q)

    d = np.linspace(0, maxd, 5000)
    pdf = distmodel.pdf(d, **distmodel.pdf_kwargs)
    ax3.plot(d, pdf)

    for struct, scale in zip(distmodel.structures, distmodel.scales):
        # opacity of each structure's curves, clipped to [0.001, 1]
        normscale = np.min([1, np.max([scale / maxscale, 0.001])])

        ax1.plot(*struct.sld_profile(),
                 color='xkcd:crimson',
                 alpha=normscale,
                 lw=1)

        dummy_model = ReflectModel(struct, bkg=objective.model.bkg.value)

        ax2.plot(q,
                 dummy_model.model(q) * q4 * normscale * distscale,
                 alpha=normscale * 0.5,
                 color='xkcd:crimson',
                 lw=1)

        thick = struct[2].thick.value

        ax3.scatter(thick,
                    np.interp(thick, d, pdf),
                    marker='.',
                    color='k',
                    alpha=normscale)

    ax2.plot(q, objective.model(q) * q4, color='k', alpha=1)
    ax2.errorbar(q, r * q4, yerr=r_err * q4, color='b', alpha=0.5)
    ax2.set_yscale('log')

    # raw strings: '\m' and '\A' are invalid escape sequences otherwise
    ax1.set_xlabel(r'Thickness, $\mathrm{\AA}$')
    ax1.set_ylabel(r'SLD, $\mathrm{\AA}^{-2}$')
    ax2.set_xlabel(r'$Q$, $\mathrm{\AA}^{-1}$')
    ax2.set_ylabel(r'$RQ^4$' if refl_mode == 'rq4' else '$R$')
    ax3.set_xlabel(r'Thickness, $\mathrm{\AA}$')

    kwargs = distmodel.pdf_kwargs
    for i, key in enumerate(kwargs):
        ax3.text(0.95,
                 0.95 - 0.06 * i,
                 '%s: %0.4f' % (key, kwargs[key]),
                 ha='right',
                 va='top',
                 size='small',
                 transform=ax3.transAxes)

    # annotate the varying parameters of the master structure, stacked upwards
    varying = [p for p in distmodel.master_structure.parameters.flattened()
               if p.vary is True]
    for i, p in enumerate(varying):
        ax1.text(0.95,
                 0.05 + 0.06 * i,
                 '%s: %0.3f' % (p.name, p.value),
                 ha='right',
                 va='bottom',
                 size='small',
                 transform=ax1.transAxes)

    # '%0.3g' rather than '%d': backgrounds are small floats (e.g. 1e-6)
    # that integer formatting would display as 0.
    ax2.text(0.95,
             0.17,
             'background: %0.3g' % objective.model.bkg.value,
             ha='right',
             va=txt_loc2,
             size='small',
             transform=ax2.transAxes)
    ax2.text(0.95,
             0.11,
             'chisqr: %d' % (objective.chisqr()),
             ha='right',
             va=txt_loc2,
             size='small',
             transform=ax2.transAxes)

    ax1.set_xbound(-50, maxd)
    fig.tight_layout()
Exemple #9
0
 def test_reflectivity_model(self):
     # Unsmeared (dq=0) reflectivity must reproduce reference values that
     # were generated by Motofit.
     calculated = ReflectModel(self.structure, dq=0).model(self.qvals)
     assert_almost_equal(calculated, self.rvals)
Exemple #10
0
class Motofit(object):
    """
    An interactive slab modeller (Jupyter/ipywidgets based) for Neutron and
    X-ray reflectometry data.

    The interactive modeller is designed to be used in a Jupyter notebook.

    Usage
    -----

    >>> # specify that plots are in a separate graph window
    >>> %matplotlib qt

    >>> # alternately if you want the graph to be embedded in the notebook use
    >>> # %matplotlib notebook

    >>> from refnx.reflect import Motofit
    >>> # create an instance of the modeller
    >>> app = Motofit()
    >>> # display it in the notebook by calling the object with a datafile.
    >>> app('dataset1.txt')
    >>> # lets fit a different dataset
    >>> app2 = Motofit()
    >>> app2('dataset2.txt')

    The `Motofit` instance has several useful attributes that can be used in
    other cells. For example, one can access the `objective` and `curvefitter`
    attributes for more advanced fitting functionality than is available in the
    GUI. A `code` attribute can be used to retrieve a Python code fragment that
    can be used as a basis for developing more complicated models, such as
    interparameter constraints, global fitting, etc.

    Attributes
    ----------
    dataset: refnx.reflect.Data1D
        The dataset associated with the modeller
    model: refnx.reflect.ReflectModel
        Calculates a theoretical model, from an interfacial structure
        (`model.Structure`).
    objective: refnx.analysis.Objective
        The Objective that allows one to compare the model against the data.
    curvefitter: refnx.analysis.CurveFitter
        Object for fitting the data based on the objective.
    fig: matplotlib.Figure
        Graph displaying the data.
    code: str
        A Python code fragment capable of fitting the data.

    Methods
    -------
    __call__ - display the GUI in a Jupyter cell
    save_model - save the current model to a pickle file
    load_model - load a pickle file and set it as the current file
    set_model - use an existing `refnx.reflect.ReflectModel` to set the GUI
                model
    load_data - load a dataset
    do_fit - do a fit
    redraw - Update the notebook cell containing the GUI
    """
    def __init__(self):
        # attributes for the graph: Q range and number of points used for
        # the theoretical curve
        self.qmin = 0.005
        self.qmax = 0.5
        self.qpnt = 1000
        self.fig = None

        self.ax_data = None
        self.ax_residual = None
        self.ax_sld = None
        # gridspecs specify how the plots are laid out. Gridspec1 is when the
        # residuals plot is displayed. Gridspec2 is when it's not visible
        self._gridspec1 = gridspec.GridSpec(2,
                                            2,
                                            height_ratios=[5, 1],
                                            width_ratios=[1, 1],
                                            hspace=0.01)
        self._gridspec2 = gridspec.GridSpec(1, 2)

        self.theoretical_plot = None
        self.theoretical_plot_sld = None

        # attributes for a user dataset
        self.dataset = None
        self.objective = None
        self._curvefitter = None
        self.data_plot = None
        self.residuals_plot = None
        self.data_plot_sld = None

        self.dataset_name = widgets.Text(description='dataset:')
        self.dataset_name.disabled = True
        self.chisqr = widgets.FloatText(description='chi-squared:')
        self.chisqr.disabled = True

        # default model: fronting, a single layer, backing
        slab0 = Slab(0, 0, 0)
        slab1 = Slab(25, 3.47, 3)
        slab2 = Slab(0, 2.07, 3)

        # build the default structure and model exactly once
        structure = slab0 | slab1 | slab2
        rename_params(structure)
        self.model = ReflectModel(structure)

        # give some default parameter limits
        self.model.scale.bounds = (0.1, 2)
        self.model.bkg.bounds = (1e-8, 2e-5)
        self.model.dq.bounds = (0, 20)
        for slab in self.model.structure:
            slab.thick.bounds = (0, 2 * slab.thick.value)
            slab.sld.real.bounds = (0, 2 * slab.sld.real.value)
            slab.sld.imag.bounds = (0, 2 * slab.sld.imag.value)
            slab.rough.bounds = (0, 2 * slab.rough.value)

        # the main GUI widget
        self.display_box = widgets.VBox()

        self.tab = widgets.Tab()
        self.tab.set_title(0, 'Model')
        self.tab.set_title(1, 'Limits')
        self.tab.set_title(2, 'Options')
        self.tab.observe(self._on_tab_changed, names='selected_index')

        # an output area for messages.
        self.output = widgets.Output()

        # options tab
        self.plot_type = widgets.Dropdown(
            options=['lin', 'logY', 'YX4', 'YX2'],
            value='lin',
            description='Plot Type:',
            disabled=False)
        self.plot_type.observe(self._on_plot_type_changed, names='value')
        self.use_weights = widgets.RadioButtons(
            options=['Yes', 'No'],
            value='Yes',
            description='use dataset weights?',
            style={'description_width': 'initial'})
        self.use_weights.observe(self._on_use_weights_changed, names='value')
        self.transform = Transform('lin')
        self.display_residuals = widgets.Checkbox(
            value=False, description='Display residuals')
        self.display_residuals.observe(self._on_display_residuals_changed,
                                       names='value')

        # set_model builds the model view and wires up its observers
        self.model_view = None
        self.set_model(self.model)

    def save_model(self, f=None):
        """
        Serialise a model to a pickle file.

        Parameters
        ----------
        f: file like or str, optional
            File to save model to. If not given, the model is saved to
            ``'model_<dataset name>.pkl'`` when a named dataset is loaded,
            otherwise to ``'model_<timestamp>.pkl'``.
        """
        if f is None:
            if self.dataset is not None and self.dataset.name:
                f = 'model_' + self.dataset.name + '.pkl'
            else:
                # isoformat() contains ':' which is not a legal filename
                # character on Windows
                stamp = datetime.datetime.now().isoformat().replace(':', '-')
                f = 'model_' + stamp + '.pkl'

        # pickle output is binary; be explicit about the mode
        with possibly_open_file(f, 'wb') as g:
            pickle.dump(self.model, g)

    def load_model(self, f):
        """
        Load a serialised model and make it the current model.

        Parameters
        ----------
        f: file like or str
            pickle file to load model from.

        Notes
        -----
        This uses `pickle`, so only load files from trusted sources:
        unpickling can execute arbitrary code.
        """
        # pickle.load needs a readable binary file, so request 'rb' explicitly
        # rather than relying on the helper's default mode (save_model uses
        # the same helper for writing).
        with possibly_open_file(f, 'rb') as g:
            reflect_model = pickle.load(g)

        # update the GUI after the file has been released
        self.set_model(reflect_model)
        self._print(repr(self.objective))

    def set_model(self, model):
        """
        Make `model` the `refnx.reflect.ReflectModel` shown by this `Motofit`
        instance and rebuild the model widgets around it.

        Parameters
        ----------
        model: refnx.reflect.ReflectModel

        """
        # stop listening to the old view before it is replaced
        if self.model_view is not None:
            self.model_view.unobserve_all()

        # only a genuinely different model instance needs a fresh Objective
        # (and therefore a fresh CurveFitter)
        is_new_model = model is not self.model
        self.model = model
        if is_new_model:
            self._update_analysis_objects()

        view = self.model_view = ReflectModelView(self.model)
        view.observe(self.update_model, names=['view_changed'])
        view.observe(self.redraw, names=['view_redraw'])
        # a change in the number of varying parameters invalidates any
        # existing curvefitter, so a new one has to be produced
        view.observe(self._on_num_varying_changed, names=['num_varying'])

        view.do_fit_button.on_click(self.do_fit)
        view.to_code_button.on_click(self._to_code)

        self.redraw(None)

    def update_model(self, change):
        """
        Updates the plots when the parameters change

        Parameters
        ----------
        change

        """
        if not self.fig:
            return

        q = np.linspace(self.qmin, self.qmax, self.qpnt)
        theoretical = self.model.model(q)
        yt, _ = self.transform(q, theoretical)

        sld_profile = self.model.structure.sld_profile()
        z, sld = sld_profile
        if self.theoretical_plot is not None:
            self.theoretical_plot.set_xdata(q)
            self.theoretical_plot.set_ydata(yt)

            self.theoretical_plot_sld.set_xdata(z)
            self.theoretical_plot_sld.set_ydata(sld)
            self.ax_sld.relim()
            self.ax_sld.autoscale_view()

        if self.dataset is not None:
            # if there's a dataset loaded then residuals_plot
            # should exist
            residuals = self.objective.residuals()
            self.chisqr.value = np.sum(residuals**2)

            self.residuals_plot.set_xdata(self.dataset.x)
            self.residuals_plot.set_ydata(residuals)
            self.ax_residual.relim()
            self.ax_residual.autoscale_view()

        self.fig.canvas.draw()

    def _on_num_varying_changed(self, change):
        # observe when the number of varying parameters changed. This
        # invalidates a curvefitter, and a new one has to be produced.
        if change['new'] != change['old']:
            self._curvefitter = None

    def _update_analysis_objects(self):
        # Rebuild the Objective from the current model, dataset, transform
        # and weighting option; any cached curvefitter is now out of date.
        weighted = (self.use_weights.value == 'Yes')
        self.objective = Objective(self.model, self.dataset,
                                   transform=self.transform,
                                   use_weights=weighted)
        self._curvefitter = None

    def __call__(self, data=None, model=None):
        """
        Display the `Motofit` GUI in a Jupyter notebook cell.

        A new figure is created on every call and any previously loaded
        dataset is forgotten before `data` and `model` (if given) are loaded.

        Parameters
        ----------
        data: refnx.dataset.Data1D, or str, or file-like
            The dataset to associate with the `Motofit` instance; it is passed
            to `load_data`.

        model: refnx.reflect.ReflectModel or str or file-like
            A model to associate with the data.
            If `model` is a `str` or `file`-like then the `load_model` method
            will be used to try and load the model from file. This assumes that
            the file is a pickle of a `ReflectModel`

        Returns
        -------
        display_box: ipywidgets.VBox
            The widget containing the GUI.
        """
        # display the main graph
        self.fig = plt.figure(figsize=(9, 4))

        # grid specs depending on whether the residuals are displayed
        if self.display_residuals.value:
            d_gs = self._gridspec1[0, 0]
            sld_gs = self._gridspec1[:, 1]
        else:
            d_gs = self._gridspec2[0, 0]
            sld_gs = self._gridspec2[0, 1]

        # labels use raw strings: '\A' is not a valid Python escape sequence
        self.ax_data = self.fig.add_subplot(d_gs)
        self.ax_data.set_xlabel(r'$Q/\AA^{-1}$')
        self.ax_data.set_ylabel('Reflectivity')

        self.ax_data.grid(True, color='b', linestyle='--', linewidth=0.1)

        self.ax_sld = self.fig.add_subplot(sld_gs)
        self.ax_sld.set_ylabel(r'$\rho/10^{-6}\AA^{-2}$')
        self.ax_sld.set_xlabel(r'$z/\AA$')

        # the residual axes is always created (sharing x with the data axes);
        # only its visibility follows the display_residuals option
        self.ax_residual = self.fig.add_subplot(self._gridspec1[1, 0],
                                                sharex=self.ax_data)
        self.ax_residual.set_xlabel(r'$Q/\AA^{-1}$')
        self.ax_residual.grid(True, color='b', linestyle='--', linewidth=0.1)
        self.ax_residual.set_visible(self.display_residuals.value)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.fig.tight_layout()

        # the theoretical model
        q = np.linspace(self.qmin, self.qmax, self.qpnt)
        theoretical = self.model.model(q)
        yt, _ = self.transform(q, theoretical)

        self.theoretical_plot = self.ax_data.plot(q, yt, zorder=2)[0]
        self.ax_data.set_yscale('log')

        z, sld = self.model.structure.sld_profile()
        self.theoretical_plot_sld = self.ax_sld.plot(z, sld)[0]

        # the figure has been reset, so remove ref to the data_plot,
        # residual_plot
        self.data_plot = None
        self.residuals_plot = None

        self.dataset = None
        if data is not None:
            self.load_data(data)

        if isinstance(model, ReflectModel):
            self.set_model(model)
            return self.display_box
        elif model is not None:
            self.load_model(model)
            return self.display_box

        self.redraw(None)
        return self.display_box

    def load_data(self, data):
        """
        Load a dataset into the `Motofit` instance.

        The new dataset replaces any previous one, the analysis objects are
        rebuilt, and the plots are updated if the figure has been created.

        Parameters
        ----------
        data: refnx.dataset.Data1D, or str, or file-like
            Anything that is not already a `ReflectDataset` is passed to the
            `ReflectDataset` constructor.
        """
        if isinstance(data, ReflectDataset):
            self.dataset = data
        else:
            self.dataset = ReflectDataset(data)

        # the Text widget expects a str, so unnamed datasets show as ''
        self.dataset_name.value = self.dataset.name or ''

        # loading a dataset changes the objective and curvefitter
        self._update_analysis_objects()

        # the theoretical model is calculated over the same range as the data
        self.qmin = np.min(self.dataset.x)
        self.qmax = np.max(self.dataset.x)

        if self.fig is None:
            return

        yt, _ = self.transform(self.dataset.x, self.dataset.y)

        if self.data_plot is None:
            self.data_plot, = self.ax_data.plot(self.dataset.x,
                                                yt,
                                                label=self.dataset.name,
                                                ms=2,
                                                marker='o',
                                                ls='',
                                                zorder=1)
            # no need to calculate residuals here, that'll be updated in
            # the redraw method
            self.residuals_plot, = self.ax_residual.plot(self.dataset.x)
        else:
            self.data_plot.set_xdata(self.dataset.x)
            self.data_plot.set_ydata(yt)
            # keep the legend entry in step with the newly loaded dataset
            self.data_plot.set_label(self.dataset.name)
        self.ax_data.legend()

        # use redraw over update_model because it ensures chi2 widget gets
        # displayed
        self.redraw(None)
        self.ax_data.relim()
        self.ax_data.autoscale_view()
        self.ax_residual.relim()
        self.ax_residual.autoscale_view()
        self.fig.canvas.draw()

    def redraw(self, change):
        """
        Rebuild the Jupyter widget layout of this `Motofit` instance, then
        refresh the plots.

        Parameters
        ----------
        change
            Change notification from an observed trait; ignored.
        """
        box = self.display_box
        self._update_display_box(box)
        self.update_model(None)

    @property
    def curvefitter(self):
        if self.objective is not None and self._curvefitter is None:
            self._curvefitter = CurveFitter(self.objective)

        return self._curvefitter

    def _print(self, string):
        """
        Replace whatever the output widget shows with `string`.
        """
        output_area = self.output
        with output_area:
            clear_output()
            print(string)

    def do_fit(self, change=None):
        """
        Ask the Motofit object to perform a fit (differential evolution).

        Parameters
        ----------
        change
            Button-click/change notification; ignored.

        Notes
        -----
        No fit is attempted, and a message is written to the output area, if
        no dataset is loaded, no parameters vary, or the current parameter
        values/bounds give a non-finite or undefined prior.
        After performing the fit the Jupyter display is updated.

        """
        if self.dataset is None:
            self._print("No dataset has been loaded, there is nothing to fit.")
            return

        if not self.model.parameters.varying_parameters():
            self._print("No parameters are being varied")
            return

        try:
            lnprior = self.objective.lnprior()
            if not np.isfinite(lnprior):
                self._print("One of your parameter values lies outside its"
                            " bounds. Please adjust the value, or the bounds.")
                return
        except ZeroDivisionError:
            self._print("One parameter has equal lower and upper bounds."
                        " Either alter the bounds, or don't let that"
                        " parameter vary.")
            return

        # show the chi-squared of the current best vector as the fit proceeds
        def callback(xk, convergence):
            self.chisqr.value = self.objective.chisqr(xk)

        self.curvefitter.fit('differential_evolution', callback=callback)

        # need to update the widgets as the model will be updated.
        # this also redraws GUI.
        self.set_model(self.model)

        self._print(repr(self.objective))

    def _to_code(self, change=None):
        # Button callback: show the generated Python fragment in the output
        # area.
        snippet = self.code
        self._print(snippet)

    @property
    def code(self):
        """
        Executable Python code fragment reproducing the GUI model.

        If no objective exists yet, one is created first.
        """
        if self.objective is None:
            self._update_analysis_objects()
        objective = self.objective
        return to_code(objective)

    def _on_tab_changed(self, change):
        pass

    def _on_plot_type_changed(self, change):
        """
        User would like to plot and fit as logR/linR/RQ4/RQ2, etc

        Parameters
        ----------
        change: dict
            Trait change notification; ``change['new']`` is the selected plot
            type ('lin', 'logY', 'YX4' or 'YX2').
        """
        self.transform = Transform(change['new'])
        if self.objective is not None:
            self.objective.transform = self.transform

        # the axes and data plot only exist once __call__ has created the
        # figure; until then only the transform needs updating
        if self.fig is None:
            return

        if self.dataset is not None and self.data_plot is not None:
            yt, _ = self.transform(self.dataset.x, self.dataset.y)

            self.data_plot.set_xdata(self.dataset.x)
            self.data_plot.set_ydata(yt)

        self.update_model(None)

        # 'logY' is shown on a linear axis (presumably because that transform
        # already log-scales the values); every other type uses a log axis.
        if change['new'] == 'logY':
            self.ax_data.set_yscale('linear')
        else:
            self.ax_data.set_yscale('log')

        self.ax_data.relim()
        self.ax_data.autoscale_view()
        self.fig.canvas.draw()

    def _on_use_weights_changed(self, change):
        # The weighting choice is baked into the Objective, so rebuild it and
        # then refresh the plots and chi-squared.
        self._update_analysis_objects()
        self.update_model(None)

    def _on_display_residuals_changed(self, change):
        if change['new']:
            self.ax_residual.set_visible(True)
            self.ax_data.set_position(self._gridspec1[0, 0].get_position(
                self.fig))
            self.ax_sld.set_position(self._gridspec1[:,
                                                     1].get_position(self.fig))
            plt.setp(self.ax_data.get_xticklabels(), visible=False)
        else:
            self.ax_residual.set_visible(False)
            self.ax_data.set_position(self._gridspec2[:, 0].get_position(
                self.fig))
            self.ax_sld.set_position(self._gridspec2[:,
                                                     1].get_position(self.fig))
            plt.setp(self.ax_data.get_xticklabels(), visible=True)

    @property
    def _options_box(self):
        # Contents of the 'Options' tab: plot type, weighting choice and the
        # residuals toggle.
        option_widgets = [self.plot_type, self.use_weights,
                          self.display_residuals]
        return widgets.VBox(option_widgets)

    def _update_display_box(self, box):
        """
        Fill `box` with the GUI. The dataset-name/chi-squared row is included
        only when a dataset is loaded. It is followed by the
        Model/Limits/Options tabs and the output area.
        """
        header = ([widgets.HBox([self.dataset_name, self.chisqr])]
                  if self.dataset is not None else [])

        self.tab.children = [
            self.model_view.model_box,
            self.model_view.limits_box,
            self._options_box,
        ]

        box.children = tuple(header + [self.tab, self.output])