Beispiel #1
0
        def pack_bands():
            """Convert the bands marked in the UI into fit specifications.

            Each entry of ``self.app_context['bands']`` becomes a dict with the
            band ``name``, the band class (looked up by the band's ``type`` in
            ``band_classes``, defaulting to ``band.Band``), the data ``dims``,
            a non-negative ``amplitude`` constraint and the marked ``points``.
            If a center constraint parses as a float -- the band's own
            ``center_float`` first, otherwise the tool-wide one -- it is stored
            under ``params['stray']``.

            :return: dict mapping band name to its packed specification.
            :raises AnalysisError: if any band has no points.
            """

            def as_float(container, key):
                """Return ``float(container[key])``, or None when it is
                missing, ``None`` or not numeric (e.g. an empty text box)."""
                try:
                    return float(container[key])
                except (KeyError, ValueError, TypeError):
                    return None

            packed_bands = {}
            for band_name, band_description in self.app_context['bands'].items():
                if not band_description['points']:
                    raise AnalysisError('Band {} is empty.'.format(band_name))

                # A per-band center constraint wins over the tool-wide one.
                stray = as_float(band_description, 'center_float')
                if stray is None:
                    stray = as_float(self.app_context, 'center_float')

                packed_bands[band_name] = {
                    'name': band_name,
                    'band': band_classes.get(band_description['type'],
                                             band.Band),
                    'dims': self.arr.dims,
                    'params': {
                        'amplitude': {
                            'min': 0
                        },
                    },
                    'points': band_description['points'],
                }

                if stray is not None:
                    packed_bands[band_name]['params']['stray'] = stray

            return packed_bands
