예제 #1
0
    def setup_layout(self):
        """Create the widgets, figures and callbacks that make up the app.

        Builds a map figure ``p`` (Mercator axes with a CartoDB Positron tile
        layer) showing flight positions as triangles, and a secondary figure
        ``self.p2`` holding a line glyph (``self.source2``) and histogram quads
        (``self.source3``). It also creates the parameter / table selectors and
        wires their ``on_change`` callbacks.

        On return, ``self.controls`` holds the control row and ``self.layout``
        the complete page layout.

        Requires ``self.source``, ``self.source2``, ``self.source3``,
        ``self.tableList`` (non-empty) and ``self.cmdline`` to exist already.
        """

        # layout setup
        # Placeholder widgets; set_data_source() replaces self.flights and
        # self.time with populated ones once a table is chosen.
        self.flights = bkwidgets.MultiSelect()
        tile_provider = get_provider(Vendors.CARTODBPOSITRON)
        p = bkplot.figure(x_axis_type='mercator', y_axis_type='mercator')
        p.add_tile(tile_provider)
        self.p2 = bkplot.figure()
        self.lines = self.p2.line(x='time', y='param', source=self.source2)
        self.hist = self.p2.quad(source=self.source3,
                                 top='top',
                                 bottom='bottom',
                                 left='left',
                                 right='right')
        self.params = bkwidgets.Select(
            options=['altitude', 'tas', 'fpf', 'performance_hist'],
            value='altitude')
        p2control = bklayouts.WidgetBox(self.params)
        # (A discarded ``bklayouts.layout(self.controls, p)`` used to be built
        # here; it read self.controls before it was assigned below.)
        self.tables = bkwidgets.Select(options=self.tableList,
                                       value=self.tableList[0])
        self.time = bkwidgets.RangeSlider()
        # False until set_data_source() has inserted the flight/time widgets
        # into self.controls; afterwards it replaces them in place.
        self.populated = False

        # plot setup
        self.points = p.triangle(x='longitude',
                                 y='latitude',
                                 angle='heading',
                                 angle_units='deg',
                                 alpha=0.5,
                                 source=self.source)
        hover = bk.models.HoverTool()
        hover.tooltips = [("Callsign", "@callsign"), ("Time", "@time"),
                          ("Phase", "@status"), ("Heading", "@heading"),
                          ("Altitude", "@altitude"), ("Speed", "@tas")]
        hover2 = bk.models.HoverTool()
        hover2.tooltips = [("Callsign", "@callsign")]
        self.p2.add_tools(hover2)
        p.add_tools(hover)

        # callback setup
        self.params.on_change('value', self.plot_param)
        self.tables.on_change('value', self.set_data_source)
        self.cmdline.on_change('value', self.runCmd)

        self.controls = bklayouts.row(
            bklayouts.WidgetBox(self.tables, self.cmdline),
            bk.models.Div(width=20), p2control)
        self.layout = bklayouts.column(self.controls,
                                       bklayouts.row(p, self.p2))
