Ejemplo n.º 1
0
def jspecgram(snd, nt=1000, starttime=None, endtime=None,
              nperseg=512, freqsmoothing=1, cmap='viridis', dynrange=(-90., -40.),
              maxfreq=10e3):
    """Display an interactive spectrogram of ``snd`` with ipywidgets controls.

    A ``sndplot.Spectrogram`` is created and displayed together with
    play/stop buttons and controls for the colour dynamic range, the shown
    frequency range and the segment length ``nperseg``.

    Parameters
    ----------
    snd
        Sound object passed to ``sndplot.Spectrogram``; its ``fs`` attribute
        sets the frequency slider's upper bound (``fs / 2e3`` kHz).
    nt, starttime, endtime, nperseg, cmap, dynrange, maxfreq
        Forwarded to ``sndplot.Spectrogram``. ``dynrange`` (dB) and
        ``maxfreq`` (assumed Hz, shown in kHz) also set the initial slider
        values so the controls match the plot as first drawn.
    freqsmoothing : float
        FFT length factor: ``nfft = int(nperseg * freqsmoothing)``.

    Returns
    -------
    sndplot.Spectrogram
        The spectrogram object driven by the widgets.
    """
    nfft = int(nperseg * freqsmoothing)
    d = sndplot.Spectrogram(snd, nt=nt, starttime=starttime, endtime=endtime,
                            nperseg=nperseg, nfft=nfft, cmap=cmap,
                            dynrange=dynrange, maxfreq=maxfreq)

    button_play = widgets.Button(description='play')
    button_play.on_click(d.play)
    button_stop = widgets.Button(description='stop')
    button_stop.on_click(d.stop_playing)
    display(widgets.HBox((button_play, button_stop)))

    # Start from the dynamic range the spectrogram was actually drawn with;
    # a fixed (-75, -35) disagreed with the plot until the slider was moved.
    drs_value = (-75, -35) if dynrange is None else tuple(dynrange)
    drs = widgets.FloatRangeSlider(value=drs_value, min=-120, max=0, step=1,
                                   description='dynamic range (dB)')
    drs.observe(lambda change: d.set_clim(change['new']), names='value')
    display(drs)

    # Frequencies are shown in kHz up to Nyquist. Start at maxfreq (capped at
    # Nyquist) rather than a fixed 10 kHz, which could exceed the slider max
    # for low sampling rates.
    nyquist_khz = snd.fs / 2e3
    if maxfreq is None:
        upper_khz = nyquist_khz
    else:
        upper_khz = min(maxfreq / 1e3, nyquist_khz)
    freqs = widgets.FloatRangeSlider(value=(0, upper_khz), min=0,
                                     max=nyquist_khz, step=0.1,
                                     description='frequency range (kHz)')
    freqs.observe(lambda change: d.set_freqrange(change['new']), names='value')
    display(freqs)

    # IntText has no min/max traits (they were silently dropped);
    # BoundedIntText enforces them. Widen max if nperseg is larger so the
    # initial value is not clamped.
    npersegw = widgets.BoundedIntText(value=nperseg, min=1,
                                      max=max(4096 * 2, nperseg),
                                      description='nperseg')
    npersegw.observe(lambda change: d.set_nperseg(change['new']),
                     names='value')
    display(npersegw)
    return d
Ejemplo n.º 2
0
    def range_slider(self,
                     callback_method,
                     _analyser,
                     variable="",
                     description="",
                     range=(0, 0),
                     max=500,
                     step=1,
                     readout_format='d'):
        """Create a horizontal FloatRangeSlider tied to ``_analyser.data``.

        ``_analyser.data[variable]`` is seeded with the initial range and
        ``callback_method`` is registered to observe the slider's value.

        Parameters
        ----------
        callback_method : callable
            Observer called with the traitlets change dict on value changes.
        _analyser : object
            Object exposing a mutable ``data`` mapping.
        variable : str
            Key written in ``_analyser.data``; also the layout ``grid_area``.
        description : str
            Label shown next to the slider.
        range : tuple
            Initial ``(lower, upper)`` value.
        max : float
            Upper slider bound (the lower bound is fixed at 0).
        step : float
            Slider increment.
        readout_format : str
            Format used for the value readout.

        Returns
        -------
        widgets.FloatRangeSlider
        """
        initial_value = range
        _analyser.data[variable] = initial_value

        options = dict(
            value=initial_value,
            min=0,
            max=max,
            step=step,
            description=description,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format=readout_format,
        )
        options['layout'] = widgets.Layout(
            width='{}px'.format(self.slider_width),
            grid_area=variable,
        )
        slider = widgets.FloatRangeSlider(**options)

        slider.style.description_width = '{}px'.format(self.desc_width)
        slider.observe(callback_method, ['value'])
        return slider
Ejemplo n.º 3
0
def float_range_controller(tmin, tmax, start_value=None):
    """Build a time-window range slider with step-back / step-forward buttons.

    Parameters
    ----------
    tmin, tmax : float
        Slider bounds.
    start_value : sequence of two floats, optional
        Initial window; defaults to ``[tmin, min(tmin + 50, tmax)]``.

    Returns
    -------
    widgets.VBox
        The slider above a row holding the ◀ / ▶ buttons, which call
        ``move_range_slider_down`` / ``move_range_slider_up`` on the slider.
    """
    if start_value is not None:
        initial_window = start_value
    else:
        initial_window = [tmin, min(tmin + 50, tmax)]

    window_slider = widgets.FloatRangeSlider(
        value=initial_window, min=tmin, max=tmax, step=0.1,
        description='time window', continuous_update=False,
        orientation='horizontal', readout=True, readout_format='.1f',
        layout=Layout(width='90%'))

    def _step_forward(_button):
        move_range_slider_up(window_slider)

    def _step_back(_button):
        move_range_slider_down(window_slider)

    fwd = widgets.Button(description='▶', layout=Layout(width='50px'))
    fwd.on_click(_step_forward)
    back = widgets.Button(description='◀', layout=Layout(width='50px'))
    back.on_click(_step_back)

    nav_row = widgets.HBox(children=[back, fwd])
    nav_row.layout.align_items = 'center'

    return widgets.VBox(layout=Layout(width='250px'),
                        children=[window_slider, nav_row])
    def make_range_slider(self, **kwargs):
        """Create the range slider that matches ``self.dtype``.

        Parameters
        ----------
        kwargs
            Extra constructor arguments; they override every default chosen
            here.

        Returns
        -------
        widgets.FloatRangeSlider
            For ``self.dtype == 'float'`` (labelled as a time window, step 0.1).
        widgets.IntRangeSlider
            For ``self.dtype == 'int'`` (labelled as a unit window).

        Raises
        ------
        ValueError
            For any other ``self.dtype``.
        """
        options = {
            'value': self.start_value,
            'min': self.vmin,
            'max': self.vmax,
            'continuous_update': False,
            'readout': True,
            'style': {'description_width': 'initial'},
            'orientation': self.orientation,
        }

        if self.dtype == 'float':
            options.update(readout_format='.1f',
                           step=0.1,
                           description='time window (s)',
                           layout=Layout(width='100%'))
            options.update(kwargs)
            slider_factory = widgets.FloatRangeSlider
        elif self.dtype == 'int':
            options.update(description='unit window',
                           layout=Layout(height='100%'))
            options.update(kwargs)
            slider_factory = widgets.IntRangeSlider
        else:
            raise ValueError('Unrecognized dtype: {}'.format(self.dtype))

        return slider_factory(**options)
Ejemplo n.º 5
0
    def _build_strat_form(self):
        """Build the extraction-strategy form.

        Creates ``self.ap_size`` (aperture radius slider, arcsec),
        ``self.overplot`` (overlay checkbox) and ``self.background_annulus``
        (annulus radii range slider, arcsec, starting 0.1 and 0.2 beyond the
        aperture radius), attaches their observers and returns them as rows
        of a VBox.
        """
        label_width = '20rem'
        control_width = '35rem'

        def _caption(text):
            # Fixed-width left label so the rows line up.
            lbl = widgets.Label(text)
            lbl.layout.width = label_width
            return lbl

        aperture_caption = _caption('Aperture radius:')
        self.ap_size = widgets.FloatSlider(
            value=0.1,
            min=self.APERTURE_MIN,
            max=self.APERTURE_MAX,
            step=self.APERTURE_INCREMENT,
        )
        self.ap_size.layout.width = control_width
        self.ap_size.observe(self.check_ann, names='value')

        self.overplot = widgets.Checkbox(description="Overlay apertures?", value=True)
        self.overplot.observe(self._trigger_update_plots, names='value')

        aperture_row = widgets.HBox(
            [aperture_caption, self.ap_size, widgets.Label('arcsec')])

        annulus_caption = _caption("Background annulus radii:")
        radius = self.ap_size.value
        self.background_annulus = widgets.FloatRangeSlider(
            value=[radius + 0.1, radius + 0.2],
            min=0,
            max=2.0,
            step=self.APERTURE_INCREMENT,
        )
        self.background_annulus.layout.width = control_width
        self.background_annulus.observe(self.check_ann, names='value')

        annulus_row = widgets.HBox(
            [annulus_caption, self.background_annulus, widgets.Label("arcsec")])
        annulus_row.layout.width = '100%'

        return widgets.VBox([aperture_row, annulus_row, self.overplot])
Ejemplo n.º 6
0
    def create_qubit_params_widgets(self):
        """Rebuild ``self.qubit_params_widgets`` for the current qubit.

        For each entry of ``self.qubit_base_params``:

        * ``"grid"`` gets a FloatRangeSlider over [-12π, 12π] initialised
          from ``self.qubit_current_params["grid"]``'s ``min_val``/``max_val``;
        * ``int`` values get an IntSlider;
        * anything else gets a FloatSlider with step 0.01.

        Slider bounds come from ``self.active_defaults[name]``, or from the
        ``"int"`` / ``"float"`` entry when that lookup is missing or falsy.
        """
        # Widgets belonging to a previously selected qubit must not survive.
        self.qubit_params_widgets.clear()

        for name, base_value in self.qubit_base_params.items():
            if name == "grid":
                grid = self.qubit_current_params["grid"]
                widget = widgets.FloatRangeSlider(
                    min=-12 * np.pi,
                    max=12 * np.pi,
                    value=[grid.min_val, grid.max_val],
                    step=0.05,
                    description="Grid range",
                    continuous_update=False,
                    layout=Layout(width="300px"),
                )
            elif isinstance(base_value, int):
                slider_defaults = (self.active_defaults.get(name)
                                   or self.active_defaults["int"])
                widget = widgets.IntSlider(
                    **slider_defaults,
                    value=base_value,
                    description="{}:".format(name),
                    continuous_update=False,
                    layout=Layout(width="300px")
                )
            else:
                slider_defaults = (self.active_defaults.get(name)
                                   or self.active_defaults["float"])
                widget = widgets.FloatSlider(
                    **slider_defaults,
                    value=base_value,
                    step=0.01,
                    description="{}:".format(name),
                    continuous_update=False,
                    layout=Layout(width="300px")
                )
            self.qubit_params_widgets[name] = widget
Ejemplo n.º 7
0
    def __init__(self,
                 on_interact=None,
                 output=None,
                 overwrite_previous_output=True,
                 feedback=False,
                 run=True,
                 action_kws=None,
                 *args,
                 **kwargs):
        """Wrap a ``widgets.FloatRangeSlider`` in the interaction base class.

        Parameters
        ----------
        on_interact, output, overwrite_previous_output, feedback
            Forwarded unchanged to the base-class constructor.
        run : bool
            If true, ``self.run()`` is called at the end of construction.
        action_kws : dict, optional
            Forwarded to the base-class constructor; a fresh empty dict is
            used when omitted.
        *args, **kwargs
            Forwarded to ``widgets.FloatRangeSlider``.
        """
        # A literal ``{}`` default is a single dict shared by every instance;
        # create a new one per call instead.
        if action_kws is None:
            action_kws = {}
        super().__init__(on_interact=on_interact,
                         output=output,
                         overwrite_previous_output=overwrite_previous_output,
                         feedback=feedback,
                         action_kws=action_kws)

        self.widget = widgets.FloatRangeSlider(*args, **kwargs)

        if run:
            self.run()
Ejemplo n.º 8
0
    def make_range_slider(self, **kwargs):
        """Return a range slider configured for ``self.dtype``.

        Parameters
        ----------
        kwargs
            Extra constructor arguments, applied last so they override the
            defaults chosen here.

        Returns
        -------
        widgets.FloatRangeSlider or widgets.IntRangeSlider
            Float slider ("time window (s)", step 0.1) when ``self.dtype`` is
            ``"float"``; integer slider ("unit window") when it is ``"int"``.

        Raises
        ------
        ValueError
            If ``self.dtype`` is neither ``"float"`` nor ``"int"``.
        """
        common = {
            "value": self.start_value,
            "min": self.vmin,
            "max": self.vmax,
            "continuous_update": False,
            "readout": True,
            "style": {"description_width": "initial"},
            "orientation": self.orientation,
        }

        dtype = self.dtype
        if dtype == "float":
            specific = {
                "readout_format": ".1f",
                "step": 0.1,
                "description": "time window (s)",
                "layout": Layout(width="100%"),
            }
            cls = widgets.FloatRangeSlider
        elif dtype == "int":
            specific = {
                "description": "unit window",
                "layout": Layout(height="100%"),
            }
            cls = widgets.IntRangeSlider
        else:
            raise ValueError("Unrecognized dtype: {}".format(dtype))

        # Later mappings win: caller kwargs > dtype-specific > common.
        return cls(**{**common, **specific, **kwargs})
