def tool_handler(self, doc):
        """Assemble the fit inspection Bokeh application and attach it to ``doc``.

        A deep copy of ``self.arr`` is taken first. When that copy is an
        ``xr.Dataset`` (and ``self.use_dataset`` is set) the raw data, the fit
        results and the residual are read from its ``data``, ``results`` and
        ``residual`` variables, and NaNs in the data and residual are replaced
        by 0. Otherwise the copy itself is used as the fit results, the raw
        data comes from ``attrs['original_data']`` and no residual is available.

        The raw-data dimension that is missing from the fit results is treated
        as the fit direction. If the raw data does not have exactly two
        dimensions the tool runs in "two dimensional" mode: the main image is a
        slice of the raw data at the cursor position along the fit direction,
        the bottom marginal is replaced by a spacer and the right marginal
        shows the data along the fit direction.

        Args:
            doc: Object receiving the finished layout through ``doc.add_root``;
                its ``title`` is set to ``'Band Tool'``.
        """
        # `events` is needed for the tap callbacks registered below; imported
        # locally like the other bokeh names so the handler is self-contained.
        from bokeh import events
        from bokeh.layouts import row, column, widgetbox, Spacer
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.models.widgets.markups import Div
        from bokeh.plotting import figure

        # Work on a private copy: NaNs are overwritten in place below.
        self.arr = self.arr.copy(deep=True)

        if not isinstance(self.arr, xr.Dataset):
            self.use_dataset = False

        residual = None
        if self.use_dataset:
            raw_data = self.arr.data
            raw_data.values[np.isnan(raw_data.values)] = 0
            fit_results = self.arr.results
            residual = self.arr.residual
            residual.values[np.isnan(residual.values)] = 0
        else:
            raw_data = self.arr.attrs['original_data']
            fit_results = self.arr

        # The fit direction is the first raw-data dimension absent from the fits.
        fit_direction = [d for d in raw_data.dims
                         if d not in fit_results.dims][0]

        two_dimensional = len(raw_data.dims) != 2
        if two_dimensional:
            x_coords = fit_results.coords[fit_results.dims[0]]
            y_coords = fit_results.coords[fit_results.dims[1]]
            z_coords = raw_data.coords[fit_direction]
        else:
            x_coords = raw_data.coords[raw_data.dims[0]]
            y_coords = raw_data.coords[raw_data.dims[1]]

        if two_dimensional:
            self.settings['palette'] = 'coolwarm'
        # Read after the palette setting has (possibly) been overridden above.
        default_palette = self.default_palette

        self.app_context.update({
            'data': raw_data,
            'fits': fit_results,
            'residual': residual,
            'original': self.arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            }
        })
        if two_dimensional:
            self.app_context['data_range']['z'] = (np.min(z_coords.values),
                                                   np.max(z_coords.values))

        figures = self.app_context['figures']
        plots = self.app_context['plots']
        app_widgets = self.app_context['widgets']

        # Start with the cursor at the midpoint of every axis range.
        self.cursor_dims = raw_data.dims
        if two_dimensional:
            self.cursor = [
                np.mean(self.data_range['x']),
                np.mean(self.data_range['y']),
                np.mean(self.data_range['z'])
            ]
        else:
            self.cursor = [
                np.mean(self.data_range['x']),
                np.mean(self.data_range['y'])
            ]

        def cursor_selection(along_fit):
            """Return the cursor entries on (``True``) or off (``False``) the fit direction."""
            return {
                k: v
                for k, v in self.cursor_dict.items()
                if (k == fit_direction) == along_fit
            }

        app_widgets['fit_info_div'] = Div(text='')

        self.app_context['color_maps']['main'] = LinearColorMapper(
            default_palette,
            low=np.min(raw_data.values),
            high=np.max(raw_data.values),
            nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset", "save"]
        main_title = 'Fit Inspection Tool: WARNING Unidentified'

        # Best effort: keep the fallback title when no label can be produced.
        try:
            main_title = 'Fit Inspection Tool: {}'.format(
                raw_data.S.label[:60])
        except Exception:
            pass

        figures['main'] = figure(tools=main_tools,
                                 plot_width=self.app_main_size,
                                 plot_height=self.app_main_size,
                                 min_border=10,
                                 min_border_left=50,
                                 toolbar_location='left',
                                 x_axis_location='below',
                                 y_axis_location='right',
                                 title=main_title,
                                 x_range=self.data_range['x'],
                                 y_range=self.app_context['data_range']['y'])
        figures['main'].xaxis.axis_label = raw_data.dims[0]
        figures['main'].yaxis.axis_label = raw_data.dims[1]
        figures['main'].toolbar.logo = None
        figures['main'].background_fill_color = "#fafafa"

        data_for_main = raw_data
        if two_dimensional:
            # Show the slice nearest to the cursor along the fit direction.
            data_for_main = data_for_main.sel(
                **{fit_direction: self.cursor[2]}, method='nearest')

        x_low, x_high = self.app_context['data_range']['x']
        y_low, y_high = self.app_context['data_range']['y']
        plots['main'] = figures['main'].image(
            [data_for_main.values.T],
            x=x_low,
            y=y_low,
            dw=x_high - x_low,
            dh=y_high - y_low,
            color_mapper=self.app_context['color_maps']['main'])

        # Overlay the fitted band centers as dashed white lines.
        band_centers = [b.center for b in fit_results.F.bands.values()]
        bands_xs = [b.coords[b.dims[0]].values for b in band_centers]
        bands_ys = [b.values for b in band_centers]
        if fit_results.dims[0] == raw_data.dims[1]:
            bands_ys, bands_xs = bands_xs, bands_ys
        plots['band_locations'] = figures['main'].multi_line(
            xs=bands_xs,
            ys=bands_ys,
            line_color='white',
            line_width=1,
            line_dash='dashed')

        # add cursor lines
        self.add_cursor_lines(figures['main'])

        # marginals
        if not two_dimensional:
            figures['bottom'] = figure(plot_width=self.app_main_size,
                                       plot_height=self.app_marginal_size,
                                       min_border=10,
                                       title=None,
                                       x_range=figures['main'].x_range,
                                       x_axis_location='above',
                                       toolbar_location=None,
                                       tools=[])
        else:
            figures['bottom'] = Spacer(width=self.app_main_size,
                                       height=self.app_marginal_size)

        right_y_range = figures['main'].y_range
        if two_dimensional:
            right_y_range = self.data_range['z']

        figures['right'] = figure(plot_width=self.app_marginal_size,
                                  plot_height=self.app_main_size,
                                  min_border=10,
                                  title=None,
                                  y_range=right_y_range,
                                  y_axis_location='left',
                                  toolbar_location=None,
                                  tools=[])

        marginal_line_width = 2
        overlay_styles = (
            ('residual', {'line_color': 'red'}),
            ('fit', {'line_color': 'blue', 'line_dash': 'dashed'}),
            ('init_fit', {'line_color': 'green', 'line_dash': 'dotted'}),
        )

        def add_fit_overlays(side):
            """Create the initially empty residual, fit and initial-fit lines on ``figures[side]``."""
            for suffix, style in overlay_styles:
                plots['{}_{}'.format(side, suffix)] = figures[side].line(
                    x=[], y=[], line_width=marginal_line_width, **style)

        if not two_dimensional:
            bottom_data = raw_data.sel(
                **{raw_data.dims[1]: self.cursor[1]}, method='nearest')
            right_data = raw_data.sel(
                **{raw_data.dims[0]: self.cursor[0]}, method='nearest')

            plots['bottom'] = figures['bottom'].line(
                x=bottom_data.coords[raw_data.dims[0]].values,
                y=bottom_data.values,
                line_width=marginal_line_width)
            add_fit_overlays('bottom')

            plots['right'] = figures['right'].line(
                y=right_data.coords[raw_data.dims[1]].values,
                x=right_data.values,
                line_width=marginal_line_width)
            add_fit_overlays('right')
        else:
            right_data = raw_data.sel(**cursor_selection(along_fit=False),
                                      method='nearest')
            plots['right'] = figures['right'].line(
                y=right_data.coords[right_data.dims[0]].values,
                x=right_data.values,
                line_width=marginal_line_width)
            add_fit_overlays('right')

        def on_change_main_view(attr, old, data_source):
            """Switch the main image to the content named by ``data_source``."""
            self.selected_data = data_source
            data = None
            # NOTE(review): outside two-dimensional mode, selecting along the
            # fit direction presumably reduces the image to one dimension --
            # confirm this is intended for the 'data' and 'residual' entries.
            if data_source == 'data':
                data = raw_data.sel(**cursor_selection(along_fit=True),
                                    method='nearest')
            elif data_source == 'residual':
                # The menu always offers "Residual", but none exists when the
                # tool was not built from a Dataset; keep the current image
                # instead of failing on None.
                if residual is not None:
                    data = residual.sel(**cursor_selection(along_fit=True),
                                        method='nearest')
            elif two_dimensional:
                data = fit_results.F.s(data_source)
                data.values[np.isnan(data.values)] = 0

            if data is None:
                return

            if self.remove_outliers:
                data = data.T.clean_outliers(clip=self.outlier_clip)

            plots['main'].data_source.data = {
                'image': [data.values.T],
            }
            update_main_colormap(None, None, main_color_range_slider.value)

        def update_fit_display():
            """Refresh the fit summary text and the fit overlays at the cursor."""
            target = 'right'
            if fit_results.dims[0] == raw_data.dims[1]:
                target = 'bottom'

            if two_dimensional:
                # Only a right marginal exists in this mode.
                target = 'right'
                current_fit = fit_results.sel(
                    **cursor_selection(along_fit=False),
                    method='nearest').item()
                coord_vals = raw_data.coords[fit_direction].values
            else:
                cursor_index = 0 if target == 'right' else 1
                current_fit = fit_results.sel(
                    **{fit_results.dims[0]: self.cursor[cursor_index]},
                    method='nearest').item()
                coord_vals = raw_data.coords[
                    raw_data.dims[0 if target == 'bottom' else 1]].values

            if current_fit is None:
                app_widgets['fit_info_div'].text = 'No fit here.'
                for suffix, _ in overlay_styles:
                    plots['{}_{}'.format(target, suffix)].data_source.data = {
                        'x': [],
                        'y': [],
                    }
                return

            app_widgets['fit_info_div'].text = current_fit._repr_html_(
                short=True)  # pylint: disable=protected-access

            curves = (
                ('residual', current_fit.residual),
                ('fit', current_fit.best_fit),
                ('init_fit', current_fit.init_fit),
            )
            for suffix, values in curves:
                # The bottom marginal plots the coordinate on x, the right
                # marginal plots it on y.
                if target == 'bottom':
                    xs, ys = coord_vals, values
                else:
                    xs, ys = values, coord_vals
                plots['{}_{}'.format(target, suffix)].data_source.data = {
                    'x': xs,
                    'y': ys,
                }

        def click_right_marginal(event):
            """Move the cursor along the third axis and redraw the main image."""
            self.cursor = [self.cursor[0], self.cursor[1], event.y]
            on_change_main_view(None, None, self.selected_data)

        def click_main_image(event):
            """Move the cursor to the tapped point and refresh the marginals."""
            if two_dimensional:
                self.cursor = [event.x, event.y, self.cursor[2]]
                right_marginal_data = raw_data.sel(
                    **cursor_selection(along_fit=False), method='nearest')
            else:
                self.cursor = [event.x, event.y]
                right_marginal_data = raw_data.sel(
                    **{raw_data.dims[0]: self.cursor[0]}, method='nearest')
                bottom_marginal_data = raw_data.sel(
                    **{raw_data.dims[1]: self.cursor[1]}, method='nearest')
                plots['bottom'].data_source.data = {
                    'x': bottom_marginal_data.coords[raw_data.dims[0]].values,
                    'y': bottom_marginal_data.values,
                }

            plots['right'].data_source.data = {
                'y':
                right_marginal_data.coords[right_marginal_data.dims[0]].values,
                'x': right_marginal_data.values,
            }

            update_fit_display()

        def on_change_outlier_clip(attr, old, new):
            """Store the new clip value and redraw the main image."""
            self.outlier_clip = new
            on_change_main_view(None, None, self.selected_data)

        def set_remove_outliers(should_remove_outliers):
            """Toggle outlier removal, redrawing only when the state changes."""
            if self.remove_outliers != should_remove_outliers:
                self.remove_outliers = should_remove_outliers

                on_change_main_view(None, None, self.selected_data)

        update_main_colormap = self.update_colormap_for('main')
        MAIN_CONTENT_OPTIONS = [
            ('Residual', 'residual'),
            ('Data', 'data'),
        ]

        if two_dimensional:
            # Each fit parameter can also be displayed in the main image.
            MAIN_CONTENT_OPTIONS.extend(
                (param_name, param_name)
                for param_name in fit_results.F.parameter_names)

        remove_outliers_toggle = widgets.Toggle(label='Remove Outliers',
                                                button_type='primary',
                                                active=self.remove_outliers)
        remove_outliers_toggle.on_click(set_remove_outliers)

        outlier_clip_slider = widgets.Slider(title='Clip',
                                             start=0,
                                             end=10,
                                             value=self.outlier_clip,
                                             callback_throttle=150,
                                             step=0.2)
        outlier_clip_slider.on_change('value', on_change_outlier_clip)

        main_content_select = widgets.Dropdown(label='Main Content',
                                               button_type='primary',
                                               menu=MAIN_CONTENT_OPTIONS)
        main_content_select.on_change('value', on_change_main_view)

        # Widgety things
        main_color_range_slider = widgets.RangeSlider(start=0,
                                                      end=100,
                                                      value=(
                                                          0,
                                                          100,
                                                      ),
                                                      title='Color Range')

        # Attach callbacks
        main_color_range_slider.on_change('value', update_main_colormap)
        figures['main'].on_event(events.Tap, click_main_image)
        if two_dimensional:
            figures['right'].on_event(events.Tap, click_right_marginal)

        # The outlier controls are only shown in two-dimensional mode.
        control_candidates = [
            self._cursor_info,
            main_color_range_slider,
            main_content_select,
            remove_outliers_toggle if two_dimensional else None,
            outlier_clip_slider if two_dimensional else None,
        ]
        controls = [w for w in control_candidates if w is not None]

        layout = row(
            column(figures['main'], figures.get('bottom')),
            column(figures['right'], app_widgets['fit_info_div']),
            column(widgetbox(*controls), ))

        update_fit_display()

        doc.add_root(layout)
        doc.title = 'Band Tool'
    def tool_handler(self, doc):
        """Assemble the comparison Bokeh application and attach it to ``doc``.

        Three images are drawn with shared axis ranges: ``self.arr``
        ("Spectrum A"), ``self.other`` ("Spectrum B") and their difference
        ("Comparison"). Three sliders control a relative intensity factor and
        a shift along each of the two axes applied to ``self.other`` before it
        is subtracted; the difference image is recomputed whenever a slider
        changes.

        Args:
            doc: Object receiving the finished layout through ``doc.add_root``;
                its ``title`` is set to ``'Comparison Tool'``.
        """
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models import widgets
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.plotting import figure

        default_palette = self.default_palette
        difference_palette = cc.coolwarm

        intensity_slider = widgets.Slider(
            title='Relative Intensity Scaling', start=0.5, end=1.5,
            step=0.005, value=1)

        self.app_context.update({
            'A': self.arr,
            'B': self.other,
            'compared': self.compared,
            'plots': {},
            'figures': {},
            'widgets': {},
            'data_range': self.arr.T.range(),
            'color_maps': {},
        })

        self.color_maps['main'] = LinearColorMapper(
            default_palette,
            low=np.min(self.arr.values),
            high=np.max(self.arr.values),
            nan_color='black')

        figure_kwargs = {
            'tools': ['reset', 'wheel_zoom'],
            'plot_width': self.app_main_size,
            'plot_height': self.app_main_size,
            'min_border': 10,
            'toolbar_location': 'left',
            'x_range': self.data_range['x'],
            'y_range': self.data_range['y'],
            'x_axis_location': 'below',
            'y_axis_location': 'right',
        }

        self.figures['A'] = figure(title='Spectrum A', **figure_kwargs)
        self.figures['B'] = figure(title='Spectrum B', **figure_kwargs)
        self.figures['compared'] = figure(title='Comparison', **figure_kwargs)

        self.compared = self.arr - self.other
        # NOTE(review): the symmetric difference colour range is derived from
        # the extrema of self.arr, not of self.compared -- confirm intended.
        diff_low, diff_high = np.min(self.arr.values), np.max(self.arr.values)
        diff_range = np.sqrt((abs(diff_low) + 1) * (abs(diff_high) + 1)) * 1.5
        self.color_maps['difference'] = LinearColorMapper(
            difference_palette, low=-diff_range, high=diff_range, nan_color='white')

        # All three images share the same placement and extent.
        image_extent = {
            'x': self.data_range['x'][0],
            'y': self.data_range['y'][0],
            'dw': self.data_range['x'][1] - self.data_range['x'][0],
            'dh': self.data_range['y'][1] - self.data_range['y'][0],
        }

        self.plots['A'] = self.figures['A'].image(
            [self.arr.values],
            color_mapper=self.color_maps['main'],
            **image_extent)

        self.plots['B'] = self.figures['B'].image(
            [self.other.values],
            color_mapper=self.color_maps['main'],
            **image_extent)

        self.plots['compared'] = self.figures['compared'].image(
            [self.compared.values],
            color_mapper=self.color_maps['difference'],
            **image_extent)

        x_axis_name = self.arr.dims[0]
        y_axis_name = self.arr.dims[1]

        stride = self.arr.T.stride()
        delta_x_axis = stride['x']
        delta_y_axis = stride['y']

        # Shift sliders span +/- 20 strides in half-stride steps.
        delta_x_slider = widgets.Slider(
            title='{} Shift'.format(x_axis_name), start=-20 * delta_x_axis,
            step=delta_x_axis / 2, end=20 * delta_x_axis, value=0)

        delta_y_slider = widgets.Slider(
            title='{} Shift'.format(y_axis_name), start=-20 * delta_y_axis,
            step=delta_y_axis / 2, end=20 * delta_y_axis, value=0)

        @Debounce(0.5)
        def update_summed_figure(attr, old, new):
            """Recompute the difference image from the current slider values."""
            # The callback arguments are ignored: every slider value is read
            # directly so one handler can serve all three sliders.
            # Slider values are in coordinate units; dividing by the stride
            # converts them to the pixel offsets that ndimage.shift expects.
            pixel_shift = [
                delta_x_slider.value / delta_x_axis,
                delta_y_slider.value / delta_y_axis,
            ]
            # scipy.ndimage.shift is the supported location of this function;
            # the scipy.ndimage.interpolation namespace is deprecated.
            shifted = intensity_slider.value * scipy.ndimage.shift(
                self.other.values, pixel_shift,
                order=1, prefilter=False, cval=np.nan)
            self.compared = self.arr - xr.DataArray(
                shifted,
                coords=self.arr.coords,
                dims=self.arr.dims)

            # Carry over the attributes of A, but drop its 'id' if present.
            self.compared.attrs.update(self.arr.attrs)
            self.compared.attrs.pop('id', None)

            self.app_context['compared'] = self.compared
            self.plots['compared'].data_source.data = {
                'image': [self.compared.values]
            }

        layout = column(
            row(
                column(self.app_context['figures']['A']),
                column(self.app_context['figures']['B']),
            ),
            row(
                column(self.app_context['figures']['compared']),
                widgetbox(
                    intensity_slider,
                    delta_x_slider,
                    delta_y_slider,
                ),
            )
        )

        # Draw the initial comparison before any slider is touched.
        update_summed_figure(None, None, None)

        for slider in (delta_x_slider, delta_y_slider, intensity_slider):
            slider.on_change('value', update_summed_figure)

        doc.add_root(layout)
        doc.title = 'Comparison Tool'
# Example no. 3
# 0
    def tool_handler(self, doc):
        from bokeh import events
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError(
                'Cannot use the band tool on non image-like spectra')

        self.data_for_display = self.arr
        x_coords, y_coords = self.data_for_display.coords[
            self.data_for_display.dims[0]], self.data_for_display.coords[
                self.data_for_display.dims[1]]

        default_palette = self.default_palette

        self.app_context.update({
            'data': self.arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            },
        })

        figures, plots = self.app_context['figures'], self.app_context['plots']

        self.cursor = [
            np.mean(self.data_range['x']),
            np.mean(self.data_range['y'])
        ]

        self.color_maps['main'] = LinearColorMapper(
            default_palette,
            low=np.min(self.data_for_display.values),
            high=np.max(self.data_for_display.values),
            nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = '{} Tool: WARNING Unidentified'.format(
            self.analysis_fn.__name__)

        try:
            main_title = '{} Tool: {}'.format(
                self.analysis_fn.__name__, self.data_for_display.S.label[:60])
        except:
            pass

        figures['main'] = figure(tools=main_tools,
                                 plot_width=self.app_main_size,
                                 plot_height=self.app_main_size,
                                 min_border=10,
                                 min_border_left=50,
                                 toolbar_location='left',
                                 x_axis_location='below',
                                 y_axis_location='right',
                                 title=main_title,
                                 x_range=self.data_range['x'],
                                 y_range=self.data_range['y'])
        figures['main'].xaxis.axis_label = self.data_for_display.dims[0]
        figures['main'].yaxis.axis_label = self.data_for_display.dims[1]
        figures['main'].toolbar.logo = None
        figures['main'].background_fill_color = "#fafafa"
        plots['main'] = figures['main'].image(
            [self.data_for_display.values.T],
            x=self.app_context['data_range']['x'][0],
            y=self.app_context['data_range']['y'][0],
            dw=self.app_context['data_range']['x'][1] -
            self.app_context['data_range']['x'][0],
            dh=self.app_context['data_range']['y'][1] -
            self.app_context['data_range']['y'][0],
            color_mapper=self.app_context['color_maps']['main'])

        # Create the bottom marginal plot
        bottom_marginal = self.data_for_display.sel(**dict(
            [[self.data_for_display.dims[1], self.cursor[1]]]),
                                                    method='nearest')
        bottom_marginal_original = self.arr.sel(**dict(
            [[self.data_for_display.dims[1], self.cursor[1]]]),
                                                method='nearest')
        figures['bottom_marginal'] = figure(
            plot_width=self.app_main_size,
            plot_height=200,
            title=None,
            x_range=figures['main'].x_range,
            y_range=(np.min(bottom_marginal.values),
                     np.max(bottom_marginal.values)),
            x_axis_location='above',
            toolbar_location=None,
            tools=[])
        plots['bottom_marginal'] = figures['bottom_marginal'].line(
            x=bottom_marginal.coords[self.data_for_display.dims[0]].values,
            y=bottom_marginal.values)
        plots['bottom_marginal_original'] = figures['bottom_marginal'].line(
            x=bottom_marginal_original.coords[self.arr.dims[0]].values,
            y=bottom_marginal_original.values,
            line_color='red')

        # Create the right marginal plot
        right_marginal = self.data_for_display.sel(**dict(
            [[self.data_for_display.dims[0], self.cursor[0]]]),
                                                   method='nearest')
        right_marginal_original = self.arr.sel(**dict(
            [[self.data_for_display.dims[0], self.cursor[0]]]),
                                               method='nearest')
        figures['right_marginal'] = figure(
            plot_width=200,
            plot_height=self.app_main_size,
            title=None,
            y_range=figures['main'].y_range,
            x_range=(np.min(right_marginal.values),
                     np.max(right_marginal.values)),
            y_axis_location='left',
            toolbar_location=None,
            tools=[])
        plots['right_marginal'] = figures['right_marginal'].line(
            y=right_marginal.coords[self.data_for_display.dims[1]].values,
            x=right_marginal.values)
        plots['right_marginal_original'] = figures['right_marginal'].line(
            y=right_marginal_original.coords[
                self.data_for_display.dims[1]].values,
            x=right_marginal_original.values,
            line_color='red')

        # add lines
        self.add_cursor_lines(figures['main'])
        _ = figures['main'].multi_line(xs=[],
                                       ys=[],
                                       line_color='white',
                                       line_width=1)  # band lines

        # prep the widgets for the analysis function
        signature = inspect.signature(self.analysis_fn)

        # drop the first which has to be the input data, we can revisit later if this is too limiting
        parameter_names = list(signature.parameters)[1:]
        named_widgets = dict(zip(parameter_names, self.widget_specification))
        built_widgets = {}

        def update_marginals():
            right_marginal_data = self.data_for_display.sel(**dict(
                [[self.data_for_display.dims[0], self.cursor[0]]]),
                                                            method='nearest')
            bottom_marginal_data = self.data_for_display.sel(**dict(
                [[self.data_for_display.dims[1], self.cursor[1]]]),
                                                             method='nearest')
            plots['bottom_marginal'].data_source.data = {
                'x': bottom_marginal_data.coords[
                    self.data_for_display.dims[0]].values,
                'y': bottom_marginal_data.values,
            }
            plots['right_marginal'].data_source.data = {
                'y': right_marginal_data.coords[
                    self.data_for_display.dims[1]].values,
                'x': right_marginal_data.values,
            }

            right_marginal_data = self.arr.sel(**dict(
                [[self.data_for_display.dims[0], self.cursor[0]]]),
                                               method='nearest')
            bottom_marginal_data = self.arr.sel(**dict(
                [[self.data_for_display.dims[1], self.cursor[1]]]),
                                                method='nearest')
            plots['bottom_marginal_original'].data_source.data = {
                'x': bottom_marginal_data.coords[
                    self.data_for_display.dims[0]].values,
                'y': bottom_marginal_data.values,
            }
            plots['right_marginal_original'].data_source.data = {
                'y': right_marginal_data.coords[
                    self.data_for_display.dims[1]].values,
                'x': right_marginal_data.values,
            }
            figures['bottom_marginal'].y_range.start = np.min(
                bottom_marginal_data.values)
            figures['bottom_marginal'].y_range.end = np.max(
                bottom_marginal_data.values)
            figures['right_marginal'].x_range.start = np.min(
                right_marginal_data.values)
            figures['right_marginal'].x_range.end = np.max(
                right_marginal_data.values)

        def click_main_image(event):
            self.cursor = [event.x, event.y]
            update_marginals()

        error_msg = widgets.Div(text='')

        @Debounce(0.25)
        def update_data_for_display():
            """Re-run the analysis function with the current widget values and redraw.

            Any exception raised while computing is written to ``error_msg``
            rather than propagated; on success the message is cleared. In both
            cases the marginals and the main image are then refreshed from
            whatever ``self.data_for_display`` currently holds.
            """
            try:
                analyze = self.analysis_fn
                source = self.arr
                # Only parameters that actually received a widget are passed,
                # in the order given by ``parameter_names``.
                widget_values = [
                    built_widgets[name].value
                    for name in parameter_names
                    if name in built_widgets
                ]
                self.data_for_display = analyze(source, *widget_values)
                error_msg.text = ''
            except Exception as exc:
                error_msg.text = '{}'.format(exc)

            # flush + update
            update_marginals()
            transposed_image = self.data_for_display.values.T
            plots['main'].data_source.data = {'image': [transposed_image]}

        def update_data_change_wrapper(attr, old, new):
            """``on_change`` callback: recompute the display only when the value really changed."""
            if old == new:
                return
            update_data_for_display()

        for parameter_name in named_widgets.keys():
            specification = named_widgets[parameter_name]

            widget = None
            if specification == int:
                widget = widgets.Slider(start=-20,
                                        end=20,
                                        value=0,
                                        title=parameter_name)
            if specification == float:
                widget = widgets.Slider(start=-20,
                                        end=20,
                                        value=0.,
                                        step=0.1,
                                        title=parameter_name)

            if widget is not None:
                built_widgets[parameter_name] = widget
                widget.on_change('value', update_data_change_wrapper)

        update_main_colormap = self.update_colormap_for('main')

        self.app_context['run'] = lambda x: x

        main_color_range_slider = widgets.RangeSlider(start=0,
                                                      end=100,
                                                      value=(
                                                          0,
                                                          100,
                                                      ),
                                                      title='Color Range')

        # Attach callbacks
        main_color_range_slider.on_change('value', update_main_colormap)

        figures['main'].on_event(events.Tap, click_main_image)

        layout = row(
            column(figures['main'], figures['bottom_marginal']),
            column(figures['right_marginal']),
            column(
                widgetbox(*[
                    built_widgets[p] for p in parameter_names
                    if p in built_widgets
                ]),
                widgetbox(
                    self._cursor_info,
                    main_color_range_slider,
                    error_msg,
                )))

        doc.add_root(layout)
        doc.title = 'Band Tool'
Ejemplo n.º 4
0
# Module-level data holders, all initialised to None here.
# NOTE(review): presumably filled in by code outside this excerpt -- confirm.
DMFullData = None
# Values from dm_range[0] upward in steps of dm_range_spacing. np.arange
# excludes its stop value, so the stop is padded by one step -- presumably so
# that dm_range[1] itself is included (float rounding in np.arange can affect
# the last element).
DM_list = list(
    np.arange(dm_range[0],
              dm_range[1] + dm_range_spacing,
              step=dm_range_spacing))

ScatterData = None

PreFoldingData = None
PostFoldingData = None

################################################################################
# Slider over the same range and step that DM_list is built from.
dmSlider = widgets.Slider(title="Dispersion Measure",
                          value=0,
                          start=dm_range[0],
                          end=dm_range[1],
                          step=dm_range_spacing)

scStep = FL_bw / FL_Nf  # Span of frequencies covered by one bin
scStart = FL_f0 - (FL_bw / 2) + (scStep / 2
                                 )  # Middle of the lowest frequency bin
scEnd = FL_f0 + (FL_bw / 2) - (scStep / 2
                               )  # Middle of the highest frequency bin
# Slider stepping from bin centre to bin centre; starts at the highest bin.
scSlider = widgets.Slider(title="Frequency (MHz)",
                          value=scEnd,
                          start=scStart,
                          end=scEnd,
                          step=scStep)

flSlider = widgets.Slider(title="Folding Frequency (Hz)",
    def tool_handler(self, doc):
        """Build the Bokeh document for the curvature tool.

        Lays out three image panels over ``self.arr`` -- its second
        derivative ("d2 Spectrum"), its curvature ("Curvature") and the
        unprocessed data ("Raw Image"), all sharing the same x/y ranges --
        together with widgets controlling smoothing, derivative direction,
        clamping, color clipping and gamma, and attaches the result to
        ``doc``.

        Side effects: ``self.app_context`` is populated with the data, the
        plots, figures, color maps and value caches, and with two callables
        under ``'d2_fn'`` and ``'curvature_fn'`` that apply the currently
        selected widget settings to data passed in by the caller.

        Args:
            doc: the Bokeh document that receives the layout and the title.
        """
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models import widgets, Spacer
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.plotting import figure

        default_palette = self.default_palette

        # The three image panels; every per-panel structure below is keyed by
        # these names, always in this order.
        panel_names = ('d2', 'curvature', 'raw')

        self.app_context.update({
            'data': self.arr,
            'cached_data': {},
            'gamma_cached_data': {},
            'plots': {},
            'data_range': self.arr.T.range(),
            'figures': {},
            'widgets': {},
            'color_maps': {}
        })

        plots, figures, data_range, cached_data, gamma_cached_data = (
            self.app_context['plots'],
            self.app_context['figures'],
            self.app_context['data_range'],
            self.app_context['cached_data'],
            self.app_context['gamma_cached_data'],
        )
        color_maps = self.app_context['color_maps']

        # Every panel starts with a mapper spanning the full range of the raw
        # data; the ranges are re-derived per panel once processed data exists.
        data_low, data_high = np.min(self.arr.values), np.max(self.arr.values)
        for name in panel_names:
            color_maps[name] = LinearColorMapper(default_palette,
                                                 low=data_low,
                                                 high=data_high,
                                                 nan_color='black')

        cached_data['raw'] = self.arr.values
        gamma_cached_data['raw'] = self.arr.values

        figure_kwargs = {
            'tools': ['reset', 'wheel_zoom'],
            'plot_width': self.app_main_size,
            'plot_height': self.app_main_size,
            'min_border': 10,
            'toolbar_location': 'left',
            'x_range': data_range['x'],
            'y_range': data_range['y'],
            'x_axis_location': 'below',
            'y_axis_location': 'right',
        }
        figures['d2'] = figure(title='d2 Spectrum', **figure_kwargs)

        # The remaining figures reuse the range objects of the d2 figure so
        # that zooming one panel zooms all of them.
        figure_kwargs.update({
            'y_range': figures['d2'].y_range,
            'x_range': figures['d2'].x_range,
            'toolbar_location': None,
            'y_axis_location': 'left',
        })

        figures['curvature'] = figure(title='Curvature', **figure_kwargs)
        figures['raw'] = figure(title='Raw Image', **figure_kwargs)

        figures['curvature'].yaxis.major_label_text_font_size = '0pt'

        # All panels initially show the raw data at the same position/extent;
        # force_update() below replaces the d2 and curvature images.
        image_extent = {
            'x': data_range['x'][0],
            'y': data_range['y'][0],
            'dw': data_range['x'][1] - data_range['x'][0],
            'dh': data_range['y'][1] - data_range['y'][0],
        }
        for name in panel_names:
            plots[name] = figures[name].image([self.arr.values],
                                              color_mapper=color_maps[name],
                                              **image_extent)

        smoothing_sliders_by_name = {}
        smoothing_sliders = []  # need one for each axis
        axis_resolution = self.arr.T.stride(generic_dim_names=False)
        for dim in self.arr.dims:
            coords = self.arr.coords[dim]
            resolution = float(axis_resolution[dim])
            # Window sizes range from one stride up to a third of the axis.
            high_resolution = len(coords) / 3 * resolution
            low_resolution = resolution

            # could make this axis dependent for more reasonable defaults
            default = 15 * resolution

            # On short axes 15 strides exceeds the slider's upper bound; fall
            # back to the middle of the allowed interval.
            if default > high_resolution:
                default = (high_resolution + low_resolution) / 2

            new_slider = widgets.Slider(title='{} Window'.format(dim),
                                        start=low_resolution,
                                        end=high_resolution,
                                        step=resolution,
                                        value=default)
            smoothing_sliders.append(new_slider)
            smoothing_sliders_by_name[dim] = new_slider

        n_smoothing_steps_slider = widgets.Slider(title="Smoothing Steps",
                                                  start=0,
                                                  end=5,
                                                  step=1,
                                                  value=2)
        beta_slider = widgets.Slider(title="β",
                                     start=-8,
                                     end=8,
                                     step=1,
                                     value=0)
        direction_select = widgets.Select(
            options=list(self.arr.dims),
            value='eV' if 'eV' in self.arr.dims else
            self.arr.dims[0],  # preference to energy,
            title='Derivative Direction')
        interleave_smoothing_toggle = widgets.Toggle(
            label='Interleave smoothing with d/dx',
            active=True,
            button_type='primary')
        clamp_spectrum_toggle = widgets.Toggle(
            label='Clamp positive values to 0',
            active=True,
            button_type='primary')
        filter_select = widgets.Select(options=['Gaussian', 'Boxcar'],
                                       value='Boxcar',
                                       title='Type of Filter')

        color_slider = widgets.RangeSlider(start=0,
                                           end=100,
                                           value=(
                                               0,
                                               100,
                                           ),
                                           title='Color Clip')
        gamma_slider = widgets.Slider(start=0.1,
                                      end=4,
                                      value=1,
                                      step=0.1,
                                      title='Gamma')

        # don't need any cacheing here for now, might if this ends up being too slow
        def smoothing_fn(n_passes):
            """Return a callable applying ``n_passes`` of the selected filter.

            Zero passes yields the identity. Window sizes are read from the
            per-axis sliders at the time this function is called.
            """
            if n_passes == 0:
                return lambda x: x

            filter_factory = {
                'Gaussian': gaussian_filter,
                'Boxcar': boxcar_filter,
            }.get(filter_select.value, boxcar_filter)

            filter_size = {
                d: smoothing_sliders_by_name[d].value
                for d in self.arr.dims
            }
            return filter_factory(filter_size, n_passes)

        def postprocess(data):
            """Zero out NaNs and, if the clamp toggle is on, negate and floor at 0.

            Modifies ``data`` in place and returns it.
            """
            # remove NaN values until Bokeh fixes NaNs over the wire
            # (NaN is the only value that compares unequal to itself)
            data.values[data.values != data.values] = 0
            if clamp_spectrum_toggle.active:
                data.values = -data.values
                data.values[data.values < 0] = 0
            return data

        def take_d2(d2_data):
            """Second derivative of ``d2_data`` using the current widget settings.

            With interleaving on, the smoothing passes are split around two
            first derivatives; otherwise all passes run first and a single
            second derivative follows.

            Args:
                d2_data: the data to differentiate along the selected direction.

            Returns:
                The differentiated data after NaN removal and optional clamping.
            """
            n_smoothing_steps = n_smoothing_steps_slider.value
            if interleave_smoothing_toggle.active:
                f = smoothing_fn(n_smoothing_steps // 2)
                d2_data = d1_along_axis(f(d2_data), direction_select.value)
                f = smoothing_fn(n_smoothing_steps - (n_smoothing_steps // 2))
                d2_data = d1_along_axis(f(d2_data), direction_select.value)
            else:
                f = smoothing_fn(n_smoothing_steps)
                # Differentiate the argument rather than self.arr, so this
                # branch honours the data the caller passed in.
                d2_data = d2_along_axis(f(d2_data), direction_select.value)

            return postprocess(d2_data)

        def take_curvature(curvature_data, curve_dims):
            """Curvature of ``curvature_data`` using the current widget settings.

            Args:
                curvature_data: the data to analyse; smoothed before use.
                curve_dims: dimensions forwarded to ``curvature``.

            Returns:
                The curvature after NaN removal and optional clamping.
            """
            curv_smoothing_fn = smoothing_fn(n_smoothing_steps_slider.value)
            smoothed_curvature_data = curv_smoothing_fn(curvature_data)
            curvature_data = curvature(smoothed_curvature_data,
                                       curve_dims,
                                       beta=beta_slider.value)
            return postprocess(curvature_data)

        # These functions will always be linked to the current context of the curvature tool.
        self.app_context['d2_fn'] = take_d2
        self.app_context['curvature_fn'] = take_curvature

        @Debounce(0.25)
        def force_update():
            """Recompute the d2 and curvature panels and refresh the color ranges."""

            def publish(name, processed):
                # Keep the un-gamma'd values so the gamma slider can re-derive
                # the displayed values without recomputing derivatives.
                cached_data[name] = processed.values
                gamma_cached_data[name] = processed.values**gamma_slider.value
                plots[name].data_source.data = {
                    'image': [gamma_cached_data[name]]
                }

            publish('d2', take_d2(self.arr))
            publish('curvature', take_curvature(self.arr, self.arr.dims))
            update_color_slider(color_slider.value)

        def force_update_change_wrapper(attr, old, new):
            """``on_change`` callback: recompute only when the value really changed."""
            if old != new:
                force_update()

        def force_update_click_wrapper(event):
            """``on_click`` callback for the toggles: always recompute."""
            force_update()

        @Debounce(0.1)
        def update_color_slider(new):
            """Clip each panel's color mapper to the ``new`` (low %, high %) window."""

            def update_plot(name, data):
                low, high = np.min(data), np.max(data)
                dynamic_range = high - low
                color_maps[name].update(
                    low=low + new[0] / 100 * dynamic_range,
                    high=low + new[1] / 100 * dynamic_range)

            for name in panel_names:
                update_plot(name, gamma_cached_data[name])

        @Debounce(0.1)
        def update_gamma_slider(new):
            """Re-derive the gamma-adjusted caches with exponent ``new``.

            NOTE(review): only the caches and the color ranges are updated
            here; the images held by the plots are not re-sent, unlike in
            force_update() -- confirm whether that is intended.
            """
            for name in panel_names:
                gamma_cached_data[name] = cached_data[name]**new
            update_color_slider(color_slider.value)

        def update_color_handler(attr, old, new):
            """``on_change`` adapter for the color clip slider."""
            update_color_slider(new)

        def update_gamma_handler(attr, old, new):
            """``on_change`` adapter for the gamma slider."""
            update_gamma_slider(new)

        layout = column(
            row(
                column(figures['d2'], interleave_smoothing_toggle,
                       direction_select),
                column(figures['curvature'], beta_slider,
                       clamp_spectrum_toggle),
                column(figures['raw'], color_slider, gamma_slider)),
            widgetbox(
                filter_select,
                *smoothing_sliders,
                n_smoothing_steps_slider,
            ),
            Spacer(height=100),
        )

        # Attach event handlers
        for w in (n_smoothing_steps_slider, beta_slider, direction_select,
                  *smoothing_sliders, filter_select):
            w.on_change('value', force_update_change_wrapper)

        interleave_smoothing_toggle.on_click(force_update_click_wrapper)
        clamp_spectrum_toggle.on_click(force_update_click_wrapper)

        color_slider.on_change('value', update_color_handler)
        gamma_slider.on_change('value', update_gamma_handler)

        # Populate the d2 and curvature panels before the document is shown.
        force_update()

        doc.add_root(layout)
        doc.title = 'Curvature Tool'
Ejemplo n.º 6
0
# Second figure: the amplitude spectrum, with frequency on the x axis.
# NOTE(review): `webgl` is a figure keyword from older Bokeh releases --
# confirm it is still accepted by the installed version.
pwave2 = bkp.figure(title='Amplitude Spectrum',
                    x_axis_label='Freq',
                    webgl=True)

# Data sources built from wave1: sw1 pairs its time series with the
# time-domain amplitude, sf1 pairs its frequency series with the amplitude
# spectrum. Both expose the amplitude under the column name 'B'.
sw1 = bkm.ColumnDataSource(data=dict(time=wave1.timeseries, B=wave1.timeAmp))
sf1 = bkm.ColumnDataSource(data=dict(freq=wave1.freqseries, B=wave1.ampSpec))

# One line glyph per figure, each bound to its source and drawn in wave1.colour.
pwave1.line('time', 'B', color=wave1.colour, source=sw1)
pwave2.line('freq', 'B', color=wave1.colour, source=sf1)

# Combine both figures into a single column layout.
waveletslayout = bkl.column(pwave1, pwave2)

# Set up widgets
# Slider for the dominant frequency: range 0 to 50 in steps of 1, initially 25.
domFreq = bkw.Slider(title="Dom Frequency",
                     value=25,
                     start=0,
                     end=50.0,
                     step=1)


# Set up callbacks
def update_data(attrname, old, new):

    # Get the current slider values
    dF = domFreq.value

    # Generate the new curve
    wave1.typeRicker(dF)
    wave1.calcAmpSpec()
    print('update')
    sw1.data = dict(time=wave1.timeseries, B=wave1.timeAmp)