Beispiel #2
0
    def tool_handler(self, doc):
        """Build the Bokeh document for an interactive analysis-function tool.

        The first parameter of ``self.analysis_fn`` receives ``self.arr``.
        Each following parameter whose entry in ``self.widget_specification``
        is ``int`` or ``float`` gets a slider. Moving a slider re-runs the
        function (debounced) and redraws the main image with the result.
        Clicking the image moves the cursor. The bottom and right marginal
        plots show nearest-neighbour cuts through the cursor, both of the
        processed data and, in red, of the original ``self.arr``.

        :param doc: Bokeh document the layout is attached to.
        :raises AnalysisError: if ``self.arr`` is not two dimensional.
        """
        from bokeh import events
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError(
                'Cannot use the band tool on non image-like spectra')

        self.data_for_display = self.arr
        x_coords = self.data_for_display.coords[self.data_for_display.dims[0]]
        y_coords = self.data_for_display.coords[self.data_for_display.dims[1]]

        default_palette = self.default_palette

        self.app_context.update({
            'data': self.arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            },
        })

        figures, plots = self.app_context['figures'], self.app_context['plots']

        # Start the cursor at the centre of the data range.
        self.cursor = [
            np.mean(self.data_range['x']),
            np.mean(self.data_range['y'])
        ]

        self.color_maps['main'] = LinearColorMapper(
            default_palette,
            low=np.min(self.data_for_display.values),
            high=np.max(self.data_for_display.values),
            nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = '{} Tool: WARNING Unidentified'.format(
            self.analysis_fn.__name__)
        try:
            main_title = '{} Tool: {}'.format(
                self.analysis_fn.__name__, self.data_for_display.S.label[:60])
        except Exception:
            # The label is cosmetic only; keep the fallback title.
            pass

        figures['main'] = figure(tools=main_tools,
                                 plot_width=self.app_main_size,
                                 plot_height=self.app_main_size,
                                 min_border=10,
                                 min_border_left=50,
                                 toolbar_location='left',
                                 x_axis_location='below',
                                 y_axis_location='right',
                                 title=main_title,
                                 x_range=self.data_range['x'],
                                 y_range=self.data_range['y'])
        figures['main'].xaxis.axis_label = self.data_for_display.dims[0]
        figures['main'].yaxis.axis_label = self.data_for_display.dims[1]
        figures['main'].toolbar.logo = None
        figures['main'].background_fill_color = "#fafafa"
        data_range = self.app_context['data_range']
        plots['main'] = figures['main'].image(
            [self.data_for_display.values.T],
            x=data_range['x'][0],
            y=data_range['y'][0],
            dw=data_range['x'][1] - data_range['x'][0],
            dh=data_range['y'][1] - data_range['y'][0],
            color_mapper=self.app_context['color_maps']['main'])

        def nearest_cut(data, dim_index, value):
            """Select ``data`` at the coordinate nearest ``value``.

            The dimension is the one ``self.data_for_display`` has at position
            ``dim_index`` when this is called; the displayed data is replaced
            each time the analysis function re-runs.
            """
            dim = self.data_for_display.dims[dim_index]
            return data.sel(**{dim: value}, method='nearest')

        def value_span(*cuts):
            """Return ``(min, max)`` over the values of all ``cuts``."""
            return (np.min([np.min(cut.values) for cut in cuts]),
                    np.max([np.max(cut.values) for cut in cuts]))

        # Bottom marginal: cut at the cursor's second coordinate, drawn
        # against the first dimension.
        bottom_marginal = nearest_cut(self.data_for_display, 1, self.cursor[1])
        bottom_marginal_original = nearest_cut(self.arr, 1, self.cursor[1])
        figures['bottom_marginal'] = figure(
            plot_width=self.app_main_size,
            plot_height=200,
            title=None,
            x_range=figures['main'].x_range,
            y_range=value_span(bottom_marginal, bottom_marginal_original),
            x_axis_location='above',
            toolbar_location=None,
            tools=[])
        plots['bottom_marginal'] = figures['bottom_marginal'].line(
            x=bottom_marginal.coords[self.data_for_display.dims[0]].values,
            y=bottom_marginal.values)
        plots['bottom_marginal_original'] = figures['bottom_marginal'].line(
            x=bottom_marginal_original.coords[self.arr.dims[0]].values,
            y=bottom_marginal_original.values,
            line_color='red')

        # Right marginal: cut at the cursor's first coordinate, drawn against
        # the second dimension (vertically).
        right_marginal = nearest_cut(self.data_for_display, 0, self.cursor[0])
        right_marginal_original = nearest_cut(self.arr, 0, self.cursor[0])
        figures['right_marginal'] = figure(
            plot_width=200,
            plot_height=self.app_main_size,
            title=None,
            y_range=figures['main'].y_range,
            x_range=value_span(right_marginal, right_marginal_original),
            y_axis_location='left',
            toolbar_location=None,
            tools=[])
        plots['right_marginal'] = figures['right_marginal'].line(
            y=right_marginal.coords[self.data_for_display.dims[1]].values,
            x=right_marginal.values)
        plots['right_marginal_original'] = figures['right_marginal'].line(
            y=right_marginal_original.coords[
                self.data_for_display.dims[1]].values,
            x=right_marginal_original.values,
            line_color='red')

        # Cursor lines and an (empty) white multi-line glyph.
        self.add_cursor_lines(figures['main'])
        figures['main'].multi_line(xs=[],
                                   ys=[],
                                   line_color='white',
                                   line_width=1)

        # Prepare the widgets for the analysis function. The first parameter
        # receives the data; we can revisit later if this is too limiting.
        signature = inspect.signature(self.analysis_fn)
        parameter_names = list(signature.parameters)[1:]
        named_widgets = dict(zip(parameter_names, self.widget_specification))
        built_widgets = {}

        def update_marginals():
            """Redraw both marginal plots at the current cursor position."""
            x_dim = self.data_for_display.dims[0]
            y_dim = self.data_for_display.dims[1]

            bottom = nearest_cut(self.data_for_display, 1, self.cursor[1])
            right = nearest_cut(self.data_for_display, 0, self.cursor[0])
            bottom_original = nearest_cut(self.arr, 1, self.cursor[1])
            right_original = nearest_cut(self.arr, 0, self.cursor[0])

            plots['bottom_marginal'].data_source.data = {
                'x': bottom.coords[x_dim].values,
                'y': bottom.values,
            }
            plots['right_marginal'].data_source.data = {
                'y': right.coords[y_dim].values,
                'x': right.values,
            }
            plots['bottom_marginal_original'].data_source.data = {
                'x': bottom_original.coords[x_dim].values,
                'y': bottom_original.values,
            }
            plots['right_marginal_original'].data_source.data = {
                'y': right_original.coords[y_dim].values,
                'x': right_original.values,
            }

            # Both curves share each marginal figure, so the value axis must
            # span both; otherwise the processed curve can fall out of view.
            bottom_low, bottom_high = value_span(bottom, bottom_original)
            figures['bottom_marginal'].y_range.start = bottom_low
            figures['bottom_marginal'].y_range.end = bottom_high
            right_low, right_high = value_span(right, right_original)
            figures['right_marginal'].x_range.start = right_low
            figures['right_marginal'].x_range.end = right_high

        def click_main_image(event):
            """Move the cursor to the clicked point and refresh marginals."""
            self.cursor = [event.x, event.y]
            update_marginals()

        error_msg = widgets.Div(text='')

        @Debounce(0.25)
        def update_data_for_display():
            """Re-run the analysis function with the slider values and redraw."""
            try:
                # Pass slider values by name: parameters without a widget are
                # skipped, so positional passing would shift later values onto
                # the wrong parameters.
                self.data_for_display = self.analysis_fn(
                    self.arr, **{
                        p: built_widgets[p].value
                        for p in parameter_names if p in built_widgets
                    })
                error_msg.text = ''
            except Exception as e:
                # Report the failure in the UI; the previous result stays shown.
                error_msg.text = '{}'.format(e)

            # flush + update
            update_marginals()
            plots['main'].data_source.data = {
                'image': [self.data_for_display.values.T]
            }

        def update_data_change_wrapper(attr, old, new):
            """Bokeh ``on_change`` adapter: recompute only on real changes."""
            if old != new:
                update_data_for_display()

        for parameter_name, specification in named_widgets.items():
            if specification == int:
                widget = widgets.Slider(start=-20,
                                        end=20,
                                        value=0,
                                        title=parameter_name)
            elif specification == float:
                widget = widgets.Slider(start=-20,
                                        end=20,
                                        value=0.,
                                        step=0.1,
                                        title=parameter_name)
            else:
                # No widget for this specification; the parameter keeps the
                # analysis function's default.
                continue

            built_widgets[parameter_name] = widget
            widget.on_change('value', update_data_change_wrapper)

        update_main_colormap = self.update_colormap_for('main')

        self.app_context['run'] = lambda x: x

        main_color_range_slider = widgets.RangeSlider(start=0,
                                                      end=100,
                                                      value=(0, 100),
                                                      title='Color Range')

        # Attach callbacks
        main_color_range_slider.on_change('value', update_main_colormap)
        figures['main'].on_event(events.Tap, click_main_image)

        layout = row(
            column(figures['main'], figures['bottom_marginal']),
            column(figures['right_marginal']),
            column(
                widgetbox(*[
                    built_widgets[p] for p in parameter_names
                    if p in built_widgets
                ]),
                widgetbox(
                    self._cursor_info,
                    main_color_range_slider,
                    error_msg,
                )))

        doc.add_root(layout)
        doc.title = 'Band Tool'
Beispiel #3
0
def convert_to_kspace(arr: xr.DataArray,
                      forward=False,
                      bounds=None,
                      resolution=None,
                      coords=None,
                      **kwargs):
    """
    "Forward" or "backward" converts the data to momentum space.

    "Backward"
    The standard method. Works in generality by regridding the data into the new coordinate space and then
    interpolating back into the original data.

    "Forward"
    By converting the coordinates, rather than by interpolating the data. As a result, the data will be totally
    unchanged by the conversion (if we do not apply a Jacobian correction), but the coordinates will no
    longer have equal spacing.

    This is only really useful for zero and one dimensional data because for two dimensional data, the coordinates
    must become two dimensional in order to fully specify every data point (this is true in generality, in 3D the
    coordinates must become 3D as well).

    The only exception to this is if the extra axes do not need to be k-space converted. As is the case where one
    of the dimensions is `cycle` or `delay`, for instance.

    This function only performs backward conversion; ``forward=True`` raises
    ``NotImplementedError``. For forward conversion of coordinates see
    `arpes.utilities.conversion.forward.convert_coordinates_to_kspace_forward`.

    You can request a particular resolution for the new data with the `resolution=` parameter,
    or a specific set of bounds with the `bounds=` parameter.

    ```
    from arpes.io import load_example_data
    f = load_example_data()

    # most standard method
    convert_to_kspace(f)

    # get a higher resolution (up-sampled) momentum image
    convert_to_kspace(f, resolution={'kp': 0.001})

    # get an image only for the positive momentum region
    convert_to_kspace(f, bounds={'kp': [0, 1]})

    # get an image manually specifying the `kp` coordinate
    convert_to_kspace(f, kp=np.linspace(0, 1, 1001))

    # or
    convert_to_kspace(f, coords={'kp': np.linspace(0, 1, 1001)})
    ```

    :param arr: The spectrum to convert. A Dataset is accepted with a warning:
        the spectrum is extracted with ``normalize_to_spectrum`` and the
        Dataset's attrs are copied onto it.
    :param forward: Must be False; forward conversion raises NotImplementedError.
    :param bounds: Optional per-coordinate bounds passed to the converter's
        ``get_coordinates``, e.g. ``{'kp': [0, 1]}``.
    :param resolution: Optional per-coordinate resolution passed to the
        converter's ``get_coordinates``, e.g. ``{'kp': 0.001}``.
    :param coords: Optional explicit output coordinates overriding the inferred
        ones. The caller's dict is not modified.
    :param kwargs: Further explicit output coordinates, merged into ``coords``.
    :return: The converted DataArray, or the (extracted) spectrum unchanged when
        none of its dimensions needs momentum conversion.
    :raises NotImplementedError: if ``forward`` is True, or if there is no
        converter for the dimensions that need conversion.
    :raises ValueError: if ``coords``/``kwargs`` name a coordinate the
        conversion does not produce.
    """
    # Copy so that merging ``kwargs`` does not mutate the caller's dict.
    coords = dict(coords) if coords is not None else {}
    coords.update(kwargs)

    if isinstance(arr, xr.Dataset):
        warnings.warn(
            'Remember to use a DataArray not a Dataset, attempting to extract spectrum'
        )
        attrs = arr.attrs.copy()
        arr = normalize_to_spectrum(arr)
        arr.attrs.update(attrs)

    if forward:
        raise NotImplementedError(
            'Forward conversion of datasets not supported. Coordinate conversion is. '
            'See `arpes.utilities.conversion.forward.convert_coordinates_to_kspace_forward`'
        )

    has_eV = 'eV' in arr.dims

    # TODO be smarter about the resolution inference
    remove_dims = {
        'eV', 'delay', 'cycle', 'temp', 'x', 'y', 'optics_insertion'
    }

    def unconvertible(dimension: str) -> bool:
        """Whether ``dimension`` passes through the conversion untouched."""
        return dimension in remove_dims or 'volt' in dimension

    removed = [d for d in arr.dims if unconvertible(d)]
    # Sorted so the tuple can be used as a lookup key in the tables below.
    old_dims = sorted(d for d in arr.dims if not unconvertible(d))

    # This should always be true because otherwise we have no hope of carrying
    # through with the conversion
    if 'eV' in removed:
        removed.remove('eV')  # This is put at the front as a standardization

    if not old_dims:
        return arr  # no need to convert, might be XPS or similar

    dims_key = tuple(old_dims)
    momentum_dims = {
        ('phi', ): ['kp'],
        ('phi', 'theta'): ['kx', 'ky'],
        ('beta', 'phi'): ['kx', 'ky'],
        ('phi', 'psi'): ['kx', 'ky'],
        ('hv', 'phi'): ['kp', 'kz'],
        ('hv', 'phi', 'theta'): ['kx', 'ky', 'kz'],
        ('beta', 'hv', 'phi'): ['kx', 'ky', 'kz'],
        ('hv', 'phi', 'psi'): ['kx', 'ky', 'kz'],
    }.get(dims_key)

    convert_cls = {
        ('phi', ): ConvertKp,
        ('beta', 'phi'): ConvertKxKy,
        ('phi', 'theta'): ConvertKxKy,
        ('phi', 'psi'): ConvertKxKy,
        # ('chi', 'phi',): ConvertKxKy,
        ('hv', 'phi'): ConvertKpKz,
    }.get(dims_key)

    # Fail with a clear message instead of an opaque TypeError from
    # ``list + None`` or from calling ``None``.
    if momentum_dims is None or convert_cls is None:
        raise NotImplementedError(
            'Momentum conversion is not supported for dimensions {}.'.format(
                dims_key))

    converted_dims = (['eV'] if has_eV else []) + momentum_dims + removed
    converter = convert_cls(arr, converted_dims)

    converted_coordinates = converter.get_coordinates(resolution=resolution,
                                                      bounds=bounds)

    extra = set(coords).difference(converted_coordinates)
    if extra:
        raise ValueError('Unexpected passed coordinates: {}'.format(extra))

    converted_coordinates.update(coords)

    return convert_coordinates(
        arr, converted_coordinates, {
            'dims': converted_dims,
            'transforms': {d: converter.conversion_for(d) for d in arr.dims},
        })[0]
Beispiel #4
0
    def tool_handler(self, doc):
        """Build the interactive Bokeh document for the band-marking tool.

        The main panel shows ``self.arr`` as an image. Users create named
        bands, and in "Band" pointer mode each click on the image appends the
        clicked point to the active band. Band definitions, fit mode, fit
        direction and center constraints live in ``self.app_context`` and are
        persisted with ``self.save_app()``. Two callables are also published
        in ``self.app_context``:

        * ``'pack_bands'``: converts the marked bands into the band
          specification passed to ``fit_patterned_bands``.
        * ``'fit'``: runs ``fit_patterned_bands`` on ``self.arr`` (or on
          ``override_data``) with the current settings.

        :param doc: Bokeh document the layout is attached to.
        :raises AnalysisError: if ``self.arr`` is not two dimensional.
        """
        from bokeh import events
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError(
                'Cannot use the band tool on non image-like spectra')

        arr = self.arr
        x_coords, y_coords = arr.coords[arr.dims[0]], arr.coords[arr.dims[1]]

        default_palette = self.default_palette

        self.app_context.update({
            'bands': {},
            'center_float': None,
            'data': arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            },
            'direction_normal': True,
            'fit_mode': 'mdc',
        })

        figures, plots = self.app_context['figures'], self.app_context['plots']

        # Start the cursor at the centre of the data range.
        self.cursor = [
            np.mean(self.data_range['x']),
            np.mean(self.data_range['y'])
        ]

        self.color_maps['main'] = LinearColorMapper(default_palette,
                                                    low=np.min(arr.values),
                                                    high=np.max(arr.values),
                                                    nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = 'Band Tool: WARNING Unidentified'
        try:
            main_title = 'Band Tool: {}'.format(arr.S.label[:60])
        except Exception:
            # The label is cosmetic only; keep the fallback title.
            pass

        figures['main'] = figure(tools=main_tools,
                                 plot_width=self.app_main_size,
                                 plot_height=self.app_main_size,
                                 min_border=10,
                                 min_border_left=50,
                                 toolbar_location='left',
                                 x_axis_location='below',
                                 y_axis_location='right',
                                 title=main_title,
                                 x_range=self.data_range['x'],
                                 y_range=self.data_range['y'])
        figures['main'].xaxis.axis_label = arr.dims[0]
        figures['main'].yaxis.axis_label = arr.dims[1]
        figures['main'].toolbar.logo = None
        figures['main'].background_fill_color = "#fafafa"
        data_range = self.app_context['data_range']
        plots['main'] = figures['main'].image(
            [arr.values.T],
            x=data_range['x'][0],
            y=data_range['y'][0],
            dw=data_range['x'][1] - data_range['x'][0],
            dh=data_range['y'][1] - data_range['y'][0],
            color_mapper=self.app_context['color_maps']['main'])

        # Cursor lines, plus one white polyline per band (filled in by
        # update_band_display).
        self.add_cursor_lines(figures['main'])
        band_lines = figures['main'].multi_line(xs=[],
                                                ys=[],
                                                line_color='white',
                                                line_width=1)

        def append_point_to_band():
            """Append the current cursor position to the active band, if any."""
            if self.active_band in self.app_context['bands']:
                self.app_context['bands'][self.active_band]['points'].append(
                    list(self.cursor))
                update_band_display()

        def click_main_image(event):
            """Move the cursor to the click; in band mode also record it."""
            self.cursor = [event.x, event.y]
            if self.pointer_mode == 'band':
                append_point_to_band()

        update_main_colormap = self.update_colormap_for('main')

        # (label, value) menu entries for the dropdowns below.
        POINTER_MODES = [('Cursor', 'cursor'), ('Band', 'band')]
        FIT_MODES = [('EDC', 'edc'), ('MDC', 'mdc')]
        DIRECTIONS = [
            ('From Bottom/Left', 'forward'),
            ('From Top/Right', 'reverse'),
        ]
        BAND_TYPES = [
            ('Lorentzian', 'Lorentzian'),
            ('Voigt', 'Voigt'),
            ('Gaussian', 'Gaussian'),
        ]

        # Band type name -> band class used for fitting; pack_bands falls back
        # to band.Band for unknown types.
        # NOTE(review): 'Gaussian' maps to band.BackgroundBand -- confirm intended.
        band_classes = {
            'Lorentzian': band.Band,
            'Gaussian': band.BackgroundBand,
            'Voigt': band.VoigtBand,
        }

        self.app_context['band_options'] = []

        def pack_bands():
            """Convert the marked bands into fit specifications.

            A center constraint that parses as a float (the band's own first,
            otherwise the tool-wide one) is stored under ``params['stray']``.

            :return: dict mapping band name to its packed specification.
            :raises AnalysisError: if any band has no points.
            """

            def as_float(container, key):
                """Return ``float(container[key])``, or None when it is
                missing, ``None`` or not numeric (e.g. an empty text box)."""
                try:
                    return float(container[key])
                except (KeyError, ValueError, TypeError):
                    return None

            packed_bands = {}
            for band_name, band_description in self.app_context['bands'].items():
                if not band_description['points']:
                    raise AnalysisError('Band {} is empty.'.format(band_name))

                stray = as_float(band_description, 'center_float')
                if stray is None:
                    stray = as_float(self.app_context, 'center_float')

                packed_bands[band_name] = {
                    'name': band_name,
                    'band': band_classes.get(band_description['type'],
                                             band.Band),
                    'dims': self.arr.dims,
                    'params': {
                        'amplitude': {
                            'min': 0
                        },
                    },
                    'points': band_description['points'],
                }

                if stray is not None:
                    packed_bands[band_name]['params']['stray'] = stray

            return packed_bands

        def fit(override_data=None):
            """Fit the marked bands with ``fit_patterned_bands``.

            :param override_data: optional data to fit instead of ``self.arr``;
                a Dataset is reduced with ``normalize_to_spectrum``.
            """
            packed_bands = pack_bands()
            dims = list(self.arr.dims)
            if 'eV' in dims:
                dims.remove('eV')
            angular_direction = dims[0]
            if isinstance(override_data, xr.Dataset):
                override_data = normalize_to_spectrum(override_data)
            return fit_patterned_bands(
                override_data if override_data is not None else self.arr,
                packed_bands,
                fit_direction='eV' if self.app_context['fit_mode'] == 'edc'
                else angular_direction,
                direction_normal=self.app_context['direction_normal'])

        self.app_context['pack_bands'] = pack_bands
        self.app_context['fit'] = fit

        self.pointer_dropdown = widgets.Dropdown(label='Pointer Mode',
                                                 button_type='primary',
                                                 menu=POINTER_MODES)
        self.direction_dropdown = widgets.Dropdown(label='Fit Direction',
                                                   button_type='primary',
                                                   menu=DIRECTIONS)
        self.band_dropdown = widgets.Dropdown(
            label='Active Band',
            button_type='primary',
            menu=self.app_context['band_options'])
        self.fit_mode_dropdown = widgets.Dropdown(label='Mode',
                                                  button_type='primary',
                                                  menu=FIT_MODES)
        self.band_type_dropdown = widgets.Dropdown(label='Band Type',
                                                   button_type='primary',
                                                   menu=BAND_TYPES)

        self.band_name_input = widgets.TextInput(placeholder='Band name...')
        self.center_float_widget = widgets.TextInput(
            placeholder='Center Constraint')
        self.center_float_copy = widgets.Button(label='Copy to all...')
        self.add_band_button = widgets.Button(label='Add Band')

        self.clear_band_button = widgets.Button(label='Clear Band')
        self.remove_band_button = widgets.Button(label='Remove Band')

        self.main_color_range_slider = widgets.RangeSlider(start=0,
                                                           end=100,
                                                           value=(0, 100),
                                                           title='Color Range')

        def add_band(band_name):
            """Register a new, empty Lorentzian band named ``band_name``.

            Empty names and names already in use are ignored. The first band
            added while none is active becomes the active band.
            """
            if not band_name or band_name in self.app_context['bands']:
                return

            self.app_context['band_options'].append((band_name, band_name))
            self.band_dropdown.menu = self.app_context['band_options']
            self.app_context['bands'][band_name] = {
                'type': 'Lorentzian',
                'points': [],
                'name': band_name,
                'center_float': None,
            }

            if self.active_band is None:
                self.active_band = band_name

            self.save_app()

        def on_copy_center_float():
            """Copy the tool-wide center constraint onto every band."""
            for band_description in self.app_context['bands'].values():
                band_description['center_float'] = self.app_context['center_float']
            self.save_app()

        def on_change_active_band(attr, old_band_id, band_id):
            self.app_context['active_band'] = band_id
            self.active_band = band_id

        def on_change_pointer_mode(attr, old_pointer_mode, pointer_mode):
            self.app_context['pointer_mode'] = pointer_mode
            self.pointer_mode = pointer_mode

        def set_center_float_value(attr, old_value, new_value):
            """Store the typed center constraint globally and on the active band."""
            self.app_context['center_float'] = new_value
            if self.active_band in self.app_context['bands']:
                self.app_context['bands'][
                    self.active_band]['center_float'] = new_value

            self.save_app()

        def set_fit_direction(attr, old_direction, new_direction):
            self.app_context['direction_normal'] = new_direction == 'forward'
            self.save_app()

        def set_fit_mode(attr, old_mode, new_mode):
            self.app_context['fit_mode'] = new_mode
            self.save_app()

        def set_band_type(attr, old_type, new_type):
            if self.active_band in self.app_context['bands']:
                self.app_context['bands'][self.active_band]['type'] = new_type

            self.save_app()

        def update_band_display():
            """Redraw every band's points as a polyline and persist state."""
            bands = list(self.app_context['bands'].values())
            band_lines.data_source.data = {
                'xs': [[p[0] for p in b['points']] for b in bands],
                'ys': [[p[1] for p in b['points']] for b in bands],
            }
            self.save_app()

        self.update_band_display = update_band_display

        def on_clear_band():
            """Remove all points from the active band."""
            if self.active_band in self.app_context['bands']:
                self.app_context['bands'][self.active_band]['points'] = []
                update_band_display()

        def on_remove_band():
            """Delete the active band and its menu entry, then deselect it."""
            if self.active_band in self.app_context['bands']:
                del self.app_context['bands'][self.active_band]
                new_band_options = [
                    b for b in self.app_context['band_options']
                    if b[0] != self.active_band
                ]
                self.band_dropdown.menu = new_band_options
                self.app_context['band_options'] = new_band_options
                self.active_band = None
                # Keep the stored selection in sync, as on_change_active_band does.
                self.app_context['active_band'] = None
                update_band_display()

        # Attach callbacks
        self.main_color_range_slider.on_change('value', update_main_colormap)

        figures['main'].on_event(events.Tap, click_main_image)
        self.band_dropdown.on_change('value', on_change_active_band)
        self.pointer_dropdown.on_change('value', on_change_pointer_mode)
        self.add_band_button.on_click(
            lambda: add_band(self.band_name_input.value))
        self.clear_band_button.on_click(on_clear_band)
        self.remove_band_button.on_click(on_remove_band)
        self.center_float_copy.on_click(on_copy_center_float)
        self.center_float_widget.on_change('value', set_center_float_value)
        self.direction_dropdown.on_change('value', set_fit_direction)
        self.fit_mode_dropdown.on_change('value', set_fit_mode)
        self.band_type_dropdown.on_change('value', set_band_type)

        layout = row(
            column(figures['main']),
            column(
                widgetbox(
                    self.pointer_dropdown,
                    self.band_dropdown,
                    self.fit_mode_dropdown,
                    self.band_type_dropdown,
                    self.direction_dropdown,
                ), row(
                    self.band_name_input,
                    self.add_band_button,
                ), row(
                    self.clear_band_button,
                    self.remove_band_button,
                ), row(self.center_float_widget, self.center_float_copy),
                widgetbox(
                    self._cursor_info,
                    self.main_color_range_slider,
                )))

        doc.add_root(layout)
        doc.title = 'Band Tool'
        self.load_app()
        self.save_app()
Beispiel #5
0
    def tool_handler(self, doc):
        """Build the Bokeh document for the interactive mask tool.

        ``self.arr`` (which must be 2D) is shown as an image. In 'region'
        pointer mode, each click on the image appends the cursor position to
        the active region's polygon. The polygons are collected into
        ``app_context['mask']`` (keys ``'dims'`` and ``'polys'``, plus
        ``'fermi'`` once an edge mask has been built), and
        ``app_context['perform_mask']`` applies that mask to data through
        ``apply_mask``.

        Args:
            doc: The Bokeh document the tool layout is added to.

        Raises:
            AnalysisError: If ``self.arr`` is not two-dimensional.
        """
        from bokeh import events
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError(
                'Cannot use mask tool on non image-like spectra')

        arr = self.arr
        x_coords, y_coords = arr.coords[arr.dims[0]], arr.coords[arr.dims[1]]

        default_palette = self.default_palette

        self.app_context.update({
            'region_options': [],
            'regions': {},
            # Seed the active region explicitly, as the path tool does for
            # 'active_path'. NOTE(review): assumes attribute reads such as
            # self.active_region fall back to app_context, the way
            # self.data_range / self.regions do just below.
            'active_region': None,
            'data': arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            },
        })

        self.cursor = [
            np.mean(self.data_range['x']),
            np.mean(self.data_range['y'])
        ]

        self.color_maps['main'] = LinearColorMapper(default_palette,
                                                    low=np.min(arr.values),
                                                    high=np.max(arr.values),
                                                    nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = 'Mask Tool: WARNING Unidentified'

        try:
            main_title = 'Mask Tool: {}'.format(arr.S.label[:60])
        except Exception:
            # The label only decorates the title; keep the fallback title
            # when it cannot be computed.
            pass

        self.figures['main'] = figure(
            tools=main_tools,
            plot_width=self.app_main_size,
            plot_height=self.app_main_size,
            min_border=10,
            min_border_left=20,
            toolbar_location='left',
            x_axis_location='below',
            y_axis_location='right',
            title=main_title,
            x_range=self.data_range['x'],
            y_range=self.data_range['y'],
        )

        self.figures['main'].xaxis.axis_label = arr.dims[0]
        self.figures['main'].yaxis.axis_label = arr.dims[1]

        # Bokeh's image glyph is row-major in (y, x), hence the transpose.
        self.plots['main'] = self.figures['main'].image(
            [np.asarray(arr.values.T)],
            x=self.data_range['x'][0],
            y=self.data_range['y'][0],
            dw=self.data_range['x'][1] - self.data_range['x'][0],
            dh=self.data_range['y'][1] - self.data_range['y'][0],
            color_mapper=self.color_maps['main'])

        self.add_cursor_lines(self.figures['main'])
        region_patches = self.figures['main'].patches(xs=[],
                                                      ys=[],
                                                      color='white',
                                                      alpha=0.35,
                                                      line_width=1)

        def add_point_to_region():
            """Append the current cursor position to the active region."""
            if self.active_region in self.regions:
                self.regions[self.active_region]['points'].append(
                    list(self.cursor))
                update_region_display()

            self.save_app()

        def click_main_image(event):
            """Move the cursor to the click and, in region mode, record it."""
            self.cursor = [event.x, event.y]
            if self.pointer_mode == 'region':
                add_point_to_region()

        POINTER_MODES = [
            (
                'Cursor',
                'cursor',
            ),
            (
                'Region',
                'region',
            ),
        ]

        def perform_mask(data=None, **kwargs):
            """Apply the current mask to ``data`` (defaults to the tool's array)."""
            if data is None:
                data = arr

            data = normalize_to_spectrum(data)
            return apply_mask(data, self.app_context['mask'], **kwargs)

        self.app_context['perform_mask'] = perform_mask
        self.app_context['mask'] = None

        pointer_dropdown = widgets.Dropdown(label='Pointer Mode',
                                            button_type='primary',
                                            menu=POINTER_MODES)
        self.region_dropdown = widgets.Dropdown(label='Active Region',
                                                button_type='primary',
                                                menu=self.region_options)

        edge_mask_button = widgets.Button(label='Edge Mask')
        region_name_input = widgets.TextInput(placeholder='Region name...')
        add_region_button = widgets.Button(label='Add Region')

        clear_region_button = widgets.Button(label='Clear Region')
        remove_region_button = widgets.Button(label='Remove Region')

        main_color_range_slider = widgets.RangeSlider(start=0,
                                                      end=100,
                                                      value=(
                                                          0,
                                                          100,
                                                      ),
                                                      title='Color Range')

        def on_click_edge_mask():
            """Close the active region's points into an edge-mask polygon.

            Also records the highest energy among the placed points as
            ``mask['fermi']``. Does nothing to the region if it has no
            points yet, since there is then no energy to anchor the mask to.
            """
            if self.active_region in self.regions:
                old_points = self.regions[self.active_region]['points']

                # Guard: np.max over an empty list raises ValueError.
                if old_points:
                    energy_index = arr.dims.index('eV')
                    other_dim = [d for d in arr.dims if d != 'eV'][0]
                    max_energy = np.max([p[energy_index] for p in old_points])

                    other_coord = arr.coords[other_dim].values
                    min_other, max_other = np.min(other_coord), np.max(
                        other_coord)
                    min_e = np.min(arr.coords['eV'].values)

                    # Bracket the user's points with vertices just outside
                    # the range of the non-energy axis, running from below
                    # the lowest energy (min_e - 3) up to 0 eV. Presumably
                    # the points were placed in order along that axis — TODO
                    # confirm. Vertex order is (x, y), so which component is
                    # the energy depends on the axis eV lies on.
                    if energy_index == 0:
                        before = [[min_e - 3, min_other - 1],
                                  [0, min_other - 1]]
                        after = [[0, max_other + 1],
                                 [min_e - 3, max_other + 1]]
                    else:
                        before = [[min_other - 1, min_e - 3],
                                  [min_other - 1, 0]]
                        after = [[max_other + 1, 0],
                                 [max_other + 1, min_e - 3]]

                    # NOTE: clicking again re-brackets the already padded
                    # polygon.
                    self.regions[self.active_region]['points'] = (
                        before + old_points + after)
                    self.app_context['mask'] = self.app_context['mask'] or {}
                    self.app_context['mask']['fermi'] = max_energy
                    update_region_display()

            self.save_app()

        def add_region(region_name):
            """Create an empty region named ``region_name`` if it is new."""
            if region_name not in self.regions:
                self.region_options.append((
                    region_name,
                    region_name,
                ))
                self.region_dropdown.menu = self.region_options
                self.regions[region_name] = {
                    'points': [],
                    'name': region_name,
                }

                if self.active_region is None:
                    self.active_region = region_name

                self.save_app()

        def on_change_active_region(attr, old_region_id, region_id):
            self.app_context['active_region'] = region_id
            self.active_region = region_id
            self.save_app()

        def on_change_pointer_mode(attr, old_pointer_mode, pointer_mode):
            self.app_context['pointer_mode'] = pointer_mode
            self.pointer_mode = pointer_mode
            self.save_app()

        def update_region_display():
            """Sync the mask description and the drawn patches with regions."""
            regions = list(self.regions.values())

            if self.app_context['mask'] is None:
                self.app_context['mask'] = {}
            self.app_context['mask'].update({
                'dims': arr.dims,
                'polys': [r['points'] for r in regions],
            })

            region_patches.data_source.data = {
                'xs': [[p[0] for p in r['points']] for r in regions],
                'ys': [[p[1] for p in r['points']] for r in regions],
            }
            self.save_app()

        self.update_region_display = update_region_display

        def on_clear_region():
            if self.active_region in self.regions:
                self.regions[self.active_region]['points'] = []
                update_region_display()

        def on_remove_region():
            if self.active_region in self.regions:
                del self.regions[self.active_region]
                new_region_options = [
                    b for b in self.region_options
                    if b[0] != self.active_region
                ]
                self.region_dropdown.menu = new_region_options
                self.region_options = new_region_options
                self.active_region = None
                update_region_display()

        # Attach callbacks
        main_color_range_slider.on_change('value',
                                          self.update_colormap_for('main'))

        self.figures['main'].on_event(events.Tap, click_main_image)
        self.region_dropdown.on_change('value', on_change_active_region)
        pointer_dropdown.on_change('value', on_change_pointer_mode)
        add_region_button.on_click(lambda: add_region(region_name_input.value))
        edge_mask_button.on_click(on_click_edge_mask)
        clear_region_button.on_click(on_clear_region)
        remove_region_button.on_click(on_remove_region)

        # The edge mask needs an energy axis, so its button is only shown
        # when 'eV' is one of the dimensions.
        layout = row(
            column(self.figures['main']),
            column(*[
                f for f in [
                    widgetbox(
                        pointer_dropdown,
                        self.region_dropdown,
                    ),
                    row(
                        region_name_input,
                        add_region_button,
                    ),
                    edge_mask_button if 'eV' in arr.dims else None,
                    row(
                        clear_region_button,
                        remove_region_button,
                    ),
                    widgetbox(
                        self._cursor_info,
                        main_color_range_slider,
                    ),
                ] if f is not None
            ]))

        doc.add_root(layout)
        doc.title = 'Mask Tool'
        self.load_app()
        self.save_app()
Beispiel #6
0
    def tool_handler(self, doc):
        """Build the Bokeh document for the interactive path tool.

        ``self.arr`` (which must be 2D) is shown as an image. In 'path'
        pointer mode, each click on the image appends the cursor position to
        the active path, and the paths are drawn as white lines.
        ``app_context['to_xarray']`` converts the paths into xarray datasets,
        and ``app_context['select']`` selects data along the first path via
        ``select_along_path``.

        Args:
            doc: The Bokeh document the tool layout is added to.

        Raises:
            AnalysisError: If ``self.arr`` is not two-dimensional.
        """
        import warnings

        from bokeh import events
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError(
                'Cannot use path tool on non image-like spectra')

        arr = self.arr
        x_coords, y_coords = arr.coords[arr.dims[0]], arr.coords[arr.dims[1]]

        default_palette = self.default_palette

        self.app_context.update({
            'path_options': [],
            'active_path': None,
            'paths': {},
            'data': arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            },
        })

        self.cursor = [
            np.mean(self.data_range['x']),
            np.mean(self.data_range['y'])
        ]

        self.color_maps['main'] = LinearColorMapper(default_palette,
                                                    low=np.min(arr.values),
                                                    high=np.max(arr.values),
                                                    nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = 'Path Tool: WARNING Unidentified'

        try:
            main_title = 'Path Tool: {}'.format(arr.S.label[:60])
        except Exception:
            # The label only decorates the title; keep the fallback title
            # when it cannot be computed.
            pass

        self.figures['main'] = figure(
            tools=main_tools,
            plot_width=self.app_main_size,
            plot_height=self.app_main_size,
            min_border=10,
            min_border_left=20,
            toolbar_location='left',
            x_axis_location='below',
            y_axis_location='right',
            title=main_title,
            x_range=self.data_range['x'],
            y_range=self.data_range['y'],
        )

        self.figures['main'].xaxis.axis_label = arr.dims[0]
        self.figures['main'].yaxis.axis_label = arr.dims[1]

        # Bokeh's image glyph is row-major in (y, x), hence the transpose.
        self.plots['main'] = self.figures['main'].image(
            [np.asarray(arr.values.T)],
            x=self.data_range['x'][0],
            y=self.data_range['y'][0],
            dw=self.data_range['x'][1] - self.data_range['x'][0],
            dh=self.data_range['y'][1] - self.data_range['y'][0],
            color_mapper=self.color_maps['main'])

        self.plots['paths'] = self.figures['main'].multi_line(
            xs=[], ys=[], line_color='white', line_width=2)

        self.add_cursor_lines(self.figures['main'])

        def add_point_to_path():
            """Append the current cursor position to the active path."""
            if self.active_path in self.paths:
                self.paths[self.active_path]['points'].append(list(
                    self.cursor))
                update_path_display()

            self.save_app()

        def click_main_image(event):
            """Move the cursor to the click and, in path mode, record it."""
            self.cursor = [event.x, event.y]
            if self.pointer_mode == 'path':
                add_point_to_path()

        POINTER_MODES = [
            (
                'Cursor',
                'cursor',
            ),
            (
                'Path',
                'path',
            ),
        ]

        def convert_to_xarray():
            """Convert every drawn path into an ``xr.Dataset``.

            Each dataset has an ``index`` dimension running over the path's
            points, and one data variable per dimension of ``self.arr``
            holding that coordinate of each point.

            Returns:
                dict mapping each path's name to its dataset.
            """
            def convert_single_path_to_xarray(points):
                index_coords = {'index': np.arange(len(points))}
                data_vars = {
                    dim: xr.DataArray(np.array([p[i] for p in points]),
                                      coords=index_coords,
                                      dims=['index'])
                    for i, dim in enumerate(self.arr.dims)
                }
                return xr.Dataset(data_vars=data_vars, coords=index_coords)

            return {
                path['name']: convert_single_path_to_xarray(path['points'])
                for path in self.paths.values()
            }

        def select(data=None, **kwargs):
            """Select ``data`` (defaults to ``self.arr``) along the first path.

            Raises:
                AnalysisError: If no path has been defined yet.
            """
            if data is None:
                data = self.arr

            if not self.paths:
                raise AnalysisError('No path has been defined to select along.')
            if len(self.paths) > 1:
                warnings.warn('Only using first path.')

            # "First" is the earliest-added path (dict insertion order).
            first_path = next(iter(convert_to_xarray().values()))
            return select_along_path(path=first_path, data=data, **kwargs)

        self.app_context['to_xarray'] = convert_to_xarray
        self.app_context['select'] = select

        pointer_dropdown = widgets.Dropdown(label='Pointer Mode',
                                            button_type='primary',
                                            menu=POINTER_MODES)
        self.path_dropdown = widgets.Dropdown(label='Active Path',
                                              button_type='primary',
                                              menu=self.path_options)

        path_name_input = widgets.TextInput(placeholder='Path name...')
        add_path_button = widgets.Button(label='Add Path')

        clear_path_button = widgets.Button(label='Clear Path')
        remove_path_button = widgets.Button(label='Remove Path')

        main_color_range_slider = widgets.RangeSlider(start=0,
                                                      end=100,
                                                      value=(
                                                          0,
                                                          100,
                                                      ),
                                                      title='Color Range')

        def add_path(path_name):
            """Create an empty path named ``path_name`` if it is new."""
            if path_name not in self.paths:
                self.path_options.append((
                    path_name,
                    path_name,
                ))
                self.path_dropdown.menu = self.path_options
                self.paths[path_name] = {
                    'points': [],
                    'name': path_name,
                }

                if self.active_path is None:
                    self.active_path = path_name

                self.save_app()

        def on_change_active_path(attr, old_path_id, path_id):
            # debug_text is presumably rendered by self.debug_div (shown in
            # the layout) — TODO confirm.
            self.debug_text = path_id
            self.app_context['active_path'] = path_id
            self.active_path = path_id
            self.save_app()

        def on_change_pointer_mode(attr, old_pointer_mode, pointer_mode):
            self.app_context['pointer_mode'] = pointer_mode
            self.pointer_mode = pointer_mode
            self.save_app()

        def update_path_display():
            """Redraw all paths from their stored points."""
            self.plots['paths'].data_source.data = {
                'xs': [[point[0] for point in p['points']]
                       for p in self.paths.values()],
                'ys': [[point[1] for point in p['points']]
                       for p in self.paths.values()],
            }
            self.save_app()

        self.update_path_display = update_path_display

        def on_clear_path():
            if self.active_path in self.paths:
                self.paths[self.active_path]['points'] = []
                update_path_display()

        def on_remove_path():
            if self.active_path in self.paths:
                del self.paths[self.active_path]
                new_path_options = [
                    b for b in self.path_options if b[0] != self.active_path
                ]
                self.path_dropdown.menu = new_path_options
                self.path_options = new_path_options
                self.active_path = None
                update_path_display()

        # Attach callbacks
        main_color_range_slider.on_change('value',
                                          self.update_colormap_for('main'))

        self.figures['main'].on_event(events.Tap, click_main_image)
        self.path_dropdown.on_change('value', on_change_active_path)
        pointer_dropdown.on_change('value', on_change_pointer_mode)
        add_path_button.on_click(lambda: add_path(path_name_input.value))
        clear_path_button.on_click(on_clear_path)
        remove_path_button.on_click(on_remove_path)

        layout = row(
            column(self.figures['main']),
            column(
                widgetbox(
                    pointer_dropdown,
                    self.path_dropdown,
                ),
                row(
                    path_name_input,
                    add_path_button,
                ),
                row(
                    clear_path_button,
                    remove_path_button,
                ),
                widgetbox(
                    self._cursor_info,
                    main_color_range_slider,
                ),
                self.debug_div,
            ))

        doc.add_root(layout)
        doc.title = 'Path Tool'
        self.load_app()
        self.save_app()