Ejemplo n.º 9
0
    def __init__(self, img_stack):
        """Show ``img_stack[0]`` with a frame slider and a value-range slider.

        Parameters
        ----------
        img_stack : indexable sequence of images
            Frames to browse (presumably 2D arrays — confirm); the first frame
            is displayed initially and sets the value-range slider's bounds.
        """
        self.img_stack = img_stack
        self.fig, self.axes = plt.subplots()
        self.img_mpl = self.axes.imshow(self.img_stack[0])

        self.idx = 0

        vbox = widgets.VBox()

        # Valid frame indices are 0 .. len(img_stack) - 1; a max of
        # len(img_stack) let the slider select a frame that does not exist.
        frame_slider = widgets.IntSlider(value=0,
                                         min=0,
                                         max=len(img_stack) - 1,
                                         step=1,
                                         disabled=False,
                                         continuous_update=True,
                                         orientation='horizontal',
                                         readout=True,
                                         readout_format='d')
        frame_slider.observe(self.update, names='value')

        # Colour-limit slider spanning the first frame's value range in 1000 steps.
        vmin = self.img_stack[0].min()
        vmax = self.img_stack[0].max()
        self.vrange_slider = widgets.FloatRangeSlider(
            value=[vmin, vmax],
            min=vmin,
            max=vmax,
            step=(vmax - vmin) / 1000,
            disabled=False,
            continuous_update=True,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
        )

        self.vrange_slider.observe(self.v_update, names='value')

        vbox.children = [frame_slider, self.vrange_slider]
        display(vbox)
Ejemplo n.º 10
0
def RandomWave():
    """Build the notebook UI for the random-wave demo (main entry point).

    Returns
    -------
    widgets.VBox
        A row of controls (number of waves, filter range, seed) above an
        ``interactive_output`` that re-runs ``runWaves`` with the keyword
        arguments ``noWaves``, ``seed`` and ``filterRange`` whenever a
        control changes.
    """
    wave_count = widgets.IntSlider(value=10,
                                   min=2,
                                   max=20,
                                   step=1,
                                   description='No. Waves',
                                   continuous_update=False)
    seed_entry = widgets.IntText(123, description='Seed')
    # Bounds come from minf/maxf defined outside this function; the range
    # starts with zero width at minf.
    band = widgets.FloatRangeSlider(value=[minf, minf],
                                    min=minf,
                                    max=maxf,
                                    description="Filter Range",
                                    continuous_update=False)

    controls = widgets.HBox([wave_count, band, seed_entry])
    bindings = {
        'noWaves': wave_count,
        'seed': seed_entry,
        'filterRange': band,
    }
    plot_area = widgets.interactive_output(runWaves, bindings)
    return widgets.VBox([controls, plot_area])
Ejemplo n.º 11
0
 def _add_widgets(self):
     """Add the colour-limit slider and colormap dropdown, then the tabs.

     ``self.val_slider`` starts at ``[self.vmin, self.vmax]`` and can be
     dragged one full data span beyond either end. ``self.cmap_sel`` offers
     ``Defaults.cmaps``. Their value changes are routed to ``self._set_clim``
     and ``self._set_cmap``, after which ``self._add_tabs()`` is called.
     """
     # Allow the limits to move one full data span past [vmin, vmax].
     span = np.fabs(self.vmax - self.vmin)
     self.val_slider = widgets.FloatRangeSlider(
         value=[self.vmin, self.vmax],
         min=self.vmin - span,
         max=self.vmax + span,
         step=0.1,
         description='Boost:',
         disabled=False,
         continuous_update=False,
         orientation='horizontal',
         readout=True,
         # One decimal to match step=0.1; 'd' rounded the readout to integers.
         readout_format='.1f',
         layout=Layout(width='70%'))
     self.cmap_sel = widgets.Dropdown(options=Defaults.cmaps,
                                      value=Defaults.cmaps[0],
                                      description='Color Map:',
                                      disabled=False,
                                      layout=Layout(width='200px'))
     # Observe only 'value': without names=..., the handlers also fire for
     # every other trait change (e.g. the Dropdown's 'index'/'label' or the
     # widget's internal '_property_lock'), each with a different 'new' type.
     self.cmap_sel.observe(self._set_cmap, names='value')
     self.val_slider.observe(self._set_clim, names='value')
     self._add_tabs()
Ejemplo n.º 12
0
    def add_range_slider(self):
        """Append a new FloatRangeSlider to ``self.range_slider_list``.

        The slider reuses this object's ``min``, ``max``, ``step`` and
        ``disabled`` settings and reports value changes to
        ``self.update_selected_values``. Only the first slider in the list is
        labelled "MultiRangeSlider"; later ones have no description.
        """
        slider_options = dict(
            min=self.min,
            max=self.max,
            step=self.step,
            disabled=self.disabled,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
        )
        new_slider = widgets.FloatRangeSlider(**slider_options)

        is_first = not self.range_slider_list
        if is_first:
            new_slider.description = "MultiRangeSlider"

        # Keep the combined selection in sync with this slider.
        new_slider.observe(self.update_selected_values, names='value')

        self.range_slider_list.append(new_slider)
Ejemplo n.º 13
0
    def _set_widgets(self):
        """Create the colour-range sliders, colormap dropdowns and step box.

        One range slider and one colormap dropdown are created per dataset
        (``self._num_dsets``), or a single pair when ``self._link`` is set.
        Observers ``_clim_observer``, ``_cmap_observer`` and
        ``_tstep_observer`` are attached to the respective widgets' values,
        and each slider/dropdown gets a ``num`` attribute holding its index.
        """
        n_links = 1 if self._link else self._num_dsets

        # (slider bound, initial value) pairs: each slider reaches one data
        # span beyond [vmin, vmax] on either side.
        min_v = []
        max_v = []
        for i in range(n_links):
            try:
                min_v.append(
                    (self.vmin[i] - np.fabs(self.vmax[i] - self.vmin[i]),
                     self.vmin[i]))
            except TypeError:
                # vmin/vmax not usable (TypeError, e.g. when None): fall back to
                # 0..1000. The old initial value (-1) lay below its own bound (0)
                # and only survived because the slider clamps its value.
                min_v.append((0, 0))
            try:
                max_v.append(
                    (self.vmax[i] + np.fabs(self.vmax[i] - self.vmin[i]),
                     self.vmax[i]))
            except TypeError:
                # Likewise, the old initial value 11000 exceeded the bound 1000.
                max_v.append((1000, 1000))

        self.val_sliders = [
            widgets.FloatRangeSlider(value=[lower[1], upper[1]],
                                     min=lower[0],
                                     max=upper[0],
                                     step=self.mag[i],
                                     description='Range:',
                                     disabled=False,
                                     continuous_update=False,
                                     orientation='horizontal',
                                     readout=True,
                                     readout_format='0.4f',
                                     layout=Layout(width='100%'))
            for i, (lower, upper) in enumerate(zip(min_v, max_v))
        ]
        self.cmap_sel = [
            widgets.Dropdown(options=self.cmaps,
                             value=self.cmaps[0],
                             description='CMap:',
                             disabled=False,
                             layout=Layout(width='200px'))
            for _ in range(n_links)
        ]
        self.t_step = widgets.BoundedFloatText(value=0,
                                               min=0,
                                               max=10000,
                                               step=1,
                                               disabled=False,
                                               description=self.step_variable,
                                               layout=Layout(width='200px',
                                                             height='30px'))

        self.t_step.observe(self._tstep_observer, names='value')
        for n, (slider, cmap_widget) in enumerate(zip(self.val_sliders,
                                                      self.cmap_sel)):
            slider.observe(self._clim_observer, names='value')
            cmap_widget.observe(self._cmap_observer, names='value')
            # Presumably read by the observers to tell datasets apart — confirm.
            slider.num = n
            cmap_widget.num = n
            if n_links > 1:
                slider.description = 'Range #{}:'.format(n + 1)
                cmap_widget.description = 'CMap #{}:'.format(n + 1)
            # NOTE(review): n_links is 1 whenever self._link is set, so this
            # branch cannot run as written — confirm the intended linking.
            if self._link and n > 0:
                _ = widgets.jslink((self.val_sliders[0], 'value'),
                                   (slider, 'value'))