예제 #2
0
 def set_data_source(self, attr, new, old):
     """Load the selected table and rebuild the flight / time widgets.

     Registered as the ``on_change('value', ...)`` callback of
     ``self.tables``. Bokeh calls such callbacks positionally as
     ``(attr, old, new)``, so ``new`` and ``old`` here hold the opposite
     values to their names. None of the three arguments is used; the table
     name is read from ``self.tables.value``.

     The table is fetched with ``checkForTable``. If that raises ``dbError``,
     the file is parsed with ``readNATS`` when ``NATS_DIR + t`` exists and
     with ``readIFF`` otherwise. The parsed result is then stored via
     ``self.db_access.addTable``.

     Afterwards, times are converted to integer epoch seconds, new flight and
     time widgets replace the old ones in ``self.controls``, lat/lon are
     projected with ``merc``, and ``self.update`` is invoked to redraw.
     """
     t = self.tables.value
     try:
         self.results = checkForTable(t)
     except dbError:
         # Only NATS files need the NATS reader; every other file
         # (including ones under SHERLOCK_DIR) is parsed with readIFF.
         if os.path.exists(NATS_DIR + t):
             cmd = readNATS.Command(t)
         else:
             cmd = readIFF.Command(t)
         self.results = cmd.executeCommand()[1]
         self.db_access.addTable(t, self.results)

     # Integer seconds since the Unix epoch, so the RangeSlider can use them.
     self.results['time'] = self.results['time'].astype(
         'datetime64[s]').astype(int)
     acids = np.unique(self.results['callsign']).tolist()
     times = sorted(np.unique(self.results['time']))

     # Fresh widgets for the new table; the first callsign is preselected.
     self.flights = bkwidgets.MultiSelect(options=acids, value=[acids[0]])
     self.flights.on_change('value', self.update)
     self.time = bkwidgets.RangeSlider(title="time",
                                       value=(times[0], times[-1]),
                                       start=times[0],
                                       end=times[-1],
                                       step=1)
     self.time.on_change('value', self.update)

     # The first load inserts the two widgets at slots 1 and 2 of the control
     # row; later loads overwrite those same slots.
     if self.populated:
         self.controls.children[1] = self.flights
         self.controls.children[2] = self.time
     else:
         self.controls.children.insert(1, self.flights)
         self.controls.children.insert(2, self.time)
         self.populated = True

     # NOTE(review): latitude is passed first and the result is unpacked as
     # (longitude, latitude); confirm this matches merc()'s signature.
     self.results['longitude'], self.results['latitude'] = merc(
         np.asarray(self.results['latitude'].astype(float)),
         np.asarray(self.results['longitude'].astype(float)))
     # Placeholder arguments; update() is normally a Bokeh callback.
     self.update('attr', 'new', 'old')
    def tool_handler(self, doc):
        """Build the Bokeh "Fit Inspection Tool" and add it to ``doc``.

        The tool shows:

        * a main image of the raw data (for 2D fits, a slice of it along the
          fit direction);
        * a right marginal plot, plus a bottom one when the fits are 1D. Each
          carries data, residual (red), best fit (blue dashed) and initial
          fit (green dotted) lines at the cursor;
        * an HTML summary of the fit under the cursor;
        * controls for the colour range, main-image content and outlier
          clipping.

        Args:
            doc: Bokeh document; the finished layout is added as a root.

        ``self.arr`` is either an ``xr.Dataset`` exposing ``data``,
        ``results`` and ``residual``, or an object whose
        ``attrs['original_data']`` holds the raw data (it is then used as the
        fit results itself). The method mutates ``self.arr`` (replaced by a
        deep copy), ``self.app_context``, ``self.cursor`` and
        ``self.cursor_dims``.
        """
        from bokeh import events  # events.Tap is used when attaching callbacks
        from bokeh.layouts import row, column, widgetbox, Spacer
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.models.widgets.markups import Div
        from bokeh.plotting import figure

        # Work on a private copy: NaNs are zeroed in place below.
        self.arr = self.arr.copy(deep=True)

        if not isinstance(self.arr, xr.Dataset):
            self.use_dataset = False

        residual = None
        if self.use_dataset:
            raw_data = self.arr.data
            raw_data.values[np.isnan(raw_data.values)] = 0
            fit_results = self.arr.results
            residual = self.arr.residual
            residual.values[np.isnan(residual.values)] = 0
        else:
            raw_data = self.arr.attrs['original_data']
            fit_results = self.arr

        # The fit direction is the raw-data dimension missing from the fit
        # results, i.e. the axis each individual fit was performed along.
        fit_direction = [d for d in raw_data.dims if d not in fit_results.dims]
        fit_direction = fit_direction[0]

        # ``two_dimensional`` refers to the *fit results*: it is set whenever
        # the raw data is not 2D. The x/y axes then come from the fit results
        # and z from the fit direction.
        two_dimensional = False
        if len(raw_data.dims) != 2:
            two_dimensional = True
            x_coords, y_coords = fit_results.coords[
                fit_results.dims[0]], fit_results.coords[fit_results.dims[1]]
            z_coords = raw_data.coords[fit_direction]
        else:
            x_coords, y_coords = raw_data.coords[
                raw_data.dims[0]], raw_data.coords[raw_data.dims[1]]

        if two_dimensional:
            self.settings['palette'] = 'coolwarm'
        default_palette = self.default_palette

        self.app_context.update({
            'data': raw_data,
            'fits': fit_results,
            'residual': residual,
            'original': self.arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            }
        })
        if two_dimensional:
            self.app_context['data_range']['z'] = (np.min(z_coords.values),
                                                   np.max(z_coords.values))

        figures = self.app_context['figures']
        plots = self.app_context['plots']
        app_widgets = self.app_context['widgets']

        # Start the cursor at the centre of the data range.
        self.cursor_dims = raw_data.dims
        if two_dimensional:
            self.cursor = [
                np.mean(self.data_range['x']),
                np.mean(self.data_range['y']),
                np.mean(self.data_range['z'])
            ]
        else:
            self.cursor = [
                np.mean(self.data_range['x']),
                np.mean(self.data_range['y'])
            ]

        app_widgets['fit_info_div'] = Div(text='')

        self.app_context['color_maps']['main'] = LinearColorMapper(
            default_palette,
            low=np.min(raw_data.values),
            high=np.max(raw_data.values),
            nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset", "save"]
        main_title = 'Fit Inspection Tool: WARNING Unidentified'

        # Best effort: fall back to the generic title if the label is missing.
        try:
            main_title = 'Fit Inspection Tool: {}'.format(
                raw_data.S.label[:60])
        except Exception:
            pass

        figures['main'] = figure(tools=main_tools,
                                 plot_width=self.app_main_size,
                                 plot_height=self.app_main_size,
                                 min_border=10,
                                 min_border_left=50,
                                 toolbar_location='left',
                                 x_axis_location='below',
                                 y_axis_location='right',
                                 title=main_title,
                                 x_range=self.data_range['x'],
                                 y_range=self.app_context['data_range']['y'])
        figures['main'].xaxis.axis_label = raw_data.dims[0]
        figures['main'].yaxis.axis_label = raw_data.dims[1]
        figures['main'].toolbar.logo = None
        figures['main'].background_fill_color = "#fafafa"

        # For 2D fits the main image is the raw-data slice nearest the
        # cursor's position along the fit direction.
        data_for_main = raw_data
        if two_dimensional:
            data_for_main = data_for_main.sel(**dict(
                [[fit_direction, self.cursor[2]]]),
                                              method='nearest')
        plots['main'] = figures['main'].image(
            [data_for_main.values.T],
            x=self.app_context['data_range']['x'][0],
            y=self.app_context['data_range']['y'][0],
            dw=self.app_context['data_range']['x'][1] -
            self.app_context['data_range']['x'][0],
            dh=self.app_context['data_range']['y'][1] -
            self.app_context['data_range']['y'][0],
            color_mapper=self.app_context['color_maps']['main'])

        # Overlay each band's fitted centre as a dashed white line; swap the
        # axes when the fit results are indexed along the image's y axis.
        band_centers = [b.center for b in fit_results.F.bands.values()]
        bands_xs = [b.coords[b.dims[0]].values for b in band_centers]
        bands_ys = [b.values for b in band_centers]
        if fit_results.dims[0] == raw_data.dims[1]:
            bands_ys, bands_xs = bands_xs, bands_ys
        plots['band_locations'] = figures['main'].multi_line(
            xs=bands_xs,
            ys=bands_ys,
            line_color='white',
            line_width=1,
            line_dash='dashed')

        # add cursor lines
        self.add_cursor_lines(figures['main'])

        # marginals
        if not two_dimensional:
            figures['bottom'] = figure(plot_width=self.app_main_size,
                                       plot_height=self.app_marginal_size,
                                       min_border=10,
                                       title=None,
                                       x_range=figures['main'].x_range,
                                       x_axis_location='above',
                                       toolbar_location=None,
                                       tools=[])
        else:
            figures['bottom'] = Spacer(width=self.app_main_size,
                                       height=self.app_marginal_size)

        # For 2D fits the right marginal plots along the fit direction (z).
        right_y_range = figures['main'].y_range
        if two_dimensional:
            right_y_range = self.data_range['z']

        figures['right'] = figure(plot_width=self.app_marginal_size,
                                  plot_height=self.app_main_size,
                                  min_border=10,
                                  title=None,
                                  y_range=right_y_range,
                                  y_axis_location='left',
                                  toolbar_location=None,
                                  tools=[])

        marginal_line_width = 2
        if not two_dimensional:
            bottom_data = raw_data.sel(**dict(
                [[raw_data.dims[1], self.cursor[1]]]),
                                       method='nearest')
            right_data = raw_data.sel(**dict(
                [[raw_data.dims[0], self.cursor[0]]]),
                                      method='nearest')

            plots['bottom'] = figures['bottom'].line(
                x=bottom_data.coords[raw_data.dims[0]].values,
                y=bottom_data.values,
                line_width=marginal_line_width)
            plots['bottom_residual'] = figures['bottom'].line(
                x=[], y=[], line_color='red', line_width=marginal_line_width)
            plots['bottom_fit'] = figures['bottom'].line(
                x=[],
                y=[],
                line_color='blue',
                line_width=marginal_line_width,
                line_dash='dashed')
            plots['bottom_init_fit'] = figures['bottom'].line(
                x=[],
                y=[],
                line_color='green',
                line_width=marginal_line_width,
                line_dash='dotted')

            # The right marginal is drawn on transposed axes (values on x).
            plots['right'] = figures['right'].line(
                y=right_data.coords[raw_data.dims[1]].values,
                x=right_data.values,
                line_width=marginal_line_width)
            plots['right_residual'] = figures['right'].line(
                x=[], y=[], line_color='red', line_width=marginal_line_width)
            plots['right_fit'] = figures['right'].line(
                x=[],
                y=[],
                line_color='blue',
                line_width=marginal_line_width,
                line_dash='dashed')
            plots['right_init_fit'] = figures['right'].line(
                x=[],
                y=[],
                line_color='green',
                line_width=marginal_line_width,
                line_dash='dotted')
        else:
            # Fix every cursor coordinate except the fit direction, leaving
            # a 1D cut along the fit direction.
            right_data = raw_data.sel(**{
                k: v
                for k, v in self.cursor_dict.items() if k != fit_direction
            },
                                      method='nearest')
            plots['right'] = figures['right'].line(
                y=right_data.coords[right_data.dims[0]].values,
                x=right_data.values,
                line_width=marginal_line_width)
            plots['right_residual'] = figures['right'].line(
                x=[], y=[], line_color='red', line_width=marginal_line_width)
            plots['right_fit'] = figures['right'].line(
                x=[],
                y=[],
                line_color='blue',
                line_width=marginal_line_width,
                line_dash='dashed')
            plots['right_init_fit'] = figures['right'].line(
                x=[],
                y=[],
                line_color='green',
                line_width=marginal_line_width,
                line_dash='dotted')

        def on_change_main_view(attr, old, data_source):
            """Switch the main image to 'data', 'residual' or a fit parameter."""
            self.selected_data = data_source
            data = None
            if data_source == 'data':
                data = raw_data.sel(**{
                    k: v
                    for k, v in self.cursor_dict.items() if k == fit_direction
                },
                                    method='nearest')
            elif data_source == 'residual':
                # residual is None when self.arr was not a Dataset; keep the
                # current image rather than raising AttributeError.
                if residual is not None:
                    data = residual.sel(**{
                        k: v
                        for k, v in self.cursor_dict.items()
                        if k == fit_direction
                    },
                                        method='nearest')
            elif two_dimensional:
                data = fit_results.F.s(data_source)
                data.values[np.isnan(data.values)] = 0

            if data is not None:
                if self.remove_outliers:
                    data = data.T.clean_outliers(clip=self.outlier_clip)

                plots['main'].data_source.data = {
                    'image': [data.values.T],
                }
                update_main_colormap(None, None, main_color_range_slider.value)

        def update_fit_display():
            """Redraw the fit overlays and info text for the fit at the cursor."""
            target = 'right'
            if fit_results.dims[0] == raw_data.dims[1]:
                target = 'bottom'

            if two_dimensional:
                target = 'right'
                current_fit = fit_results.sel(**{
                    k: v
                    for k, v in self.cursor_dict.items() if k != fit_direction
                },
                                              method='nearest').item()
                coord_vals = raw_data.coords[fit_direction].values
            else:
                current_fit = fit_results.sel(**dict([[
                    fit_results.dims[0],
                    self.cursor[0 if target == 'right' else 1]
                ]]),
                                              method='nearest').item()
                coord_vals = raw_data.coords[
                    raw_data.dims[0 if target == 'bottom' else 1]].values

            if current_fit is not None:
                app_widgets['fit_info_div'].text = current_fit._repr_html_(
                    short=True)  # pylint: disable=protected-access
            else:
                app_widgets['fit_info_div'].text = 'No fit here.'
                plots['{}_residual'.format(target)].data_source.data = {
                    'x': [],
                    'y': [],
                }
                plots['{}_fit'.format(target)].data_source.data = {
                    'x': [],
                    'y': [],
                }
                plots['{}_init_fit'.format(target)].data_source.data = {
                    'x': [],
                    'y': [],
                }
                return

            # The right marginal uses transposed axes (values on x).
            if target == 'bottom':
                residual_x = coord_vals
                residual_y = current_fit.residual
                init_fit_x = coord_vals
                init_fit_y = current_fit.init_fit
                fit_x = coord_vals
                fit_y = current_fit.best_fit
            else:
                residual_y = coord_vals
                residual_x = current_fit.residual
                init_fit_y = coord_vals
                init_fit_x = current_fit.init_fit
                fit_y = coord_vals
                fit_x = current_fit.best_fit

            plots['{}_residual'.format(target)].data_source.data = {
                'x': residual_x,
                'y': residual_y,
            }
            plots['{}_fit'.format(target)].data_source.data = {
                'x': fit_x,
                'y': fit_y,
            }
            plots['{}_init_fit'.format(target)].data_source.data = {
                'x': init_fit_x,
                'y': init_fit_y,
            }

        def click_right_marginal(event):
            """Move the cursor along the fit direction (2D fits only)."""
            self.cursor = [self.cursor[0], self.cursor[1], event.y]
            on_change_main_view(None, None, self.selected_data)

        def click_main_image(event):
            """Move the cursor in the image plane and refresh the marginals."""
            if two_dimensional:
                self.cursor = [event.x, event.y, self.cursor[2]]
            else:
                self.cursor = [event.x, event.y]

            if not two_dimensional:
                right_marginal_data = raw_data.sel(**dict(
                    [[raw_data.dims[0], self.cursor[0]]]),
                                                   method='nearest')
                bottom_marginal_data = raw_data.sel(**dict(
                    [[raw_data.dims[1], self.cursor[1]]]),
                                                    method='nearest')
                plots['bottom'].data_source.data = {
                    'x': bottom_marginal_data.coords[raw_data.dims[0]].values,
                    'y': bottom_marginal_data.values,
                }
            else:
                right_marginal_data = raw_data.sel(**{
                    k: v
                    for k, v in self.cursor_dict.items() if k != fit_direction
                },
                                                   method='nearest')

            plots['right'].data_source.data = {
                'y':
                right_marginal_data.coords[right_marginal_data.dims[0]].values,
                'x': right_marginal_data.values,
            }

            update_fit_display()

        def on_change_outlier_clip(attr, old, new):
            self.outlier_clip = new
            on_change_main_view(None, None, self.selected_data)

        def set_remove_outliers(should_remove_outliers):
            # Redraw only when the toggle state actually changes.
            if self.remove_outliers != should_remove_outliers:
                self.remove_outliers = should_remove_outliers

                on_change_main_view(None, None, self.selected_data)

        update_main_colormap = self.update_colormap_for('main')
        MAIN_CONTENT_OPTIONS = [
            ('Residual', 'residual'),
            ('Data', 'data'),
        ]

        # 2D fits can additionally show any fitted parameter as the image.
        if two_dimensional:
            available_parameters = fit_results.F.parameter_names

            for param_name in available_parameters:
                MAIN_CONTENT_OPTIONS.append((
                    param_name,
                    param_name,
                ))

        remove_outliers_toggle = widgets.Toggle(label='Remove Outliers',
                                                button_type='primary',
                                                active=self.remove_outliers)
        remove_outliers_toggle.on_click(set_remove_outliers)

        outlier_clip_slider = widgets.Slider(title='Clip',
                                             start=0,
                                             end=10,
                                             value=self.outlier_clip,
                                             callback_throttle=150,
                                             step=0.2)
        outlier_clip_slider.on_change('value', on_change_outlier_clip)

        main_content_select = widgets.Dropdown(label='Main Content',
                                               button_type='primary',
                                               menu=MAIN_CONTENT_OPTIONS)
        main_content_select.on_change('value', on_change_main_view)

        # Widgety things
        main_color_range_slider = widgets.RangeSlider(start=0,
                                                      end=100,
                                                      value=(
                                                          0,
                                                          100,
                                                      ),
                                                      title='Color Range')

        # Attach callbacks
        main_color_range_slider.on_change('value', update_main_colormap)
        figures['main'].on_event(events.Tap, click_main_image)
        if two_dimensional:
            figures['right'].on_event(events.Tap, click_right_marginal)

        # Outlier controls are only offered for 2D fits.
        layout = row(
            column(figures['main'], figures.get('bottom')),
            column(figures['right'], app_widgets['fit_info_div']),
            column(
                widgetbox(*[
                    widget for widget in [
                        self._cursor_info,
                        main_color_range_slider,
                        main_content_select,
                        remove_outliers_toggle if two_dimensional else None,
                        outlier_clip_slider if two_dimensional else None,
                    ] if widget is not None
                ]), ))

        update_fit_display()

        doc.add_root(layout)
        doc.title = 'Band Tool'
예제 #4
0
    def tool_handler_2d(self, doc):
        from bokeh import events
        from bokeh.layouts import row, column, widgetbox, Spacer
        from bokeh.models import ColumnDataSource, widgets
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models.widgets.markups import Div
        from bokeh.plotting import figure

        arr = self.arr
        # Set up the data
        x_coords, y_coords = arr.coords[arr.dims[0]], arr.coords[arr.dims[1]]

        # Styling
        default_palette = self.default_palette
        if arr.S.is_subtracted:
            default_palette = cc.coolwarm

        error_alpha = 0.3
        error_fill = '#3288bd'

        # Application Organization
        self.app_context.update({
            'data': arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            },
            'show_stat_variation': False,
            'color_mode': 'linear',
        })

        def stats_patch_from_data(data, subsampling_rate=None):
            if subsampling_rate is None:
                subsampling_rate = int(min(data.values.shape[0] / 50, 5))
                if subsampling_rate == 0:
                    subsampling_rate = 1

            x_values = data.coords[data.dims[0]].values[::subsampling_rate]
            values = data.values[::subsampling_rate]
            sq = np.sqrt(values)
            lower, upper = values - sq, values + sq

            return {
                'x': np.append(x_values, x_values[::-1]),
                'y': np.append(lower, upper[::-1]),
            }

        def update_stat_variation(plot_name, data):
            patch_data = stats_patch_from_data(data)
            if plot_name != 'right':  # the right plot is on transposed axes
                plots[plot_name +
                      '_marginal_err'].data_source.data = patch_data
            else:
                plots[plot_name + '_marginal_err'].data_source.data = {
                    'x': patch_data['y'],
                    'y': patch_data['x'],
                }

        figures, plots, app_widgets = self.app_context[
            'figures'], self.app_context['plots'], self.app_context['widgets']

        if self.cursor_default is not None and len(self.cursor_default) == 2:
            self.cursor = self.cursor_default
        else:
            self.cursor = [
                np.mean(self.app_context['data_range']['x']),
                np.mean(self.app_context['data_range']['y'])
            ]  # try a sensible default

        # create the main inset plot
        main_image = arr
        prepped_main_image = self.prep_image(main_image)
        self.app_context['color_maps']['main'] = LinearColorMapper(
            default_palette,
            low=np.min(prepped_main_image),
            high=np.max(prepped_main_image),
            nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset", "save"]
        main_title = 'Bokeh Tool: WARNING Unidentified'
        try:
            main_title = "Bokeh Tool: %s" % arr.S.label[:60]
        except:
            pass
        figures['main'] = figure(tools=main_tools,
                                 plot_width=self.app_main_size,
                                 plot_height=self.app_main_size,
                                 min_border=10,
                                 min_border_left=50,
                                 toolbar_location='left',
                                 x_axis_location='below',
                                 y_axis_location='right',
                                 title=main_title,
                                 x_range=self.app_context['data_range']['x'],
                                 y_range=self.app_context['data_range']['y'])
        figures['main'].xaxis.axis_label = arr.dims[0]
        figures['main'].yaxis.axis_label = arr.dims[1]
        figures['main'].toolbar.logo = None
        figures['main'].background_fill_color = "#fafafa"
        plots['main'] = figures['main'].image(
            [prepped_main_image.T],
            x=self.app_context['data_range']['x'][0],
            y=self.app_context['data_range']['y'][0],
            dw=self.app_context['data_range']['x'][1] -
            self.app_context['data_range']['x'][0],
            dh=self.app_context['data_range']['y'][1] -
            self.app_context['data_range']['y'][0],
            color_mapper=self.app_context['color_maps']['main'])

        app_widgets['info_div'] = Div(text='',
                                      width=self.app_marginal_size,
                                      height=100)

        # Create the bottom marginal plot
        bottom_marginal = arr.sel(**dict([[arr.dims[1], self.cursor[1]]]),
                                  method='nearest')
        figures['bottom_marginal'] = figure(
            plot_width=self.app_main_size,
            plot_height=200,
            title=None,
            x_range=figures['main'].x_range,
            y_range=(np.min(bottom_marginal.values),
                     np.max(bottom_marginal.values)),
            x_axis_location='above',
            toolbar_location=None,
            tools=[])
        plots['bottom_marginal'] = figures['bottom_marginal'].line(
            x=bottom_marginal.coords[arr.dims[0]].values,
            y=bottom_marginal.values)
        plots['bottom_marginal_err'] = figures['bottom_marginal'].patch(
            x=[],
            y=[],
            color=error_fill,
            fill_alpha=error_alpha,
            line_color=None)

        # Create the right marginal plot
        right_marginal = arr.sel(**dict([[arr.dims[0], self.cursor[0]]]),
                                 method='nearest')
        figures['right_marginal'] = figure(
            plot_width=200,
            plot_height=self.app_main_size,
            title=None,
            y_range=figures['main'].y_range,
            x_range=(np.min(right_marginal.values),
                     np.max(right_marginal.values)),
            y_axis_location='left',
            toolbar_location=None,
            tools=[])
        plots['right_marginal'] = figures['right_marginal'].line(
            y=right_marginal.coords[arr.dims[1]].values,
            x=right_marginal.values)
        plots['right_marginal_err'] = figures['right_marginal'].patch(
            x=[],
            y=[],
            color=error_fill,
            fill_alpha=error_alpha,
            line_color=None)

        cursor_lines = self.add_cursor_lines(figures['main'])

        # Attach tools and callbacks
        toggle = widgets.Toggle(label="Show Stat. Variation",
                                button_type="success",
                                active=False)

        def set_show_stat_variation(should_show):
            self.app_context['show_stat_variation'] = should_show

            if should_show:
                main_image_data = arr
                update_stat_variation(
                    'bottom',
                    main_image_data.sel(**dict([[arr.dims[1],
                                                 self.cursor[1]]]),
                                        method='nearest'))
                update_stat_variation(
                    'right',
                    main_image_data.sel(**dict([[arr.dims[0],
                                                 self.cursor[0]]]),
                                        method='nearest'))
                plots['bottom_marginal_err'].visible = True
                plots['right_marginal_err'].visible = True
            else:
                plots['bottom_marginal_err'].visible = False
                plots['right_marginal_err'].visible = False

        toggle.on_click(set_show_stat_variation)

        scan_keys = [
            'x', 'y', 'z', 'pass_energy', 'hv', 'location', 'id', 'probe_pol',
            'pump_pol'
        ]
        scan_info_source = ColumnDataSource({
            'keys': [k for k in scan_keys if k in arr.attrs],
            'values': [
                str(v) if isinstance(v, float) and np.isnan(v) else v
                for v in [arr.attrs[k] for k in scan_keys if k in arr.attrs]
            ],
        })
        scan_info_columns = [
            widgets.TableColumn(field='keys', title='Attr.'),
            widgets.TableColumn(field='values', title='Value'),
        ]

        POINTER_MODES = [
            (
                'Cursor',
                'cursor',
            ),
            (
                'Path',
                'path',
            ),
        ]

        COLOR_MODES = [
            (
                'Adaptive Hist. Eq. (Slow)',
                'adaptive_equalization',
            ),
            # ('Histogram Eq.', 'equalization',), # not implemented
            (
                'Linear',
                'linear',
            ),
            # ('Log', 'log',), # not implemented
        ]

        def on_change_color_mode(attr, old, new_color_mode):
            self.app_context['color_mode'] = new_color_mode
            if old is None or old != new_color_mode:
                right_image_data = arr.sel(**dict(
                    [[arr.dims[0], self.cursor[0]]]),
                                           method='nearest')
                bottom_image_data = arr.sel(**dict(
                    [[arr.dims[1], self.cursor[1]]]),
                                            method='nearest')
                main_image_data = arr
                prepped_right_image = self.prep_image(right_image_data)
                prepped_bottom_image = self.prep_image(bottom_image_data)
                prepped_main_image = self.prep_image(main_image_data)
                plots['right'].data_source.data = {
                    'image': [prepped_right_image]
                }
                plots['bottom'].data_source.data = {
                    'image': [prepped_bottom_image.T]
                }
                plots['main'].data_source.data = {
                    'image': [prepped_main_image.T]
                }
                update_main_colormap(None, None, main_color_range_slider.value)

        color_mode_dropdown = widgets.Dropdown(label='Color Mode',
                                               button_type='primary',
                                               menu=COLOR_MODES)
        color_mode_dropdown.on_change('value', on_change_color_mode)

        symmetry_point_name_input = widgets.TextInput(
            title='Symmetry Point Name', value="G")
        snap_checkbox = widgets.CheckboxButtonGroup(labels=['Snap Axes'],
                                                    active=[])
        place_symmetry_point_at_cursor_button = widgets.Button(
            label="Place Point", button_type="primary")

        def update_symmetry_points_for_display():
            pass

        def place_symmetry_point():
            cursor_dict = dict(zip(arr.dims, self.cursor))
            skip_dimensions = {'eV', 'delay', 'cycle'}
            if 'symmetry_points' not in arr.attrs:
                arr.attrs['symmetry_points'] = {}

            snap_distance = {
                'phi': 2,
                'beta': 2,
                'kx': 0.01,
                'ky': 0.01,
                'kz': 0.01,
                'kp': 0.01,
                'hv': 4,
            }

            cursor_dict = {
                k: v
                for k, v in cursor_dict.items() if k not in skip_dimensions
            }
            snapped = copy.copy(cursor_dict)

            if 'Snap Axes' in [
                    snap_checkbox.labels[i] for i in snap_checkbox.active
            ]:
                for axis, value in cursor_dict.items():
                    options = [
                        point[axis]
                        for point in arr.attrs['symmetry_points'].values()
                        if axis in point
                    ]
                    options = sorted(options, key=lambda x: np.abs(x - value))
                    if options and np.abs(options[0] -
                                          value) < snap_distance[axis]:
                        snapped[axis] = options[0]

            arr.attrs['symmetry_points'][
                symmetry_point_name_input.value] = snapped

        place_symmetry_point_at_cursor_button.on_click(place_symmetry_point)

        main_color_range_slider = widgets.RangeSlider(
            start=0, end=100, value=(
                0,
                100,
            ), title='Color Range (Main)')

        layout = row(
            column(figures['main'], figures['bottom_marginal']),
            column(figures['right_marginal'], Spacer(width=200, height=200)),
            column(
                widgetbox(
                    widgets.Dropdown(label='Pointer Mode',
                                     button_type='primary',
                                     menu=POINTER_MODES)),
                widgets.Tabs(tabs=[
                    widgets.Panel(child=widgetbox(
                        Div(text='<h2>Colorscale:</h2>'),
                        color_mode_dropdown,
                        main_color_range_slider,
                        Div(text=
                            '<h2 style="padding-top: 30px;">General Settings:</h2>'
                            ),
                        toggle,
                        self._cursor_info,
                        sizing_mode='scale_width'),
                                  title='Settings'),
                    widgets.Panel(child=widgetbox(
                        app_widgets['info_div'],
                        Div(text=
                            '<h2 style="padding-top: 30px; padding-bottom: 10px;">Scan Info</h2>'
                            ),
                        widgets.DataTable(source=scan_info_source,
                                          columns=scan_info_columns,
                                          width=400,
                                          height=400),
                        sizing_mode='scale_width',
                        width=400),
                                  title='Info'),
                    widgets.Panel(child=widgetbox(
                        Div(text='<h2>Preparation</h2>'),
                        symmetry_point_name_input,
                        snap_checkbox,
                        place_symmetry_point_at_cursor_button,
                        sizing_mode='scale_width'),
                                  title='Preparation'),
                ],
                             width=400)))

        update_main_colormap = self.update_colormap_for('main')

        def on_click_save(event):
            save_dataset(arr)
            print(event)

        def click_main_image(event):
            self.cursor = [event.x, event.y]

            right_marginal_data = arr.sel(**dict(
                [[arr.dims[0], self.cursor[0]]]),
                                          method='nearest')
            bottom_marginal_data = arr.sel(**dict(
                [[arr.dims[1], self.cursor[1]]]),
                                           method='nearest')
            plots['bottom_marginal'].data_source.data = {
                'x': bottom_marginal_data.coords[arr.dims[0]].values,
                'y': bottom_marginal_data.values,
            }
            plots['right_marginal'].data_source.data = {
                'y': right_marginal_data.coords[arr.dims[1]].values,
                'x': right_marginal_data.values,
            }
            if self.app_context['show_stat_variation']:
                update_stat_variation('right', right_marginal_data)
                update_stat_variation('bottom', bottom_marginal_data)
            figures['bottom_marginal'].y_range.start = np.min(
                bottom_marginal_data.values)
            figures['bottom_marginal'].y_range.end = np.max(
                bottom_marginal_data.values)
            figures['right_marginal'].x_range.start = np.min(
                right_marginal_data.values)
            figures['right_marginal'].x_range.end = np.max(
                right_marginal_data.values)

            self.save_app()

        figures['main'].on_event(events.Tap, click_main_image)
        main_color_range_slider.on_change('value', update_main_colormap)

        doc.add_root(layout)
        doc.title = "Bokeh Tool"
        self.load_app()
        self.save_app()
예제 #5
0
    def tool_handler(self, doc):
        """Populate ``doc`` with an interactive explorer for ``self.analysis_fn``.

        The main panel shows ``self.data_for_display`` (initially ``self.arr``)
        as an image. Bottom and right marginal line plots show cuts through the
        cursor: the processed data, plus the same cut of ``self.arr`` drawn in
        red. Tapping the main image moves the cursor.

        One slider is built for every ``int`` / ``float`` entry of
        ``self.widget_specification``. Entries are matched positionally to the
        parameters of ``self.analysis_fn`` after its first one, which receives
        the data. Changing a slider re-runs ``self.analysis_fn(self.arr, ...)``
        (debounced) and redraws the image and marginals. Errors raised by the
        analysis function are shown in a text div instead of propagating.

        Args:
            doc: Bokeh document; the layout is added as its root and its title
                is set.

        Raises:
            AnalysisError: if ``self.arr`` is not two-dimensional.
        """
        from bokeh import events
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError(
                'Cannot use the band tool on non image-like spectra')

        def cut(data, dim, value):
            """Return the slice of ``data`` at the coordinate of ``dim`` nearest ``value``."""
            return data.sel(method='nearest', **{dim: value})

        def value_span(*arrays):
            """Return ``(min, max)`` taken jointly over all ``arrays``."""
            return (np.min([np.min(a) for a in arrays]),
                    np.max([np.max(a) for a in arrays]))

        self.data_for_display = self.arr
        x_coords = self.data_for_display.coords[self.data_for_display.dims[0]]
        y_coords = self.data_for_display.coords[self.data_for_display.dims[1]]

        default_palette = self.default_palette

        self.app_context.update({
            'data': self.arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            },
        })

        figures, plots = self.app_context['figures'], self.app_context['plots']

        # Start the cursor at the midpoint of each axis range.
        self.cursor = [
            np.mean(self.data_range['x']),
            np.mean(self.data_range['y'])
        ]

        self.color_maps['main'] = LinearColorMapper(
            default_palette,
            low=np.min(self.data_for_display.values),
            high=np.max(self.data_for_display.values),
            nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = '{} Tool: WARNING Unidentified'.format(
            self.analysis_fn.__name__)

        # Best effort: fall back to the "Unidentified" title if the data has no
        # usable label, but do not swallow KeyboardInterrupt / SystemExit.
        try:
            main_title = '{} Tool: {}'.format(
                self.analysis_fn.__name__, self.data_for_display.S.label[:60])
        except Exception:
            pass

        figures['main'] = figure(tools=main_tools,
                                 plot_width=self.app_main_size,
                                 plot_height=self.app_main_size,
                                 min_border=10,
                                 min_border_left=50,
                                 toolbar_location='left',
                                 x_axis_location='below',
                                 y_axis_location='right',
                                 title=main_title,
                                 x_range=self.data_range['x'],
                                 y_range=self.data_range['y'])
        figures['main'].xaxis.axis_label = self.data_for_display.dims[0]
        figures['main'].yaxis.axis_label = self.data_for_display.dims[1]
        figures['main'].toolbar.logo = None
        figures['main'].background_fill_color = "#fafafa"
        plots['main'] = figures['main'].image(
            [self.data_for_display.values.T],
            x=self.app_context['data_range']['x'][0],
            y=self.app_context['data_range']['y'][0],
            dw=self.app_context['data_range']['x'][1] -
            self.app_context['data_range']['x'][0],
            dh=self.app_context['data_range']['y'][1] -
            self.app_context['data_range']['y'][0],
            color_mapper=self.app_context['color_maps']['main'])

        # Create the bottom marginal plot (cut at the cursor's second coordinate)
        dims = self.data_for_display.dims
        bottom_marginal = cut(self.data_for_display, dims[1], self.cursor[1])
        bottom_marginal_original = cut(self.arr, dims[1], self.cursor[1])
        figures['bottom_marginal'] = figure(
            plot_width=self.app_main_size,
            plot_height=200,
            title=None,
            x_range=figures['main'].x_range,
            y_range=(np.min(bottom_marginal.values),
                     np.max(bottom_marginal.values)),
            x_axis_location='above',
            toolbar_location=None,
            tools=[])
        plots['bottom_marginal'] = figures['bottom_marginal'].line(
            x=bottom_marginal.coords[dims[0]].values,
            y=bottom_marginal.values)
        plots['bottom_marginal_original'] = figures['bottom_marginal'].line(
            x=bottom_marginal_original.coords[self.arr.dims[0]].values,
            y=bottom_marginal_original.values,
            line_color='red')

        # Create the right marginal plot (cut at the cursor's first coordinate)
        right_marginal = cut(self.data_for_display, dims[0], self.cursor[0])
        right_marginal_original = cut(self.arr, dims[0], self.cursor[0])
        figures['right_marginal'] = figure(
            plot_width=200,
            plot_height=self.app_main_size,
            title=None,
            y_range=figures['main'].y_range,
            x_range=(np.min(right_marginal.values),
                     np.max(right_marginal.values)),
            y_axis_location='left',
            toolbar_location=None,
            tools=[])
        plots['right_marginal'] = figures['right_marginal'].line(
            y=right_marginal.coords[dims[1]].values,
            x=right_marginal.values)
        plots['right_marginal_original'] = figures['right_marginal'].line(
            y=right_marginal_original.coords[dims[1]].values,
            x=right_marginal_original.values,
            line_color='red')

        # add lines
        self.add_cursor_lines(figures['main'])
        figures['main'].multi_line(xs=[],
                                   ys=[],
                                   line_color='white',
                                   line_width=1)  # band lines

        # prep the widgets for the analysis function
        signature = inspect.signature(self.analysis_fn)

        # drop the first which has to be the input data, we can revisit later if this is too limiting
        parameter_names = list(signature.parameters)[1:]
        named_widgets = dict(zip(parameter_names, self.widget_specification))
        built_widgets = {}

        def update_marginals():
            """Redraw all four marginal curves at the current cursor.

            Processed cuts come from ``self.data_for_display``; reference cuts
            (red) come from ``self.arr``. Each marginal's value axis spans both
            curves, so the processed one is not clipped when the analysis
            function changes the data's scale.
            """
            x_dim = self.data_for_display.dims[0]
            y_dim = self.data_for_display.dims[1]

            right_display = cut(self.data_for_display, x_dim, self.cursor[0])
            bottom_display = cut(self.data_for_display, y_dim, self.cursor[1])
            right_original = cut(self.arr, x_dim, self.cursor[0])
            bottom_original = cut(self.arr, y_dim, self.cursor[1])

            plots['bottom_marginal'].data_source.data = {
                'x': bottom_display.coords[x_dim].values,
                'y': bottom_display.values,
            }
            plots['right_marginal'].data_source.data = {
                'y': right_display.coords[y_dim].values,
                'x': right_display.values,
            }
            plots['bottom_marginal_original'].data_source.data = {
                'x': bottom_original.coords[x_dim].values,
                'y': bottom_original.values,
            }
            plots['right_marginal_original'].data_source.data = {
                'y': right_original.coords[y_dim].values,
                'x': right_original.values,
            }

            bottom_low, bottom_high = value_span(bottom_display.values,
                                                 bottom_original.values)
            right_low, right_high = value_span(right_display.values,
                                               right_original.values)
            figures['bottom_marginal'].y_range.start = bottom_low
            figures['bottom_marginal'].y_range.end = bottom_high
            figures['right_marginal'].x_range.start = right_low
            figures['right_marginal'].x_range.end = right_high

        def click_main_image(event):
            self.cursor = [event.x, event.y]
            update_marginals()

        error_msg = widgets.Div(text='')

        @Debounce(0.25)
        def update_data_for_display():
            """Re-run the analysis function with the current slider values and redraw."""
            try:
                self.data_for_display = self.analysis_fn(
                    self.arr, *[
                        built_widgets[p].value for p in parameter_names
                        if p in built_widgets
                    ])
                error_msg.text = ''
            except Exception as e:
                # Surface the failure in the UI and keep the previous result.
                error_msg.text = str(e)

            # flush + update
            update_marginals()
            plots['main'].data_source.data = {
                'image': [self.data_for_display.values.T]
            }

        def update_data_change_wrapper(attr, old, new):
            if old != new:
                update_data_for_display()

        for parameter_name, specification in named_widgets.items():
            widget = None
            if specification == int:
                widget = widgets.Slider(start=-20,
                                        end=20,
                                        value=0,
                                        title=parameter_name)
            elif specification == float:
                widget = widgets.Slider(start=-20,
                                        end=20,
                                        value=0.,
                                        step=0.1,
                                        title=parameter_name)

            # Parameters with any other specification get no widget and are
            # not passed to the analysis function.
            if widget is not None:
                built_widgets[parameter_name] = widget
                widget.on_change('value', update_data_change_wrapper)

        update_main_colormap = self.update_colormap_for('main')

        self.app_context['run'] = lambda x: x

        main_color_range_slider = widgets.RangeSlider(start=0,
                                                      end=100,
                                                      value=(
                                                          0,
                                                          100,
                                                      ),
                                                      title='Color Range')

        # Attach callbacks
        main_color_range_slider.on_change('value', update_main_colormap)

        figures['main'].on_event(events.Tap, click_main_image)

        layout = row(
            column(figures['main'], figures['bottom_marginal']),
            column(figures['right_marginal']),
            column(
                widgetbox(*[
                    built_widgets[p] for p in parameter_names
                    if p in built_widgets
                ]),
                widgetbox(
                    self._cursor_info,
                    main_color_range_slider,
                    error_msg,
                )))

        doc.add_root(layout)
        doc.title = 'Band Tool'
    def tool_handler(self, doc):
        """Populate ``doc`` with the interactive curvature tool.

        Three image panels share pan/zoom ranges:

        * the second derivative ("d2") of ``self.arr`` along the dimension
          chosen in the "Derivative Direction" select;
        * the curvature of ``self.arr``;
        * the raw data.

        Widgets control:

        * the smoothing filter: type, per-dimension window and number of passes;
        * whether smoothing is interleaved between the two first derivatives;
        * the curvature ``beta``;
        * clamping;
        * a shared colour clip and gamma applied to all three panels.

        ``self.app_context['d2_fn']`` and ``self.app_context['curvature_fn']``
        are also set. They apply the widgets' current settings to arbitrary
        data.

        Args:
            doc: Bokeh document; the layout is added as its root and its title
                is set.
        """
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models import widgets, Spacer
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.plotting import figure

        default_palette = self.default_palette
        panel_names = ('d2', 'curvature', 'raw')

        self.app_context.update({
            'data': self.arr,
            'cached_data': {},
            'gamma_cached_data': {},
            'plots': {},
            'data_range': self.arr.T.range(),
            'figures': {},
            'widgets': {},
            'color_maps': {}
        })

        # Each panel gets its own colour mapper, initially spanning the raw data.
        for name in panel_names:
            self.app_context['color_maps'][name] = LinearColorMapper(
                default_palette,
                low=np.min(self.arr.values),
                high=np.max(self.arr.values),
                nan_color='black')

        plots, figures, data_range, cached_data, gamma_cached_data = (
            self.app_context['plots'],
            self.app_context['figures'],
            self.app_context['data_range'],
            self.app_context['cached_data'],
            self.app_context['gamma_cached_data'],
        )

        # cached_data holds each panel's image before gamma; gamma_cached_data
        # the same after gamma. Only 'raw' is known now; force_update fills
        # in 'd2' and 'curvature'.
        cached_data['raw'] = self.arr.values
        gamma_cached_data['raw'] = self.arr.values

        figure_kwargs = {
            'tools': ['reset', 'wheel_zoom'],
            'plot_width': self.app_main_size,
            'plot_height': self.app_main_size,
            'min_border': 10,
            'toolbar_location': 'left',
            'x_range': data_range['x'],
            'y_range': data_range['y'],
            'x_axis_location': 'below',
            'y_axis_location': 'right',
        }
        figures['d2'] = figure(title='d2 Spectrum', **figure_kwargs)

        # The other panels reuse the d2 panel's range objects (linked pan/zoom)
        # and have no toolbar of their own.
        figure_kwargs.update({
            'y_range': figures['d2'].y_range,
            'x_range': figures['d2'].x_range,
            'toolbar_location': None,
            'y_axis_location': 'left',
        })

        figures['curvature'] = figure(title='Curvature', **figure_kwargs)
        figures['raw'] = figure(title='Raw Image', **figure_kwargs)

        figures['curvature'].yaxis.major_label_text_font_size = '0pt'

        image_kwargs = {
            'x': data_range['x'][0],
            'y': data_range['y'][0],
            'dw': data_range['x'][1] - data_range['x'][0],
            'dh': data_range['y'][1] - data_range['y'][0],
        }
        for name in panel_names:
            plots[name] = figures[name].image(
                [self.arr.values],
                color_mapper=self.app_context['color_maps'][name],
                **image_kwargs)

        # One smoothing-window slider per dimension, stepping by that
        # dimension's stride and capped at a third of the axis length.
        smoothing_sliders_by_name = {}
        smoothing_sliders = []
        axis_resolution = self.arr.T.stride(generic_dim_names=False)
        for dim in self.arr.dims:
            coords = self.arr.coords[dim]
            resolution = float(axis_resolution[dim])
            high_resolution = len(coords) / 3 * resolution
            low_resolution = resolution

            # could make this axis dependent for more reasonable defaults
            default = 15 * resolution

            if default > high_resolution:
                default = (high_resolution + low_resolution) / 2

            new_slider = widgets.Slider(title='{} Window'.format(dim),
                                        start=low_resolution,
                                        end=high_resolution,
                                        step=resolution,
                                        value=default)
            smoothing_sliders.append(new_slider)
            smoothing_sliders_by_name[dim] = new_slider

        n_smoothing_steps_slider = widgets.Slider(title="Smoothing Steps",
                                                  start=0,
                                                  end=5,
                                                  step=1,
                                                  value=2)
        beta_slider = widgets.Slider(title="β",
                                     start=-8,
                                     end=8,
                                     step=1,
                                     value=0)
        direction_select = widgets.Select(
            options=list(self.arr.dims),
            value='eV' if 'eV' in self.arr.dims else
            self.arr.dims[0],  # preference to energy,
            title='Derivative Direction')
        interleave_smoothing_toggle = widgets.Toggle(
            label='Interleave smoothing with d/dx',
            active=True,
            button_type='primary')
        clamp_spectrum_toggle = widgets.Toggle(
            label='Clamp positive values to 0',
            active=True,
            button_type='primary')
        filter_select = widgets.Select(options=['Gaussian', 'Boxcar'],
                                       value='Boxcar',
                                       title='Type of Filter')

        color_slider = widgets.RangeSlider(start=0,
                                           end=100,
                                           value=(
                                               0,
                                               100,
                                           ),
                                           title='Color Clip')
        gamma_slider = widgets.Slider(start=0.1,
                                      end=4,
                                      value=1,
                                      step=0.1,
                                      title='Gamma')

        # don't need any cacheing here for now, might if this ends up being too slow
        def smoothing_fn(n_passes):
            """Build a smoother applying the selected filter ``n_passes`` times."""
            if n_passes == 0:
                return lambda x: x

            filter_factory = {
                'Gaussian': gaussian_filter,
                'Boxcar': boxcar_filter,
            }.get(filter_select.value, boxcar_filter)

            filter_size = {
                d: smoothing_sliders_by_name[d].value
                for d in self.arr.dims
            }
            return filter_factory(filter_size, n_passes)

        def take_d2(d2_data):
            """Return the second derivative of ``d2_data`` under the current settings.

            The derivative is taken along the selected direction. It is built
            either from two first derivatives with the smoothing passes split
            between them, or from one second derivative after all smoothing.
            NaNs are zeroed. When clamping is on, the result is negated and
            negative values are set to 0.
            """
            n_smoothing_steps = n_smoothing_steps_slider.value
            if interleave_smoothing_toggle.active:
                f = smoothing_fn(n_smoothing_steps // 2)
                d2_data = d1_along_axis(f(d2_data), direction_select.value)
                f = smoothing_fn(n_smoothing_steps - (n_smoothing_steps // 2))
                d2_data = d1_along_axis(f(d2_data), direction_select.value)
            else:
                f = smoothing_fn(n_smoothing_steps)
                # Operate on the argument (previously this used self.arr, so
                # d2_fn ignored its input in non-interleaved mode).
                d2_data = d2_along_axis(f(d2_data), direction_select.value)

            # remove NaN values until Bokeh fixes NaNs over the wire
            d2_data.values[np.isnan(d2_data.values)] = 0
            if clamp_spectrum_toggle.active:
                d2_data.values = -d2_data.values
                d2_data.values[d2_data.values < 0] = 0

            return d2_data

        def take_curvature(curvature_data, curve_dims):
            """Return the curvature of ``curvature_data`` over ``curve_dims``.

            The data is smoothed first. ``beta`` comes from the slider. NaNs
            are zeroed and the same clamping as ``take_d2`` is applied.
            """
            curv_smoothing_fn = smoothing_fn(n_smoothing_steps_slider.value)
            smoothed_curvature_data = curv_smoothing_fn(curvature_data)
            curvature_data = curvature(smoothed_curvature_data,
                                       curve_dims,
                                       beta=beta_slider.value)
            curvature_data.values[np.isnan(curvature_data.values)] = 0
            if clamp_spectrum_toggle.active:
                curvature_data.values = -curvature_data.values
                curvature_data.values[curvature_data.values < 0] = 0

            return curvature_data

        # These functions will always be linked to the current context of the curvature tool.
        self.app_context['d2_fn'] = take_d2
        self.app_context['curvature_fn'] = take_curvature

        @Debounce(0.25)
        def force_update():
            """Recompute the d2 and curvature panels from ``self.arr`` and redraw."""
            # NOTE(review): with clamping off these values may be negative;
            # a non-integer gamma then yields NaN. Confirm that is acceptable.
            d2_values = take_d2(self.arr).values
            cached_data['d2'] = d2_values
            gamma_cached_data['d2'] = d2_values**gamma_slider.value
            plots['d2'].data_source.data = {'image': [gamma_cached_data['d2']]}

            curvature_values = take_curvature(self.arr, self.arr.dims).values
            cached_data['curvature'] = curvature_values
            gamma_cached_data['curvature'] = curvature_values**gamma_slider.value
            plots['curvature'].data_source.data = {
                'image': [gamma_cached_data['curvature']]
            }
            update_color_slider(color_slider.value)

        def force_update_change_wrapper(attr, old, new):
            if old != new:
                force_update()

        def force_update_click_wrapper(event):
            force_update()

        @Debounce(0.1)
        def update_color_slider(new):
            """Clip every panel's colour mapper to percentages ``new`` of its range."""
            def update_plot(name, data):
                low, high = np.min(data), np.max(data)
                dynamic_range = high - low
                self.app_context['color_maps'][name].update(
                    low=low + new[0] / 100 * dynamic_range,
                    high=low + new[1] / 100 * dynamic_range)

            for name in panel_names:
                update_plot(name, gamma_cached_data[name])

        @Debounce(0.1)
        def update_gamma_slider(new):
            """Re-apply gamma ``new`` to each panel's cached image, then re-clip colours."""
            for name in panel_names:
                gamma_cached_data[name] = cached_data[name]**new
            update_color_slider(color_slider.value)

        def update_color_handler(attr, old, new):
            update_color_slider(new)

        def update_gamma_handler(attr, old, new):
            update_gamma_slider(new)

        layout = column(
            row(
                column(figures['d2'], interleave_smoothing_toggle,
                       direction_select),
                column(figures['curvature'], beta_slider,
                       clamp_spectrum_toggle),
                column(figures['raw'], color_slider, gamma_slider)),
            widgetbox(
                filter_select,
                *smoothing_sliders,
                n_smoothing_steps_slider,
            ),
            Spacer(height=100),
        )

        # Attach event handlers
        for w in (n_smoothing_steps_slider, beta_slider, direction_select,
                  *smoothing_sliders, filter_select):
            w.on_change('value', force_update_change_wrapper)

        interleave_smoothing_toggle.on_click(force_update_click_wrapper)
        clamp_spectrum_toggle.on_click(force_update_click_wrapper)

        color_slider.on_change('value', update_color_handler)
        gamma_slider.on_change('value', update_gamma_handler)

        force_update()

        doc.add_root(layout)
        doc.title = 'Curvature Tool'
예제 #7
0
    def tool_handler(self, doc):
        """Build the interactive Band Tool application into a Bokeh document.

        ``self.arr`` is rendered as an image. The user creates named bands and,
        with the pointer mode set to ``'band'``, clicks on the image to append
        points to the active band. Each band stores its clicked points, a band
        type (``'Lorentzian'``, ``'Voigt'`` or ``'Gaussian'``) and an optional
        center constraint. Two closures are published on ``self.app_context``:

        * ``'pack_bands'`` -- converts the recorded bands into the mapping
          passed to ``fit_patterned_bands``.
        * ``'fit'`` -- calls ``fit_patterned_bands`` on ``self.arr`` (or on an
          override array / dataset) using the selected fit mode and direction.

        Args:
            doc: Bokeh document that receives the layout and title.

        Raises:
            AnalysisError: If ``self.arr`` is not two dimensional.
        """
        # ``events`` is required for the Tap handler below. Import it locally,
        # as the sibling tools do, instead of relying on a module-level name.
        from bokeh import events
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError(
                'Cannot use the band tool on non image-like spectra')

        arr = self.arr
        x_coords, y_coords = arr.coords[arr.dims[0]], arr.coords[arr.dims[1]]

        default_palette = self.default_palette

        self.app_context.update({
            'bands': {},
            'center_float': None,
            'data': arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            },
            'direction_normal': True,
            'fit_mode': 'mdc',
        })

        figures = self.app_context['figures']
        plots = self.app_context['plots']

        # Start the cursor in the middle of the data range.
        self.cursor = [
            np.mean(self.data_range['x']),
            np.mean(self.data_range['y'])
        ]

        self.color_maps['main'] = LinearColorMapper(default_palette,
                                                    low=np.min(arr.values),
                                                    high=np.max(arr.values),
                                                    nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = 'Band Tool: WARNING Unidentified'

        # The label is cosmetic; fall back to the placeholder title when the
        # data has no usable label rather than failing to build the tool.
        try:
            main_title = 'Band Tool: {}'.format(arr.S.label[:60])
        except Exception:
            pass

        figures['main'] = figure(tools=main_tools,
                                 plot_width=self.app_main_size,
                                 plot_height=self.app_main_size,
                                 min_border=10,
                                 min_border_left=50,
                                 toolbar_location='left',
                                 x_axis_location='below',
                                 y_axis_location='right',
                                 title=main_title,
                                 x_range=self.data_range['x'],
                                 y_range=self.data_range['y'])
        figures['main'].xaxis.axis_label = arr.dims[0]
        figures['main'].yaxis.axis_label = arr.dims[1]
        figures['main'].toolbar.logo = None
        figures['main'].background_fill_color = "#fafafa"

        data_range = self.app_context['data_range']
        # Bokeh images are indexed (row=y, col=x), hence the transpose.
        plots['main'] = figures['main'].image(
            [arr.values.T],
            x=data_range['x'][0],
            y=data_range['y'][0],
            dw=data_range['x'][1] - data_range['x'][0],
            dh=data_range['y'][1] - data_range['y'][0],
            color_mapper=self.app_context['color_maps']['main'])

        # Cursor crosshair plus one polyline per band.
        self.add_cursor_lines(figures['main'])
        band_lines = figures['main'].multi_line(xs=[],
                                                ys=[],
                                                line_color='white',
                                                line_width=1)

        def append_point_to_band():
            """Append the current cursor position to the active band."""
            cursor = self.cursor
            if self.active_band in self.app_context['bands']:
                self.app_context['bands'][self.active_band]['points'].append(
                    list(cursor))
                update_band_display()

        def click_main_image(event):
            """Move the cursor; in 'band' mode also record the point."""
            self.cursor = [event.x, event.y]
            if self.pointer_mode == 'band':
                append_point_to_band()

        update_main_colormap = self.update_colormap_for('main')

        # (label, value) pairs for the dropdown menus.
        POINTER_MODES = [
            ('Cursor', 'cursor'),
            ('Band', 'band'),
        ]

        FIT_MODES = [
            ('EDC', 'edc'),
            ('MDC', 'mdc'),
        ]

        DIRECTIONS = [
            ('From Bottom/Left', 'forward'),
            ('From Top/Right', 'reverse'),
        ]

        BAND_TYPES = [
            ('Lorentzian', 'Lorentzian'),
            ('Voigt', 'Voigt'),
            ('Gaussian', 'Gaussian'),
        ]

        # Band type name -> band model class used by the fit.
        band_classes = {
            'Lorentzian': band.Band,
            'Gaussian': band.BackgroundBand,
            'Voigt': band.VoigtBand,
        }

        self.app_context['band_options'] = []

        def pack_bands():
            """Convert the recorded bands into the ``fit_patterned_bands`` input.

            The per-band center constraint is used when it parses as a float.
            Otherwise the global center constraint is tried. If neither
            parses, no ``'stray'`` parameter is set.

            Raises:
                AnalysisError: If any band has no points.
            """
            packed_bands = {}
            for band_name, band_description in self.app_context['bands'].items():
                if not band_description['points']:
                    raise AnalysisError('Band {} is empty.'.format(band_name))

                stray = None
                try:
                    stray = float(band_description['center_float'])
                except (KeyError, ValueError, TypeError):
                    try:
                        stray = float(self.app_context['center_float'])
                    except Exception:
                        pass

                packed_bands[band_name] = {
                    'name': band_name,
                    'band': band_classes.get(band_description['type'],
                                             band.Band),
                    'dims': self.arr.dims,
                    'params': {
                        'amplitude': {
                            'min': 0
                        },
                    },
                    'points': band_description['points'],
                }

                if stray is not None:
                    packed_bands[band_name]['params']['stray'] = stray

            return packed_bands

        def fit(override_data=None):
            """Fit the packed bands to ``self.arr`` or to ``override_data``.

            In 'edc' mode the fit runs along ``'eV'``. Otherwise it runs along
            the first non-``'eV'`` dimension of ``self.arr``. An
            ``xr.Dataset`` override is reduced with ``normalize_to_spectrum``
            first.
            """
            packed_bands = pack_bands()
            dims = list(self.arr.dims)
            if 'eV' in dims:
                dims.remove('eV')
            angular_direction = dims[0]
            if isinstance(override_data, xr.Dataset):
                override_data = normalize_to_spectrum(override_data)
            return fit_patterned_bands(
                override_data if override_data is not None else self.arr,
                packed_bands,
                fit_direction='eV' if self.app_context['fit_mode'] == 'edc'
                else angular_direction,
                direction_normal=self.app_context['direction_normal'])

        self.app_context['pack_bands'] = pack_bands
        self.app_context['fit'] = fit

        self.pointer_dropdown = widgets.Dropdown(label='Pointer Mode',
                                                 button_type='primary',
                                                 menu=POINTER_MODES)
        self.direction_dropdown = widgets.Dropdown(label='Fit Direction',
                                                   button_type='primary',
                                                   menu=DIRECTIONS)
        self.band_dropdown = widgets.Dropdown(
            label='Active Band',
            button_type='primary',
            menu=self.app_context['band_options'])
        self.fit_mode_dropdown = widgets.Dropdown(label='Mode',
                                                  button_type='primary',
                                                  menu=FIT_MODES)
        self.band_type_dropdown = widgets.Dropdown(label='Band Type',
                                                   button_type='primary',
                                                   menu=BAND_TYPES)

        self.band_name_input = widgets.TextInput(placeholder='Band name...')
        self.center_float_widget = widgets.TextInput(
            placeholder='Center Constraint')
        self.center_float_copy = widgets.Button(label='Copy to all...')
        self.add_band_button = widgets.Button(label='Add Band')

        self.clear_band_button = widgets.Button(label='Clear Band')
        self.remove_band_button = widgets.Button(label='Remove Band')

        self.main_color_range_slider = widgets.RangeSlider(start=0,
                                                           end=100,
                                                           value=(0, 100),
                                                           title='Color Range')

        def add_band(band_name):
            """Register a new, empty Lorentzian band unless the name exists."""
            if band_name not in self.app_context['bands']:
                self.app_context['band_options'].append((band_name, band_name))
                self.band_dropdown.menu = self.app_context['band_options']
                self.app_context['bands'][band_name] = {
                    'type': 'Lorentzian',
                    'points': [],
                    'name': band_name,
                    'center_float': None,
                }

                # The first band created becomes active automatically.
                if self.active_band is None:
                    self.active_band = band_name

                self.save_app()

        def on_copy_center_float():
            """Apply the global center constraint to every band, then save once."""
            center_float = self.app_context['center_float']
            for band_description in self.app_context['bands'].values():
                band_description['center_float'] = center_float
            self.save_app()

        def on_change_active_band(attr, old_band_id, band_id):
            self.app_context['active_band'] = band_id
            self.active_band = band_id

        def on_change_pointer_mode(attr, old_pointer_mode, pointer_mode):
            self.app_context['pointer_mode'] = pointer_mode
            self.pointer_mode = pointer_mode

        def set_center_float_value(attr, old_value, new_value):
            """Store the typed constraint globally and on the active band."""
            self.app_context['center_float'] = new_value
            if self.active_band in self.app_context['bands']:
                self.app_context['bands'][
                    self.active_band]['center_float'] = new_value

            self.save_app()

        def set_fit_direction(attr, old_direction, new_direction):
            self.app_context['direction_normal'] = new_direction == 'forward'
            self.save_app()

        def set_fit_mode(attr, old_mode, new_mode):
            self.app_context['fit_mode'] = new_mode
            self.save_app()

        def set_band_type(attr, old_type, new_type):
            if self.active_band in self.app_context['bands']:
                self.app_context['bands'][self.active_band]['type'] = new_type

            self.save_app()

        def update_band_display():
            """Redraw every band's polyline from its stored points and save."""
            bands = self.app_context['bands']
            band_names = bands.keys()
            band_lines.data_source.data = {
                'xs': [[p[0] for p in bands[b]['points']] for b in band_names],
                'ys': [[p[1] for p in bands[b]['points']] for b in band_names],
            }
            self.save_app()

        self.update_band_display = update_band_display

        def on_clear_band():
            if self.active_band in self.app_context['bands']:
                self.app_context['bands'][self.active_band]['points'] = []
                update_band_display()

        def on_remove_band():
            """Delete the active band, drop it from the menu and deselect it."""
            if self.active_band in self.app_context['bands']:
                del self.app_context['bands'][self.active_band]
                new_band_options = [
                    b for b in self.app_context['band_options']
                    if b[0] != self.active_band
                ]
                self.band_dropdown.menu = new_band_options
                self.app_context['band_options'] = new_band_options
                self.active_band = None
                update_band_display()

        # Attach callbacks
        self.main_color_range_slider.on_change('value', update_main_colormap)

        figures['main'].on_event(events.Tap, click_main_image)
        self.band_dropdown.on_change('value', on_change_active_band)
        self.pointer_dropdown.on_change('value', on_change_pointer_mode)
        self.add_band_button.on_click(
            lambda: add_band(self.band_name_input.value))
        self.clear_band_button.on_click(on_clear_band)
        self.remove_band_button.on_click(on_remove_band)
        self.center_float_copy.on_click(on_copy_center_float)
        self.center_float_widget.on_change('value', set_center_float_value)
        self.direction_dropdown.on_change('value', set_fit_direction)
        self.fit_mode_dropdown.on_change('value', set_fit_mode)
        self.band_type_dropdown.on_change('value', set_band_type)

        layout = row(
            column(figures['main']),
            column(
                widgetbox(
                    self.pointer_dropdown,
                    self.band_dropdown,
                    self.fit_mode_dropdown,
                    self.band_type_dropdown,
                    self.direction_dropdown,
                ),
                row(self.band_name_input, self.add_band_button),
                row(self.clear_band_button, self.remove_band_button),
                row(self.center_float_widget, self.center_float_copy),
                widgetbox(
                    self._cursor_info,
                    self.main_color_range_slider,
                ),
            ))

        doc.add_root(layout)
        doc.title = 'Band Tool'
        self.load_app()
        self.save_app()
예제 #8
0
    def tool_handler(self, doc):
        """Build the interactive Mask Tool application into a Bokeh document.

        ``self.arr`` is rendered as an image with the user's regions drawn as
        translucent polygons. The user creates named regions and, with the
        pointer mode set to ``'region'``, clicks on the image to append vertices
        to the active region. ``self.app_context['mask']`` is kept in sync with
        the region polygons. The "Edge Mask" button also records a ``'fermi'``
        entry there. ``self.app_context['perform_mask']`` applies that mask
        with ``apply_mask``.

        Args:
            doc: Bokeh document that receives the layout and title.

        Raises:
            AnalysisError: If ``self.arr`` is not two dimensional.
        """
        from bokeh import events
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError(
                'Cannot use mask tool on non image-like spectra')

        arr = self.arr
        x_coords, y_coords = arr.coords[arr.dims[0]], arr.coords[arr.dims[1]]

        default_palette = self.default_palette

        self.app_context.update({
            'region_options': [],
            'regions': {},
            'data': arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            },
        })

        # Start the cursor in the middle of the data range.
        self.cursor = [
            np.mean(self.data_range['x']),
            np.mean(self.data_range['y'])
        ]

        self.color_maps['main'] = LinearColorMapper(default_palette,
                                                    low=np.min(arr.values),
                                                    high=np.max(arr.values),
                                                    nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = 'Mask Tool: WARNING Unidentified'

        # The label is cosmetic; keep the placeholder title if it is missing.
        try:
            main_title = 'Mask Tool: {}'.format(arr.S.label[:60])
        except Exception:
            pass

        self.figures['main'] = figure(
            tools=main_tools,
            plot_width=self.app_main_size,
            plot_height=self.app_main_size,
            min_border=10,
            min_border_left=20,
            toolbar_location='left',
            x_axis_location='below',
            y_axis_location='right',
            title=main_title,
            x_range=self.data_range['x'],
            y_range=self.data_range['y'],
        )

        self.figures['main'].xaxis.axis_label = arr.dims[0]
        self.figures['main'].yaxis.axis_label = arr.dims[1]

        # Bokeh images are indexed (row=y, col=x), hence the transpose.
        self.plots['main'] = self.figures['main'].image(
            [np.asarray(arr.values.T)],
            x=self.data_range['x'][0],
            y=self.data_range['y'][0],
            dw=self.data_range['x'][1] - self.data_range['x'][0],
            dh=self.data_range['y'][1] - self.data_range['y'][0],
            color_mapper=self.color_maps['main'])

        self.add_cursor_lines(self.figures['main'])
        region_patches = self.figures['main'].patches(xs=[],
                                                      ys=[],
                                                      color='white',
                                                      alpha=0.35,
                                                      line_width=1)

        def add_point_to_region():
            """Append the cursor position to the active region, then save."""
            if self.active_region in self.regions:
                self.regions[self.active_region]['points'].append(
                    list(self.cursor))
                update_region_display()

            self.save_app()

        def click_main_image(event):
            """Move the cursor; in 'region' mode also record a vertex."""
            self.cursor = [event.x, event.y]
            if self.pointer_mode == 'region':
                add_point_to_region()

        # (label, value) pairs for the pointer-mode dropdown.
        POINTER_MODES = [
            ('Cursor', 'cursor'),
            ('Region', 'region'),
        ]

        def perform_mask(data=None, **kwargs):
            """Apply the current mask to ``data`` (defaults to ``self.arr``)."""
            if data is None:
                data = arr

            data = normalize_to_spectrum(data)
            return apply_mask(data, self.app_context['mask'], **kwargs)

        self.app_context['perform_mask'] = perform_mask
        self.app_context['mask'] = None

        pointer_dropdown = widgets.Dropdown(label='Pointer Mode',
                                            button_type='primary',
                                            menu=POINTER_MODES)
        self.region_dropdown = widgets.Dropdown(label='Active Region',
                                                button_type='primary',
                                                menu=self.region_options)

        edge_mask_button = widgets.Button(label='Edge Mask')
        region_name_input = widgets.TextInput(placeholder='Region name...')
        add_region_button = widgets.Button(label='Add Region')

        clear_region_button = widgets.Button(label='Clear Region')
        remove_region_button = widgets.Button(label='Remove Region')

        main_color_range_slider = widgets.RangeSlider(start=0,
                                                      end=100,
                                                      value=(0, 100),
                                                      title='Color Range')

        def on_click_edge_mask():
            """Extend the active region out to the data edges.

            Vertices just outside the non-energy extent of the data are added
            before and after the user's points. They run from 3 below the
            minimum energy up to eV = 0. The highest energy among the user's
            points is stored as ``mask['fermi']``. A region with no points is
            left untouched, because there is no energy to take a maximum of.
            """
            if self.active_region in self.regions:
                old_points = self.regions[self.active_region]['points']
            else:
                old_points = []

            if old_points:
                energy_index = arr.dims.index('eV')
                max_energy = np.max([p[energy_index] for p in old_points])

                other_dim = [d for d in arr.dims if d != 'eV'][0]
                other_coord = arr.coords[other_dim].values
                min_other, max_other = np.min(other_coord), np.max(other_coord)
                min_e = np.min(arr.coords['eV'].values)

                # Points are stored in (dims[0], dims[1]) order.
                if energy_index == 0:
                    before = [[min_e - 3, min_other - 1], [0, min_other - 1]]
                    after = [[0, max_other + 1], [min_e - 3, max_other + 1]]
                else:
                    before = [[min_other - 1, min_e - 3], [min_other - 1, 0]]
                    after = [[max_other + 1, 0], [max_other + 1, min_e - 3]]
                self.regions[
                    self.active_region]['points'] = before + old_points + after
                self.app_context['mask'] = self.app_context['mask'] or {}
                self.app_context['mask']['fermi'] = max_energy
                update_region_display()

            self.save_app()

        def add_region(region_name):
            """Register a new, empty region unless the name already exists."""
            if region_name not in self.regions:
                self.region_options.append((region_name, region_name))
                self.region_dropdown.menu = self.region_options
                self.regions[region_name] = {
                    'points': [],
                    'name': region_name,
                }

                # The first region created becomes active automatically.
                if self.active_region is None:
                    self.active_region = region_name

                self.save_app()

        def on_change_active_region(attr, old_region_id, region_id):
            self.app_context['active_region'] = region_id
            self.active_region = region_id
            self.save_app()

        def on_change_pointer_mode(attr, old_pointer_mode, pointer_mode):
            self.app_context['pointer_mode'] = pointer_mode
            self.pointer_mode = pointer_mode
            self.save_app()

        def update_region_display():
            """Sync the mask polygons and redraw the region patches, then save."""
            region_names = self.regions.keys()

            if self.app_context['mask'] is None:
                self.app_context['mask'] = {}
            self.app_context['mask'].update({
                'dims': arr.dims,
                'polys': [r['points'] for r in self.regions.values()]
            })

            region_patches.data_source.data = {
                'xs': [[p[0] for p in self.regions[r]['points']]
                       for r in region_names],
                'ys': [[p[1] for p in self.regions[r]['points']]
                       for r in region_names],
            }
            self.save_app()

        self.update_region_display = update_region_display

        def on_clear_region():
            if self.active_region in self.regions:
                self.regions[self.active_region]['points'] = []
                update_region_display()

        def on_remove_region():
            """Delete the active region, drop it from the menu and deselect it."""
            if self.active_region in self.regions:
                del self.regions[self.active_region]
                new_region_options = [
                    b for b in self.region_options
                    if b[0] != self.active_region
                ]
                self.region_dropdown.menu = new_region_options
                self.region_options = new_region_options
                self.active_region = None
                update_region_display()

        # Attach callbacks
        main_color_range_slider.on_change('value',
                                          self.update_colormap_for('main'))

        self.figures['main'].on_event(events.Tap, click_main_image)
        self.region_dropdown.on_change('value', on_change_active_region)
        pointer_dropdown.on_change('value', on_change_pointer_mode)
        add_region_button.on_click(lambda: add_region(region_name_input.value))
        edge_mask_button.on_click(on_click_edge_mask)
        clear_region_button.on_click(on_clear_region)
        remove_region_button.on_click(on_remove_region)

        controls = [
            widgetbox(pointer_dropdown, self.region_dropdown),
            row(region_name_input, add_region_button),
        ]
        # The edge mask needs an energy axis, so only offer it when one exists.
        if 'eV' in arr.dims:
            controls.append(edge_mask_button)
        controls.extend([
            row(clear_region_button, remove_region_button),
            widgetbox(self._cursor_info, main_color_range_slider),
        ])

        layout = row(column(self.figures['main']), column(*controls))

        doc.add_root(layout)
        doc.title = 'Mask Tool'
        self.load_app()
        self.save_app()
예제 #9
0
    def tool_handler(self, doc):
        """Build the interactive Path Tool application into a Bokeh document.

        ``self.arr`` is rendered as an image with the user's paths drawn as
        polylines. The user creates named paths and, with the pointer mode set
        to ``'path'``, clicks on the image to append points to the active path.
        Two closures are published on ``self.app_context``:

        * ``'to_xarray'`` -- returns ``{path_name: xr.Dataset}``. Each dataset
          has an ``index`` coordinate and one data variable per dimension of
          ``self.arr``.
        * ``'select'`` -- calls ``select_along_path`` with the first path.

        Args:
            doc: Bokeh document that receives the layout and title.

        Raises:
            AnalysisError: If ``self.arr`` is not two dimensional.
        """
        # ``warnings`` is the standard-library module. ``bokeh.models`` does
        # not document a ``warnings`` export, so it must not be imported
        # from there.
        import warnings

        from bokeh import events
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError(
                'Cannot use path tool on non image-like spectra')

        arr = self.arr
        x_coords, y_coords = arr.coords[arr.dims[0]], arr.coords[arr.dims[1]]

        default_palette = self.default_palette

        self.app_context.update({
            'path_options': [],
            'active_path': None,
            'paths': {},
            'data': arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            },
        })

        # Start the cursor in the middle of the data range.
        self.cursor = [
            np.mean(self.data_range['x']),
            np.mean(self.data_range['y'])
        ]

        self.color_maps['main'] = LinearColorMapper(default_palette,
                                                    low=np.min(arr.values),
                                                    high=np.max(arr.values),
                                                    nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = 'Path Tool: WARNING Unidentified'

        # The label is cosmetic; keep the placeholder title if it is missing.
        try:
            main_title = 'Path Tool: {}'.format(arr.S.label[:60])
        except Exception:
            pass

        self.figures['main'] = figure(
            tools=main_tools,
            plot_width=self.app_main_size,
            plot_height=self.app_main_size,
            min_border=10,
            min_border_left=20,
            toolbar_location='left',
            x_axis_location='below',
            y_axis_location='right',
            title=main_title,
            x_range=self.data_range['x'],
            y_range=self.data_range['y'],
        )

        self.figures['main'].xaxis.axis_label = arr.dims[0]
        self.figures['main'].yaxis.axis_label = arr.dims[1]

        # Bokeh images are indexed (row=y, col=x), hence the transpose.
        self.plots['main'] = self.figures['main'].image(
            [np.asarray(arr.values.T)],
            x=self.data_range['x'][0],
            y=self.data_range['y'][0],
            dw=self.data_range['x'][1] - self.data_range['x'][0],
            dh=self.data_range['y'][1] - self.data_range['y'][0],
            color_mapper=self.color_maps['main'])

        self.plots['paths'] = self.figures['main'].multi_line(
            xs=[], ys=[], line_color='white', line_width=2)

        self.add_cursor_lines(self.figures['main'])

        def add_point_to_path():
            """Append the cursor position to the active path, then save."""
            if self.active_path in self.paths:
                self.paths[self.active_path]['points'].append(list(
                    self.cursor))
                update_path_display()

            self.save_app()

        def click_main_image(event):
            """Move the cursor; in 'path' mode also record the point."""
            self.cursor = [event.x, event.y]
            if self.pointer_mode == 'path':
                add_point_to_path()

        # (label, value) pairs for the pointer-mode dropdown.
        POINTER_MODES = [
            ('Cursor', 'cursor'),
            ('Path', 'path'),
        ]

        def convert_to_xarray():
            """Convert every path to an ``xr.Dataset`` keyed by path name.

            Each dataset has an integer ``index`` coordinate (one entry per
            point) and one data variable per dimension of ``self.arr``. The
            variable holds the path's coordinate along that dimension.
            """
            def convert_single_path_to_xarray(points):
                coords = {'index': np.arange(len(points))}
                data_vars = {
                    dim: xr.DataArray(np.array([p[i] for p in points]),
                                      coords=coords,
                                      dims=['index'])
                    for i, dim in enumerate(self.arr.dims)
                }
                return xr.Dataset(data_vars=data_vars, coords=coords)

            return {
                path['name']: convert_single_path_to_xarray(path['points'])
                for path in self.paths.values()
            }

        def select(data=None, **kwargs):
            """Select ``data`` (default ``self.arr``) along the first path.

            Raises:
                AnalysisError: If no paths have been defined.
            """
            if data is None:
                data = self.arr

            if len(self.paths) > 1:
                warnings.warn('Only using first path.')

            path_datasets = list(convert_to_xarray().values())
            if not path_datasets:
                raise AnalysisError(
                    'No paths have been defined; add a path before selecting.')

            return select_along_path(path=path_datasets[0],
                                     data=data,
                                     **kwargs)

        self.app_context['to_xarray'] = convert_to_xarray
        self.app_context['select'] = select

        pointer_dropdown = widgets.Dropdown(label='Pointer Mode',
                                            button_type='primary',
                                            menu=POINTER_MODES)
        self.path_dropdown = widgets.Dropdown(label='Active Path',
                                              button_type='primary',
                                              menu=self.path_options)

        path_name_input = widgets.TextInput(placeholder='Path name...')
        add_path_button = widgets.Button(label='Add Path')

        clear_path_button = widgets.Button(label='Clear Path')
        remove_path_button = widgets.Button(label='Remove Path')

        main_color_range_slider = widgets.RangeSlider(start=0,
                                                      end=100,
                                                      value=(0, 100),
                                                      title='Color Range')

        def add_path(path_name):
            """Register a new, empty path unless the name already exists."""
            if path_name not in self.paths:
                self.path_options.append((path_name, path_name))
                self.path_dropdown.menu = self.path_options
                self.paths[path_name] = {
                    'points': [],
                    'name': path_name,
                }

                # The first path created becomes active automatically.
                if self.active_path is None:
                    self.active_path = path_name

                self.save_app()

        def on_change_active_path(attr, old_path_id, path_id):
            self.debug_text = path_id
            self.app_context['active_path'] = path_id
            self.active_path = path_id
            self.save_app()

        def on_change_pointer_mode(attr, old_pointer_mode, pointer_mode):
            self.app_context['pointer_mode'] = pointer_mode
            self.pointer_mode = pointer_mode
            self.save_app()

        def update_path_display():
            """Redraw every path's polyline from its stored points and save."""
            self.plots['paths'].data_source.data = {
                'xs': [[point[0] for point in p['points']]
                       for p in self.paths.values()],
                'ys': [[point[1] for point in p['points']]
                       for p in self.paths.values()],
            }
            self.save_app()

        self.update_path_display = update_path_display

        def on_clear_path():
            if self.active_path in self.paths:
                self.paths[self.active_path]['points'] = []
                update_path_display()

        def on_remove_path():
            """Delete the active path, drop it from the menu and deselect it."""
            if self.active_path in self.paths:
                del self.paths[self.active_path]
                new_path_options = [
                    b for b in self.path_options if b[0] != self.active_path
                ]
                self.path_dropdown.menu = new_path_options
                self.path_options = new_path_options
                self.active_path = None
                update_path_display()

        # Attach callbacks
        main_color_range_slider.on_change('value',
                                          self.update_colormap_for('main'))

        self.figures['main'].on_event(events.Tap, click_main_image)
        self.path_dropdown.on_change('value', on_change_active_path)
        pointer_dropdown.on_change('value', on_change_pointer_mode)
        add_path_button.on_click(lambda: add_path(path_name_input.value))
        clear_path_button.on_click(on_clear_path)
        remove_path_button.on_click(on_remove_path)

        layout = row(
            column(self.figures['main']),
            column(
                widgetbox(pointer_dropdown, self.path_dropdown),
                row(path_name_input, add_path_button),
                row(clear_path_button, remove_path_button),
                widgetbox(self._cursor_info, main_color_range_slider),
                self.debug_div,
            ))

        doc.add_root(layout)
        doc.title = 'Path Tool'
        self.load_app()
        self.save_app()