Ejemplo n.º 14
0
def orthoslices_3D(data, normalize=True, continuous_update=False, **kwargs):
    """Show interactive orthogonal slices through a 3D data cube.

    Three linked panels are drawn: the cut ``data[e, :, :]`` (top left),
    ``data[:, ky, :]`` below it (x axis labelled k_x) and ``data[:, :, kx]``
    transposed to its right (y axis labelled k_y). Slice sliders and
    graphics options (contrast, colormap, resampling, interpolation, grid,
    tracker lines) are shown in a two-page tab widget.

    Parameters
    ----------
    data : array-like
        3D array; this function indexes it as ``data[e, ky, kx]``.
    normalize : bool, optional
        If true, every cut is passed through ``norm_img`` before display.
    continuous_update : bool, optional
        Forwarded to the slice sliders and the contrast slider.
    **kwargs
        ``figsize`` (default ``(5, 5)``); the rest is forwarded to
        ``plt.figure``.
    """
    # NOTE(review): if ``isnotebook`` is a function (as its name suggests),
    # ``not isnotebook`` is always False and this check never triggers;
    # ``isnotebook()`` may have been intended — confirm its definition.
    if not isnotebook:
        print('Function not suited for working outside of Jupyter notebooks!')
    else:
        # Issued twice as before; presumably a workaround for the backend not
        # taking effect on the first call — confirm before removing.
        # run_line_magic is the non-deprecated form of ``.magic()``.
        get_ipython().run_line_magic('matplotlib', 'notebook')
        get_ipython().run_line_magic('matplotlib', 'notebook')

    # update() slices axis 1 with ``ky`` and axis 2 with ``kx`` (the trackers
    # and axis labels follow the same convention), so size each slider from
    # the axis it indexes. Unpacking as (E, x, y) sized the kx slider from
    # axis 1, which raised IndexError on non-square data.
    e_range, y_range, x_range = data.shape

    def _index_slider(size, description):
        # Integer slider over the valid indices 0 .. size - 1, starting mid-way.
        return widgets.IntSlider(value=size // 2,
                                 min=0,
                                 max=size - 1,
                                 step=1,
                                 description=description,
                                 disabled=False,
                                 continuous_update=continuous_update,
                                 orientation='horizontal',
                                 readout=True,
                                 readout_format='d')

    w_e = _index_slider(e_range, 'Energy:')
    w_kx = _index_slider(x_range, 'kx:')
    w_ky = _index_slider(y_range, 'ky:')
    w_clim = widgets.FloatRangeSlider(value=[.1, .9],
                                      min=0,
                                      max=1,
                                      step=0.001,
                                      description='Contrast:',
                                      disabled=False,
                                      continuous_update=continuous_update,
                                      orientation='horizontal',
                                      readout=True,
                                      readout_format='.1f')
    w_cmap = widgets.Dropdown(options=cmaps,
                              value='terrain',
                              description='colormap:',
                              disabled=False)
    w_bin = widgets.BoundedIntText(value=1,
                                   min=1,
                                   max=min(data.shape),
                                   step=1,
                                   description='resample:',
                                   disabled=False)
    w_interpolate = widgets.Checkbox(value=True,
                                     description='Interpolate',
                                     disabled=False)
    w_grid = widgets.Checkbox(value=False, description='Grid', disabled=False)
    w_trackers = widgets.Checkbox(value=True,
                                  description='Trackers',
                                  disabled=False)
    w_trackercol = widgets.ColorPicker(concise=False,
                                       description='tracker line color',
                                       value='orange')

    ui_pos = widgets.HBox([w_e, w_kx, w_ky])
    ui_color = widgets.HBox([
        widgets.VBox([w_clim, w_cmap]),
        widgets.VBox([w_bin, w_interpolate, w_grid]),
        widgets.VBox([w_trackers, w_trackercol]),
    ])

    tab = widgets.Tab(children=[ui_pos, ui_color])
    tab.set_title(0, 'data select')
    tab.set_title(1, 'colormap')

    figsize = kwargs.pop('figsize', (5, 5))
    fig = plt.figure(figsize=figsize, **kwargs)
    plt.tight_layout()

    # Axes rectangles are [left, bottom, width, height] in figure fractions.
    img_ax = fig.add_axes([.15, .4, .4, .4], xticklabels=[], yticklabels=[])
    img_ax.xaxis.set_major_locator(plt.LinearLocator(5))
    img_ax.yaxis.set_major_locator(plt.LinearLocator(5))

    xproj_ax = fig.add_axes([.15, .1, .4, .28], xticklabels=[], yticklabels=[])
    xproj_ax.set_xlabel('$k_x$')
    xproj_ax.xaxis.set_major_locator(plt.LinearLocator(5))

    yproj_ax = fig.add_axes([.57, .4, .28, .4], xticklabels=[], yticklabels=[])
    yproj_ax.yaxis.set_label_position("right")
    yproj_ax.set_ylabel('$k_y$')
    yproj_ax.yaxis.set_major_locator(plt.LinearLocator(5))

    for ax in [img_ax, yproj_ax, xproj_ax]:
        ax.tick_params(axis="both",
                       direction="in",
                       bottom=True,
                       top=True,
                       left=True,
                       right=True,
                       which='both')

    clim_ = 0.01, .99

    def _show(ax, img):
        # Initial image with the default look; update() restyles it later.
        return ax.imshow(img,
                         cmap='terrain',
                         aspect='auto',
                         interpolation='gaussian',
                         clim=clim_)

    e_plot = _show(img_ax, norm_img(data[e_range // 2, :, :]))
    x_plot = _show(yproj_ax, norm_img(data[:, :, x_range // 2].T))
    y_plot = _show(xproj_ax, norm_img(data[:, y_range // 2, :]))

    pe_x = img_ax.axvline(x_range / 2, c='orange')
    pe_y = img_ax.axhline(y_range / 2, c='orange')
    px_x = xproj_ax.axvline(x_range / 2, c='orange')
    px_e = xproj_ax.axhline(e_range / 2, c='orange')
    py_y = yproj_ax.axhline(y_range / 2, c='orange')
    py_e = yproj_ax.axvline(e_range / 2, c='orange')
    tracker_lines = (pe_x, pe_y, px_x, px_e, py_y, py_e)

    def update(e, kx, ky, clim, cmap, binning, interpolate, grid, trackers,
               trackerscol):
        e_img = data[e, :, :][::binning, ::binning]
        y_img = data[:, ky, :][::binning, ::binning]
        x_img = data[:, :, kx][::binning, ::binning]
        if normalize:
            e_img = norm_img(e_img)
            y_img = norm_img(y_img)
            x_img = norm_img(x_img)

        for axis, plot, img in zip([img_ax, yproj_ax, xproj_ax],
                                   [e_plot, x_plot, y_plot],
                                   [e_img, x_img.T, y_img]):
            plot.set_data(img)
            plot.set_clim(clim)
            plot.set_cmap(cmap)
            axis.grid(grid)
            # 'none' really disables interpolation; None would fall back to
            # rcParams['image.interpolation'] instead.
            plot.set_interpolation('gaussian' if interpolate else 'none')

        # Hide the trackers when disabled instead of freezing them in place.
        for line in tracker_lines:
            line.set_visible(trackers)
        if trackers:
            # axvline/axhline hold two-point data; pass sequences, since
            # scalar set_xdata/set_ydata is deprecated in Matplotlib.
            for line, pos in ((pe_x, kx), (px_x, kx), (py_e, e)):
                line.set_xdata([pos, pos])
            for line, pos in ((pe_y, ky), (px_e, e), (py_y, ky)):
                line.set_ydata([pos, pos])
            for line in tracker_lines:
                line.set_color(trackerscol)

    interactive_plot = interactive_output(
        update, {
            'e': w_e,
            'kx': w_kx,
            'ky': w_ky,
            'clim': w_clim,
            'cmap': w_cmap,
            'binning': w_bin,
            'interpolate': w_interpolate,
            'grid': w_grid,
            'trackers': w_trackers,
            'trackerscol': w_trackercol,
        })
    display(interactive_plot, tab)
Ejemplo n.º 15
0
def orthoslices_4D(data,
                   axis_order=('E', 'kx', 'ky', 'kz'),
                   normalize=True,
                   continuous_update=True,
                   **kwargs):
    """Show interactive orthogonal slices through a 4D data set.

    A ``kz`` slider selects the index along axis 3, held fixed for all three
    panels: ``data[e, :, :, kz]`` (top left), ``data[:, ky, :, kz]`` below it
    (x axis labelled k_x) and ``data[:, :, kx, kz]`` transposed to its right
    (y axis labelled k_y). Slice sliders and graphics options (contrast,
    colormap, resampling, interpolation, grid, tracker lines) are shown in a
    two-page tab widget.

    Parameters
    ----------
    data : array-like
        4D array; this function indexes it as ``data[e, ky, kx, kz]``.
    axis_order : sequence of str, optional
        Currently unused (kept for API compatibility).
    normalize : bool, optional
        If true, every cut is passed through ``norm_img`` before display.
    continuous_update : bool, optional
        Forwarded to the slice sliders and the contrast slider.
    **kwargs
        ``figsize`` (default ``(5, 5)``); the rest is forwarded to
        ``plt.figure``.

    Raises
    ------
    EnvironmentError
        If ``isnotebook`` is falsy.
    AssertionError
        If ``data`` does not have exactly 4 dimensions.
    """
    # NOTE(review): if ``isnotebook`` is a function (as its name suggests),
    # ``not isnotebook`` is always False and this check never triggers;
    # ``isnotebook()`` may have been intended — confirm its definition.
    if not isnotebook:
        raise EnvironmentError(
            'Function not suited for working outside of Jupyter notebooks!')
    else:
        # Issued twice as before; presumably a workaround for the backend not
        # taking effect on the first call — confirm before removing.
        # run_line_magic is the non-deprecated form of ``.magic()``.
        get_ipython().run_line_magic('matplotlib', 'notebook')
        get_ipython().run_line_magic('matplotlib', 'notebook')

    assert len(data.shape) == 4, (
        'Data should be 4-dimensional, but data has {} dimensions'.format(
            len(data.shape)))

    # update() slices axis 1 with ``ky`` and axis 2 with ``kx`` (the trackers
    # and axis labels follow the same convention), so size each slider from
    # the axis it indexes. Unpacking as (E, x, y, z) sized the kx slider from
    # axis 1, which raised IndexError on non-square data.
    e_range, y_range, x_range, z_range = data.shape

    def _index_slider(size, description):
        # Integer slider over the valid indices 0 .. size - 1, starting mid-way.
        return widgets.IntSlider(value=size // 2,
                                 min=0,
                                 max=size - 1,
                                 step=1,
                                 description=description,
                                 disabled=False,
                                 continuous_update=continuous_update,
                                 orientation='horizontal',
                                 readout=True,
                                 readout_format='d')

    w_e = _index_slider(e_range, 'Energy:')
    w_kx = _index_slider(x_range, 'kx:')
    w_ky = _index_slider(y_range, 'ky:')
    w_kz = _index_slider(z_range, 'kz:')

    ui_slicers = widgets.HBox([w_e, w_kx, w_ky, w_kz])

    # Graphics controls. The contrast slider now honours continuous_update
    # like the slice sliders (it was hard-coded to True, the default).
    w_clim = widgets.FloatRangeSlider(value=[.1, .9],
                                      min=0,
                                      max=1,
                                      step=0.001,
                                      description='Contrast:',
                                      disabled=False,
                                      continuous_update=continuous_update,
                                      orientation='horizontal',
                                      readout=True,
                                      readout_format='.1f')
    w_cmap = widgets.Dropdown(options=cmaps,
                              value='terrain',
                              description='colormap:',
                              disabled=False)
    w_bin = widgets.BoundedIntText(value=1,
                                   min=1,
                                   max=min(data.shape),
                                   step=1,
                                   description='resample:',
                                   disabled=False)
    w_interpolate = widgets.Checkbox(value=True,
                                     description='Interpolate',
                                     disabled=False)
    w_grid = widgets.Checkbox(value=False, description='Grid', disabled=False)
    w_trackers = widgets.Checkbox(value=True,
                                  description='Trackers',
                                  disabled=False)
    w_trackercol = widgets.ColorPicker(concise=False,
                                       description='tracker line color',
                                       value='orange')
    ui_color = widgets.HBox([
        widgets.VBox([w_clim, w_cmap]),
        widgets.VBox([w_bin, w_interpolate, w_grid]),
        widgets.VBox([w_trackers, w_trackercol]),
    ])

    tab = widgets.Tab(children=[ui_slicers, ui_color])
    tab.set_title(0, 'Data slicing')
    tab.set_title(1, 'Graphics')

    figsize = kwargs.pop('figsize', (5, 5))
    fig = plt.figure(figsize=figsize, **kwargs)
    plt.tight_layout()

    # Axes rectangles are [left, bottom, width, height] in figure fractions.
    img_ax = fig.add_axes([.15, .4, .4, .4], xticklabels=[], yticklabels=[])
    img_ax.xaxis.set_major_locator(plt.LinearLocator(5))
    img_ax.yaxis.set_major_locator(plt.LinearLocator(5))

    xproj_ax = fig.add_axes([.15, .1, .4, .28], xticklabels=[], yticklabels=[])
    xproj_ax.set_xlabel('$k_x$')
    xproj_ax.xaxis.set_major_locator(plt.LinearLocator(5))

    yproj_ax = fig.add_axes([.57, .4, .28, .4], xticklabels=[], yticklabels=[])
    yproj_ax.yaxis.set_label_position("right")
    yproj_ax.set_ylabel('$k_y$')
    yproj_ax.yaxis.set_major_locator(plt.LinearLocator(5))

    for ax in [img_ax, yproj_ax, xproj_ax]:
        ax.tick_params(axis="both",
                       direction="in",
                       bottom=True,
                       top=True,
                       left=True,
                       right=True,
                       which='both')

    clim_ = 0.01, .99

    def _show(ax, img):
        # Initial image with the default look; update() restyles it later.
        return ax.imshow(img,
                         cmap='terrain',
                         aspect='auto',
                         interpolation='gaussian',
                         clim=clim_)

    kz0 = z_range // 2
    e_plot = _show(img_ax, norm_img(data[e_range // 2, :, :, kz0]))
    x_plot = _show(yproj_ax, norm_img(data[:, :, x_range // 2, kz0].T))
    y_plot = _show(xproj_ax, norm_img(data[:, y_range // 2, :, kz0]))

    pe_x = img_ax.axvline(x_range / 2, c='orange')
    pe_y = img_ax.axhline(y_range / 2, c='orange')
    px_x = xproj_ax.axvline(x_range / 2, c='orange')
    px_e = xproj_ax.axhline(e_range / 2, c='orange')
    py_y = yproj_ax.axhline(y_range / 2, c='orange')
    py_e = yproj_ax.axvline(e_range / 2, c='orange')
    tracker_lines = (pe_x, pe_y, px_x, px_e, py_y, py_e)

    def update(e, kx, ky, kz, clim, cmap, binning, interpolate, grid, trackers,
               trackerscol):
        e_img = data[e, :, :, kz][::binning, ::binning]
        y_img = data[:, ky, :, kz][::binning, ::binning]
        x_img = data[:, :, kx, kz][::binning, ::binning]
        if normalize:
            e_img = norm_img(e_img)
            y_img = norm_img(y_img)
            x_img = norm_img(x_img)

        for axis, plot, img in zip([img_ax, yproj_ax, xproj_ax],
                                   [e_plot, x_plot, y_plot],
                                   [e_img, x_img.T, y_img]):
            plot.set_data(img)
            plot.set_clim(clim)
            plot.set_cmap(cmap)
            axis.grid(grid)
            # 'none' really disables interpolation; None would fall back to
            # rcParams['image.interpolation'] instead.
            plot.set_interpolation('gaussian' if interpolate else 'none')

        # Hide the trackers when disabled instead of freezing them in place.
        for line in tracker_lines:
            line.set_visible(trackers)
        if trackers:
            # axvline/axhline hold two-point data; pass sequences, since
            # scalar set_xdata/set_ydata is deprecated in Matplotlib.
            for line, pos in ((pe_x, kx), (px_x, kx), (py_e, e)):
                line.set_xdata([pos, pos])
            for line, pos in ((pe_y, ky), (px_e, e), (py_y, ky)):
                line.set_ydata([pos, pos])
            for line in tracker_lines:
                line.set_color(trackerscol)

    interactive_plot = interactive_output(
        update, {
            'e': w_e,
            'kx': w_kx,
            'ky': w_ky,
            'kz': w_kz,
            'clim': w_clim,
            'cmap': w_cmap,
            'binning': w_bin,
            'interpolate': w_interpolate,
            'grid': w_grid,
            'trackers': w_trackers,
            'trackerscol': w_trackercol,
        })
    display(interactive_plot, tab)
Ejemplo n.º 16
0
def make_tabulated_sandbox(num_experiments=1):
    """Build a tabbed ipywidgets sandbox with one independent experiment per tab.

    Every tab gets its own copy of each control so that experiments do not
    interfere with one another.  The controls of tab ``k`` drive ``sandbox``
    through ``widgets.interactive_output``.

    Parameters
    ----------
    num_experiments : int
        Number of tabs (independent experiments) to create.

    Returns
    -------
    tab_nest : widgets.Tab
        Tab widget with one page per experiment.
    Keys : list of dict
        For each experiment, the mapping from ``sandbox`` argument names to
        the widgets supplying them.
    fixed_obs_window : list of widgets.Checkbox
        Per-experiment "Fixed Obs. Window" checkboxes (not passed to
        ``sandbox``; they only couple ``Delta_t`` and ``num_observations``).
    """
    # We create many copies of the same widget objects in order to isolate our experimental areas.
    num_samples = [widgets.IntSlider(value=1500, continuous_update=False,
        orientation='vertical', disabled=False,
        min=int(5E2), max=int(2.5E4), step=500,
        description='Samples $N$') for k in range(num_experiments)]

    sd = [widgets.FloatSlider(value=0.25, continuous_update=False,
        orientation='vertical', disabled=False,
        min=0.05, max=1.75, step=0.05,
        description=r'$\sigma$') for k in range(num_experiments)]

    lam_min, lam_max = 2.0, 7.0

    # Start with the whole parameter space selected so that the default
    # lam_0 (4.5, the midpoint) lies inside the bounds it is derived from.
    lam_bound = [widgets.FloatRangeSlider(value=[lam_min, lam_max], continuous_update=False,
        orientation='horizontal', disabled=False,
        min=lam_min, max=lam_max, step=0.5,
        description=r'$\Lambda \in$') for k in range(num_experiments)]

    lam_0 = [widgets.FloatSlider(value=4.5, continuous_update=False,
        orientation='horizontal', disabled=False,
        min=lam_bound[k].value[0], max=lam_bound[k].value[1], step=0.1,
        description=r'IC: $\lambda_0$') for k in range(num_experiments)]

    t_0 = [widgets.FloatSlider(value=0.5, continuous_update=False,
        orientation='horizontal', disabled=False,
        min=0.1, max=2.0, step=0.05,
        description='$t_0$ =', readout_format='.2f') for k in range(num_experiments)]

    Delta_t = [widgets.FloatSlider(value=0.1, continuous_update=False,
        orientation='horizontal', disabled=False,
        min=0.05, max=0.5, step=0.05,
        description=r'$\Delta_t$ =', readout_format='1.2e') for k in range(num_experiments)]

    num_observations = [widgets.IntSlider(value=50, continuous_update=False,
        orientation='horizontal', disabled=False,
        min=1, max=100,
        description='# Obs. =') for k in range(num_experiments)]

    compare = [widgets.Checkbox(value=False, disabled=False,
        description='Observed v. Q(Post)') for k in range(num_experiments)]

    smooth_post = [widgets.Checkbox(value=False, disabled=False,
        description='Smooth Posterior') for k in range(num_experiments)]

    fixed_noise = [widgets.Checkbox(value=False, disabled=False,
        description='Fixed Noise Model') for k in range(num_experiments)]

    num_trials = [widgets.IntSlider(value=1, continuous_update=False,
        orientation='vertical', disabled=False,
        min=1, max=25,
        description='Num. Trials') for k in range(num_experiments)]

    # IF YOU ADD MORE FUNCTIONS to cb_sandbox.py, increase max below.
    fun_choice = [widgets.IntSlider(value=0, continuous_update=False,
        orientation='horizontal', disabled=False,
        min=0, max=2,
        description='Fun. Choice') for k in range(num_experiments)]

    fixed_obs_window = [widgets.Checkbox(value=False, disabled=False,
        description='Fixed Obs. Window') for k in range(num_experiments)]

    Keys = [{'num_samples': num_samples[k],
            'lam_bound': lam_bound[k],
            'lam_0': lam_0[k],
            't_0': t_0[k],
            'Delta_t': Delta_t[k],
            'num_observations': num_observations[k],
            'sd': sd[k],
            'compare': compare[k],
            'smooth_post': smooth_post[k],
            'fixed_noise': fixed_noise[k],
            'num_trials': num_trials[k],
            'fun_choice': fun_choice[k]} for k in range(num_experiments)]

    # Make all the interactive outputs for each tab and store them in a vector called out. (for output)
    out = [widgets.interactive_output(sandbox, Keys[k]) for k in range(num_experiments)]

    ### LINK WIDGETS TOGETHER (dependent variables) ###
    def _set_bounds(widget, lo, hi):
        """Set ``widget.min``/``widget.max`` to ``lo``/``hi`` without ever having min > max."""
        # ipywidgets raises TraitError for a min above the current max, so when
        # the window moves upwards the max has to be raised first.
        if lo > widget.max:
            widget.max = hi
            widget.min = lo
        else:
            widget.min = lo
            widget.max = hi

    def _bind(callback, k):
        """Return an observer that runs ``callback(k)`` for experiment ``k``."""
        # Binding k here (instead of reading the selected tab) keeps each
        # callback tied to the widgets of its own experiment.
        return lambda change: callback(k)

    # Different ranges for different problems
    def update_lam_bound(k):
        choice = fun_choice[k].value
        if choice == 2:  # fixed frequency
            _set_bounds(lam_bound[k], 0.0, 10.0)
            lam_bound[k].value = [0, 1]
            lam_0[k].value = 0.5
        elif choice == 1:  # fixed initial condition
            _set_bounds(lam_bound[k], -1.0, 1.0)
            lam_bound[k].value = [-1, 1]
            lam_0[k].value = 0
        # NOTE(review): choosing 0 again keeps the previous choice's range;
        # confirm whether it should restore [lam_min, lam_max].

    # if you change the bounds on the parameter space, update the bounds of lambda_0
    def update_lam_0(k):
        lo, hi = lam_bound[k].value
        _set_bounds(lam_0[k], lo, hi)

    current_window_size = [num_observations[k].value * Delta_t[k].value
                           for k in range(num_experiments)]

    def lock_window_size(k):  # if you want to lock the window
        if fixed_obs_window[k].value:
            # record the present value for later use.
            current_window_size[k] = num_observations[k].value * Delta_t[k].value

    def update_num_obs(k):  # update num obs if Delta_t changes
        if fixed_obs_window[k].value:
            # IntSlider coerces floats by truncation; round explicitly so float
            # error (e.g. 16.9999...) does not drop an observation and make this
            # observer and update_delta_t keep nudging each other.
            num_observations[k].value = int(round(current_window_size[k] / Delta_t[k].value))

    def update_delta_t(k):  # update Delta_t if num obs changes
        if fixed_obs_window[k].value:
            Delta_t[k].value = current_window_size[k] / num_observations[k].value

    for k in range(num_experiments):
        fun_choice[k].observe(_bind(update_lam_bound, k), 'value')
        lam_bound[k].observe(_bind(update_lam_0, k), 'value')
        fixed_obs_window[k].observe(_bind(lock_window_size, k), 'value')
        Delta_t[k].observe(_bind(update_num_obs, k), 'value')
        num_observations[k].observe(_bind(update_delta_t, k), 'value')

    ### GENERATE USER INTERFACE ###
    lbl = widgets.Label("UQ Sandbox")
    # horizontal and vertical sliders are grouped together, displayed in one horizontal box.
    # This HBox lives in a collapsable accordion below which the results are displayed.
    h_sliders = [widgets.VBox([lam_bound[k], lam_0[k],
                               t_0[k], Delta_t[k],
                               num_observations[k]]) for k in range(num_experiments)]
    v_sliders = [widgets.HBox([num_samples[k], num_trials[k],
                               sd[k]]) for k in range(num_experiments)]
    options = [widgets.VBox([widgets.Text('Model Options', disabled=True),
                             fixed_noise[k], fixed_obs_window[k], fun_choice[k],
                             widgets.Text('Plotting Options', disabled=True),
                             compare[k], smooth_post[k]]) for k in range(num_experiments)]
    user_interface = [widgets.HBox([h_sliders[k], options[k], v_sliders[k]]) for k in range(num_experiments)]

    # format the widgets layout (non-default options)
    for k in range(num_experiments):
        h_sliders[k].layout.justify_content = 'center'
        v_sliders[k].layout.justify_content = 'center'
        user_interface[k].layout.justify_content = 'center'

    ### MAKE TABULATED NOTEBOOK ###
    # Create our pages
    pages = [widgets.HBox() for k in range(num_experiments)]

    # instantiate notebook with tabs (accordions) representing experiments
    tab_nest = widgets.Tab()
    tab_nest.children = [pages[k] for k in range(num_experiments)]

    # title your notebooks
    experiment_names = ['Experiment %d' % k for k in range(num_experiments)]
    for k in range(num_experiments):
        tab_nest.set_title(k, experiment_names[k])

    # Spawn the children!!!
    for k in range(num_experiments):
        A = widgets.Accordion(children=[user_interface[k]])
        A.set_title(0, lbl.value)
        A.layout.justify_content = 'center'
        content = widgets.VBox([A, out[k]])
        content.layout.justify_content = 'center'
        tab_nest.children[k].children = [content]

    return tab_nest, Keys, fixed_obs_window
Ejemplo n.º 17
0
    def create_plot_settings_widgets(self):
        """Create all the widgets that will be used for general plotting options.

        The widgets are stored by name in ``self.qubit_plot_options_widgets``.
        The save button is wired to ``self.save_button_clicked_action`` and the
        scan dropdown to ``self.scan_dropdown_eventhandler``.
        """
        self.qubit_plot_options_widgets = {}
        std_layout = Layout(width="300px")

        operator_dropdown_list = self.get_operators()
        scan_dropdown_list = self.qubit_scan_params.keys()
        mode_dropdown_list = [
            ("Re(·)", "real"),
            ("Im(·)", "imag"),
            ("|·|", "abs"),
            (u"|\u00B7|\u00B2", "abs_sqr"),
        ]
        # Use a context manager so the image file handle is closed promptly.
        with open(self.active_qubit._image_filename, "rb") as image_file:
            image = image_file.read()

        # Defaults (min/max) of the currently selected scan parameter.
        scan_param = self.active_defaults["scan_param"]
        scan_defaults = self.active_defaults[scan_param]

        self.qubit_plot_options_widgets = {
            "qubit_info_image_widget":
            widgets.Image(value=image,
                          format="jpg",
                          layout=Layout(width="700px")),
            "save_button":
            widgets.Button(icon="save", layout=widgets.Layout(width="35px")),
            "filename_text":
            widgets.Text(
                value=str(Path.cwd().joinpath("plot.pdf")),
                description="",
                disabled=False,
                layout=Layout(width="500px"),
            ),
            "scan_dropdown":
            widgets.Dropdown(
                options=scan_dropdown_list,
                value=scan_param,
                description="Scan over",
                disabled=False,
                layout=std_layout,
            ),
            "mode_dropdown":
            widgets.Dropdown(
                options=mode_dropdown_list,
                description="Plot as:",
                disabled=False,
                layout=std_layout,
            ),
            "operator_dropdown":
            widgets.Dropdown(
                options=operator_dropdown_list,
                value=self.active_defaults["operator"],
                description="Operator",
                disabled=False,
                layout=std_layout,
            ),
            "scan_range_slider":
            widgets.FloatRangeSlider(
                min=scan_defaults["min"],
                max=scan_defaults["max"],
                value=[scan_defaults["min"], scan_defaults["max"]],
                step=0.05,
                description="{} range".format(scan_param),
                continuous_update=False,
                layout=std_layout,
            ),
            "eigenvalue_state_slider":
            widgets.IntSlider(
                min=1,
                max=10,
                value=7,
                description="Highest state",
                continuous_update=False,
                layout=std_layout,
            ),
            "matrix_element_state_slider":
            widgets.IntSlider(
                min=1,
                max=6,
                value=4,
                description="Highest state",
                continuous_update=False,
                layout=std_layout,
            ),
            "wavefunction_single_state_selector":
            widgets.IntSlider(
                min=0,
                max=10,
                value=0,
                description="State no.",
                continuous_update=False,
                layout=std_layout,
            ),
            "wavefunction_scale_slider":
            widgets.FloatSlider(
                min=0.1,
                max=4,
                value=self.active_defaults["scale"],
                description="\u03c8 ampl.",
                continuous_update=False,
                layout=std_layout,
            ),
            "wavefunction_multi_state_selector":
            widgets.SelectMultiple(
                options=range(0, 10),
                value=[0, 1, 2, 3, 4],
                description="States",
                disabled=False,
                continuous_update=False,
                layout=std_layout,
            ),
            "show_numbers_checkbox":
            widgets.Checkbox(value=False,
                             description="Show values",
                             disabled=False),
            "show3d_checkbox":
            widgets.Checkbox(value=True, description="Show 3D",
                             disabled=False),
            "subtract_ground_checkbox":
            widgets.Checkbox(value=False,
                             description="Subtract E\u2080",
                             disabled=False),
            "manual_scale_checkbox":
            widgets.Checkbox(value=False,
                             description="Manual Scaling",
                             disabled=False),
        }
        self.qubit_plot_options_widgets["save_button"].on_click(
            self.save_button_clicked_action)
        self.qubit_plot_options_widgets["scan_dropdown"].observe(
            self.scan_dropdown_eventhandler, names="value")
Ejemplo n.º 18
0
    def __init__(self, data_handler, output_dir, scale=1):
        """Build the augmentation preview: original and augmented slice grids plus controls.

        Parameters
        ----------
        data_handler
            Data source; ``data_ready`` must be truthy.  Its ``batch_size`` is
            forced to 1 and ``crop_size[0]`` gives the number of 2D slices
            drawn per figure.
        output_dir : str
            Stored as ``self.output_dir``; created (with parents) if missing.
        scale : float
            Figure size per subplot cell (matplotlib ``figsize`` units).
        """
        super().__init__()

        assert data_handler.data_ready
        self.data_handler = data_handler

        self.output_dir = output_dir
        # exist_ok avoids the check-then-create race of an isdir() guard.
        os.makedirs(self.output_dir, exist_ok=True)

        # always use batch size of 1
        self.data_handler.batch_size = 1
        # this is the number of 2d plots we need to draw
        n_slices = self.data_handler.crop_size[0]
        self.n_rows = int(np.sqrt(n_slices))
        # enough columns (ceil division) so that every slice gets a cell
        self.n_cols = n_slices // self.n_rows
        if n_slices % self.n_rows > 0:
            self.n_cols += 1
        width = self.n_cols * scale
        height = self.n_rows * scale
        self.fig_ori, self.ax_ori = plt.subplots(self.n_rows,
                                                 self.n_cols,
                                                 figsize=(width, height),
                                                 squeeze=False)
        self.fig_aug, self.ax_aug = plt.subplots(self.n_rows,
                                                 self.n_cols,
                                                 figsize=(width, height),
                                                 squeeze=False)

        self.fig_ori.tight_layout()
        self.fig_aug.tight_layout()

        # squeeze=False guarantees 2D axes arrays, so .flat visits every cell
        for axes in (self.ax_ori, self.ax_aug):
            for ax in axes.flat:
                ax.axis("off")

        def _checkbox(description):
            """On/off toggle for one augmentation family."""
            return widgets.Checkbox(value=True,
                                    description=description,
                                    disabled=False,
                                    indent=False)

        def _prob_slider(description, value):
            """Probability slider on [0, 1]."""
            return widgets.FloatSlider(value=value,
                                       min=0,
                                       max=1.0,
                                       description=description,
                                       continuous_update=False,
                                       orientation='horizontal',
                                       readout=True,
                                       readout_format='.2f')

        def _range_slider(description, value, max_value, step):
            """(low, high) range slider starting at 0."""
            return widgets.FloatRangeSlider(value=value,
                                            min=0,
                                            max=max_value,
                                            step=step,
                                            description=description,
                                            continuous_update=False,
                                            orientation='horizontal',
                                            readout=True,
                                            readout_format='.2f')

        def _boxed(*children):
            """Group widgets in a VBox with a thin black border."""
            return widgets.VBox(children=list(children),
                                layout={'border': '1px solid black'})

        # rotation: toggle, probability, angle
        self.do_rotation = _checkbox('do_rotation')
        self.p_rot_per_sample = _prob_slider("p_rot_per_sample", 0.5)
        self.augment_slider_rot = widgets.FloatSlider(
            value=15., min=0., max=360, description="rotation angle")

        # elastic deformation: toggle, probability, deformation_scale (0, 0.25)
        self.do_elastic_deform = _checkbox('do_elastic_deform')
        self.p_el_per_sample = _prob_slider("p_el_per_sample", 0.5)
        self.deformation_scale = _range_slider('deformation_scale',
                                               [0, 0.25], 1.0, 0.05)

        # scaling: toggle, probability, scale (0.75, 1.25)
        self.do_scale = _checkbox('do_scale')
        self.p_scale_per_sample = _prob_slider("p_scale_per_sample", 0.5)
        self.scale = _range_slider('scale', [0.75, 1.25], 5.0, 0.1)

        # mirroring
        self.do_mirror = _checkbox('do_mirror')

        # p_per_sample is shared by gamma, gaussian noise and brightness
        self.p_per_sample = _prob_slider("p_per_sample", 0.15)
        self.brightness_range = _range_slider('brightness_range',
                                              [0.7, 1.5], 5.0, 0.1)
        self.gaussian_noise_variance = _range_slider('gaussian_noise_variance',
                                                     [0., 0.05], 1.0, 0.05)
        self.gamma_range = _range_slider('gamma_range', [0.5, 2.0], 5.0, 0.1)

        self.arg_to_widget_map = {
            # rotations
            "do_rotation": self.do_rotation,
            "p_rot_per_sample": self.p_rot_per_sample,
            "angle_x": self.augment_slider_rot,
            "angle_y": self.augment_slider_rot,
            "angle_z": self.augment_slider_rot,
            # elastic deforms
            "do_elastic_deform": self.do_elastic_deform,
            "p_el_per_sample": self.p_el_per_sample,
            "deformation_scale": self.deformation_scale,
            # scaling
            "do_scale": self.do_scale,
            "p_scale_per_sample": self.p_scale_per_sample,
            "scale": self.scale,
            # mirroring
            "do_mirror": self.do_mirror,
            # all others
            "p_per_sample": self.p_per_sample,
            "brightness_range": self.brightness_range,
            "gaussian_noise_variance": self.gaussian_noise_variance,
            "gamma_range": self.gamma_range,
        }

        self.patient_dropdown = widgets.Dropdown(
            options=self.data_handler.patient_ids, value=None)
        self.patient_dropdown.observe(self._init_patient_cb, names="value")

        self.apply_aug_button = widgets.Button(description="Augment",
                                               button_style="success")
        self.apply_aug_button.on_click(self._plot_augmented_batch_cb)

        self.output_widget = widgets.Output(
            layout={'border': '1px solid black'})

        controls = widgets.VBox(children=[
            widgets.Label("Augmentation args"),
            _boxed(self.do_rotation, self.p_rot_per_sample,
                   self.augment_slider_rot),
            _boxed(self.do_elastic_deform, self.p_el_per_sample,
                   self.deformation_scale),
            _boxed(self.do_scale, self.p_scale_per_sample, self.scale),
            _boxed(self.do_mirror),
            _boxed(self.p_per_sample, self.brightness_range,
                   self.gaussian_noise_variance, self.gamma_range),
        ])

        self.children = [
            widgets.VBox(children=[
                self.patient_dropdown,
                widgets.HBox(children=[
                    self.fig_ori.canvas,
                    controls,
                    self.fig_aug.canvas,
                ]),
                self.apply_aug_button,
                self.output_widget
            ]),
        ]

        self.generator = None

        self.output_widget.clear_output()
        with self.output_widget:
            print(self.ax_ori.shape, self.ax_aug.shape)
Ejemplo n.º 19
0
    def draw_workflow(self, start_end, workflow, name):
        """Draw a stacked plot of the request intervals of each workflow group.

        If ``self.out_path`` is truthy the full range is rendered once.
        Otherwise a widget tab (x/y range sliders and a highlight picker)
        drives the plot interactively.

        Parameters
        ----------
        start_end : dict
            Needs ``["seconds"]["last"]`` and ``["time"]["start"]`` /
            ``["time"]["last"]``.
        workflow : Workflow
            Workflow whose groups are plotted (checked with ``assert``).
        name
            Figure name forwarded to ``self._build_fig``.
        """
        assert isinstance(workflow, Workflow)

        last_s = start_end["seconds"]["last"]
        start_t = start_end["time"]["start"]
        last_t = start_end["time"]["last"]

        # 1. get main stacked axis list
        main_stacked_results = []

        # Leading "BLANK" band: from 0 up to the start of each interval of the
        # groups that follow the start group.
        main_stacked_results.append(("BLANK", "#eeeeee", [
            (0, c.from_seconds) for c in chain.from_iterable(
                s.intervals for s in workflow.start_group.iter_nxtgroups())
        ]))

        for group in workflow.sort_topologically():
            _name = group.desc
            color = getcolor_byint(group.intervals[0], ignore_lr=True)
            main_stacked_results.append(
                (_name, color, [(c.from_seconds, c.to_seconds)
                                for c in group.intervals]))

        # 2. calc main_x_list from main_stacked_results
        types, colors, main_x_list, main_ys_list =\
                _generate_stackeddata(main_stacked_results)

        # 3. calc and check the arguments
        x_start_default = 0
        y_start_default = 0
        x_end_default = last_s * 1.05
        y_end_default = workflow.len_reqs

        def _interact(x, y, selects, color):
            with self._build_fig("workflowplot", name, figsize=(30, 6)) as fig:
                ax = fig.add_subplot(1, 1, 1)
                ax.set_xlabel("lapse (seconds)")
                ax.set_ylabel("requests")
                ax.set_xlim(x[0], x[1])
                ax.set_ylim(y[0], y[1])
                plot_colors = colors[:]
                if color:
                    # selects are indices into colors (see the options below)
                    for s in selects:
                        plot_colors[s] = color
                ax.annotate(start_t, xy=(0, 0), xytext=(0, 0))
                ax.annotate(last_t, xy=(last_s, 0), xytext=(last_s, 0))
                ax.plot([0, last_s], [0, 0], 'r*')

                ax.stackplot(main_x_list,
                             *main_ys_list,
                             colors=plot_colors,
                             edgecolor="black",
                             linewidth=.1)

        if self.out_path:
            _interact((x_start_default, x_end_default),
                      (y_start_default, y_end_default), [], None)
            return

        from ipywidgets import widgets, interactive_output, Layout
        from IPython.display import display

        layout = Layout(width="99%")
        w_lapse = widgets.FloatRangeSlider(
            value=[x_start_default, x_end_default],
            min=x_start_default,
            max=x_end_default,
            step=0.0001,
            description='x-lapse:',
            continuous_update=False,
            readout_format='.4f',
            layout=layout)
        w_requests = widgets.IntRangeSlider(
            value=[y_start_default, y_end_default],
            min=y_start_default,
            max=y_end_default,
            description='y-requests:',
            continuous_update=False,
            layout=layout)
        w_range = widgets.VBox([w_lapse, w_requests])

        w_color = widgets.ColorPicker(concise=True,
                                      description='Highlight:',
                                      value='#ff40ff',
                                      layout=layout)
        # Drop the leading BLANK band (index 0).
        # NOTE(review): the last type is excluded as well; confirm that is intended.
        options = list(types[1:-1])
        # Repeated names get a " (n)" suffix so every option label is unique.
        dedup = defaultdict(lambda: -1)
        for i, o in enumerate(options):
            dedup[o] += 1
            if dedup[o]:
                options[i] = ("%s (%d)" % (options[i], dedup[o]))
        # Option values are offset by one to skip the BLANK entry in colors.
        options = [(v, i + 1) for i, v in enumerate(options)]
        w_select = widgets.SelectMultiple(options=options,
                                          rows=min(10, len(options)),
                                          description='Steps:',
                                          layout=layout)
        w_highlight = widgets.VBox([w_color, w_select])

        w_tab = widgets.Tab()
        w_tab.children = [w_range, w_highlight]
        w_tab.set_title(0, "range")
        w_tab.set_title(1, "highlight")

        out = interactive_output(
            _interact, {
                'x': w_lapse,
                'y': w_requests,
                'selects': w_select,
                'color': w_color
            })
        display(w_tab, out)
Ejemplo n.º 20
0
    def slider(self,
               figsize=(8, 8),
               exclude_particle_records=['charge', 'mass'],
               **kw):
        """
        Navigate the simulation using a slider

        Parameters:
        -----------
        figsize: tuple
            Size of the figures

        exclude_particle_records: list of strings
            List of particle quantities that should not be displayed
            in the slider (typically because they are less interesting)

        kw: dict
            Extra arguments to pass to matplotlib's imshow
        """

        # -----------------------
        # Define useful functions
        # -----------------------

        def refresh_field(change=None, force=False):
            """
            Refresh the current field figure

            Parameters :
            ------------
            change: dictionary
                Dictionary passed by the widget to a callback functions
                whenever a change of a widget happens
                (see docstring of ipywidgets.Widget.observe)
                This is mainline a place holder ; not used in this function

            force: bool
                Whether to force the update
            """
            # Determine whether to do the refresh
            do_refresh = False
            if (self.avail_fields is not None):
                if force or fld_refresh_toggle.value:
                    do_refresh = True
            # Do the refresh
            if do_refresh:
                plt.figure(fld_figure_button.value, figsize=figsize)
                plt.clf()

                # When working in inline mode, in an ipython notebook,
                # clear the output (prevents the images from stacking
                # in the notebook)
                if 'inline' in matplotlib.get_backend():
                    clear_output()

                if fld_use_button.value:
                    i_power = fld_magnitude_button.value
                    vmin = fld_range_button.value[0] * 10**i_power
                    vmax = fld_range_button.value[1] * 10**i_power
                else:
                    vmin = None
                    vmax = None

                self.get_field(t=self.current_t,
                               output=False,
                               plot=True,
                               field=fieldtype_button.value,
                               coord=coord_button.value,
                               m=convert_to_int(mode_button.value),
                               slicing=slicing_button.value,
                               theta=theta_button.value,
                               slicing_dir=slicing_dir_button.value,
                               vmin=vmin,
                               vmax=vmax,
                               cmap=fld_color_button.value)

        def refresh_ptcl(change=None, force=False):
            """
            Refresh the current particle figure

            Parameters :
            ------------
            change: dictionary
                Dictionary passed by the widget to a callback functions
                whenever a change of a widget happens
                (see docstring of ipywidgets.Widget.observe)
                This is mainline a place holder ; not used in this function

            force: bool
                Whether to force the update
            """
            # Determine whether to do the refresh
            do_refresh = False
            if self.avail_species is not None:
                if force or ptcl_refresh_toggle.value:
                    do_refresh = True
            # Do the refresh
            if do_refresh:
                plt.figure(ptcl_figure_button.value, figsize=figsize)
                plt.clf()

                # When working in inline mode, in an ipython notebook,
                # clear the output (prevents the images from stacking
                # in the notebook)
                if 'inline' in matplotlib.get_backend():
                    clear_output()

                if ptcl_use_button.value:
                    i_power = ptcl_magnitude_button.value
                    vmin = ptcl_range_button.value[0] * 10**i_power
                    vmax = ptcl_range_button.value[1] * 10**i_power
                else:
                    vmin = None
                    vmax = None

                if ptcl_yaxis_button.value == 'None':
                    # 1D histogram
                    self.get_particle(t=self.current_t,
                                      output=False,
                                      var_list=[ptcl_xaxis_button.value],
                                      select=ptcl_select_widget.to_dict(),
                                      species=ptcl_species_button.value,
                                      plot=True,
                                      vmin=vmin,
                                      vmax=vmax,
                                      cmap=ptcl_color_button.value,
                                      nbins=ptcl_bins_button.value)
                else:
                    # 2D histogram
                    self.get_particle(t=self.current_t,
                                      output=False,
                                      var_list=[
                                          ptcl_xaxis_button.value,
                                          ptcl_yaxis_button.value
                                      ],
                                      select=ptcl_select_widget.to_dict(),
                                      species=ptcl_species_button.value,
                                      plot=True,
                                      vmin=vmin,
                                      vmax=vmax,
                                      cmap=ptcl_color_button.value,
                                      nbins=ptcl_bins_button.value)

        def refresh_field_type(change):
            """
            Refresh the field type and disable the coordinates buttons
            if the field is scalar.

            Parameter
            ---------
            change: dictionary
                Dictionary passed by the widget to a callback functions
                whenever a change of a widget happens
                (see docstring of ipywidgets.Widget.observe)
            """
            if self.avail_fields[change['new']] == 'scalar':
                coord_button.disabled = True
            elif self.avail_fields[change['new']] == 'vector':
                coord_button.disabled = False
            refresh_field()

        def refresh_species(change=None):
            """
            Refresh the particle species buttons by populating them
            with the available records for the current species

            Parameter
            ---------
            change: dictionary
                Dictionary passed by the widget to a callback functions
                whenever a change of a widget happens
                (see docstring of ipywidgets.Widget.observe)
            """
            # Deactivate the particle refreshing to avoid callback
            # while modifying the widgets
            saved_refresh_value = ptcl_refresh_toggle.value
            ptcl_refresh_toggle.value = False

            # Get available records for this species
            avail_records = [
                q for q in self.avail_record_components[
                    ptcl_species_button.value]
                if q not in exclude_particle_records
            ]
            # Update the plotting buttons
            ptcl_xaxis_button.options = avail_records
            ptcl_yaxis_button.options = avail_records + ['None']
            if ptcl_xaxis_button.value not in ptcl_xaxis_button.options:
                ptcl_xaxis_button.value = avail_records[0]
            if ptcl_yaxis_button.value not in ptcl_yaxis_button.options:
                ptcl_yaxis_button.value = 'None'

            # Update the selection widgets
            for dropdown_button in ptcl_select_widget.quantity:
                dropdown_button.options = avail_records

            # Put back the previous value of the refreshing button
            ptcl_refresh_toggle.value = saved_refresh_value

        def change_t(change):
            "Plot the result at the required time"
            self.current_t = 1.e-15 * change['new']
            refresh_field()
            refresh_ptcl()

        def step_fw(b):
            "Plot the result one iteration further"
            if self.current_i < len(self.t) - 1:
                self.current_t = self.t[self.current_i + 1]
            else:
                self.current_t = self.t[self.current_i]
            slider.value = self.current_t * 1.e15

        def step_bw(b):
            "Plot the result one iteration before"
            if self.current_t > 0:
                self.current_t = self.t[self.current_i - 1]
            else:
                self.current_t = self.t[self.current_i]
            slider.value = self.current_t * 1.e15

        # ---------------
        # Define widgets
        # ---------------

        # Slider
        slider = widgets.FloatSlider(
            min=math.ceil(1.e15 * self.tmin),
            max=math.ceil(1.e15 * self.tmax),
            step=math.ceil(1.e15 * (self.tmax - self.tmin)) / 20.,
            description="t (fs)")
        slider.observe(change_t, names='value', type='change')

        # Forward button
        button_p = widgets.Button(description="+")
        button_p.on_click(step_fw)

        # Backward button
        button_m = widgets.Button(description="-")
        button_m.on_click(step_bw)

        # Display the time widgets
        container = widgets.HBox(children=[button_m, button_p, slider])
        display(container)

        # Field widgets
        # -------------
        if (self.avail_fields is not None):

            # Field type
            # ----------
            # Field button
            fieldtype_button = widgets.ToggleButtons(
                description='Field:', options=sorted(self.avail_fields.keys()))
            fieldtype_button.observe(refresh_field_type, 'value', 'change')

            # Coord button
            if self.geometry == "thetaMode":
                coord_button = widgets.ToggleButtons(
                    description='Coord:', options=['x', 'y', 'z', 'r', 't'])
            elif self.geometry in ["2dcartesian", "3dcartesian"]:
                coord_button = widgets.ToggleButtons(description='Coord:',
                                                     options=['x', 'y', 'z'])
            coord_button.observe(refresh_field, 'value', 'change')
            # Mode and theta button (for thetaMode)
            mode_button = widgets.ToggleButtons(description='Mode:',
                                                options=self.avail_circ_modes)
            mode_button.observe(refresh_field, 'value', 'change')
            theta_button = widgets.FloatSlider(width=140,
                                               value=0.,
                                               description=r'Theta:',
                                               min=-math.pi / 2,
                                               max=math.pi / 2)
            theta_button.observe(refresh_field, 'value', 'change')
            # Slicing buttons (for 3D)
            slicing_dir_button = widgets.ToggleButtons(
                value=self.axis_labels[1],
                options=self.axis_labels,
                description='Slicing direction:')
            slicing_dir_button.observe(refresh_field, 'value', 'change')
            slicing_button = widgets.FloatSlider(width=150,
                                                 description='Slicing:',
                                                 min=-1.,
                                                 max=1.,
                                                 value=0.)
            slicing_button.observe(refresh_field, 'value', 'change')

            # Plotting options
            # ----------------
            # Figure number
            fld_figure_button = widgets.IntText(description='Figure ',
                                                value=0,
                                                width=50)
            # Range of values
            fld_range_button = widgets.FloatRangeSlider(min=-10,
                                                        max=10,
                                                        width=220)
            fld_range_button.observe(refresh_field, 'value', 'change')
            # Order of magnitude
            fld_magnitude_button = widgets.IntText(description='x 10^',
                                                   value=9,
                                                   width=50)
            fld_magnitude_button.observe(refresh_field, 'value', 'change')
            # Use button
            fld_use_button = widgets.Checkbox(description=' Use this range',
                                              value=False)
            fld_use_button.observe(refresh_field, 'value', 'change')
            # Colormap button
            fld_color_button = widgets.Select(options=sorted(
                plt.cm.datad.keys()),
                                              height=50,
                                              width=200,
                                              value='jet')
            fld_color_button.observe(refresh_field, 'value', 'change')
            # Resfresh buttons
            fld_refresh_toggle = widgets.ToggleButton(
                description='Always refresh', value=True)
            fld_refresh_button = widgets.Button(description='Refresh now!')
            fld_refresh_button.on_click(partial(refresh_field, force=True))

            # Containers
            # ----------
            # Field type container
            if self.geometry == "thetaMode":
                container_fields = widgets.VBox(width=260,
                                                children=[
                                                    fieldtype_button,
                                                    coord_button, mode_button,
                                                    theta_button
                                                ])
            elif self.geometry == "2dcartesian":
                container_fields = widgets.VBox(
                    width=260, children=[fieldtype_button, coord_button])
            elif self.geometry == "3dcartesian":
                container_fields = widgets.VBox(width=260,
                                                children=[
                                                    fieldtype_button,
                                                    coord_button,
                                                    slicing_dir_button,
                                                    slicing_button
                                                ])
            # Plotting options container
            container_fld_plots = widgets.VBox(width=260,
                                               children=[
                                                   fld_figure_button,
                                                   fld_range_button,
                                                   widgets.HBox(children=[
                                                       fld_magnitude_button,
                                                       fld_use_button
                                                   ],
                                                                height=50),
                                                   fld_color_button
                                               ])
            # Accordion for the field widgets
            accord1 = widgets.Accordion(
                children=[container_fields, container_fld_plots])
            accord1.set_title(0, 'Field type')
            accord1.set_title(1, 'Plotting options')
            # Complete field container
            container_fld = widgets.VBox(
                width=300,
                children=[
                    accord1,
                    widgets.HBox(
                        children=[fld_refresh_toggle, fld_refresh_button])
                ])

        # Particle widgets
        # ----------------
        if (self.avail_species is not None):

            # Particle quantities
            # -------------------
            # Species selection
            ptcl_species_button = widgets.Dropdown(width=250,
                                                   options=self.avail_species)
            ptcl_species_button.observe(refresh_species, 'value', 'change')
            # Get available records for this species
            avail_records = [
                q for q in self.avail_record_components[
                    ptcl_species_button.value]
                if q not in exclude_particle_records
            ]
            # Particle quantity on the x axis
            ptcl_xaxis_button = widgets.ToggleButtons(options=avail_records)
            ptcl_xaxis_button.observe(refresh_ptcl, 'value', 'change')
            # Particle quantity on the y axis
            ptcl_yaxis_button = widgets.ToggleButtons(options=avail_records +
                                                      ['None'],
                                                      value='None')
            ptcl_yaxis_button.observe(refresh_ptcl, 'value', 'change')

            # Particle selection
            # ------------------
            # 3 selection rules at maximum
            ptcl_select_widget = ParticleSelectWidget(3, avail_records,
                                                      refresh_ptcl)

            # Plotting options
            # ----------------
            # Figure number
            ptcl_figure_button = widgets.IntText(description='Figure ',
                                                 value=1,
                                                 width=50)
            # Number of bins
            ptcl_bins_button = widgets.IntText(description='nbins:',
                                               value=100,
                                               width=100)
            ptcl_bins_button.observe(refresh_ptcl, 'value', 'change')
            # Colormap button
            ptcl_color_button = widgets.Select(options=sorted(
                plt.cm.datad.keys()),
                                               height=50,
                                               width=200,
                                               value='Blues')
            ptcl_color_button.observe(refresh_ptcl, 'value', 'change')
            # Range of values
            ptcl_range_button = widgets.FloatRangeSlider(min=0,
                                                         max=10,
                                                         width=220,
                                                         value=(0, 5))
            ptcl_range_button.observe(refresh_ptcl, 'value', 'change')
            # Order of magnitude
            ptcl_magnitude_button = widgets.IntText(description='x 10^',
                                                    value=9,
                                                    width=50)
            ptcl_magnitude_button.observe(refresh_ptcl, 'value', 'change')
            # Use button
            ptcl_use_button = widgets.Checkbox(description=' Use this range',
                                               value=False)
            ptcl_use_button.observe(refresh_ptcl, 'value', 'change')
            # Resfresh buttons
            ptcl_refresh_toggle = widgets.ToggleButton(
                description='Always refresh', value=True)
            ptcl_refresh_button = widgets.Button(description='Refresh now!')
            ptcl_refresh_button.on_click(partial(refresh_ptcl, force=True))

            # Containers
            # ----------
            # Particle quantity container
            container_ptcl_quantities = widgets.VBox(width=310,
                                                     children=[
                                                         ptcl_species_button,
                                                         ptcl_xaxis_button,
                                                         ptcl_yaxis_button
                                                     ])
            # Particle selection container
            container_ptcl_select = ptcl_select_widget.to_container()
            # Plotting options container
            container_ptcl_plots = widgets.VBox(width=310,
                                                children=[
                                                    ptcl_figure_button,
                                                    ptcl_bins_button,
                                                    ptcl_range_button,
                                                    widgets.HBox(children=[
                                                        ptcl_magnitude_button,
                                                        ptcl_use_button
                                                    ],
                                                                 height=50),
                                                    ptcl_color_button
                                                ])
            # Accordion for the field widgets
            accord2 = widgets.Accordion(children=[
                container_ptcl_quantities, container_ptcl_select,
                container_ptcl_plots
            ])
            accord2.set_title(0, 'Particle quantities')
            accord2.set_title(1, 'Particle selection')
            accord2.set_title(2, 'Plotting options')
            # Complete particle container
            container_ptcl = widgets.VBox(
                width=370,
                children=[
                    accord2,
                    widgets.HBox(
                        children=[ptcl_refresh_toggle, ptcl_refresh_button])
                ])

        # Global container
        if (self.avail_fields is not None) and \
                (self.avail_species is not None):
            global_container = widgets.HBox(
                children=[container_fld, container_ptcl])
            display(global_container)
        elif self.avail_species is None:
            display(container_fld)
        elif self.avail_fields is None:
            display(container_ptcl)
Ejemplo n.º 21
0
def do_plots_gui():
    """Assemble the plotting GUI and hand it to ``widgets.interact_manual``.

    Selectors are built for the product family, variable, month of the
    current year, anomaly mode, colormap, colour-scale bounds and LTA
    window.  Callbacks keep the variable/month choices in step with the
    selected product, and switch the colormap list, scale bounds and LTA
    selector visibility whenever anomaly mode is toggled.  Plotting is
    delegated to ``plot_anomaly`` or ``plot_field`` and only happens on
    manual request (``interact_manual``).
    """
    current_year = dt.datetime.now().year
    last_dates = check_dates()
    initial_product = "TAMSAT"

    # ---- selector widgets ----------------------------------------------
    product_sel = widgets.RadioButtons(
        options=["TAMSAT", "ERA5", "MODIS"],
        value=initial_product,
        description="Product family",
    )
    variable_sel = widgets.Dropdown(
        options=variable_lists[initial_product],
        description="Variable",
    )
    months_this_year = widgets.Select(
        options=range(1, last_dates[initial_product] + 1),
        description="Month current year",
    )
    anomaly_calc = widgets.Checkbox(
        value=True,
        description="Do anomaly calculations",
    )
    sel_cmap = widgets.Dropdown(options=diverging_cmaps, value="Spectral")
    sel_boundz = widgets.FloatRangeSlider(
        value=[-2.5, 2.5],
        min=-6,
        max=6,
        step=0.1,
        description="Scale for anomaly colormap",
    )
    sel_period = widgets.RadioButtons(
        options=periods,
        description="LTA temporal window",
        disabled=False,
    )

    # ---- callbacks -------------------------------------------------------
    def on_product_change(change):
        # Both the variable list and the selectable months depend on the
        # chosen product family.
        chosen = change.new
        variable_sel.options = variable_lists[chosen]
        months_this_year.options = range(1, last_dates[chosen] + 1)

    def on_anomaly_change(change):
        # Anomaly mode: diverging palette, symmetric scale, LTA window shown.
        # Plain mode: uniform palette, 0-100 scale, LTA window hidden.
        if change.new:
            sel_cmap.options = diverging_cmaps
            sel_boundz.min = -6
            sel_boundz.max = 6
            sel_boundz.value = [-2.5, 2.5]
            sel_boundz.step = 0.1
            sel_period.layout.visibility = "visible"
        else:
            sel_cmap.options = uniform_cmaps
            sel_boundz.min = 0
            sel_boundz.max = 100
            sel_boundz.value = [5, 95]
            sel_boundz.step = 1
            sel_period.layout.visibility = "hidden"

    product_sel.observe(on_product_change, "value")
    anomaly_calc.observe(on_anomaly_change, "value")

    def do_plots(**kwds):
        # interact_manual supplies one keyword per widget passed to it below.
        anomaly_mode = kwds["anomaly"]
        common = (
            kwds["product"],
            kwds["variable"],
            current_year,
            kwds["month"],
            kwds["cmap"],
            kwds["boundz"],
        )
        if anomaly_mode:
            plot_anomaly(*(common + (kwds["lta_period"],)))
        else:
            plot_field(*common)

    # Keyword order here fixes the order the controls appear in the UI.
    widgets.interact_manual(
        do_plots,
        anomaly=anomaly_calc,
        product=product_sel,
        variable=variable_sel,
        cmap=sel_cmap,
        boundz=sel_boundz,
        month=months_this_year,
        lta_period=sel_period,
    )
Ejemplo n.º 22
0
    def _generate_widgets_macro_stress_strain(self):
        """Build the experimental stress-strain figure and its controls.

        Returns
        -------
        dict
            ``'fig'``: a ``go.FigureWidget`` holding the experimental
            engineering stress-strain curve (trace 0) and one trace of two
            vertical segments at the plastic-range bounds (trace 1).
            ``'fig_trace_idx'``: the trace indices for each of those roles.
            ``'controls'``: the stress/strain-type radio buttons and the
            plastic-range slider, already observed by
            ``_update_widgets_stress_strain_type`` and
            ``_update_widgets_plastic_range`` respectively.
        """
        tensile = self.exp_tensile_test
        pad_x = HardeningLawFitter.FIG_PAD[0]
        pad_y = HardeningLawFitter.FIG_PAD[1]
        y_top = pad_y + tensile.max_stress
        lower_bound = tensile.plastic_range[0]
        upper_bound = tensile.plastic_range[1]

        experimental_trace = {
            'x': tensile.eng_strain,
            'y': tensile.eng_stress,
            'name': 'Experimental',
            'line': {
                'color': DEFAULT_PLOTLY_COLORS[0],
            },
        }
        # Both boundary lines live in one trace; the None entries break the
        # polyline so two separate vertical segments are drawn.
        boundary_trace = {
            'x': [lower_bound, lower_bound, None, upper_bound, upper_bound],
            'y': [-pad_y, y_top, None, -pad_y, y_top],
            'mode': 'lines',
            'line': {
                'color': '#888',
                'width': 2,
            },
            'showlegend': False,
        }

        layout = {
            'title': 'Experimental Data',
            'width': HardeningLawFitter.FIG_WIDTH,
            'height': HardeningLawFitter.FIG_HEIGHT,
            'margin': HardeningLawFitter.FIG_MARG,
            'xaxis': {
                'title': 'Engineering strain, ε',
                'range': [-pad_x, pad_x + tensile.max_strain],
            },
            'yaxis': {
                'title': 'Engineering stress, σ / MPa',
                'range': [-pad_y, y_top],
            },
        }

        stress_strain_type = widgets.RadioButtons(
            options=['Engineering', 'True'],
            description='Stress/strain:',
            value='Engineering',
        )
        # The slider bounds come from the true-strain extent of the test.
        plastic_range_slider = widgets.FloatRangeSlider(
            value=tensile.plastic_range,
            step=0.005,
            min=tensile.min_true_strain,
            max=tensile.max_true_strain,
            description='Plastic range:',
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout_format='.4f',
            layout=widgets.Layout(width='90%'),
        )
        stress_strain_type.observe(self._update_widgets_stress_strain_type,
                                   names='value')
        plastic_range_slider.observe(self._update_widgets_plastic_range,
                                     names='value')

        return {
            'fig': go.FigureWidget(
                data=[experimental_trace, boundary_trace],
                layout=layout,
            ),
            'fig_trace_idx': {
                'macro_stress_strain': [0],
                'plastic_range_boundaries': [1],
            },
            'controls': {
                'stress_strain_type': stress_strain_type,
                'plastic_range': plastic_range_slider,
            },
        }
Ejemplo n.º 23
0
import numpy as np

py.init_notebook_mode()

# Fetch the published figure by its Plotly URL.
fig = plotly.plotly.get_figure("https://plot.ly/~jordanpeterson/889")

# The slider spans the x-axis range declared in the stored figure; this
# unpacking assumes the layout carries an explicit two-element range.
xmin, xmax = fig['layout']['xaxis']['range']

# Widget copy of the figure, so its layout can be changed from Python below.
f = go.FigureWidget(data=fig.data, layout=fig.layout)

# Two-handle slider over [xmin, xmax], stepping in 1/1000 of the span, with
# the numeric readout hidden and the control stretched to 800px.
slider = widgets.FloatRangeSlider(min=xmin,
                                  max=xmax,
                                  step=(xmax - xmin) / 1000.0,
                                  readout=False,
                                  description='Time')
slider.layout.width = '800px'


def update_range(y):
    """Restrict the figure's x-axis to the pair ``y`` chosen on the slider."""
    lower, upper = y
    f.layout.xaxis.range = [lower, upper]


# Stack the figure above the interactive slider, centre both, and leave the
# container as the final bare expression so the notebook cell displays it.
vb = VBox((f, interactive(update_range, y=slider)))
vb.layout.align_items = 'center'
vb
Ejemplo n.º 24
0
    def slider(self, figsize=(10, 10), **kw):
        """
        Navigate the simulation using a slider

        Parameters :
        ------------
        figsize: tuple
            Size of the figures

        kw: dict
            Extra arguments to pass to matplotlib's imshow
        """

        # -----------------------
        # Define useful functions
        # -----------------------

        def refresh_field(force=False):
            "Refresh the current field figure"

            # Determine whether to do the refresh
            do_refresh = False
            if (self.avail_fields is not None):
                if force == True or fld_refresh_toggle.value == True:
                    do_refresh = True
            # Do the refresh
            if do_refresh == True:
                plt.figure(fld_figure_button.value, figsize=figsize)
                plt.clf()

                # When working in inline mode, in an ipython notebook,
                # clear the output (prevents the images from stacking
                # in the notebook)
                if 'inline' in matplotlib.get_backend():
                    clear_output()

                if fld_use_button.value == True:
                    i_power = fld_magnitude_button.value
                    vmin = fld_range_button.value[0] * 10**i_power
                    vmax = fld_range_button.value[1] * 10**i_power
                else:
                    vmin = None
                    vmax = None

                self.get_field(t=self.current_t,
                               output=False,
                               plot=True,
                               field=fieldtype_button.value,
                               coord=coord_button.value,
                               m=convert_to_int(mode_button.value),
                               slicing=slicing_button.value,
                               theta=theta_button.value,
                               slicing_dir=slicing_dir_button.value,
                               vmin=vmin,
                               vmax=vmax,
                               cmap=fld_color_button.value)

        def refresh_ptcl(force=False):
            "Refresh the current particle figure"

            # Determine whether to do the refresh
            do_refresh = False
            if self.avail_species is not None:
                if force == True or ptcl_refresh_toggle.value == True:
                    do_refresh = True
            # Do the refresh
            if do_refresh == True:
                plt.figure(ptcl_figure_button.value, figsize=figsize)
                plt.clf()

                # When working in inline mode, in an ipython notebook,
                # clear the output (prevents the images from stacking
                # in the notebook)
                if 'inline' in matplotlib.get_backend():
                    clear_output()

                if ptcl_use_button.value == True:
                    i_power = ptcl_magnitude_button.value
                    vmin = ptcl_range_button.value[0] * 10**i_power
                    vmax = ptcl_range_button.value[1] * 10**i_power
                else:
                    vmin = None
                    vmax = None

                if ptcl_yaxis_button.value == 'None':
                    # 1D histogram
                    self.get_particle(t=self.current_t,
                                      output=False,
                                      var_list=[ptcl_xaxis_button.value],
                                      select=ptcl_select_widget.to_dict(),
                                      species=ptcl_species_button.value,
                                      plot=True,
                                      vmin=vmin,
                                      vmax=vmax,
                                      cmap=ptcl_color_button.value,
                                      nbins=ptcl_bins_button.value)
                else:
                    # 2D histogram
                    self.get_particle(t=self.current_t,
                                      output=False,
                                      var_list=[
                                          ptcl_xaxis_button.value,
                                          ptcl_yaxis_button.value
                                      ],
                                      select=ptcl_select_widget.to_dict(),
                                      species=ptcl_species_button.value,
                                      plot=True,
                                      vmin=vmin,
                                      vmax=vmax,
                                      cmap=ptcl_color_button.value,
                                      nbins=ptcl_bins_button.value)

        def refresh_ptcl_now(b):
            "Refresh the particles immediately"
            refresh_ptcl(force=True)

        def refresh_fld_now(b):
            "Refresh the fields immediately"
            refresh_field(force=True)

        def change_t(name, value):
            "Plot the result at the required time"
            self.current_t = 1.e-15 * value
            refresh_field()
            refresh_ptcl()

        def step_fw(b):
            "Plot the result one iteration further"
            if self.current_i < len(self.t) - 1:
                self.current_t = self.t[self.current_i + 1]
            else:
                print("Reached last iteration.")
                self.current_t = self.t[self.current_i]
            slider.value = self.current_t * 1.e15

        def step_bw(b):
            "Plot the result one iteration before"
            if self.current_t > 0:
                self.current_t = self.t[self.current_i - 1]
            else:
                print("Reached first iteration.")
                self.current_t = self.t[self.current_i]
            slider.value = self.current_t * 1.e15

        # ---------------
        # Define widgets
        # ---------------

        # Slider
        slider = widgets.FloatSlider(
            min=math.ceil(1.e15 * self.tmin),
            max=math.ceil(1.e15 * self.tmax),
            step=math.ceil(1.e15 * (self.tmax - self.tmin)) / 20.,
            description="t (fs)")
        slider.on_trait_change(change_t, 'value')

        # Forward button
        button_p = widgets.Button(description="+")
        button_p.on_click(step_fw)

        # Backward button
        button_m = widgets.Button(description="-")
        button_m.on_click(step_bw)

        # Display the time widgets
        container = widgets.HBox(children=[button_m, button_p, slider])
        display(container)

        # Field widgets
        # -------------
        if (self.avail_fields is not None):

            # Field type
            # ----------
            # Field button
            fieldtype_button = widgets.ToggleButtons(
                description='Field:', options=sorted(self.avail_fields.keys()))
            fieldtype_button.on_trait_change(refresh_field)

            # Coord button
            if self.geometry == "thetaMode":
                coord_button = widgets.ToggleButtons(
                    description='Coord:', options=['x', 'y', 'z', 'r', 't'])
            elif self.geometry in ["2dcartesian", "3dcartesian"]:
                coord_button = widgets.ToggleButtons(description='Coord:',
                                                     options=['x', 'y', 'z'])
            coord_button.on_trait_change(refresh_field)
            # Mode and theta button (for thetaMode)
            mode_button = widgets.ToggleButtons(description='Mode:',
                                                options=self.avail_circ_modes)
            mode_button.on_trait_change(refresh_field)
            theta_button = widgets.FloatSlider(width=140,
                                               value=0.,
                                               description=r'Theta:',
                                               min=-math.pi / 2,
                                               max=math.pi / 2)
            theta_button.on_trait_change(refresh_field)
            # Slicing buttons (for 3D)
            slicing_dir_button = widgets.ToggleButtons(
                value='y',
                description='Slicing direction:',
                options=['x', 'y', 'z'])
            slicing_dir_button.on_trait_change(refresh_field)
            slicing_button = widgets.FloatSlider(width=150,
                                                 description='Slicing:',
                                                 min=-1.,
                                                 max=1.,
                                                 value=0.)
            slicing_button.on_trait_change(refresh_field)

            # Plotting options
            # ----------------
            # Figure number
            fld_figure_button = widgets.IntText(description='Figure ',
                                                value=0,
                                                width=50)
            # Range of values
            fld_range_button = widgets.FloatRangeSlider(min=-10,
                                                        max=10,
                                                        width=220)
            fld_range_button.on_trait_change(refresh_field)
            # Order of magnitude
            fld_magnitude_button = widgets.IntText(description='x 10^',
                                                   value=9,
                                                   width=50)
            fld_magnitude_button.on_trait_change(refresh_field)
            # Use button
            fld_use_button = widgets.Checkbox(description=' Use this range',
                                              value=False)
            fld_use_button.on_trait_change(refresh_field)
            # Colormap button
            fld_color_button = widgets.Select(options=sorted(
                plt.cm.datad.keys()),
                                              height=50,
                                              width=200,
                                              value='jet')
            fld_color_button.on_trait_change(refresh_field)
            # Resfresh buttons
            fld_refresh_toggle = widgets.ToggleButton(
                description='Always refresh', value=True)
            fld_refresh_button = widgets.Button(description='Refresh now!')
            fld_refresh_button.on_click(refresh_fld_now)

            # Containers
            # ----------
            # Field type container
            if self.geometry == "thetaMode":
                container_fields = widgets.VBox(width=260,
                                                children=[
                                                    fieldtype_button,
                                                    coord_button, mode_button,
                                                    theta_button
                                                ])
            elif self.geometry == "2dcartesian":
                container_fields = widgets.VBox(
                    width=260, children=[fieldtype_button, coord_button])
            elif self.geometry == "3dcartesian":
                container_fields = widgets.VBox(width=260,
                                                children=[
                                                    fieldtype_button,
                                                    coord_button,
                                                    slicing_dir_button,
                                                    slicing_button
                                                ])
            # Plotting options container
            container_fld_plots = widgets.VBox(width=260,
                                               children=[
                                                   fld_figure_button,
                                                   fld_range_button,
                                                   widgets.HBox(children=[
                                                       fld_magnitude_button,
                                                       fld_use_button
                                                   ],
                                                                height=50),
                                                   fld_color_button
                                               ])
            # Accordion for the field widgets
            accord1 = widgets.Accordion(
                children=[container_fields, container_fld_plots])
            accord1.set_title(0, 'Field type')
            accord1.set_title(1, 'Plotting options')
            # Complete field container
            container_fld = widgets.VBox(
                width=300,
                children=[
                    accord1,
                    widgets.HBox(
                        children=[fld_refresh_toggle, fld_refresh_button])
                ])

        # Particle widgets
        # ----------------
        if (self.avail_species is not None):

            # Particle quantities
            # -------------------
            # Species selection
            ptcl_species_button = widgets.Dropdown(width=250,
                                                   options=self.avail_species)
            ptcl_species_button.on_trait_change(refresh_ptcl)
            # Remove charge and mass (less interesting)
            avail_ptcl_quantities = [ q for q in self.avail_ptcl_quantities \
                        if (q in ['charge', 'mass'])==False ]
            # Particle quantity on the x axis
            ptcl_xaxis_button = widgets.ToggleButtons(
                value='z', options=avail_ptcl_quantities)
            ptcl_xaxis_button.on_trait_change(refresh_ptcl)
            # Particle quantity on the y axis
            ptcl_yaxis_button = widgets.ToggleButtons(
                value='x', options=avail_ptcl_quantities + ['None'])
            ptcl_yaxis_button.on_trait_change(refresh_ptcl)

            # Particle selection
            # ------------------
            # 3 selection rules at maximum
            ptcl_select_widget = ParticleSelectWidget(3, avail_ptcl_quantities,
                                                      refresh_ptcl)

            # Plotting options
            # ----------------
            # Figure number
            ptcl_figure_button = widgets.IntText(description='Figure ',
                                                 value=1,
                                                 width=50)
            # Number of bins
            ptcl_bins_button = widgets.IntSlider(description='nbins:',
                                                 min=50,
                                                 max=300,
                                                 value=100,
                                                 width=150)
            ptcl_bins_button.on_trait_change(refresh_ptcl)
            # Colormap button
            ptcl_color_button = widgets.Select(options=sorted(
                plt.cm.datad.keys()),
                                               height=50,
                                               width=200,
                                               value='Blues')
            ptcl_color_button.on_trait_change(refresh_ptcl)
            # Range of values
            ptcl_range_button = widgets.FloatRangeSlider(min=0,
                                                         max=10,
                                                         width=220,
                                                         value=(0, 5))
            ptcl_range_button.on_trait_change(refresh_ptcl)
            # Order of magnitude
            ptcl_magnitude_button = widgets.IntText(description='x 10^',
                                                    value=9,
                                                    width=50)
            ptcl_magnitude_button.on_trait_change(refresh_ptcl)
            # Use button
            ptcl_use_button = widgets.Checkbox(description=' Use this range',
                                               value=False)
            ptcl_use_button.on_trait_change(refresh_ptcl)
            # Resfresh buttons
            ptcl_refresh_toggle = widgets.ToggleButton(
                description='Always refresh', value=True)
            ptcl_refresh_button = widgets.Button(description='Refresh now!')
            ptcl_refresh_button.on_click(refresh_ptcl_now)

            # Containers
            # ----------
            # Particle quantity container
            container_ptcl_quantities = widgets.VBox(width=310,
                                                     children=[
                                                         ptcl_species_button,
                                                         ptcl_xaxis_button,
                                                         ptcl_yaxis_button
                                                     ])
            # Particle selection container
            container_ptcl_select = ptcl_select_widget.to_container()
            # Plotting options container
            container_ptcl_plots = widgets.VBox(width=310,
                                                children=[
                                                    ptcl_figure_button,
                                                    ptcl_bins_button,
                                                    ptcl_range_button,
                                                    widgets.HBox(children=[
                                                        ptcl_magnitude_button,
                                                        ptcl_use_button
                                                    ],
                                                                 height=50),
                                                    ptcl_color_button
                                                ])
            # Accordion for the field widgets
            accord2 = widgets.Accordion(children=[
                container_ptcl_quantities, container_ptcl_select,
                container_ptcl_plots
            ])
            accord2.set_title(0, 'Particle quantities')
            accord2.set_title(1, 'Particle selection')
            accord2.set_title(2, 'Plotting options')
            # Complete particle container
            container_ptcl = widgets.VBox(
                width=370,
                children=[
                    accord2,
                    widgets.HBox(
                        children=[ptcl_refresh_toggle, ptcl_refresh_button])
                ])

        # Global container
        if (self.avail_fields is not None) and \
          (self.avail_species is not None):
            global_container = widgets.HBox(
                children=[container_fld, container_ptcl])
            display(global_container)
        elif self.avail_species is None:
            display(container_fld)
        elif self.avail_fields is None:
            display(container_ptcl)
Ejemplo n.º 25
0
    def __init__(self, parent):
        """Set all widgets for the tab and connect them to their callbacks.

        The tab is a vertical box made of two rows:

        * row 1: calibrant selection, probe distance, photon energy and
          detector pixel size;
        * row 2: value-range (clim) slider, transparency slider and the
          Apply / Clear buttons that show or remove the ring overlay.

        Parameters:
            parent (ipywidget) : The parent widget object embedding this tab
        """
        self.parent = parent
        self.title = 'Set Calibrant'
        self.alpha = 0.5  # The transparency value of the overlay
        self.clim = (0.5, 0.9)  # Standard clim (the slider max is always 1)
        self.img = None  # Image to be overlaid
        energy = 10e3  # [eV] Photon energy, default value can be overwritten
        # Convert the energy to wave-length
        self.wave_length = self._energy2lambda(energy)
        self.calibrant = 'None'  # Calibrant material (initially the 'None' entry)
        # Standard detector pixel size of 0.2 mm. The factor 1/1000 suggests
        # it is stored in metres, while the widget below displays mm.
        # NOTE(review): confirm the unit against _set_pxsize.
        self.pxsize = 0.2 / 1000
        self.cdist = 0.2  # [m] Standard probe distance
        # All calibrants defined in pyFAI, preceded by the 'None' entry
        self.calibrants = [self.calibrant] + calibrants
        # Calibrant selection
        self.calib_btn = widgets.Dropdown(options=self.calibrants,
                                          value=self.calibrant,
                                          description='Calibrant',
                                          disabled=False,
                                          layout=Layout(width='250px',
                                                        height='30px'))
        # Probe distance selection
        self.dist_btn = widgets.BoundedFloatText(value=self.cdist,
                                                 min=0,
                                                 max=10000,
                                                 layout=Layout(width='150px',
                                                               height='30px'),
                                                 step=0.01,
                                                 disabled=False,
                                                 description='Distance [m]')
        # Photon energy selection
        self.energy_btn = widgets.BoundedFloatText(value=energy,
                                                   min=3000,
                                                   max=100000,
                                                   layout=Layout(
                                                       width='200px',
                                                       height='30px'),
                                                   step=1,
                                                   disabled=False,
                                                   description='Energy [eV]')
        # Pixel size selection, displayed in mm. Initialise it from the pixel
        # size itself rather than from self.cdist (the probe distance), which
        # only happened to have the same numeric value.
        # (0.2 / 1000 * 1000 == 0.2 exactly, so the shown value is unchanged.)
        self.pxsize_btn = widgets.BoundedFloatText(
            value=self.pxsize * 1000,
            min=0,
            max=20,
            layout=Layout(width='150px', height='30px'),
            step=0.01,
            disabled=False,
            description='Pixel Size [mm]')
        # Apply button to display the ring structure
        self.aply_btn = widgets.Button(description='Apply',
                                       disabled=False,
                                       button_style='',
                                       icon='',
                                       tooltip='Apply Material',
                                       layout=Layout(width='100px',
                                                     height='30px'))
        # Clear button to delete overlay
        self.clr_btn = widgets.Button(description='Clear',
                                      tooltip='Do not show overlay',
                                      disabled=False,
                                      button_style='',
                                      icon='',
                                      layout=Layout(width='100px',
                                                    height='30px'))
        # Set transparency value
        self.alpha_slider = widgets.FloatSlider(value=self.alpha,
                                                min=0,
                                                max=1,
                                                step=0.01,
                                                description='Transparency:',
                                                orientation='horizontal',
                                                readout=True,
                                                readout_format='.2f',
                                                layout=Layout(width='40%'))
        # Set clim; continuous_update=False -> only fires on release
        self.val_slider = widgets.FloatRangeSlider(value=self.clim,
                                                   min=0,
                                                   max=1,
                                                   step=0.01,
                                                   description='Range:',
                                                   disabled=False,
                                                   continuous_update=False,
                                                   orientation='horizontal',
                                                   readout=True,
                                                   readout_format='.2f',
                                                   layout=Layout(width='40%'))
        # Arrange the widgets in two rows
        self.row1 = widgets.HBox(
            [self.calib_btn, self.dist_btn, self.energy_btn, self.pxsize_btn])
        self.row2 = widgets.HBox(
            [self.val_slider, self.alpha_slider, self.aply_btn, self.clr_btn])
        # Connect all methods to the widgets
        self.val_slider.observe(self._set_clim, names='value')
        self.alpha_slider.observe(self._set_alpha, names='value')
        self.clr_btn.on_click(self._clear_overlay)
        self.aply_btn.on_click(self._draw_overlay)
        self.calib_btn.observe(self._set_calibrant, names='value')
        self.pxsize_btn.observe(self._set_pxsize, names='value')
        self.energy_btn.observe(self._set_wavelength, names='value')
        self.dist_btn.observe(self._set_cdist, names='value')
        # The two-argument super() starts the MRO lookup *after*
        # widgets.VBox, so this calls the __init__ of VBox's base class
        # (presumably ipywidgets' Box) with the two rows as children.
        super(widgets.VBox, self).__init__([self.row1, self.row2])