Ejemplo n.º 1
0
class OutputTest(unittest.TestCase):
    """Fixture for slow output tests that render a raw-data plot of a sample
    XYE file into an OverlayPlotContainer.

    Skipped unless the OUTPUT_TESTS environment variable is defined.
    """

    def setUp(self):
        """Load the sample dataset, build the plot container and make sure
        the tests/tmp output directory exists."""
        import errno

        if 'OUTPUT_TESTS' not in os.environ:
            raise SkipTest('Slow: define OUTPUT_TESTS to run')
        self.data = xye.XYEDataset.from_file(
            r'tests/testdata/si640c_low_temp_cal_p1_scan0.000000_adv0_0000.xye'
        )

        # Minimal stand-in for the 'ui' metadata object (presumably read by
        # RawDataPlot when drawing the dataset).
        class UI(object):
            color = None
            name = ''
            active = True
            markers = False

        self.data.metadata['ui'] = UI()
        self.datasets = [self.data]
        self.plot = RawDataPlot(self.datasets)
        self.plot.plot_datasets(self.datasets, scale='log')
        self.container = OverlayPlotContainer(self.plot.get_plot(),
                                              bgcolor="white",
                                              use_backbuffer=True,
                                              border_visible=False)
        self.container.request_redraw()
        self.basedir = os.path.join('tests', 'tmp')
        try:
            os.mkdir(self.basedir)
        except OSError as e:
            # An already-existing directory is fine; anything else is a real
            # failure. Checking errno is locale-independent and, unlike an
            # `assert`, is not stripped when running under `python -O`.
            if e.errno != errno.EEXIST:
                raise
Ejemplo n.º 2
0
class HistogramPlotHandler(HasTraits):
    """
    Class for handling the histograms.

    Draws a bar-chart histogram of the single column currently selected in
    the table view into ``container``. ``self.data`` must be assigned by the
    caller: a pandas DataFrame when AS_PANDAS_DATAFRAME is set, otherwise a
    2-D numpy array.
    """

    # Index for the histogram plot (bin numbers 0 .. nbins-1 once drawn)
    index = Array

    # The selection handler object for the selected data
    selection_handler = Instance(SelectionHandler)

    # OverlayPlotContainer for the histogram plot
    container = Instance(OverlayPlotContainer)

    # Number of bins of the histogram
    nbins = Int(10)

    # Whether the data is a pandas dataframe or a numpy array
    AS_PANDAS_DATAFRAME = Bool

    def __init__(self):
        super(HistogramPlotHandler, self).__init__()
        self.index = range(self.nbins)
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()

    def _selected_column(self, column_position):
        """Return the values of the data column at position ``column_position``."""
        if self.AS_PANDAS_DATAFRAME:
            return self.data[self.data.columns[column_position]]
        return self.data[:, column_position]

    def draw_histogram(self):
        """
        Default function called when drawing the histogram.

        Clears the container, then, when exactly one selection range exists,
        plots an ``nbins``-bin histogram of the selected column. The selection
        is flushed afterwards in every case.
        """
        # Iterate over a snapshot: removing from the live components list while
        # looping over it would skip every other component.
        for component in list(self.container.components):
            self.container.remove(component)

        self.selection_handler.create_selection()

        if len(self.selection_handler.selected_indices) == 1:
            tuple_list = self.selection_handler.selected_indices[0]
            y = self._selected_column(tuple_list[1])
            self.index = np.arange(self.nbins)
            hist = np.histogram(y, self.nbins)[0]
            plotdata = ArrayPlotData(x=self.index, y=hist)
            plot = Plot(plotdata)
            plot.plot(("x", "y"), type="bar", bar_width=0.5)
            self.container.add(plot)
            self.container.request_redraw()

        self.selection_handler.flush()
Ejemplo n.º 3
0
class HistogramPlotHandler(HasTraits):
    '''
    Class for handling the histograms.

    Draws a bar-chart histogram of the single column currently selected in
    the table view into ``container``. ``self.data`` must be assigned by the
    caller: a pandas DataFrame when AS_PANDAS_DATAFRAME is set, otherwise a
    2-D numpy array.
    '''

    # Index for the histogram plot (bin numbers 0 .. nbins-1 once drawn)
    index = Array

    # The selection handler object for the selected data
    selection_handler = Instance(SelectionHandler)

    # OverlayPlotContainer for the histogram plot
    container = Instance(OverlayPlotContainer)

    # Number of bins of the histogram
    nbins = Int(10)

    # Whether the data is a pandas dataframe or a numpy array
    AS_PANDAS_DATAFRAME = Bool

    def __init__(self):
        super(HistogramPlotHandler, self).__init__()
        self.index = range(self.nbins)
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()

    def _selected_column(self, column_position):
        '''Return the values of the data column at position ``column_position``.'''
        if self.AS_PANDAS_DATAFRAME:
            return self.data[self.data.columns[column_position]]
        return self.data[:, column_position]

    def draw_histogram(self):
        '''
        Default function called when drawing the histogram.

        Clears the container, then, when exactly one selection range exists,
        plots an ``nbins``-bin histogram of the selected column. The selection
        is flushed afterwards in every case.
        '''
        # Iterate over a snapshot: removing from the live components list while
        # looping over it would skip every other component.
        for component in list(self.container.components):
            self.container.remove(component)

        self.selection_handler.create_selection()

        if len(self.selection_handler.selected_indices) == 1:
            tuple_list = self.selection_handler.selected_indices[0]
            y = self._selected_column(tuple_list[1])
            self.index = np.arange(self.nbins)
            hist = np.histogram(y, self.nbins)[0]
            plotdata = ArrayPlotData(x=self.index, y=hist)
            plot = Plot(plotdata)
            plot.plot(("x", "y"), type='bar', bar_width=0.5)
            self.container.add(plot)
            self.container.request_redraw()

        self.selection_handler.flush()
Ejemplo n.º 4
0
class PCPlotHandler(HasTraits):
    '''
    Class for handling principal component plots.

    Projects the selected rectangular block of ``table`` onto its first two
    principal components and draws the result as a scatter plot.
    '''

    # The container for the plot.
    container = Instance(OverlayPlotContainer)

    # the sklearn.decomposition.PCA object
    pca = Instance(PCA)

    # Whether or not to normalize the data, one of the parameters of the PCA
    # object.
    # NOTE(review): not read anywhere in this class; __init__ always sets
    # pca.whiten = True regardless of this trait.
    whiten = Bool

    # The input data.
    table = Array

    # The selection_handler instance for the tableview
    selection_handler = Instance(SelectionHandler)

    def __init__(self):
        super(PCPlotHandler, self).__init__()
        self.pca = PCA(n_components=2)
        self.pca.whiten = True
        self.container = OverlayPlotContainer()
        self.selection_handler = SelectionHandler()

    def draw_pc_plot(self):
        '''
        Called to draw the PCA plot.

        When exactly one selection range exists, replaces any previous plot
        with a scatter of the two leading principal components of the
        selected block (rows/columns sliced from the top-left corner up to,
        but excluding, the bottom-right corner). The selection is flushed
        afterwards.
        '''
        self.selection_handler.create_selection()
        if len(self.selection_handler.selected_indices) == 1:
            top_left = self.selection_handler.selected_indices[0][0:2]
            bot_right = self.selection_handler.selected_indices[0][2:4]
            data = self.table[top_left[0]:bot_right[0],
                              top_left[1]:bot_right[1]]
            pc_red = self.pca.fit_transform(data)
            plotdata = ArrayPlotData(x=pc_red[:, 0], y=pc_red[:, 1])
            plot = Plot(plotdata)
            plot.plot(("x", "y"), type='scatter')
            # Drop earlier projections so repeated draws do not pile up.
            for component in list(self.container.components):
                self.container.remove(component)
            self.container.add(plot)
            self.container.request_redraw()
        # Reset the selection, as the other plot handlers do after drawing.
        self.selection_handler.flush()
Ejemplo n.º 5
0
class PCPlotHandler(HasTraits):
    """
    Class for handling principal component plots.

    Projects the selected rectangular block of ``table`` onto its first two
    principal components and draws the result as a scatter plot.
    """

    # The container for the plot.
    container = Instance(OverlayPlotContainer)

    # the sklearn.decomposition.PCA object
    pca = Instance(PCA)

    # Whether or not to normalize the data, one of the parameters of the PCA
    # object.
    # NOTE(review): not read anywhere in this class; __init__ always sets
    # pca.whiten = True regardless of this trait.
    whiten = Bool

    # The input data.
    table = Array

    # The selection_handler instance for the tableview
    selection_handler = Instance(SelectionHandler)

    def __init__(self):
        super(PCPlotHandler, self).__init__()
        self.pca = PCA(n_components=2)
        self.pca.whiten = True
        self.container = OverlayPlotContainer()
        self.selection_handler = SelectionHandler()

    def draw_pc_plot(self):
        """
        Called to draw the PCA plot.

        When exactly one selection range exists, replaces any previous plot
        with a scatter of the two leading principal components of the
        selected block (rows/columns sliced from the top-left corner up to,
        but excluding, the bottom-right corner). The selection is flushed
        afterwards.
        """
        self.selection_handler.create_selection()
        if len(self.selection_handler.selected_indices) == 1:
            top_left = self.selection_handler.selected_indices[0][0:2]
            bot_right = self.selection_handler.selected_indices[0][2:4]
            data = self.table[top_left[0] : bot_right[0], top_left[1] : bot_right[1]]
            pc_red = self.pca.fit_transform(data)
            plotdata = ArrayPlotData(x=pc_red[:, 0], y=pc_red[:, 1])
            plot = Plot(plotdata)
            plot.plot(("x", "y"), type="scatter")
            # Drop earlier projections so repeated draws do not pile up.
            for component in list(self.container.components):
                self.container.remove(component)
            self.container.add(plot)
            self.container.request_redraw()
        # Reset the selection, as the other plot handlers do after drawing.
        self.selection_handler.flush()
Ejemplo n.º 6
0
class OutputTest(unittest.TestCase):
    """Fixture for slow output tests that render a raw-data plot of a sample
    XYE file into an OverlayPlotContainer.

    Skipped unless the OUTPUT_TESTS environment variable is defined.
    """

    def setUp(self):
        """Load the sample dataset, build the plot container and make sure
        the tests/tmp output directory exists."""
        import errno

        if 'OUTPUT_TESTS' not in os.environ:
            raise SkipTest('Slow: define OUTPUT_TESTS to run')
        self.data = xye.XYEDataset.from_file(r'tests/testdata/si640c_low_temp_cal_p1_scan0.000000_adv0_0000.xye')

        # Minimal stand-in for the 'ui' metadata object (presumably read by
        # RawDataPlot when drawing the dataset).
        class UI(object):
            color = None
            name = ''
            active = True
            markers = False

        self.data.metadata['ui'] = UI()
        self.datasets = [self.data]
        self.plot = RawDataPlot(self.datasets)
        self.plot.plot_datasets(self.datasets, scale='log')
        self.container = OverlayPlotContainer(self.plot.get_plot(),
            bgcolor="white", use_backbuffer=True,
            border_visible=False)
        self.container.request_redraw()
        self.basedir = os.path.join('tests', 'tmp')
        try:
            os.mkdir(self.basedir)
        except OSError as e:
            # An already-existing directory is fine; anything else is a real
            # failure. Checking errno is locale-independent and, unlike an
            # `assert`, is not stripped when running under `python -O`.
            if e.errno != errno.EEXIST:
                raise
Ejemplo n.º 7
0
class RegressionPlotHandler(HasTraits):
    '''
    Class for handling regression plots

    Fits an ordinary-least-squares model to a selected run of rows in a
    single column of ``data`` and plots the data together with the fit.
    '''

    # The input data from the csv file
    data = Array

    # The input, or the selected column / row
    Y = Array

    # OLS fitted values of the current selection
    selection_olsfit = Array

    # the index used to plot the output
    index = Array

    # the container for the plots
    container = Instance(OverlayPlotContainer)

    # The selection handler object for the tableview
    selection_handler = Instance(SelectionHandler)

    def __init__(self):
        super(RegressionPlotHandler, self).__init__()
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()

    def fit_selection(self):
        '''
        Function that computes the curve to fit.

        When exactly one selection range exists and it spans a single column,
        regresses the selected values on their position 0..n-1 with OLS,
        storing the inputs in ``index``/``Y`` and the fitted values in
        ``selection_olsfit``. The selection is flushed afterwards.
        '''
        self.selection_handler.create_selection()
        if len(self.selection_handler.selected_indices) == 1:
            # Entries appear to be (top_row, left_col, bottom_row, right_col)
            # with an inclusive bottom row -- hence the "+ 1" below.
            tuple_list = self.selection_handler.selected_indices[0]
            if tuple_list[1] == tuple_list[3]:
                first_row, last_row = tuple_list[0], tuple_list[2]
                self.index = np.arange(last_row - first_row + 1)
                # Take only the selected rows: the whole column would not match
                # the length of ``index`` unless the entire column was selected.
                self.Y = self.data[first_row:last_row + 1, tuple_list[1]]

                # NOTE(review): no constant column is added, so the fitted line
                # is forced through the origin -- confirm this is intended.
                results = OLS(self.Y, self.index).fit()
                self.selection_olsfit = results.fittedvalues
        self.selection_handler.flush()

    def plot_fits(self):
        '''
        Function called to plot the regression fits.

        Replaces the container contents with the selected data (red, dashed)
        and the OLS fitted values (blue), both against ``index``.
        '''
        # Remove through the container's own API (not by mutating its
        # components list) so each component's back-reference is cleared too.
        self.container.remove(*list(self.container.components))

        plotdata = ArrayPlotData(x=self.index, y=self.Y)
        plot = Plot(plotdata)
        plot.plot(("x", "y"), type='line', color='red')
        plot.line_style = 'dash'
        self.container.add(plot)

        plotdata = ArrayPlotData(x=self.index, y=self.selection_olsfit)
        plot = Plot(plotdata)
        plot.plot(("x", "y"), type='line', color='blue')
        self.container.add(plot)

        self.container.request_redraw()
Ejemplo n.º 8
0
class RegressionPlotHandler(HasTraits):
    """
    Class for handling regression plots

    Fits an ordinary-least-squares model to a selected run of rows in a
    single column of ``data`` and plots the data together with the fit.
    """

    # The input data from the csv file
    data = Array

    # The input, or the selected column / row
    Y = Array

    # OLS fitted values of the current selection
    selection_olsfit = Array

    # the index used to plot the output
    index = Array

    # the container for the plots
    container = Instance(OverlayPlotContainer)

    # The selection handler object for the tableview
    selection_handler = Instance(SelectionHandler)

    def __init__(self):
        super(RegressionPlotHandler, self).__init__()
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()

    def fit_selection(self):
        """
        Function that computes the curve to fit.

        When exactly one selection range exists and it spans a single column,
        regresses the selected values on their position 0..n-1 with OLS,
        storing the inputs in ``index``/``Y`` and the fitted values in
        ``selection_olsfit``. The selection is flushed afterwards.
        """
        self.selection_handler.create_selection()
        if len(self.selection_handler.selected_indices) == 1:
            # Entries appear to be (top_row, left_col, bottom_row, right_col)
            # with an inclusive bottom row -- hence the "+ 1" below.
            tuple_list = self.selection_handler.selected_indices[0]
            if tuple_list[1] == tuple_list[3]:
                first_row, last_row = tuple_list[0], tuple_list[2]
                self.index = np.arange(last_row - first_row + 1)
                # Take only the selected rows: the whole column would not match
                # the length of ``index`` unless the entire column was selected.
                self.Y = self.data[first_row : last_row + 1, tuple_list[1]]

                # NOTE(review): no constant column is added, so the fitted line
                # is forced through the origin -- confirm this is intended.
                results = OLS(self.Y, self.index).fit()
                self.selection_olsfit = results.fittedvalues
        self.selection_handler.flush()

    def plot_fits(self):
        """
        Function called to plot the regression fits.

        Replaces the container contents with the selected data (red, dashed)
        and the OLS fitted values (blue), both against ``index``.
        """
        # Remove through the container's own API (not by mutating its
        # components list) so each component's back-reference is cleared too.
        self.container.remove(*list(self.container.components))

        plotdata = ArrayPlotData(x=self.index, y=self.Y)
        plot = Plot(plotdata)
        plot.plot(("x", "y"), type="line", color="red")
        plot.line_style = "dash"
        self.container.add(plot)

        plotdata = ArrayPlotData(x=self.index, y=self.selection_olsfit)
        plot = Plot(plotdata)
        plot.plot(("x", "y"), type="line", color="blue")
        self.container.add(plot)

        self.container.request_redraw()
Ejemplo n.º 9
0
class ImagePlotHandler(HasTraits):
    """
    Class for handling image plots.

    Draws a selected rectangular block of ``table`` as an image plot with
    pan, zoom and traits-editing tools, and can hide or re-show a colorbar.
    """

    # the overlay container for the plot
    container = Instance(OverlayPlotContainer)

    # the selection handler instance for the tableview
    selection_handler = Instance(SelectionHandler)

    # the input data.
    table = Array

    # The chaco colorbar object accompanying the plot. Until a colorbar has
    # been found in the container this may be the class-level default (the
    # ColorBar class itself) rather than an instance.
    colorbar = ColorBar

    def __init__(self):
        super(ImagePlotHandler, self).__init__()
        self.container = OverlayPlotContainer()

        self.selection_handler = SelectionHandler()

    def imageplot_check(self):
        """
        Called to check whether the current selection is compatible with an
        image plot.

        Returns True when there is at most one selection range, or when every
        selected block of ``table`` has the same shape; False otherwise.
        """
        selected = self.selection_handler.selected_indices
        if len(selected) <= 1:
            return True
        shapes = set()
        for index_tuple in selected:
            block = self.table[index_tuple[0] : index_tuple[2], index_tuple[1] : index_tuple[3]]
            shapes.add(block.shape)
        return len(shapes) == 1

    def toggle_colorbar(self, checked):
        """
        Function called to toggle the colorbar on the image plot.

        When unchecked, remembers the ColorBar found in the container and
        removes it; when checked, re-adds the remembered colorbar. Each
        direction is a no-op if there is no colorbar to remove / re-add.
        """
        if not checked:
            for component in self.container.components:
                if isinstance(component, ColorBar):
                    self.colorbar = component
            # Use Container.remove (not components.remove) so the colorbar's
            # container back-reference is cleared and it can be re-added later.
            if self.colorbar in self.container.components:
                self.container.remove(self.colorbar)
        elif isinstance(self.colorbar, ColorBar) and self.colorbar not in self.container.components:
            self.container.add(self.colorbar)
        self.container.request_redraw()

    def draw_image_plot(self):
        """
        Function called to draw the image plot.

        Plots the block of ``table`` spanned by the first selection range
        (sliced from the top-left corner up to, but excluding, the
        bottom-right corner) and adds it to the container. Does nothing when
        there is no selection.
        """
        if len(self.selection_handler.selected_indices) == 0:
            return
        self.top_left = self.selection_handler.selected_indices[0][0:2]
        self.bot_right = self.selection_handler.selected_indices[0][2:4]
        data = self.table[self.top_left[0] : self.bot_right[0], self.top_left[1] : self.bot_right[1]]
        plotdata = ArrayPlotData(imagedata=data)
        plot = Plot(plotdata)
        plot.img_plot("imagedata")
        plot.tools.append(PanTool(plot))
        plot.tools.append(ZoomTool(plot))
        plot.tools.append(TraitsTool(plot))
        self.container.add(plot)

        # NOTE: colorbar creation is currently disabled:
        # colorbar = ColorBar(
        #    index_mapper=LinearMapper(range=plot.color_mapper.range),
        #    color_mapper = plot.color_mapper,
        #    orientation='v'
        # )
        # self.colorbar = ColorBar
        # self.container.add(colorbar)

        self.container.request_redraw()
Ejemplo n.º 10
0
class XYPlotHandler(HasTraits):
    """
    Class for handling XY plots

    Builds XY line/scatter plots from pairs of selected table columns, keeps
    them by name in ``plot_list_view``, overlays them in ``container`` and
    manages their grids and pan/zoom tools.
    """

    # Whether the data is a pandas dataframe
    AS_PANDAS_DATAFRAME = Bool

    # The container for all current plots. Gets updated every time a plot is
    # added.
    container = OverlayPlotContainer

    # This can be removed.
    plotdata = ArrayPlotData

    # The current Plot object.
    plot = Plot

    # ColorTrait, mainly required for the TraitsUIItem view.
    color = ColorTrait("blue")

    # Marker trait for the view
    marker = marker_trait

    # Marker size trait
    marker_size = Int(4)

    # An instance of SelectionHandler for adding plots from the current
    # selection.
    selection_handler = Instance(SelectionHandler)

    # Bool traits for checking the type of the plot (discrete / continuous)
    plot_type_disc = Bool
    plot_type_cont = Bool

    # The data from which to draw the plots, same as the table attribute of
    # CsvModel
    table = Array

    # The pandas data frame if AS_PANDAS_DATAFRAME
    data_frame = Instance(DataFrame)

    # Contains the grid underlays of all the current plots
    grid_underlays = List

    # Used for viewing the list of the plots and the legend
    plot_list_view = Dict

    # TraitsUI view for plot properties, yet to find an enaml implementation
    view = View(Item("color"), Item("marker"), Item("marker_size"))

    # Trait that defines whether tools are present.
    add_pan_tool = Bool
    add_zoom_tool = Bool
    add_dragzoom = Bool

    # Whether grids and axes are visible
    show_grid = Bool

    def __init__(self):
        super(XYPlotHandler, self).__init__()
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()
        self.underlays = []
        # BroadcasterTools currently attached to the container, keyed by the
        # checkbox ("pan", "zoom", "dragzoom") that created them.
        self._broadcasters = {}
        self.add_pan_tool = False
        self.add_zoom_tool = False
        self.add_dragzoom = False
        self.show_grid = False

    def _unique_plot_name(self):
        """Return the first "plot<N>" name not yet in plot_list_view, trying
        N = len(plot_list_view) first."""
        number = len(self.plot_list_view)
        while "plot" + str(number) in self.plot_list_view:
            number += 1
        return "plot" + str(number)

    def add_xyplot_selection(self, plot_name):
        """
        Called when the 'add plot from selection' button is clicked.

        Plots the second selected column against the first, registers the
        plot in ``plot_list_view`` under ``plot_name`` (or an auto-generated
        "plot<N>" name when it is empty) and adds it to the container. The new
        plot's grid underlays are collected into ``grid_underlays`` and
        removed from the plot. The selection is flushed afterwards.
        """
        self.selection_handler.create_selection()
        if self.selection_handler.xyplot_check():
            first_column = self.selection_handler.selected_indices[0]
            second_column = self.selection_handler.selected_indices[1]

            if self.AS_PANDAS_DATAFRAME:
                x_column = self.data_frame.columns[first_column[1]]
                y_column = self.data_frame.columns[second_column[1]]
                x = np.array(self.data_frame[x_column])
                y = np.array(self.data_frame[y_column])
                self.plotdata = ArrayPlotData(x=x, y=y)
            else:
                self.plotdata = ArrayPlotData(x=self.table[:, first_column[1]], y=self.table[:, second_column[1]])

            plot = Plot(self.plotdata)

            if self.plot_type_disc:
                plot_type = "scatter"
            else:
                plot_type = "line"
            plot.plot(("x", "y"), type=plot_type, color=self.color, marker=self.marker, marker_size=self.marker_size)

            self.plot = plot

            for underlay in self.plot.underlays:
                if isinstance(underlay, PlotGrid):
                    if underlay not in self.grid_underlays:
                        self.grid_underlays.append(underlay)

            for underlay in self.grid_underlays:
                if underlay in self.plot.underlays:
                    self.plot.underlays.remove(underlay)

            # An auto-generated name must not overwrite an existing entry (this
            # can happen once plots have been removed from the list).
            if plot_name == "":
                plot_name = self._unique_plot_name()
            self.plot_list_view[plot_name] = self.plot
            self.container.add(self.plot)

            self.container.request_redraw()

        self.selection_handler.flush()

    def _apply_grid_visibility(self, visible):
        """Attach (``visible`` true) or detach every known grid underlay on
        every plot in the container, then request a redraw."""
        for plot in self.container.components:
            for underlay in self.grid_underlays:
                if visible:
                    if underlay not in plot.underlays:
                        plot.underlays.append(underlay)
                elif underlay in plot.underlays:
                    plot.underlays.remove(underlay)
        self.container.request_redraw()

    def grid_toggle(self, checked):
        """
        Called when the 'Show Grid' checkbox is toggled.
        """
        self._apply_grid_visibility(checked)

    def remove_selected_plots(self, selection):
        """
        Called when the 'Remove Selected Plots' button is clicked.

        ``selection`` holds model indices whose first element's ``row`` is the
        position of a plot name in ``plot_list_view``. Those plots are dropped
        from the dict and the container is rebuilt from the remaining ones.
        """
        plot_names = list(self.plot_list_view.keys())
        # Resolve every name before popping so rows refer to the ordering the
        # list view was showing.
        remove_names = [plot_names[model_index[0].row] for model_index in selection]
        for name in remove_names:
            self.plot_list_view.pop(name, None)
        # Snapshot: removing while iterating the live list would skip entries.
        for plot in list(self.container.components):
            self.container.remove(plot)
        for name in self.plot_list_view:
            self.container.add(self.plot_list_view[name])
        self.container.request_redraw()

    def edit_selection(self, show_grid, plot_visible, plot_type_disc):
        """
        Called to start editing the selected plot. Should accompany the 'Edit
        Plot' dialog.

        Rebuilds the current plot (``self.plot``) from its data with the
        current color/marker settings and the given plot type, visibility and
        grid setting, and swaps it into the container and ``plot_list_view``.
        """
        old_plot = self.plot
        # The current plot may already have been taken out of the container
        # (e.g. by remove_selected_plots); only remove it if it is there.
        if old_plot in self.container.components:
            self.container.remove(old_plot)

        self.plot_type_disc = plot_type_disc

        if self.plot_type_disc:
            plot_type = "scatter"
        else:
            plot_type = "line"

        plot = Plot(old_plot.data)
        plot.plot(("x", "y"), color=self.color, type=plot_type, marker=self.marker, marker_size=self.marker_size)
        self.plot = plot
        self.plot.visible = plot_visible

        # Keep plot_list_view in step with the container; otherwise a later
        # remove_selected_plots() would re-add the stale plot from the dict.
        for name, listed_plot in list(self.plot_list_view.items()):
            if listed_plot is old_plot:
                self.plot_list_view[name] = self.plot

        if not show_grid:
            grid_underlays = [underlay for underlay in self.plot.underlays if isinstance(underlay, PlotGrid)]
            for underlay in grid_underlays:
                self.plot.underlays.remove(underlay)

        self.container.add(self.plot)
        self.container.request_redraw()

        self.selection_handler.flush()

    def _toggle_broadcaster_tool(self, key, enabled, make_tool):
        """
        Attach or detach the BroadcasterTool identified by ``key``.

        The broadcaster previously created for ``key`` (if any) is removed.
        When ``enabled`` and the container has plots, a new BroadcasterTool
        holding ``make_tool(plot)`` for every plot currently in the container
        is appended to the container's tools. Broadcasters created for other
        keys are left untouched.
        """
        previous = self._broadcasters.pop(key, None)
        if previous is not None and previous in self.container.tools:
            self.container.tools.remove(previous)
        if not enabled or not self.container.components:
            return
        broadcaster = BroadcasterTool()
        for plot in self.container.components:
            broadcaster.tools.append(make_tool(plot))
        self.container.tools.append(broadcaster)
        self._broadcasters[key] = broadcaster

    def _add_pan_tool_changed(self):
        """
        Method called when the Pan Tool checkbox is checked or unchecked.
        Adds the Pan Tool to the plot container if it isn't there and vice versa.
        """
        self._toggle_broadcaster_tool("pan", self.add_pan_tool, PanTool)

    def _add_zoom_tool_changed(self):
        """
        Method called when the Zoom Tool checkbox is checked or unchecked.
        Adds the Zoom Tool to the plot container if it isn't there and vice versa.
        """
        self._toggle_broadcaster_tool("zoom", self.add_zoom_tool, ZoomTool)

    def _add_dragzoom_changed(self):
        """
        Method called when the drag-zoom checkbox is checked or unchecked.
        Adds the box drag-zoom tool to the plot container if it isn't there
        and vice versa.
        """

        def make_dragzoom(plot):
            return BetterSelectingZoom(
                plot,
                always_on=True,
                tool_mode="box",
                drag_button="left",
                color="lightskyblue",
                alpha=0.4,
                border_color="dodgerblue",
            )

        self._toggle_broadcaster_tool("dragzoom", self.add_dragzoom, make_dragzoom)

    def _show_grid_changed(self):
        """
        Called when the Show grid checkbox is checked or unchecked. Adds a grid
        if one is not present and removes if present.
        """
        self._apply_grid_visibility(self.show_grid)

    def reassign_current_plot(self):
        """
        Makes the plot at the selected row of ``plot_list_view`` the current
        plot (``self.plot``), then flushes the selection.
        """
        self.selection_handler.create_selection()
        plot_index = self.selection_handler.selected_indices[0][0]
        # list(...) so indexing works on both Python 2 and 3 dict key views.
        plot_name = list(self.plot_list_view.keys())[plot_index]
        self.plot = self.plot_list_view[plot_name]
        self.selection_handler.flush()
Ejemplo n.º 11
0
class StackedPlot(ChacoPlot):
    """Chaco plot that draws several datasets as vertically offset line
    plots sharing one x axis and one y axis."""

    offset = Range(0.0, 1.0, 0.015)
    value_range = Range(0.01, 1.05, 1.00)
    flip_order = Bool(False)

    def _get_traits_group(self):
        return VGroup(
            HGroup(Item("flip_order"), Item("offset"), Item("value_range")),
            UItem("component", editor=ComponentEditor()),
        )

    def __init__(self):
        super(StackedPlot, self).__init__()
        self.container = OverlayPlotContainer(
            bgcolor="white", use_backbuffer=True, border_visible=True, padding=50, padding_left=110, fill_padding=True
        )
        self.data = ArrayPlotData()
        self.chaco_plot = None
        self.value_mapper = None
        self.index_mapper = None
        self.x_axis = PlotAxis(
            component=self.container,
            orientation="bottom",
            title=u"Angle (2\u0398)",
            title_font=settings.axis_title_font,
            tick_label_font=settings.tick_font,
        )
        y_axis_title = "Normalized intensity (%s)" % get_value_scale_label("linear")
        self.y_axis = PlotAxis(
            component=self.container,
            orientation="left",
            title=y_axis_title,
            title_font=settings.axis_title_font,
            tick_label_font=settings.tick_font,
        )
        self.container.overlays.extend([self.x_axis, self.y_axis])
        self.container.tools.append(TraitsTool(self.container, classes=[LinePlot, PlotAxis]))
        self.colors = []
        self.last_flip_order = self.flip_order

    @on_trait_change("offset, value_range, flip_order")
    def _replot_data(self):
        """Redraw the last plotted data with the current offset, value_range
        and flip_order; does nothing until _plot() has stored data."""
        if getattr(self, "data_x", None) is None:
            return
        self._plot(self.data_x, None, self.data_z, self.scale)
        self.container.request_redraw()

    def _prepare_data(self, datasets, interpolate=True):
        """
        Combine ``datasets`` into arrays for _plot; returns ``(x, None, z)``.

        With ``interpolate`` (the default) the stacked data are passed through
        interpolate_datasets(..., points=4800) and the returned x values are
        repeated once per dataset; otherwise x and z are taken directly from
        the transposed stack.
        """
        stack = stack_datasets(datasets)
        if interpolate:
            (x, z) = interpolate_datasets(stack, points=4800)
            x = array([x] * len(datasets))
        else:
            x, z = map(np.transpose, np.transpose(stack))
        return x, None, z

    def _plot(self, x, y, z, scale):
        """
        Draw one offset line plot per row of ``z`` against the matching row
        of ``x`` (``y`` is unused) and return the container.

        Previous line colors are reused when the number of rows is unchanged
        (reversed if ``flip_order`` changed since the last call); otherwise
        every line is black.
        """
        self.data_x, self.data_z, self.scale = x, z, scale
        if self.container.components:
            # A real list is needed for .reverse() (map() is lazy on Python 3).
            self.colors = [plot.color for plot in self.container.components]
            if self.last_flip_order != self.flip_order:
                self.colors.reverse()
            self.container.remove(*self.container.components)
        # Use a custom renderer so plot lines are clickable
        self.chaco_plot = Plot(self.data, renderer_map={"line": ClickableLinePlot})
        self.chaco_plot.bgcolor = "white"
        self.value_mapper = None
        self.index_mapper = None

        if len(self.data_x) == len(self.colors):
            colors = self.colors[:]
        else:
            colors = ["black"] * len(self.data_x)

        if self.flip_order:
            z = z[::-1]

        spacing = (z.max(axis=1) - z.min(axis=1)).min() * self.value_range
        offset = spacing * self.offset
        for i, (x_row, z_row) in enumerate(zip(x, z)):
            self.data.set_data("data_x_" + str(i), x_row)
            self.data.set_data("data_y_offset_" + str(i), z_row * self.value_range + offset * i)
            plots = self.chaco_plot.plot(("data_x_" + str(i), "data_y_offset_" + str(i)), color=colors[i], type="line")
            plot = plots[0]
            self.container.add(plot)
            # Required for double-clicking plots
            plot.index.sort_order = "ascending"
            plot.value.sort_order = "ascending"

            # All lines share the first line's mappers so they use one scale.
            if self.value_mapper is None:
                self.index_mapper = plot.index_mapper
                self.value_mapper = plot.value_mapper
            else:
                plot.value_mapper = self.value_mapper
                self.value_mapper.range.add(plot.value)
                plot.index_mapper = self.index_mapper
                self.index_mapper.range.add(plot.index)
        # Named so it does not shadow the builtin range().
        mapper_range = self.value_mapper.range
        mapper_range.high = (mapper_range.high - mapper_range.low) * self.value_range + mapper_range.low
        self.x_axis.mapper = self.index_mapper
        self.y_axis.mapper = self.value_mapper
        self.y_axis.title = "Normalized intensity (%s)" % get_value_scale_label(scale)
        # The zoom tool is attached to the last line drawn.
        self.zoom_tool = ClickUndoZoomTool(
            plot,
            tool_mode="box",
            always_on=True,
            pointer="cross",
            drag_button=settings.zoom_button,
            undo_button=settings.undo_button,
        )
        plot.overlays.append(self.zoom_tool)
        self.last_flip_order = self.flip_order
        return self.container

    def _reset_view(self):
        """Undo all zoom steps; does nothing before the first _plot()."""
        zoom_tool = getattr(self, "zoom_tool", None)
        if zoom_tool is not None:
            zoom_tool.revert_history_all()
Ejemplo n.º 12
0
class KMeansPlotHandler(HasTraits):
    '''
    Class for plotting the k-means clusters.

    Clusters the rows of ``dataset`` with sklearn's KMeans and plots every
    cluster, plus the cluster centers, in a shared 2-D PCA projection.
    '''

    # the data to cluster
    data = Array

    # the dataset created after preprocessing
    dataset = Array

    # the sklearn.cluster.KMeans object
    kmeans = Instance(KMeans)

    # Number of clusters
    n_clusters = Int

    # Maximum iterations for the clustering algorithm (0 keeps the KMeans
    # object's own setting)
    max_iter = Int

    # Container for the cluster plots
    container = Instance(OverlayPlotContainer)

    # the columns from the dataset to omit when performing clustering
    # (column numbers as strings; non-digit entries are ignored)
    to_omit = List

    def __init__(self):
        super(KMeansPlotHandler, self).__init__()
        self.kmeans = KMeans()
        self.container = OverlayPlotContainer()

    def create_dataset(self):
        '''
        Creates a numpy array from the current selection to pass to the
        sklearn.cluster.kmeans object.

        Sets ``dataset`` to the columns of the 2-D ``data`` whose index is not
        listed in ``to_omit``; with nothing to omit, every column is kept.
        Does nothing if ``data`` is not 2-D.
        '''
        if self.data.ndim != 2:
            return
        omitted = set(int(elem) for elem in self.to_omit if elem.isdigit())
        # Column 0 is treated like any other column, so it can be omitted too.
        keep = [col for col in range(self.data.shape[1]) if col not in omitted]
        self.dataset = self.data[:, keep]

    def plot_clusters(self):
        '''
        Plots the clusters after calling the .fit method of the sklearn kmeans
        estimator.

        Replaces the container contents with one scatter plot per cluster and
        a scatter of the cluster centers (cross markers).
        '''
        self.kmeans.n_clusters = self.n_clusters
        if self.max_iter > 0:
            self.kmeans.max_iter = self.max_iter
        self.kmeans.fit(self.dataset)

        # Fit the 2-D projection on the data once and reuse it for the centers
        # so points and centers share the same coordinate system.
        pca = PCA(n_components=2, whiten=True)
        dataset_red = pca.fit_transform(self.dataset)
        cluster_centers = pca.transform(self.kmeans.cluster_centers_)

        for component in list(self.container.components):
            self.container.remove(component)

        for i in range(self.n_clusters):
            current_data = dataset_red[self.kmeans.labels_ == i]

            plotdata = ArrayPlotData(x=current_data[:, 0],
                                     y=current_data[:, 1])
            plot = Plot(plotdata)
            # Wrap around so more clusters than palette entries still plot.
            plot.plot(("x", "y"),
                      type='scatter',
                      color=tuple(COLOR_PALETTE[i % len(COLOR_PALETTE)]))
            self.container.add(plot)

        plotdata_cent = ArrayPlotData(x=cluster_centers[:, 0],
                                      y=cluster_centers[:, 1])
        plot_cent = Plot(plotdata_cent)
        plot_cent.plot(("x", "y"),
                       type='scatter',
                       marker='cross',
                       marker_size=8)
        self.container.add(plot_cent)

        self.container.request_redraw()
Ejemplo n.º 13
0
class StackedPlot(ChacoPlot):
    """
    Overlay several traces, each shifted upwards by a constant step.

    Trace ``i`` is drawn at ``z_i * value_range + i * step``. Here ``step`` is
    ``offset`` times the smallest per-trace span (max - min), with that span
    scaled by ``value_range``. All traces share one index mapper and one
    value mapper, so the two axes apply to every trace.
    """
    # Vertical step between successive traces, as a fraction of the
    # smallest trace span.
    offset = Range(0.0, 1.0, 0.015)
    # Scale applied to every trace and to the top of the shared value range.
    value_range = Range(0.01, 1.05, 1.00)
    # Draw the traces in reverse order.
    flip_order = Bool(False)

    def _get_traits_group(self):
        """Return the TraitsUI layout: the three options above the plot."""
        return VGroup(
            HGroup(
                Item('flip_order'),
                Item('offset'),
                Item('value_range'),
            ),
            UItem('component', editor=ComponentEditor()),
        )

    def __init__(self):
        super(StackedPlot, self).__init__()
        self.container = OverlayPlotContainer(bgcolor='white',
                                              use_backbuffer=True,
                                              border_visible=True,
                                              padding=50,
                                              padding_left=110,
                                              fill_padding=True)
        self.data = ArrayPlotData()
        self.chaco_plot = None
        self.value_mapper = None
        self.index_mapper = None
        self.x_axis = PlotAxis(component=self.container,
                               orientation='bottom',
                               title=u'Angle (2\u0398)',
                               title_font=settings.axis_title_font,
                               tick_label_font=settings.tick_font)
        y_axis_title = 'Normalized intensity (%s)' % get_value_scale_label('linear')
        self.y_axis = PlotAxis(component=self.container,
                               orientation='left',
                               title=y_axis_title,
                               title_font=settings.axis_title_font,
                               tick_label_font=settings.tick_font)
        self.container.overlays.extend([self.x_axis, self.y_axis])
        # TraitsTool restricted to line plots and axes.
        self.container.tools.append(
            TraitsTool(self.container, classes=[LinePlot, PlotAxis]))
        # Colours of the currently drawn traces, carried across replots.
        self.colors = []
        self.last_flip_order = self.flip_order

    @on_trait_change('offset, value_range, flip_order')
    def _replot_data(self):
        """Redraw the last plotted data after a display option changes."""
        # An option may change before anything was plotted; _plot is what
        # stores data_x, so there is nothing to redraw yet.
        if getattr(self, 'data_x', None) is None:
            return
        self._plot(self.data_x, None, self.data_z, self.scale)
        self.container.request_redraw()

    def _prepare_data(self, datasets):
        """
        Turn ``datasets`` into ``(x, None, z)`` for ``_plot``.

        The datasets are stacked with ``stack_datasets`` and passed to
        ``interpolate_datasets`` with ``points=4800``, presumably giving a
        common x grid. That grid is then repeated once per dataset, so ``x``
        has one row per trace, like ``z``.
        """
        interpolate = True
        stack = stack_datasets(datasets)
        if interpolate:
            (x, z) = interpolate_datasets(stack, points=4800)
            x = array([x] * len(datasets))
        else:
            # Unreachable while ``interpolate`` is hard-coded to True.
            x, z = map(np.transpose, np.transpose(stack))
        return x, None, z

    def _plot(self, x, y, z, scale):
        """
        Draw one line per row of ``x``/``z`` and return the container.

        :param x: sequence of index rows, one per trace.
        :param y: unused here.
        :param z: 2-D array of value rows, one per trace (max/min are taken
            along axis 1).
        :param scale: passed to ``get_value_scale_label`` for the y-axis title.
        :returns: ``self.container``.
        """
        self.data_x, self.data_z, self.scale = x, z, scale
        if self.container.components:
            # Remember the current line colours so a replot keeps them.
            self.colors = [plot.color for plot in self.container.components]
            if self.last_flip_order != self.flip_order:
                # The traces are about to be drawn in the opposite order.
                self.colors.reverse()
            self.container.remove(*self.container.components)
        # Use a custom renderer so plot lines are clickable
        self.chaco_plot = Plot(self.data,
                               renderer_map={'line': ClickableLinePlot})
        self.chaco_plot.bgcolor = 'white'
        self.value_mapper = None
        self.index_mapper = None

        if len(x) == 0:
            # Nothing to draw: keep the (now empty) container instead of
            # failing on the reductions and mapper lookups below.
            self.last_flip_order = self.flip_order
            return self.container

        if len(x) == len(self.colors):
            colors = self.colors[:]
        else:
            colors = ['black'] * len(x)

        if self.flip_order:
            # Reverse x together with z so each trace keeps its own index row.
            x = x[::-1]
            z = z[::-1]

        spacing = (z.max(axis=1) - z.min(axis=1)).min() * self.value_range
        offset = spacing * self.offset
        for i, (x_row, z_row) in enumerate(zip(x, z)):
            x_name = 'data_x_' + str(i)
            y_name = 'data_y_offset_' + str(i)
            self.data.set_data(x_name, x_row)
            self.data.set_data(y_name, z_row * self.value_range + offset * i)
            plot = self.chaco_plot.plot((x_name, y_name),
                                        color=colors[i], type='line')[0]
            self.container.add(plot)
            # Required for double-clicking plots
            plot.index.sort_order = 'ascending'
            plot.value.sort_order = 'ascending'

            if self.value_mapper is None:
                # The first trace's mappers become the shared ones.
                self.index_mapper = plot.index_mapper
                self.value_mapper = plot.value_mapper
            else:
                plot.value_mapper = self.value_mapper
                self.value_mapper.range.add(plot.value)
                plot.index_mapper = self.index_mapper
                self.index_mapper.range.add(plot.index)

        # Scale the top of the shared value range by value_range.
        mapper_range = self.value_mapper.range
        mapper_range.high = ((mapper_range.high - mapper_range.low)
                             * self.value_range + mapper_range.low)
        self.x_axis.mapper = self.index_mapper
        self.y_axis.mapper = self.value_mapper
        self.y_axis.title = 'Normalized intensity (%s)' % \
                get_value_scale_label(scale)
        # The zoom tool is attached to the last trace drawn.
        self.zoom_tool = ClickUndoZoomTool(
            plot, tool_mode="box", always_on=True, pointer="cross",
            drag_button=settings.zoom_button,
            undo_button=settings.undo_button,
        )
        plot.overlays.append(self.zoom_tool)
        self.last_flip_order = self.flip_order
        return self.container

    def _reset_view(self):
        """Revert every recorded zoom step, if a zoom tool exists yet."""
        zoom_tool = getattr(self, 'zoom_tool', None)
        if zoom_tool is not None:
            zoom_tool.revert_history_all()
Ejemplo n.º 14
0
class KMeansPlotHandler(HasTraits):
    """
    Class for plotting the k-means clusters.

    ``create_dataset`` derives ``dataset`` from ``data`` by dropping the
    columns listed in ``to_omit``. ``plot_clusters`` then fits ``kmeans`` on
    ``dataset`` and draws the clusters into ``container``.
    """

    # the data to cluster
    data = Array

    # the dataset created after preprocessing
    dataset = Array

    # the sklearn.cluster.KMeans object
    kmeans = Instance(KMeans)

    # Number of clusters
    n_clusters = Int

    # Maximum iterations for the clustering algorithm
    max_iter = Int

    # Container for the cluster plots
    container = Instance(OverlayPlotContainer)

    # the columns from the dataset to omit when performing clustering
    to_omit = List

    def __init__(self):
        super(KMeansPlotHandler, self).__init__()
        self.kmeans = KMeans()
        self.container = OverlayPlotContainer()

    def create_dataset(self):
        """
        Build ``self.dataset`` from ``self.data`` for the k-means estimator.

        ``self.to_omit`` lists the column indices to leave out. Entries may be
        ints or strings; entries that do not parse as integers are ignored.
        Every remaining column of ``self.data`` is kept, in its original
        order. Column 0 is omitted too if it is listed. With nothing to omit,
        the whole of ``self.data`` is used, so a previous (reduced) dataset
        never lingers.
        """
        omitted = set()
        for elem in self.to_omit:
            try:
                omitted.add(int(str(elem).strip()))
            except ValueError:
                # Free-text entries that are not column numbers are skipped.
                continue

        if not omitted:
            self.dataset = self.data
            return

        n_cols = self.data.shape[1]
        keep = [col for col in range(n_cols) if col not in omitted]
        # Fancy indexing returns a 2-D copy of the selected columns.
        self.dataset = self.data[:, keep]

    def plot_clusters(self):
        """
        Fit the k-means estimator and plot the resulting clusters.

        Sets ``self.kmeans.n_clusters`` from ``self.n_clusters``, fits the
        estimator on ``self.dataset`` and replaces every component of
        ``self.container`` with one scatter plot per cluster, followed by one
        scatter plot (cross markers) of the cluster centres. Points and
        centres are reduced to two dimensions with a single PCA fitted on the
        dataset, so both are drawn in the same coordinates.
        """
        self.kmeans.n_clusters = self.n_clusters
        self.kmeans.fit(self.dataset)

        # Fit the projection once, on the data, and reuse it for the centres;
        # a PCA fitted on the centres alone uses unrelated coordinates.
        pca = PCA(n_components=2, whiten=True)
        dataset_red = pca.fit_transform(self.dataset)
        cluster_centers = pca.transform(self.kmeans.cluster_centers_)

        # Iterate over a copy: Container.remove mutates the component list.
        for component in list(self.container.components):
            self.container.remove(component)

        labels = np.asarray(self.kmeans.labels_)
        for i in range(self.n_clusters):
            current_data = dataset_red[np.flatnonzero(labels == i), :]

            plotdata = ArrayPlotData(x=current_data[:, 0], y=current_data[:, 1])
            plot = Plot(plotdata)
            # Wrap around so more clusters than palette entries still work.
            plot.plot(("x", "y"), type="scatter",
                      color=tuple(COLOR_PALETTE[i % len(COLOR_PALETTE)]))
            self.container.add(plot)

        plotdata_cent = ArrayPlotData(x=cluster_centers[:, 0], y=cluster_centers[:, 1])
        plot_cent = Plot(plotdata_cent)
        plot_cent.plot(("x", "y"), type="scatter", marker="cross", marker_size=8)
        self.container.add(plot_cent)

        self.container.request_redraw()
Ejemplo n.º 15
0
class ImagePlotHandler(HasTraits):
    '''
    Class for handling image plots.
    '''

    # the overlay container for the plot
    container = Instance(OverlayPlotContainer)

    # the selection handler instance for the tableview
    selection_handler = Instance(SelectionHandler)

    # the input data.
    table = Array

    # The chaco colorbar object accompanying the plot.
    colorbar = ColorBar

    def __init__(self):
        super(ImagePlotHandler, self).__init__()
        self.container = OverlayPlotContainer()

        self.selection_handler = SelectionHandler()

    def imageplot_check(self):
        '''
        Return whether the current selection can be drawn as an image plot.

        Zero or one selected block is always compatible. Several blocks are
        compatible only when they all have the same shape. Each entry of
        ``selection_handler.selected_indices`` is used as the bounds
        ``(row_start, col_start, row_stop, col_stop)`` of a slice of
        ``self.table``.
        '''
        selections = self.selection_handler.selected_indices
        if len(selections) <= 1:
            return True
        shapes = [self.table[bounds[0]:bounds[2], bounds[1]:bounds[3]].shape
                  for bounds in selections]
        return all(shape == shapes[0] for shape in shapes)

    def toggle_colorbar(self, checked):
        '''
        Show (``checked`` true) or hide the colorbar, then redraw.

        When hiding, a ColorBar found in the container is remembered in
        ``self.colorbar`` so it can be shown again later. Nothing happens if
        no colorbar is known, or if it is already in the requested state.
        '''
        components = self.container.components
        if not checked:
            for component in components:
                if isinstance(component, ColorBar):
                    self.colorbar = component
            if self.colorbar is not None and self.colorbar in components:
                # Container.remove (not list.remove) also detaches the
                # component, so a later Container.add does not fail.
                self.container.remove(self.colorbar)
        elif self.colorbar is not None and self.colorbar not in components:
            self.container.add(self.colorbar)
        self.container.request_redraw()

    def draw_image_plot(self):
        '''
        Draw the first selected block of ``self.table`` as an image plot.

        The plot gets pan, zoom and traits tools and is added to the
        container; existing components are not removed.
        '''
        self.top_left = self.selection_handler.selected_indices[0][0:2]
        self.bot_right = self.selection_handler.selected_indices[0][2:4]
        data = self.table[self.top_left[0]:self.bot_right[0],
                          self.top_left[1]:self.bot_right[1]]
        plotdata = ArrayPlotData(imagedata=data)
        plot = Plot(plotdata)
        plot.img_plot('imagedata')
        plot.tools.append(PanTool(plot))
        plot.tools.append(ZoomTool(plot))
        plot.tools.append(TraitsTool(plot))
        self.container.add(plot)

        # TODO: colorbar creation is disabled, so self.colorbar is never set
        # here. The intended code was:
        #colorbar = ColorBar(
        #    index_mapper=LinearMapper(range=plot.color_mapper.range),
        #    color_mapper = plot.color_mapper,
        #    orientation='v'
        #)
        #self.colorbar = ColorBar
        #self.container.add(colorbar)

        self.container.request_redraw()
Ejemplo n.º 16
0
class XYPlotHandler(HasTraits):
    '''
    Class for handling XY plots
    '''

    # Whether the data is a pandas dataframe
    AS_PANDAS_DATAFRAME = Bool

    # The container for all current plots. Gets updated every time a plot is
    # added.
    container = OverlayPlotContainer

    # This can be removed.
    plotdata = ArrayPlotData

    # The current Plot object.
    plot = Plot

    # ColorTrait, mainly required for the TraitsUIItem view.
    color = ColorTrait("blue")

    # Marker trait for the view
    marker = marker_trait

    # Marker size trait
    marker_size = Int(4)

    # An instance of SelectionHandler for adding plots from the current
    # selection.
    selection_handler = Instance(SelectionHandler)

    # Bool traits for checking the type of the plot (discrete / continuous)
    plot_type_disc = Bool
    plot_type_cont = Bool

    # The data from which to draw the plots, same as the table attribute of
    # CsvModel
    table = Array

    # The pandas data frame if AS_PANDAS_DATAFRAME
    data_frame = Instance(DataFrame)

    # Contains the grid underlays of all the current plots
    grid_underlays = List

    # Used for viewing the list of the plots and the legend
    plot_list_view = Dict

    # TraitsUI view for plot properties, yet to find an enaml implementation
    view = View(Item('color'), Item('marker'), Item('marker_size'))

    # Trait that defines whether tools are present.
    add_pan_tool = Bool
    add_zoom_tool = Bool
    add_dragzoom = Bool

    # Whether grids and axes are visible
    show_grid = Bool

    def __init__(self):
        super(XYPlotHandler, self).__init__()
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()
        # Not read anywhere in this class; kept for external readers.
        self.underlays = []
        # One BroadcasterTool per tool kind ('pan', 'zoom', 'dragzoom'), so
        # toggling one kind never removes another kind's tools.
        self._broadcasters = {}
        self.add_pan_tool = False
        self.add_zoom_tool = False
        self.add_dragzoom = False
        self.show_grid = False

    def add_xyplot_selection(self, plot_name):
        '''
        Called when the 'add plot from selection' button is clicked.

        The first two selected columns become x and y of a new Plot, drawn
        as a scatter plot if ``plot_type_disc`` is set and as a line
        otherwise. The plot's grid underlays are registered in
        ``grid_underlays`` and hidden. The plot is then stored in
        ``plot_list_view`` under ``plot_name``; an empty name picks the first
        unused ``'plotN'``. Finally the plot is added to the container.
        '''

        self.selection_handler.create_selection()
        if self.selection_handler.xyplot_check():

            if self.AS_PANDAS_DATAFRAME:
                x_column = self.data_frame.columns[
                    self.selection_handler.selected_indices[0][1]]
                y_column = self.data_frame.columns[
                    self.selection_handler.selected_indices[1][1]]
                x = np.array(self.data_frame[x_column])
                y = np.array(self.data_frame[y_column])
                self.plotdata = ArrayPlotData(x=x, y=y)

            else:

                first_column = self.selection_handler.selected_indices[0]
                second_column = self.selection_handler.selected_indices[1]
                self.plotdata = ArrayPlotData(x=self.table[:, first_column[1]],
                                              y=self.table[:,
                                                           second_column[1]])

            plot = Plot(self.plotdata)

            if self.plot_type_disc:
                plot_type = 'scatter'
            else:
                plot_type = 'line'
            plot.plot(("x", "y"),
                      type=plot_type,
                      color=self.color,
                      marker=self.marker,
                      marker_size=self.marker_size)

            self.plot = plot

            # Register this plot's grids so the grid toggles manage them,
            # then hide every known grid: new plots start without one.
            for underlay in self.plot.underlays:
                if isinstance(underlay, PlotGrid):
                    if underlay not in self.grid_underlays:
                        self.grid_underlays.append(underlay)

            for underlay in self.grid_underlays:
                if underlay in self.plot.underlays:
                    self.plot.underlays.remove(underlay)

            if plot_name == '':
                plot_name = self._unique_plot_name()
            self.plot_list_view[plot_name] = self.plot
            self.container.add(self.plot)

            self.container.request_redraw()

        self.selection_handler.flush()

    def _unique_plot_name(self):
        '''Return ``'plotN'`` for the first N >= len(plot_list_view) not in use.'''
        index = len(self.plot_list_view)
        # After removals the count-based name may already be taken; skip it
        # rather than overwrite an existing plot.
        while 'plot' + str(index) in self.plot_list_view:
            index += 1
        return 'plot' + str(index)

    def _apply_grid_visibility(self, visible):
        '''Attach (``visible``) or detach every known grid on every plot, then redraw.'''
        for plot in self.container.components:
            for underlay in self.grid_underlays:
                if visible and underlay not in plot.underlays:
                    plot.underlays.append(underlay)
                elif not visible and underlay in plot.underlays:
                    plot.underlays.remove(underlay)
        self.container.request_redraw()

    def grid_toggle(self, checked):
        '''
        Called when the 'Show Grid' checkbox is toggled.
        '''
        self._apply_grid_visibility(bool(checked))

    def _plot_name_at(self, row):
        '''Return the ``plot_list_view`` key at position ``row`` in key order.'''
        return list(self.plot_list_view.keys())[row]

    def remove_selected_plots(self, selection):
        '''
        Called when the 'Remove Selected Plots' button is clicked.

        ``selection`` holds items whose ``[0].row`` is the row of a plot in
        ``plot_list_view`` key order (presumably view model indexes). The
        matching plots are dropped from ``plot_list_view``, and the container
        is rebuilt from the remaining entries.
        '''
        # Resolve every name before popping, since popping changes the key
        # order the rows refer to. The set ignores duplicate rows.
        remove_names = set(self._plot_name_at(model_index[0].row)
                           for model_index in selection)
        for name in remove_names:
            self.plot_list_view.pop(name)
        # Iterate over a copy: Container.remove mutates the component list,
        # so iterating it directly skips every other plot.
        for plot in list(self.container.components):
            self.container.remove(plot)
        for name in self.plot_list_view:
            self.container.add(self.plot_list_view[name])
        self.container.request_redraw()

    def edit_selection(self, show_grid, plot_visible, plot_type_disc):
        '''
        Rebuild the current plot with the current style settings. Should
        accompany the 'Edit Plot' dialog.

        The current plot is replaced by a new Plot on the same data, using
        ``color``, ``marker``, ``marker_size`` and the scatter/line choice
        from ``plot_type_disc``. Its visibility is set to ``plot_visible``,
        and its own grids are removed unless ``show_grid`` is true. The
        ``plot_list_view`` entry for the old plot is updated to the new one.
        '''
        old_plot = self.plot
        self.container.remove(old_plot)

        self.plot_type_disc = plot_type_disc
        plot_type = 'scatter' if self.plot_type_disc else 'line'

        plot = Plot(old_plot.data)
        plot.plot(("x", "y"),
                  color=self.color,
                  type=plot_type,
                  marker=self.marker,
                  marker_size=self.marker_size)
        self.plot = plot
        self.plot.visible = plot_visible

        if not show_grid:
            grids = [underlay for underlay in self.plot.underlays
                     if isinstance(underlay, PlotGrid)]
            for underlay in grids:
                self.plot.underlays.remove(underlay)

        # Point the plot list at the replacement, so later removals and
        # reassignments do not resurrect the detached old plot.
        stale_names = [name for name, listed in self.plot_list_view.items()
                       if listed is old_plot]
        for name in stale_names:
            self.plot_list_view[name] = self.plot

        self.container.add(self.plot)
        self.container.request_redraw()

        self.selection_handler.flush()

    def _set_broadcast_tools(self, kind, enabled, make_tool):
        '''
        Install or remove the BroadcasterTool owning one kind of tool.

        The broadcaster previously installed for ``kind`` (if any) is removed
        from the container's tools. If ``enabled`` is true and the container
        has plots, one new BroadcasterTool holding ``make_tool(plot)`` for
        every plot is appended. Broadcasters of other kinds are untouched.
        '''
        previous = self._broadcasters.pop(kind, None)
        if previous is not None and previous in self.container.tools:
            self.container.tools.remove(previous)
        if enabled and self.container.components:
            broadcaster = BroadcasterTool()
            for plot in self.container.components:
                broadcaster.tools.append(make_tool(plot))
            # Appended once: one copy per plot would replay every event
            # several times.
            self.container.tools.append(broadcaster)
            self._broadcasters[kind] = broadcaster

    def _add_pan_tool_changed(self):
        '''
        Method called when the Pan Tool checkbox is checked or unchecked.
        Adds a Pan Tool for every plot, or removes the pan tools again.
        '''
        self._set_broadcast_tools('pan', self.add_pan_tool, PanTool)

    def _add_zoom_tool_changed(self):
        '''
        Method called when the Zoom Tool checkbox is checked or unchecked.
        Adds a Zoom Tool for every plot, or removes the zoom tools again.
        '''
        self._set_broadcast_tools('zoom', self.add_zoom_tool, ZoomTool)

    def _add_dragzoom_changed(self):
        '''
        Method called when the drag-zoom checkbox is checked or unchecked.
        Adds a box-selection zoom (BetterSelectingZoom) for every plot, or
        removes those tools again.
        '''
        def make_dragzoom(plot):
            return BetterSelectingZoom(plot,
                                       always_on=True,
                                       tool_mode='box',
                                       drag_button='left',
                                       color='lightskyblue',
                                       alpha=0.4,
                                       border_color='dodgerblue')

        self._set_broadcast_tools('dragzoom', self.add_dragzoom, make_dragzoom)

    def _show_grid_changed(self):
        '''
        Called when the Show grid checkbox is checked or unchecked. Attaches
        every known grid to every plot, or detaches them all.
        '''
        self._apply_grid_visibility(self.show_grid)

    def reassign_current_plot(self):
        '''
        Make the plot at the selected row of ``plot_list_view`` current.
        '''
        self.selection_handler.create_selection()
        plot_index = self.selection_handler.selected_indices[0][0]
        plot_name = self._plot_name_at(plot_index)
        self.plot = self.plot_list_view[plot_name]
        self.selection_handler.flush()
Ejemplo n.º 17
0
class StackedPlot(ChacoPlot):
    """
    Overlay several traces, each shifted upwards by a constant step.

    Trace ``i`` is drawn at ``z_i * value_range + i * step``. Here ``step`` is
    ``offset`` times the smallest per-trace span (max - min), with that span
    scaled by ``value_range``. All traces share one index mapper and one
    value mapper, so the two axes apply to every trace.
    """
    # Vertical step between successive traces, as a fraction of the
    # smallest trace span.
    offset = Range(0.0, 1.0, 0.015)
    # Scale applied to every trace and to the top of the shared value range.
    value_range = Range(0.01, 1.05, 1.00)
    # Draw the traces in reverse order.
    flip_order = Bool(False)

    def _get_traits_group(self):
        """Return the TraitsUI layout: the three options above the plot."""
        return VGroup(
            HGroup(
                Item('flip_order'),
                Item('offset'),
                Item('value_range'),
            ),
            UItem('component', editor=ComponentEditor()),
        )

    def __init__(self):
        super(StackedPlot, self).__init__()
        self.container = OverlayPlotContainer(bgcolor='white',
                                              use_backbuffer=True,
                                              border_visible=True,
                                              padding=50,
                                              padding_left=110,
                                              fill_padding=True)
        self.data = ArrayPlotData()
        self.chaco_plot = None
        self.value_mapper = None
        self.index_mapper = None
        self.x_axis = MyPlotAxis(component=self.container,
                                 orientation='bottom',
                                 title=u'Angle (2\u0398)',
                                 title_font=settings.axis_title_font,
                                 tick_label_font=settings.tick_font)
        y_axis_title = 'Normalized intensity (%s)' % get_value_scale_label('linear')
        self.y_axis = MyPlotAxis(component=self.container,
                                 orientation='left',
                                 title=y_axis_title,
                                 title_font=settings.axis_title_font,
                                 tick_label_font=settings.tick_font)
        self.container.overlays.extend([self.x_axis, self.y_axis])
        # TraitsTool restricted to line plots and axes.
        self.container.tools.append(
            TraitsTool(self.container, classes=[LinePlot, MyPlotAxis]))
        # Colours of the currently drawn traces, carried across replots.
        self.colors = []
        self.last_flip_order = self.flip_order

    @on_trait_change('offset, value_range, flip_order')
    def _replot_data(self):
        """Redraw the last plotted data after a display option changes."""
        # An option may change before anything was plotted; _plot is what
        # stores data_x, so there is nothing to redraw yet.
        if getattr(self, 'data_x', None) is None:
            return
        self._plot(self.data_x, None, self.data_z, self.scale)
        self.container.request_redraw()

    def _prepare_data(self, stack):
        """
        Split a stacked 3-D array into ``(x, None, z)`` for ``_plot``.

        ``x`` is ``stack[:, :, 0]`` and ``z`` is ``stack[:, :, 2]``, giving one
        row per entry along axis 0 (presumably one per dataset). Column 1
        is presumably a per-dataset y coordinate and is unused here; TODO
        confirm.
        """
        x = stack[:, :, 0]
        z = stack[:, :, 2]
        return x, None, z

    def _plot(self, x, y, z, scale):
        """
        Draw one line per row of ``x``/``z`` and return the container.

        :param x: sequence of index rows, one per trace.
        :param y: unused here.
        :param z: 2-D array of value rows, one per trace (max/min are taken
            along axis 1).
        :param scale: passed to ``get_value_scale_label`` for the y-axis title.
        :returns: ``self.container``.
        """
        self.data_x, self.data_z, self.scale = x, z, scale
        if self.container.components:
            # Remember the current line colours so a replot keeps them.
            self.colors = [plot.color for plot in self.container.components]
            if self.last_flip_order != self.flip_order:
                # The traces are about to be drawn in the opposite order.
                self.colors.reverse()
            self.container.remove(*self.container.components)
        # Use a custom renderer so plot lines are clickable
        self.chaco_plot = Plot(self.data,
                               renderer_map={'line': ClickableLinePlot})
        self.chaco_plot.bgcolor = 'white'
        self.value_mapper = None
        self.index_mapper = None

        if len(x) == 0:
            # Nothing to draw: keep the (now empty) container instead of
            # failing on the reductions and mapper lookups below.
            self.last_flip_order = self.flip_order
            return self.container

        if len(x) == len(self.colors):
            colors = self.colors[:]
        else:
            colors = ['black'] * len(x)

        if self.flip_order:
            # Reverse x together with z: here each dataset has its own x row,
            # so flipping z alone would pair traces with the wrong x values.
            x = x[::-1]
            z = z[::-1]

        spacing = (z.max(axis=1) - z.min(axis=1)).min() * self.value_range
        offset = spacing * self.offset
        for i, (x_row, z_row) in enumerate(zip(x, z)):
            x_name = 'data_x_' + str(i)
            y_name = 'data_y_offset_' + str(i)
            self.data.set_data(x_name, x_row)
            self.data.set_data(y_name, z_row * self.value_range + offset * i)
            plot = self.chaco_plot.plot((x_name, y_name),
                                        color=colors[i], type='line')[0]
            self.container.add(plot)
            # Required for double-clicking plots
            plot.index.sort_order = 'ascending'
            plot.value.sort_order = 'ascending'

            if self.value_mapper is None:
                # The first trace's mappers become the shared ones.
                self.index_mapper = plot.index_mapper
                self.value_mapper = plot.value_mapper
            else:
                plot.value_mapper = self.value_mapper
                self.value_mapper.range.add(plot.value)
                plot.index_mapper = self.index_mapper
                self.index_mapper.range.add(plot.index)

        # Scale the top of the shared value range by value_range.
        mapper_range = self.value_mapper.range
        mapper_range.high = ((mapper_range.high - mapper_range.low)
                             * self.value_range + mapper_range.low)
        self.x_axis.mapper = self.index_mapper
        self.y_axis.mapper = self.value_mapper
        self.y_axis.title = 'Normalized intensity (%s)' % \
                get_value_scale_label(scale)
        # The zoom tool is attached to the last trace drawn.
        self.zoom_tool = ClickUndoZoomTool(
            plot, tool_mode="box", always_on=True, pointer="cross",
            drag_button=settings.zoom_button,
            undo_button=settings.undo_button,
        )
        plot.overlays.append(self.zoom_tool)
        self.last_flip_order = self.flip_order
        return self.container

    def _reset_view(self):
        """Revert every recorded zoom step, if a zoom tool exists yet."""
        zoom_tool = getattr(self, 'zoom_tool', None)
        if zoom_tool is not None:
            zoom_tool.revert_history_all()
Ejemplo n.º 18
0
class MainApp(HasTraits):
    """
    Main application window: loads .xye datasets, plots them and drives the
    View, Process and theta/d/Q tabs.

    Plain (non-trait) state created in __init__:
      datasets           -- list of loaded XYEDataset objects
      dataset_pairs      -- set of (first, second) file *basename* tuples that
                            pair partner positions (see _get_partners)
      processed_datasets -- results of the most recent Process > Apply
      undo_state         -- (datasets, dataset_pairs) snapshot used by Undo,
                            or None when there is nothing to undo
    """
    container = Instance(OverlayPlotContainer)

    file_paths = List(Str)
    # Button group above tabbed area
    open_files = Button("Open files...")
    edit_datasets = Button("Edit datasets...")
    generate_plot = Button("Generate plot...")
    help_button = Button("Help...")

    # View tab
    scale = Enum('linear', 'log', 'sqrt')
    options = List
    reset_button = Button("Reset view")
    copy_to_clipboard = Button("Copy to clipboard")
    save_as_image = Button("Save as image...")

    # Process tab
    merge_positions = Enum('all', 'p1+p2', 'p3+p4', 'p12+p34')('p1+p2')
    load_partners = Button
    splice = Bool(True)
    merge = Bool(False)
    merge_regrid = Bool(False)
    normalise = Bool(True)
    # See comment in class Global() for an explanation of the following traits
    g = Instance(Global, ())
    file_list = DelegatesTo('g')
    normalisation_source_filenames = Enum(values='file_list')
    def _g_default(self):
        return g

    correction = Float(0.0)
    align_positions = Bool(False)
    bt_start_peak_select = Button
    bt_end_peak_select = Button
    peak_selecting = Bool(False)

    what_to_plot = Enum('Plot new', 'Plot old and new')('Plot old and new')

    bt_process = Button("Apply")
    bt_undo_processing = Button("Undo")
    bt_save = Button("Save...")

    # Background removal tab
    bt_manually_define_background = Button("Define")
    polynomial_order = Range(1, 20)(7)
    bt_poly_fit = Button("Poly fit")
    bt_load_background = Button("Load...")

    # theta/d/Q tab
    filename_field = Str("d")
    bt_convertscale_abscissa = Button("Convert/scale abscissa...")

    raw_data_plot = Instance(RawDataPlot)

    #-------------------------------------------------------------------------------------
    # MVC View
    view_group = VGroup(
        Label('Scale:'),
        UItem('scale', enabled_when='object._has_data()'),
        UItem('options', editor=CheckListEditor(name='_options'), style='custom', enabled_when='object._has_data()'),
        UItem('reset_button', enabled_when='object._has_data()'),
        spring,
        '_',
        spring,
        UItem('copy_to_clipboard', enabled_when='object._has_data()'),
        UItem('save_as_image', enabled_when='object._has_data()'),
        label='View',
        springy=False,
    )

    process_group = VGroup(
        VGroup(
            Label('Positions to process:'),
            UItem(name='merge_positions',
                 style='custom',
                 editor=EnumEditor(values={
                     'p1+p2'   : '1: p1+p2',
                     'p3+p4'   : '2: p3+p4',
                     'p12+p34' : '3: p12+p34',
                     'all'     : '4: all',
                 }, cols=2),
                 enabled_when='object._has_data()'
            ),
            UItem('load_partners', enabled_when='object._has_data() and (object.merge_positions != "all")'),
            show_border = True,
        ),
        VGroup(
            HGroup(Item('align_positions'), enabled_when='object._has_data() and (object.merge_positions != "all")'),
            HGroup(
                UItem('bt_start_peak_select', label='Select peak',
                      enabled_when='object.align_positions and not object.peak_selecting and (object.merge_positions != "all")'),
                UItem('bt_end_peak_select', label='Align',
                      enabled_when='object.peak_selecting and (object.merge_positions != "all")'),
            ),
            Item('correction', label='Zero correction:', enabled_when='object._has_data()'),
            show_border = True,
        ),
        VGroup(
            HGroup(
                Item('splice'),
                Item('merge', enabled_when='object.merge_positions != "p12+p34"'),
                enabled_when='object._has_data() and (object.merge_positions != "all")'
            ),
            HGroup(
                Item('normalise', label='Normalise', enabled_when='object._has_data() and (object.merge_positions != "p12+p34")'),
                Item('merge_regrid', label='Grid', enabled_when='object._has_data()'),
            ),
            VGroup(
                Label('Normalise to:'),
                UItem('normalisation_source_filenames', style='simple',
                     enabled_when='object.normalise and object._has_data()'),
            ),
            show_border = True,
        ),
        spring,
        UItem('what_to_plot', editor=DefaultOverride(cols=2), style='custom',
              enabled_when='object._has_data()'),
        spring,
        UItem('bt_process', enabled_when='object._has_data()'),
        UItem('bt_undo_processing', enabled_when='object.undo_state is not None'),
        UItem('bt_save', enabled_when='object._has_data()'),
        label='Process',
        springy=False,
    )

    background_removal_group = VGroup(
        VGroup(
            Label('Manually define:'),
            UItem('bt_manually_define_background', enabled_when='object._has_data()'),
            show_border = True,
        ),
        VGroup(
            Label('Fit polynomial:'),
            HGroup(
                   Item('polynomial_order', label='order', enabled_when='object._has_data()'),
            ),
            UItem('bt_poly_fit', enabled_when='object._has_data()'),
            show_border = True,
        ),
        VGroup(
            Label('Load from file:'),
            UItem('bt_load_background', enabled_when='object._has_data()'),
            show_border = True,
        ),
        label='Backgrnd',
        springy=False,
    )

    convert_xscale_group = VGroup(
        Label('Filename label (prefix_<label>_nnnn.xye):'),
        UItem('filename_field',
             enabled_when='object._has_data()'),
        UItem('bt_convertscale_abscissa',
              label='Convert/scale abscissa...',
              enabled_when='object._has_data()',
        ),
        # u'\u0398' is the same string the former ur'' literal produced on
        # Python 2, and this spelling also parses on Python 3.
        label=u'\u0398 d Q',
        springy=True,
    )

    traits_view = View(
        HGroup(
            VGroup(
                UItem('open_files'),
                UItem('edit_datasets', enabled_when='object._has_data()'),
                UItem('generate_plot', enabled_when='object._has_data()'),
                UItem('help_button'),
                spring,
                spring,
                Tabbed(
                    view_group,
                    process_group,
                    # background_removal_group,
                    convert_xscale_group,
                    springy=False,
                ),
                show_border=False,
            ),
            UItem('container', editor=ComponentEditor(bgcolor='white')),
            show_border=False,
        ),
        resizable=True, title=title, width=size[0], height=size[1]
    )

    #-------------------------------------------------------------------------------------
    # MVC Control

    def _has_data(self):
        """Return True when at least one dataset is loaded (used by enabled_when)."""
        return len(self.datasets) != 0

    def __init__(self, *args, **kws):
        """
        self.datasets = [ <XYEDataset>, ..., <XYEDataset> ]
        self.dataset_pairs = set([ (<p1 file basename>, <p2 file basename>),
                                   ...,
                                   (<p1 file basename>, <p2 file basename>) ])
        """
        super(MainApp, self).__init__(*args, **kws)
        self.datasets = []
        self.dataset_pairs = set()
        self.undo_state = None
        # Results of the last Process > Apply. Initialised here so that pressing
        # Save before any processing does not raise AttributeError.
        self.processed_datasets = []
        self.raw_data_plot = RawDataPlot()
        self.plot = self.raw_data_plot.get_plot()
        self.container = OverlayPlotContainer(self.plot,
            bgcolor="white", use_backbuffer=True,
            border_visible=False)
        self.pan_tool = None
        # The list of all options.
        self._options = [ 'Show legend', 'Show gridlines', 'Show crosslines' ]
        # The list of currently set options, updated by the UI.
        self.options = self._options
        self.file_paths = []

    def _open_files_changed(self):
        """Ask the user for files; a non-empty selection replaces file_paths."""
        file_list = get_file_list_from_dialog()
        if file_list:
            self.file_paths = file_list

    def _options_changed(self, opts):
        """Apply the ticked View options (legend, gridlines, crosslines) and redraw."""
        # opts just contains the keys that are true.
        # Create a dict all_options that has True/False for each item.
        all_options = dict.fromkeys(self._options, False)
        true_options = dict.fromkeys(opts, True)
        all_options.update(true_options)
        self.raw_data_plot.show_legend(all_options['Show legend'])
        self.raw_data_plot.show_grids(all_options['Show gridlines'])
        self.raw_data_plot.show_crosslines(all_options['Show crosslines'])
        self.container.request_redraw()

    def _bt_start_peak_select_changed(self):
        """Start interactive range selection for peak alignment."""
        self.raw_data_plot.start_range_select()
        self.peak_selecting = True

    def _bt_end_peak_select_changed(self):
        """
        Finish range selection and fit the selected peak in every partner pair.

        Opens a PeakFitWindow showing the fits. Does nothing if no range was
        selected.
        """
        self.peak_selecting = False
        selection_range = self.raw_data_plot.end_range_select()
        if not selection_range:
            return

        range_low, range_high = selection_range
        # fit the peak in all loaded dataseries
        self._get_partners()
        dataset_pairs = self._get_dataset_pairs()
        for datapair in dataset_pairs:
            processing.fit_peaks_for_a_dataset_pair(
                range_low, range_high, datapair, self.normalise)
        editor = PeakFitWindow(dataset_pairs=dataset_pairs,
                               range=selection_range)
        editor.edit_traits()

    def _get_dataset_pairs(self):
        """Map the basename pairs in self.dataset_pairs to (dataset, dataset) tuples."""
        datasets_dict = dict((d.name, d) for d in self.datasets)
        return [ (datasets_dict[file1], datasets_dict[file2])
                    for file1, file2 in self.dataset_pairs ]

    def _label_as_processed(self, dataset):
        """Suffix dataset's UI name with ' (processed)' and clear its UI colour."""
        ui = dataset.metadata['ui']
        ui.name = dataset.name + ' (processed)'
        ui.color = None

    def _bt_process_changed(self):
        '''
        Button click event handler for processing.

        Runs the processing chosen in the Process tab over the loaded datasets,
        stores the results in self.processed_datasets and replots. The
        pre-processing state used by Undo is saved by
        _plot_processed_datasets().
        '''
        processed_datasets = []
        processor = DatasetProcessor(self.normalise, self.correction,
                                     self.align_positions,
                                     self.splice, self.merge, self.merge_regrid,
                                     self.normalisation_source_filenames,
                                     self.datasets)
        # Processing at this point depends on the "Positions to process:" radiobutton
        # selection:
        # If Splice==True, get all pairs and splice them
        # If Merge==True, get all pairs and merge them
        # If Normalise==True, always normalise
        # If Grid==True, output gridded and ungridded
        # The following processing code should really be placed into a
        # processor.process() method; it lives here because that is where the
        # required UI state is available.
        if self.merge_positions == 'p12+p34':
            self._get_partners()        # pair up datasets corresponding to the radiobutton selection
            for dataset_pair in self._get_dataset_pairs():
                datasets = processor.splice_overlapping_datasets(dataset_pair)
                for dataset in datasets:
                    self._label_as_processed(dataset)
                processed_datasets.extend(datasets)
        elif self.merge_positions == 'all':
            # Handle "all" selection for regrid and normalise
            for d in self.datasets:
                dataset = processor.normalise_me(d)
                if dataset is not None:
                    processed_datasets.append(dataset)
                    self._label_as_processed(dataset)
                    # Regrid the normalised result rather than the original.
                    d = dataset

                dataset = processor.regrid_me(d)
                if dataset is not None:
                    processed_datasets.append(dataset)
                    self._label_as_processed(dataset)
        else:
            self._get_partners()        # pair up datasets corresponding to the radiobutton selection
            for dataset_pair in self._get_dataset_pairs():
                datasets = processor.process_dataset_pair(dataset_pair)
                for dataset in datasets:
                    self._label_as_processed(dataset)
                processed_datasets.extend(datasets)

        self.processed_datasets = processed_datasets
        self._plot_processed_datasets()

    def _plot_processed_datasets(self):
        """Snapshot state for Undo, then plot old and/or new datasets per what_to_plot."""
        self._save_state()
        self.dataset_pairs = set()          # TODO: Check whether this line should be removed
        if 'old' not in self.what_to_plot:
            self.datasets = []
        if 'new' in self.what_to_plot:
            self.datasets.extend(self.processed_datasets)
        self._plot_datasets()

    def _save_state(self):
        """Store shallow copies of datasets and dataset_pairs for Undo."""
        self.undo_state = (self.datasets[:], self.dataset_pairs.copy())

    def _restore_state(self):
        """Restore the state saved by _save_state (one level of undo)."""
        if self.undo_state is not None:
            self.datasets, self.dataset_pairs = self.undo_state
            self.undo_state = None

    def _bt_undo_processing_changed(self):
        self._restore_state()
        self._plot_datasets()

    def _bt_save_changed(self):
        """
        Save every processed dataset as <directory>/<chosen prefix><dataset name>
        and open the target location with the system handler.
        """
        wildcard = 'All files (*.*)|*.*'
        default_filename = 'prefix_'
        dlg = FileDialog(title='Save results', action='save as',
                         default_filename=default_filename, wildcard=wildcard)
        if dlg.open() == OK:
            for dataset in self.processed_datasets:
                filename = os.path.join(dlg.directory, dlg.filename + dataset.name)
                dataset.save(filename)
            open_file_dir_with_default_handler(dlg.path)

    def _save_as_image_changed(self):
        if not self.datasets:
            return
        filename = get_save_as_filename()
        if filename:
            PlotOutput.save_as_image(self.container, filename)
            open_file_dir_with_default_handler(filename)

    def _copy_to_clipboard_changed(self):
        if self.datasets:
            PlotOutput.copy_to_clipboard(self.container)

    def _scale_changed(self):
        self._plot_datasets()

    def _get_partner(self, position_index):
        """
        Return the partner of a position index: 1<->2, 3<->4, 12<->34.

        Raises ValueError for any other index.
        """
        if position_index in (1, 2, 3, 4):
            # Flip the low bit of the zero-based index: 0<->1, 2<->3.
            return ((position_index - 1) ^ 1) + 1
        if position_index == 12:
            return 34
        if position_index == 34:
            return 12
        raise ValueError('unparsable position: %r' % (position_index,))

    def _get_position(self, filename):
        """Return N from the first '_pN_' in filename, or None if absent/empty."""
        m = re.search(r'_p([0-9]*)_', filename)
        try:
            return int(m.group(1))
        except (AttributeError, ValueError):
            return None

    def _partner_filebase(self, filebase, position_index):
        """Return filebase with '_p<position>_' swapped for the partner's tag, or None if there is no partner."""
        try:
            partner = self._get_partner(position_index)
        except ValueError:
            return None
        return re.sub('_p{}_'.format(position_index),
                      '_p{}_'.format(partner),
                      filebase)

    def _add_dataset_pair(self, filename):
        """
        Load the partner file of filename (same directory, partner position tag)
        if it exists and is not already loaded, then refresh the 'Normalise to'
        list.
        """
        current_directory, filebase = os.path.split(filename)
        # Parse the position from the basename only, so a '_pN_' in a
        # directory name is not mistaken for the file's position.
        position_index = self._get_position(filebase)
        if position_index is None:
            return

        # base filename for the associated position.
        other_filebase = self._partner_filebase(filebase, position_index)
        if other_filebase is None:
            return
        other_filename = os.path.join(current_directory, other_filebase)
        if not os.path.exists(other_filename):
            return

        # OK, we've got the names and paths, now add the actual data and references.
        if other_filename not in self.file_paths:
            self._add_xye_dataset(other_filename)
            # We also need to append the new path to the file_paths List trait which is
            # already populated by the files selected using the file selection dialog
            self.file_paths.append(other_filename)
        self._refresh_normalise_to_list()

    def _get_partners(self):
        """
        Populates the self.dataset_pairs set with all dataset partners in
        self.file_paths corresponding to the merge_positions radiobutton selection.

        Each pair is stored as (first, second) basenames, ordered p1/p3/p12
        first. Files without a parsable position or without a loaded partner
        are skipped. Returns self.dataset_pairs.
        """
        matching_re = {'all':    ''            ,
                       'p1+p2':  '_p[12]_'     ,
                       'p3+p4':  '_p[34]_'     ,
                       'p12+p34':'_p(?:12|34)_',
                      }
        pattern = matching_re[self.merge_positions]
        basenames = [os.path.basename(f) for f in self.file_paths]
        loaded = set(basenames)
        self.dataset_pairs = set()
        for filebase in basenames:
            if re.search(pattern, filebase) is None:
                continue
            position_index = self._get_position(filebase)
            if position_index is None:
                continue
            other_filebase = self._partner_filebase(filebase, position_index)
            if other_filebase is None or other_filebase not in loaded:
                continue
            # Even positions (2, 4) and 34 are the second member of their pair.
            if position_index != 12 and (position_index & 1) == 0:
                self.dataset_pairs.add((other_filebase, filebase))
            else:
                self.dataset_pairs.add((filebase, other_filebase))
        return self.dataset_pairs

    def _refresh_normalise_to_list(self):
        g.populate_list(self.file_paths)

    def _file_paths_changed(self, new):
        """
        When the file dialog box is closed with a selection of filenames,
        (re)load every dataset, replot and refresh the 'Normalise to' list.
        Files that raise IOError on load are skipped.
        """
        self.datasets = []
        # Iterate over a copy: _add_dataset_pair() appends to self.file_paths.
        for filename in self.file_paths[:]:
            self._add_xye_dataset(filename)
        # Sort before plotting so the plot order matches self.datasets.
        self.datasets.sort(key=lambda d: d.name)
        self._plot_datasets()
        self._refresh_normalise_to_list()

    def _load_partners_changed(self):
        """Load the partner file of every loaded file, then replot in name order."""
        for filename in self.file_paths[:]:
            self._add_dataset_pair(filename)
        self.datasets.sort(key=lambda d: d.name)
        self._plot_datasets()

    def _plot_datasets(self, reset_view=True):
        """Replot all datasets at the current scale and re-apply the View options."""
        self.raw_data_plot.plot_datasets(self.datasets, scale=self.scale,
                                         reset_view=reset_view)
        self._options_changed(self.options)
        self.container.request_redraw()

    def _edit_datasets_changed(self):
        editor = DatasetEditor(datasets=self.datasets)
        editor.edit_traits()
        self._plot_datasets(reset_view=False)

    def _generate_plot_changed(self):
        if self.datasets:
            generator = PlotGenerator(datasets=self.datasets)
            generator.show()

    def _help_button_changed(self):
        help_box = HelpBox()
        help_box.edit_traits()

    def _reset_button_changed(self):
        self.raw_data_plot.reset_view()

    def _add_xye_dataset(self, file_path):
        """Load file_path as an XYEDataset and attach its UI metadata; IOError is ignored."""
        try:
            dataset = XYEDataset.from_file(file_path)
        except IOError:
            return
        self.datasets.append(dataset)
        create_datasetui(dataset)

    def _bt_convertscale_abscissa_changed(self):
        editor = WavelengthEditor(datasets=self.datasets, filename_field=self.filename_field)
        editor.edit_traits()
        self._plot_datasets(reset_view=False)
Ejemplo n.º 19
0
class TemplatePicker(HasTraits):
    template = Array
    CC = Array
    peaks = List
    zero=Int(0)
    tmp_size = Range(low=2, high=512, value=64, cols=4)
    max_pos_x=Property(depends_on=['tmp_size'])
    max_pos_y=Property(depends_on=['tmp_size'])
    top = Range(low='zero',high='max_pos_y', value=20, cols=4)
    left = Range(low='zero',high='max_pos_x', value=20, cols=4)
    is_square = Bool
    img_plot = Instance(Plot)
    tmp_plot = Instance(Plot)
    findpeaks = Button
    next_img = Button
    prev_img = Button
    peak_width = Range(low=2, high=200, value=10)
    tab_selected = Event
    ShowCC = Bool
    img_container = Instance(Component)
    container = Instance(Component)
    colorbar= Instance(Component)
    numpeaks_total = Int(0,cols=5)
    numpeaks_img = Int(0,cols=5)
    OK_custom=OK_custom_handler
    cbar_selection = Instance(RangeSelection)
    cbar_selected = Event
    thresh=Trait(None,None,List,Tuple,Array)
    thresh_upper=Float(1.0)
    thresh_lower=Float(0.0)
    numfiles=Int(1)
    img_idx=Int(0)
    tmp_img_idx=Int(0)

    csr=Instance(BaseCursorTool)

    traits_view = View(
        Group(
            Group(
                Item("img_container",editor=ComponentEditor(), show_label=False),
                HGroup(
                    Item("ShowCC", editor=BooleanEditor(), label="Show cross correlation image"),
                    Spring(),
                    Item("prev_img",editor=ButtonEditor(label="<"),show_label=False, enabled_when='numfiles > 1'),
                    Item("next_img",editor=ButtonEditor(label=">"),show_label=False, enabled_when='numfiles > 1'),
                    ),
                label="Original image", show_border=True, trait_modified="tab_selected",
                orientation='vertical',),
            VGroup(
                Group(
                    HGroup(
                        Item("left", label="Left coordinate", style="custom"),
                        Spring(),
                        Item("top", label="Top coordinate", style="custom"),
                        ),
                    Item("tmp_size", label="Template size", style="custom"),
                    Item("tmp_plot",editor=ComponentEditor(height=256, width=256), show_label=False, resizable=True),
                    label="Template", show_border=True),
                Group(
                    HGroup(
                        Item("peak_width", label="Peak width", style="custom"),
                        Spring(),
                        Item("findpeaks",editor=ButtonEditor(label="Find Peaks"),show_label=False),
                        ),
                    HGroup(
                        Item("thresh_lower",label="Threshold Lower Value", editor=TextEditor(evaluate=float,
                                                                                             format_str='%1.4f')),
                        Spring(),
                        Item("thresh_upper",label="Threshold Upper Value", editor=TextEditor(evaluate=float,
                                                                                             format_str='%1.4f')),
                    ),
                    HGroup(
                        Item("numpeaks_img",label="Number of Cells selected (this image)",style='readonly'),
                        Spring(),
                        Item("numpeaks_total",label="Total",style='readonly'),                          
                        ),
                    label="Peak parameters", show_border=True),
                ),
            orientation='horizontal'),
        buttons = [ Action(name='OK', enabled_when = 'numpeaks_total > 0' ),
            CancelButton ],
        title="Template Picker",
        handler=OK_custom, kind='livemodal',
        key_bindings = key_bindings,
        width=940, height=530,resizable=True)        

    def __init__(self, signal_instance):
        """
        Build the template picker for signal_instance.

        signal_instance.data is indexed below as data[image, row, col].
        Its mapped_parameters must provide either `original_files` (one title
        per image) or `title`.

        If OpenCV cannot be imported, a message is printed and the constructor
        returns early, leaving the plots uninitialised.
        """
        super(TemplatePicker, self).__init__()
        # OpenCV is only probed here so the picker can bail out early; the
        # actual cross correlation goes through cv_funcs.
        try:
            import cv
        except ImportError:
            try:
                import cv2.cv as cv
            except ImportError:
                print("OpenCV unavailable.  Can't do cross correlation without it.  Aborting.")
                return None
        self.OK_custom = OK_custom_handler()
        self.sig = signal_instance
        params = self.sig.mapped_parameters
        if not hasattr(params, "original_files"):
            self.titles = [os.path.splitext(params.title)[0]]
        else:
            self.numfiles = len(params.original_files.keys())
            # list() keeps titles indexable regardless of what keys() returns.
            self.titles = list(params.original_files.keys())
        template = self.sig.data[self.img_idx,
                                 self.top:self.top + self.tmp_size,
                                 self.left:self.left + self.tmp_size]
        tmp_plot_data = ArrayPlotData(imagedata=template)
        tmp_plot = Plot(tmp_plot_data, default_origin="top left")
        tmp_plot.img_plot("imagedata", colormap=jet)
        tmp_plot.aspect_ratio = 1.0
        self.tmp_plot = tmp_plot
        self.tmp_plotdata = tmp_plot_data
        self.img_plotdata = ArrayPlotData(imagedata=self.sig.data[self.img_idx, :, :])
        self.img_container = self._image_plot_container()

        self.crop_sig = None

    def render_image(self):
        """
        Create the main image Plot with a white draggable cursor, a
        right-button pan tool and a box-zoom overlay.

        Stores the cursor in self.csr and the plot in self.img_plot, and
        returns the plot.
        """
        image_plot = Plot(self.img_plotdata, default_origin="top left")
        renderer = image_plot.img_plot("imagedata", colormap=gray)[0]
        image_plot.title = "%s of %s: " % (self.img_idx + 1, self.numfiles) + self.titles[self.img_idx]
        shape = self.sig.data.shape
        image_plot.aspect_ratio = float(shape[2]) / float(shape[1])

        cursor = CursorTool(renderer, drag_button='left', color='white',
                            line_width=2.0)
        # self.csr is assigned before the position is set, so the
        # 'csr:current_position' listener also sees this initial position.
        self.csr = cursor
        cursor.current_position = self.left, self.top
        renderer.overlays.append(cursor)

        # Attach the pan (right drag) and box-zoom tools.
        image_plot.tools.append(PanTool(image_plot, drag_button="right"))
        box_zoom = ZoomTool(image_plot, tool_mode="box", always_on=False,
                            aspect_ratio=image_plot.aspect_ratio)
        image_plot.overlays.append(box_zoom)
        self.img_plot = image_plot
        return image_plot

    def render_scatplot(self):
        """
        Build a colour-mapped scatter Plot of the peaks of image img_idx.

        Columns 0, 1 and 2 of the peak array become index, value and colour.
        The plot reuses self.img_plot's aspect ratio and 2D range. Stores
        self.scatplot and self.peakdata and returns the scatter plot.
        """
        current = self.peaks[self.img_idx]
        peakdata = ArrayPlotData()
        for key, column in (("index", 0), ("value", 1), ("color", 2)):
            peakdata.set_data(key, current[:, column])
        scatter = Plot(peakdata, aspect_ratio=self.img_plot.aspect_ratio,
                       default_origin="top left")
        scatter.plot(("index", "value", "color"),
                     type="cmap_scatter",
                     name="my_plot",
                     color_mapper=jet(DataRange1D(low=0.0, high=1.0)),
                     marker="circle",
                     fill_alpha=0.5,
                     marker_size=6)
        scatter.x_grid.visible = False
        scatter.y_grid.visible = False
        # Reuse the image plot's 2D data range for the overlay.
        scatter.range2d = self.img_plot.range2d
        self.scatplot = scatter
        self.peakdata = peakdata
        return scatter

    def _image_plot_container(self):
        """
        Assemble the image panel.

        The image plot goes into an OverlayPlotContainer (self.container),
        which sits inside a white HPlotContainer (self.img_container). When
        the current image has peaks, the scatter overlay and the colorbar are
        added too. Returns self.img_container.
        """
        image = self.render_image()

        overlay = OverlayPlotContainer()
        self.container = overlay
        overlay.add(image)

        # Side-by-side container for the plot stack and the colorbar.
        outer = HPlotContainer(use_backbuffer=False)
        self.img_container = outer
        outer.add(overlay)
        outer.bgcolor = "white"

        if self.numpeaks_img > 0:
            overlay.add(self.render_scatplot())
            outer.add(self.draw_colorbar())
        return outer

    def draw_colorbar(self):
        """
        Build a vertical colorbar for the peak scatter plot.

        The method does the following:
          * adds a range ColormappedSelectionOverlay to the scatter renderer;
          * applies any existing self.thresh as the renderer's 'selections'
            metadata;
          * gives the colorbar a RangeSelection tool and overlay so a value
            range can be dragged out.

        Stores self.cbar_selection, self.cmap_renderer and self.colorbar, and
        returns the colorbar. Reads self.scatplot, so render_scatplot() must
        have run first.
        """
        scatplot = self.scatplot
        cmap_renderer = scatplot.plots["my_plot"][0]
        selection = ColormappedSelectionOverlay(cmap_renderer, fade_alpha=0.35,
                                                selection_type="range")
        cmap_renderer.overlays.append(selection)
        if self.thresh is not None:
            cmap_renderer.color_data.metadata['selections'] = self.thresh
            cmap_renderer.color_data.metadata_changed = {'selections': self.thresh}
        # Colorbar spans the fixed 0..1 range used by render_scatplot's mapper.
        # NOTE(review): no color_mapper is passed to ColorBar; confirm it picks
        # up the scatter colormap (e.g. via colorbar.plot below).
        colorbar = ColorBar(index_mapper=LinearMapper(range=DataRange1D(low=0.0,
                                                                        high=1.0)),
                            orientation='v',
                            resizable='v',
                            width=30,
                            padding=20)
        colorbar_selection = RangeSelection(component=colorbar)
        colorbar.tools.append(colorbar_selection)
        colorbar.overlays.append(RangeSelectionOverlay(component=colorbar,
                                                       border_color="white",
                                                       alpha=0.8,
                                                       fill_color="lightgray",
                                                       metadata_name='selections'))
        self.cbar_selection = colorbar_selection
        self.cmap_renderer = cmap_renderer
        colorbar.plot = cmap_renderer
        # Align the bar vertically with the scatter plot.
        colorbar.padding_top = scatplot.padding_top
        colorbar.padding_bottom = scatplot.padding_bottom
        self.colorbar = colorbar
        return colorbar

    @on_trait_change('ShowCC')
    def toggle_cc_view(self):
        """
        Switch the main image between the raw image and the cross-correlation map.

        With ShowCC on, the method recomputes self.CC and resizes the image
        plot's grid data source to CC's shape. With ShowCC off, it shows image
        img_idx again. It redraws in both cases.
        """
        if not self.ShowCC:
            # NOTE(review): the grid data source is not reset to the image
            # shape here; confirm this is intended.
            self.img_plotdata.set_data("imagedata", self.sig.data[self.img_idx, :, :])
        else:
            self.update_CC()
            grid = self.img_plot.range2d.sources[0]
            n_rows, n_cols = self.CC.shape[0], self.CC.shape[1]
            grid.set_data(np.arange(n_cols), np.arange(n_rows))
        self.redraw_plots()

    @on_trait_change("img_idx")
    def update_img_depth(self):
        """Show image img_idx (or its cross correlation when ShowCC), retitle, and redraw."""
        if not self.ShowCC:
            self.img_plotdata.set_data("imagedata", self.sig.data[self.img_idx, :, :])
        else:
            self.update_CC()
        counter = "%s of %s: " % (self.img_idx + 1, self.numfiles)
        self.img_plot.title = counter + self.titles[self.img_idx]
        self.redraw_plots()

    def _get_max_pos_x(self):
        """Getter for max_pos_x: image width - tmp_size - 1, or None when that is not positive."""
        limit = self.sig.data.shape[-1] - self.tmp_size - 1
        return limit if limit > 0 else None

    def _get_max_pos_y(self):
        """Getter for max_pos_y: image height - tmp_size - 1, or None when that is not positive."""
        limit = self.sig.data.shape[-2] - self.tmp_size - 1
        return limit if limit > 0 else None

    @on_trait_change('next_img')
    def increase_img_idx(self,info):
        """Advance to the next image, wrapping from the last back to the first."""
        self.img_idx = 0 if self.img_idx == self.numfiles - 1 else self.img_idx + 1

    @on_trait_change('prev_img')
    def decrease_img_idx(self,info):
        """Step back to the previous image, wrapping from the first to the last."""
        self.img_idx = self.numfiles - 1 if self.img_idx == 0 else self.img_idx - 1

    @on_trait_change('left, top')
    def update_csr_position(self):
        """Move the cursor to (left, top); nothing happens while left is not positive."""
        if self.left <= 0:
            return
        self.csr.current_position = self.left, self.top

    @on_trait_change('csr:current_position')
    def update_top_left(self):
        """
        Copy the cursor position into (left, top), clamped to (max_pos_x, max_pos_y).

        Setting left/top fires update_csr_position, which (when left > 0)
        moves the cursor back to (left, top). Setting csr.current_position
        re-enters this handler. When a max_pos_* is None, the comparisons
        below rely on Python 2 ordering, where any number is greater than None.
        """
        # Ignore positions where both coordinates are <= 0.
        if self.csr.current_position[0]>0 or self.csr.current_position[1]>0:
            if self.csr.current_position[0]>self.max_pos_x:
                # x is out of range: only top follows the cursor, left is left as is.
                if self.csr.current_position[1]<self.max_pos_y:
                    self.top=self.csr.current_position[1]
                else:
                    # Both out of range: snap the cursor to the bottom-right limit.
                    self.csr.current_position=self.max_pos_x, self.max_pos_y
            elif self.csr.current_position[1]>self.max_pos_y:
                # Only y is out of range: clamp top, left follows x.
                self.left,self.top=self.csr.current_position[0],self.max_pos_y
            else:
                self.left,self.top=self.csr.current_position
        
    @on_trait_change('left, top, tmp_size')
    def update_tmp_plot(self):
        """
        Refresh the template preview: a tmp_size square of image img_idx at (left, top).

        Records the source image in tmp_img_idx. If any peaks were counted,
        resets peaks to a single placeholder row [0, 0, -1], presumably
        because they were found with the previous template.
        """
        size = self.tmp_size
        row0, col0 = self.top, self.left
        patch = self.sig.data[self.img_idx, row0:row0 + size, col0:col0 + size]
        self.tmp_plotdata.set_data("imagedata", patch)
        self.tmp_plot.range2d.sources[0].set_data(np.arange(size), np.arange(size))
        self.tmp_img_idx = self.img_idx
        if self.numpeaks_total > 0:
            print("clearing peaks")
            self.peaks = [np.array([[0, 0, -1]])]

    @on_trait_change('left, top, tmp_size')
    def update_CC(self):
        """
        When ShowCC is on, cross-correlate the template with image img_idx and
        display the result.

        The template is cut from image tmp_img_idx at (left, top) with size
        tmp_size. The result is stored in self.CC.
        """
        if not self.ShowCC:
            return
        size = self.tmp_size
        template = self.sig.data[self.tmp_img_idx,
                                 self.top:self.top + size,
                                 self.left:self.left + size]
        target = self.sig.data[self.img_idx, :, :]
        self.CC = cv_funcs.xcorr(template, target)
        self.img_plotdata.set_data("imagedata", self.CC)

    @on_trait_change('cbar_selection:selection')
    def update_thresh(self):
        """
        Mirror the colorbar range selection into the rest of the picker, then redraw.

        The selection is copied into self.thresh, the renderer's 'selections'
        metadata, and the thresh_lower/thresh_upper fields.

        Best effort: if the selection cannot be indexed (e.g. it was cleared
        to None) or the renderer/tool is missing, the update stops silently.
        Unexpected errors are no longer swallowed.
        """
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
            self.cmap_renderer.color_data.metadata['selections'] = thresh
            self.thresh_lower = thresh[0]
            self.thresh_upper = thresh[1]
            self.cmap_renderer.color_data.metadata_changed = {'selections': thresh}
            self.container.request_redraw()
            self.img_container.request_redraw()
        except (AttributeError, TypeError, IndexError):
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        """Apply the typed-in lower/upper thresholds to the scatter overlay and redraw."""
        self.thresh = [self.thresh_lower, self.thresh_upper]
        color_data = self.cmap_renderer.color_data
        # self.thresh is re-read rather than reusing the literal list, since
        # the trait may store a converted value.
        color_data.metadata['selections'] = self.thresh
        color_data.metadata_changed = {'selections': self.thresh}
        self.container.request_redraw()
        self.img_container.request_redraw()

    @on_trait_change('peaks,cbar_selection:selection,img_idx')
    def calc_numpeaks(self):
        """Count the peaks whose height (third column) lies inside the threshold.

        Sets ``numpeaks_total`` (all images) and ``numpeaks_img`` (image
        ``img_idx``; 0 if that image has no peak array).  Without a usable
        colour-bar selection the inclusive range (0, 1) is used.
        """
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
        except AttributeError:
            # Colour bar not created yet.
            thresh = None
        if thresh is None or len(thresh) == 0:
            thresh = (0, 1)
        low, high = thresh[0], thresh[1]

        def count_inside(pks):
            # masked_inside masks values in [low, high]; the mask counts them.
            return int(np.sum(np.ma.masked_inside(pks[:, 2], low, high).mask))

        self.numpeaks_total = sum(count_inside(pks) for pks in self.peaks)
        try:
            self.numpeaks_img = count_inside(self.peaks[self.img_idx])
        except (IndexError, TypeError):
            # No (well-formed) peak array for the current image.
            self.numpeaks_img = 0

    @on_trait_change('findpeaks')
    def locate_peaks(self):
        """Correlate the template with every image and store the found peaks.

        The template is the ``tmp_size`` square at (``top``, ``left``) of
        image ``tmp_img_idx``.  For each image the correlation map is kept in
        ``self.CC`` and passed to hyperspy's ``two_dim_findpeaks``.  The
        third column of each resulting array is divided by 255 to undo the
        scaling applied before the search.  All per-image arrays are
        assigned to ``self.peaks`` at once at the end.
        """
        from hyperspy import peak_char as pc
        found = []
        progress = ProgressDialog(
            title="Peak finder progress",
            message="Finding peaks on %s images" % self.numfiles,
            max=self.numfiles,
            show_time=True,
            can_cancel=False)
        progress.open()
        for idx in xrange(self.numfiles):
            template = self.sig.data[self.tmp_img_idx,
                                     self.top:self.top + self.tmp_size,
                                     self.left:self.left + self.tmp_size]
            self.CC = cv_funcs.xcorr(template, self.sig.data[idx, :, :])
            # The peak finder needs peaks greater than 1, so scale by 255
            # for the search and undo it on the height column afterwards.
            pks = pc.two_dim_findpeaks(self.CC * 255,
                                       peak_width=self.peak_width,
                                       medfilt_radius=None)
            pks[:, 2] = pks[:, 2] / 255.
            found.append(pks)
            progress.update(idx + 1)
        self.peaks = found
        
    def mask_peaks(self, idx):
        """Return image ``idx``'s peaks with out-of-threshold heights masked.

        The threshold is the colour-bar range selection.  When there is no
        colour bar yet, or its selection is empty/``None``, the default
        inclusive range (0, 1) is used, matching ``calc_numpeaks``.

        Returns a masked array; rows with a masked height can be dropped
        with ``np.ma.compress_rows``.
        """
        selector = getattr(self, 'cbar_selection', None)
        thresh = getattr(selector, 'selection', None)
        # Previously only [] fell back to the default; None crashed at thresh[0].
        if thresh is None or len(thresh) == 0:
            thresh = (0, 1)
        mpeaks = np.ma.asarray(self.peaks[idx])
        mpeaks[:, 2] = np.ma.masked_outside(mpeaks[:, 2], thresh[0], thresh[1])
        return mpeaks

    @on_trait_change("peaks")
    def redraw_plots(self):
        """Rebuild the image plot and, if this image has peaks, its overlays.

        The old image plot is replaced by a fresh ``render_image()``.  Any
        previous scatter overlay and colour bar are removed, and new ones
        are added when ``numpeaks_img > 0``.  Both containers are then
        redrawn.
        """
        self.container.remove(self.img_plot)
        newplot = self.render_image()
        self.container.add(newplot)
        self.img_plot = newplot

        # The overlays may not exist yet (first draw) or may already be gone
        # from their container.  Remove each one independently so a failure
        # on the scatter plot cannot leave a stale colour bar behind.
        try:
            self.container.remove(self.scatplot)
        except Exception:
            pass
        try:
            self.img_container.remove(self.colorbar)
        except Exception:
            pass

        if self.numpeaks_img > 0:
            newscat = self.render_scatplot()
            self.container.add(newscat)
            self.scatplot = newscat
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
            self.colorbar = colorbar

        self.container.request_redraw()
        self.img_container.request_redraw()

    def crop_cells_stack(self):
        """Crop cells around the accepted peaks and store them in ``self.crop_sig``.

        A single image is cropped directly.  For a stack, only images that
        still have peaks after thresholding are cropped, and the per-image
        results are combined into one ``AggregateCells`` signal.
        """
        from hyperspy.signals.aggregate import AggregateCells
        if self.numfiles == 1:
            self.crop_sig = self.crop_cells()
            return
        crops = [self.crop_cells(idx)
                 for idx in xrange(self.numfiles)
                 if np.ma.compress_rows(self.mask_peaks(idx)).any()]
        self.crop_sig = AggregateCells(*crops)

    def crop_cells(self, idx=0):
        """Cut a ``tmp_size`` square at every accepted peak of image ``idx``.

        Peaks outside the colour-bar threshold are dropped (``mask_peaks``).
        Each crop's origin is recorded in a structured ``locations`` array
        of (source title, running id, (x, y) position).

        Returns
        -------
        Image
            One frame per accepted peak (possibly zero frames).  Its
            ``original_files`` entry links back to the signal the cells were
            cropped from.
        """
        print("cropping cells...")
        from hyperspy.signals.image import Image
        # filter the peaks that are outside the selected threshold
        peaks = np.ma.compress_rows(self.mask_peaks(idx))
        tmp_sz = self.tmp_size
        npeaks = peaks.shape[0]
        data = np.zeros((npeaks, tmp_sz, tmp_sz))
        if not hasattr(self.sig.mapped_parameters, "original_files"):
            parent = self.sig
        else:
            parent = self.sig.mapped_parameters.original_files[self.titles[idx]]
        pmp = parent.mapped_parameters
        positions = np.zeros((npeaks, 1),
                             dtype=[('filename', 'a256'), ('id', 'i4'),
                                    ('position', 'f4', (1, 2))])
        for i in xrange(npeaks):
            # Peak coordinates are floats; array slicing needs integers.
            x = int(peaks[i, 0])
            y = int(peaks[i, 1])
            data[i, :, :] = self.sig.data[idx, y:y + tmp_sz, x:x + tmp_sz]
            positions[i] = (self.titles[idx], i, peaks[i, :2])
        # Build the signal once, after every cell has been cropped.  It used
        # to be rebuilt inside the loop, and was never defined when no peak
        # survived the threshold.
        crop_sig = Image({'data': data,
                          'mapped_parameters': {
                              'title': 'Cropped cells from %s' % self.titles[idx],
                              'record_by': 'image',
                              'locations': positions,
                              'original_files': {pmp.title: parent},
                          }
                          })
        return crop_sig
Ejemplo n.º 20
0
class CellCropper(StackViewer):
    """Interactive template picker that finds and crops similar cells.

    Workflow:

    1. Drag the square template (origin ``left``/``top``, side
       ``tmp_size``) over the current image.
    2. Press *Find Peaks* to cross-correlate it with every image served by
       ``controller``.
    3. Narrow the accepted peaks with the colour-bar range or the threshold
       text fields.
    4. ``crop_cells`` cuts one ``tmp_size`` square per accepted peak and
       passes them to ``controller.add_cells``.
    """
    # Current template (a 2-D window of self.data) and its correlation map.
    template = Array
    CC = Array
    # One peak array per image; columns used as (x, y, height).
    peaks = List
    zero = Int(0)
    tmp_size = Range(low=2, high=512, value=64, cols=4)
    max_pos_x = Property(depends_on=['tmp_size'])
    max_pos_y = Property(depends_on=['tmp_size'])
    # Template origin, bounded by the image size minus the template size.
    top = Range(low='zero', high='max_pos_y', value=20, cols=4)
    left = Range(low='zero', high='max_pos_x', value=20, cols=4)
    is_square = Bool
    tmp_plot = Instance(Plot)
    findpeaks = Button
    peak_width = Range(low=2, high=200, value=10)
    ShowCC = Bool
    numpeaks_total = Int(0, cols=5)
    numpeaks_img = Int(0, cols=5)
    OK_custom = OK_custom_handler
    thresh = Trait(None, None, List, Tuple, Array)
    thresh_upper = Float(1.0)
    thresh_lower = Float(-1.0)
    tmp_img_idx = Int(0)

    csr = Instance(BaseCursorTool)

    traits_view = View(
        Group(
            Group(
                Item("img_container", editor=ComponentEditor(), show_label=False),
                HGroup(
                    Item("ShowCC", editor=BooleanEditor(), label="Show cross correlation image"),
                    Spring(),
                    Item("prev_img", editor=ButtonEditor(label="<"), show_label=False, enabled_when='numfiles > 1'),
                    Item("next_img", editor=ButtonEditor(label=">"), show_label=False, enabled_when='numfiles > 1'),
                    ),
                label="Original image", show_border=True, trait_modified="tab_selected",
                orientation='vertical',),
            VGroup(
                Group(
                    HGroup(
                        Item("left", label="Left coordinate", style="custom"),
                        Spring(),
                        Item("top", label="Top coordinate", style="custom"),
                        ),
                    Item("tmp_size", label="Template size", style="custom"),
                    Item("tmp_plot", editor=ComponentEditor(height=256, width=256), show_label=False, resizable=True),
                    label="Template", show_border=True),
                Group(
                    HGroup(
                        Item("peak_width", label="Peak width", style="custom"),
                        Spring(),
                        Item("findpeaks", editor=ButtonEditor(label="Find Peaks"), show_label=False),
                        ),
                    HGroup(
                        Item("thresh_lower", label="Threshold Lower Value", editor=TextEditor(evaluate=float,
                                                                                             format_str='%1.4f')),
                        Spring(),
                        Item("thresh_upper", label="Threshold Upper Value", editor=TextEditor(evaluate=float,
                                                                                             format_str='%1.4f')),
                    ),
                    HGroup(
                        Item("numpeaks_img", label="Number of Cells selected (this image)", style='readonly'),
                        Spring(),
                        Item("numpeaks_total", label="Total", style='readonly'),
                        ),
                    label="Peak parameters", show_border=True),
                ),
            orientation='horizontal'),
        buttons=[Action(name='OK', enabled_when='numpeaks_total > 0'),
                 CancelButton],
        title="Template Picker",
        handler=OK_custom, kind='livemodal',
        key_bindings=key_bindings,
        width=940, height=530, resizable=True)

    def __init__(self, controller, *args, **kw):
        """Build the viewer and the template preview plot.

        If OpenCV cannot be imported a message is printed.  Construction
        then stops early, leaving the template preview uninitialised.
        """
        super(CellCropper, self).__init__(controller, *args, **kw)
        # OpenCV is only probed for availability here.
        try:
            import cv
        except ImportError:
            try:
                import cv2.cv as cv
            except ImportError:
                print("OpenCV unavailable.  Can't do cross correlation without it.  Aborting.")
                return None
        self.OK_custom = OK_custom_handler()
        self.template = self.data[self.top:self.top + self.tmp_size, self.left:self.left + self.tmp_size]
        tmp_plot_data = ArrayPlotData(imagedata=self.template)
        tmp_plot = Plot(tmp_plot_data, default_origin="top left")
        tmp_plot.img_plot("imagedata", colormap=jet)
        tmp_plot.aspect_ratio = 1.0
        self.tmp_plot = tmp_plot
        self.tmp_plotdata = tmp_plot_data
        self.crop_sig = None

    def render_image(self):
        """Render the image plot and attach a draggable cursor at the template origin."""
        plot = super(CellCropper, self).render_image()
        img = plot.img_plot("imagedata", colormap=gray)[0]
        csr = CursorTool(img, drag_button='left', color='white',
                         line_width=2.0)
        self.csr = csr
        csr.current_position = self.left, self.top
        img.overlays.append(csr)
        self.img_plot = plot
        return plot

    # TODO: use base class render_scatter_overlay method
    def render_scatplot(self):
        """Build a colour-mapped scatter plot of the current image's peaks.

        Columns 0/1 of the peak array give the point positions.  Column 2
        gives the colour, mapped through ``jet`` over the range [0, 1].
        """
        peakdata = ArrayPlotData()
        peakdata.set_data("index", self.peaks[self.img_idx][:, 0])
        peakdata.set_data("value", self.peaks[self.img_idx][:, 1])
        peakdata.set_data("color", self.peaks[self.img_idx][:, 2])
        scatplot = Plot(peakdata, aspect_ratio=self.img_plot.aspect_ratio, default_origin="top left")
        scatplot.plot(("index", "value", "color"),
                      type="cmap_scatter",
                      name="my_plot",
                      color_mapper=jet(DataRange1D(low=0.0,
                                                   high=1.0)),
                      marker="circle",
                      fill_alpha=0.5,
                      marker_size=6,
                      )
        scatplot.x_grid.visible = False
        scatplot.y_grid.visible = False
        # Share the image plot's range so the points line up with the pixels.
        scatplot.range2d = self.img_plot.range2d
        self.scatplot = scatplot
        self.peakdata = peakdata
        return scatplot

    def _image_plot_container(self):
        """Create the overlay container (image + peaks) and its horizontal host."""
        plot = self.render_image()

        # Create a container to position the plot and the colorbar side-by-side
        self.container = OverlayPlotContainer()
        self.container.add(plot)
        self.img_container = HPlotContainer(use_backbuffer=False)
        self.img_container.add(self.container)
        self.img_container.bgcolor = "white"

        if self.numpeaks_img > 0:
            scatplot = self.render_scatplot()
            self.container.add(scatplot)
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
        return self.img_container

    def draw_colorbar(self):
        """Create the peak-height colour bar with a range-selection tool.

        The selection tool is stored as ``self.cbar_selection``; its
        selection drives the threshold (see ``update_thresh``).
        """
        scatplot = self.scatplot
        cmap_renderer = scatplot.plots["my_plot"][0]
        selection = ColormappedSelectionOverlay(cmap_renderer, fade_alpha=0.35,
                                                selection_type="range")
        cmap_renderer.overlays.append(selection)
        if self.thresh is not None:
            cmap_renderer.color_data.metadata['selections'] = self.thresh
            cmap_renderer.color_data.metadata_changed = {'selections': self.thresh}
        # Create the colorbar, handing in the appropriate range
        colorbar = ColorBar(
            index_mapper=LinearMapper(
                range=DataRange1D(
                    low=-1.0,
                    high=1.0)
                ),
            orientation='v',
            resizable='v',
            width=30,
            padding=20)
        colorbar_selection = RangeSelection(component=colorbar)
        colorbar.tools.append(colorbar_selection)
        colorbar.overlays.append(
            RangeSelectionOverlay(component=colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray",
                                  metadata_name='selections'))
        self.cbar_selection = colorbar_selection
        self.cmap_renderer = cmap_renderer
        colorbar.plot = cmap_renderer
        colorbar.padding_top = scatplot.padding_top
        colorbar.padding_bottom = scatplot.padding_bottom
        self.colorbar = colorbar
        return colorbar

    @on_trait_change('ShowCC')
    def toggle_cc_view(self):
        """Switch the main view between the raw image and the correlation map."""
        if self.ShowCC:
            self.update_CC()
            grid_data_source = self.img_plot.range2d.sources[0]
            grid_data_source.set_data(np.arange(self.CC.shape[1]),
                                      np.arange(self.CC.shape[0]))
        else:
            self.img_plotdata.set_data("imagedata", self.data)
        self.redraw_plots()

    @on_trait_change("img_idx")
    def update_img_depth(self):
        """Refresh the display after the shown image index changes.

        Delegates to the base class, then refreshes the correlation map (if
        shown) and the overlays.

        TODO: We look up the index in the model - first get a list of files,
        then get the name of the file at the given index.
        """
        super(CellCropper, self).update_img_depth()
        if self.ShowCC:
            self.update_CC()
        self.redraw_plots()

    def _get_max_pos_x(self):
        """Largest allowed ``left`` for the current template size, or None if it does not fit."""
        max_pos_x = self.data.shape[-1] - self.tmp_size - 1
        if max_pos_x > 0:
            return max_pos_x
        else:
            return None

    def _get_max_pos_y(self):
        """Largest allowed ``top`` for the current template size, or None if it does not fit."""
        max_pos_y = self.data.shape[-2] - self.tmp_size - 1
        if max_pos_y > 0:
            return max_pos_y
        else:
            return None

    @on_trait_change('left, top')
    def update_csr_position(self):
        """Move the cursor to the template origin (only while ``left`` > 0)."""
        if self.left > 0:
            self.csr.current_position = self.left, self.top

    @on_trait_change('csr:current_position')
    def update_top_left(self):
        """Follow a dragged cursor with ``left``/``top``, clamped to ``max_pos_x``/``max_pos_y``."""
        if self.csr.current_position[0] > 0 or self.csr.current_position[1] > 0:
            if self.csr.current_position[0] > self.max_pos_x:
                if self.csr.current_position[1] < self.max_pos_y:
                    self.top = self.csr.current_position[1]
                else:
                    self.csr.current_position = self.max_pos_x, self.max_pos_y
            elif self.csr.current_position[1] > self.max_pos_y:
                self.left, self.top = self.csr.current_position[0], self.max_pos_y
            else:
                self.left, self.top = self.csr.current_position

    @on_trait_change('left, top, tmp_size')
    def update_tmp_plot(self):
        """Re-cut the template from the current image and refresh its preview.

        Previously found peaks no longer match a moved template, so they are
        replaced by a single placeholder row ``[0, 0, -1]``.
        """
        self.template = self.data[self.top:self.top + self.tmp_size, self.left:self.left + self.tmp_size]
        self.tmp_plotdata.set_data("imagedata", self.template)
        grid_data_source = self.tmp_plot.range2d.sources[0]
        grid_data_source.set_data(np.arange(self.tmp_size), np.arange(self.tmp_size))
        self.tmp_img_idx = self.img_idx
        if self.numpeaks_total > 0:
            print("clearing peaks")
            self.peaks = [np.array([[0, 0, -1]])]
        self.update_CC()
        return

    def update_CC(self):
        """Recompute and show the template/image cross-correlation when enabled."""
        if self.ShowCC:
            self.CC = cv_funcs.xcorr(self.template, self.data)
            self.img_plotdata.set_data("imagedata", self.CC)

    @on_trait_change('cbar_selection:selection')
    def update_thresh(self):
        """Propagate the colour-bar range selection to the threshold traits.

        A cleared selection (``None``) is pushed to the overlay and redrawn
        too, but leaves ``thresh_lower``/``thresh_upper`` unchanged.  The
        update is skipped while the colour bar or renderer does not exist.
        """
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
            color_data = self.cmap_renderer.color_data
            color_data.metadata['selections'] = thresh
            # Indexing a cleared (None) selection used to abort the handler
            # before the overlay was notified and the plots redrawn.
            if thresh is not None:
                self.thresh_lower = thresh[0]
                self.thresh_upper = thresh[1]
            color_data.metadata_changed = {'selections': thresh}
            self.container.request_redraw()
            self.img_container.request_redraw()
        except AttributeError:
            # Colour bar / renderer / containers not built yet (no peaks drawn).
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        """Apply the typed-in lower/upper thresholds to the scatter overlay and redraw."""
        self.thresh = [self.thresh_lower, self.thresh_upper]
        self.cmap_renderer.color_data.metadata['selections'] = self.thresh
        self.cmap_renderer.color_data.metadata_changed = {'selections': self.thresh}
        self.container.request_redraw()
        self.img_container.request_redraw()

    @on_trait_change('peaks,cbar_selection:selection,img_idx')
    def calc_numpeaks(self):
        """Count peaks whose height lies inside the threshold.

        Counts over all images and for image ``img_idx``.  Without a usable
        colour-bar selection the inclusive range (-1, 1) is used.
        """
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
        except AttributeError:
            # Colour bar not created yet.
            thresh = None
        if thresh is None or len(thresh) == 0:
            thresh = (-1, 1)
        low, high = thresh[0], thresh[1]

        def count_inside(pks):
            # masked_inside masks values in [low, high]; the mask counts them.
            return int(np.sum(np.ma.masked_inside(pks[:, 2], low, high).mask))

        self.numpeaks_total = sum(count_inside(pks) for pks in self.peaks)
        try:
            self.numpeaks_img = count_inside(self.peaks[self.img_idx])
        except (IndexError, TypeError):
            # No (well-formed) peak array for the current image.
            self.numpeaks_img = 0

    @on_trait_change('findpeaks')
    def locate_peaks(self):
        """Correlate the template with every controller image and store the peaks.

        The loop walks the controller through all images.  Afterwards the
        controller's active index is reset to ``img_idx`` and ``self.data``
        is restored.  Without this, later template cuts and redraws would
        use the last image instead of the one on screen.
        """
        peaks = []
        progress = ProgressDialog(title="Peak finder progress", message="Finding peaks on %s images" % self.numfiles, max=self.numfiles, show_time=True, can_cancel=False)
        progress.open()
        shown_data = self.data
        try:
            for idx in xrange(self.numfiles):
                self.controller.set_active_index(idx)
                self.data = self.controller.get_active_image()[:]
                self.CC = cv_funcs.xcorr(self.template, self.data)
                # peak finder needs peaks greater than 1.  Multiply by 255 to scale them.
                pks = pc.two_dim_findpeaks(self.CC * 255, peak_width=self.peak_width, medfilt_radius=None)
                pks[:, 2] = pks[:, 2] / 255.
                peaks.append(pks)
                progress.update(idx + 1)
        finally:
            self.controller.set_active_index(self.img_idx)
            self.data = shown_data
        self.peaks = peaks
        self.redraw_plots()

    def mask_peaks(self, idx):
        """Return image ``idx``'s peaks with out-of-threshold heights masked.

        Uses the colour-bar selection, or the default range (-1, 1) when
        there is no colour bar yet or its selection is empty/``None``.
        """
        selector = getattr(self, 'cbar_selection', None)
        thresh = getattr(selector, 'selection', None)
        # Previously only [] fell back to the default; None crashed at thresh[0].
        if thresh is None or len(thresh) == 0:
            thresh = (-1, 1)
        mpeaks = np.ma.asarray(self.peaks[idx])
        mpeaks[:, 2] = np.ma.masked_outside(mpeaks[:, 2], thresh[0], thresh[1])
        return mpeaks

    @on_trait_change("peaks")
    def redraw_plots(self):
        """Rebuild the image plot and, when this image has peaks, its overlays."""
        oldplot = self.img_plot
        self.container.remove(oldplot)
        newplot = self.render_image()
        self.container.add(newplot)
        self.img_plot = newplot

        # The overlays may not exist yet (first draw) or may already be gone
        # from their container.  Remove each one independently so a failure
        # on the scatter plot cannot leave a stale colour bar behind.
        try:
            self.container.remove(self.scatplot)
        except Exception:
            pass
        try:
            self.img_container.remove(self.colorbar)
        except Exception:
            pass

        if self.numpeaks_img > 0:
            newscat = self.render_scatplot()
            self.container.add(newscat)
            self.scatplot = newscat
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
            self.colorbar = colorbar

        self.container.request_redraw()
        self.img_container.request_redraw()

    def crop_cells(self):
        """Crop a ``tmp_size`` square at every accepted peak of every image.

        For each image with accepted peaks, the stack of crops and the peak
        rows are handed to ``controller.add_cells``.

        NOTE(review): this leaves the controller on the last image; it
        appears to run only when the dialog is accepted.
        """
        print("cropping cells...")
        for idx in xrange(self.numfiles):
            # filter the peaks that are outside the selected threshold
            self.controller.set_active_index(idx)
            self.data = self.controller.get_active_image()
            self.name = self.controller.get_active_name()
            peaks = np.ma.compress_rows(self.mask_peaks(idx))
            tmp_sz = self.tmp_size
            data = np.zeros((peaks.shape[0], tmp_sz, tmp_sz))
            if data.shape[0] > 0:
                # Peak coordinates are floats; array slicing needs integers.
                for i, (x, y) in enumerate(peaks[:, :2].astype(int)):
                    # crop the cells from the given locations
                    data[i, :, :] = self.data[y:y + tmp_sz, x:x + tmp_sz]
                # send the data to the controller for storage in the chest
                self.controller.add_cells(name=self.name, data=data,
                                          locations=peaks)
Ejemplo n.º 21
0
class MainApp(HasTraits):
    container = Instance(OverlayPlotContainer)

    file_paths = List(Str)
    # Button group above tabbed area
    open_files = Button("Open files...")
    edit_datasets = Button("Edit datasets...")
    generate_plot = Button("Generate plot...")
    help_button = Button("Help...")

    # View tab
    scale = Enum('linear', 'log', 'sqrt')
    options = List
    reset_button = Button("Reset view")
    copy_to_clipboard = Button("Copy to clipboard")
    save_as_image = Button("Save as image...")

    # Process tab
    merge_positions = Enum('all', 'p1+p2', 'p3+p4', 'p12+p34')('p1+p2')
    load_partners = Button
    splice = Bool(True)
    merge = Bool(False)
    merge_regrid = Bool(False)
    normalise = Bool(True)
    # See comment in class Global() for an explanation of the following traits
    g = Instance(Global, ())
    file_list = DelegatesTo('g')
    normalisation_source_filenames = Enum(values='file_list')

    def _g_default(self):
        return g

    correction = Float(0.0)
    align_positions = Bool(False)
    bt_start_peak_select = Button
    bt_end_peak_select = Button
    peak_selecting = Bool(False)

    what_to_plot = Enum('Plot new', 'Plot old and new')('Plot old and new')

    bt_process = Button("Apply")
    bt_undo_processing = Button("Undo")
    bt_save = Button("Save...")

    # Background removal tab
    bt_manually_define_background = Button("Define")
    polynomial_order = Range(1, 20)(7)
    bt_poly_fit = Button("Poly fit")
    bt_load_background = Button("Load...")

    # theta/d/Q tab
    filename_field = Str("d")
    bt_convertscale_abscissa = Button("Convert/scale abscissa...")

    raw_data_plot = Instance(RawDataPlot)

    #-------------------------------------------------------------------------------------
    # MVC View
    view_group = VGroup(
        Label('Scale:'),
        UItem('scale', enabled_when='object._has_data()'),
        UItem('options',
              editor=CheckListEditor(name='_options'),
              style='custom',
              enabled_when='object._has_data()'),
        UItem('reset_button', enabled_when='object._has_data()'),
        spring,
        '_',
        spring,
        UItem('copy_to_clipboard', enabled_when='object._has_data()'),
        UItem('save_as_image', enabled_when='object._has_data()'),
        label='View',
        springy=False,
    )

    process_group = VGroup(
        VGroup(
            Label('Positions to process:'),
            UItem(name='merge_positions',
                  style='custom',
                  editor=EnumEditor(values={
                      'p1+p2': '1: p1+p2',
                      'p3+p4': '2: p3+p4',
                      'p12+p34': '3: p12+p34',
                      'all': '4: all',
                  },
                                    cols=2),
                  enabled_when='object._has_data()'),
            UItem('load_partners',
                  enabled_when=
                  'object._has_data() and (object.merge_positions != "all")'),
            show_border=True,
        ),
        VGroup(
            HGroup(Item('align_positions'),
                   enabled_when=
                   'object._has_data() and (object.merge_positions != "all")'),
            HGroup(
                UItem(
                    'bt_start_peak_select',
                    label='Select peak',
                    enabled_when=
                    'object.align_positions and not object.peak_selecting and (object.merge_positions != "all")'
                ),
                UItem(
                    'bt_end_peak_select',
                    label='Align',
                    enabled_when=
                    'object.peak_selecting and (object.merge_positions != "all")'
                ),
            ),
            Item('correction',
                 label='Zero correction:',
                 enabled_when='object._has_data()'),
            show_border=True,
        ),
        VGroup(
            HGroup(Item('splice'),
                   Item('merge',
                        enabled_when='object.merge_positions != "p12+p34"'),
                   enabled_when=
                   'object._has_data() and (object.merge_positions != "all")'),
            HGroup(
                Item(
                    'normalise',
                    label='Normalise',
                    enabled_when=
                    'object._has_data() and (object.merge_positions != "p12+p34")'
                ),
                Item('merge_regrid',
                     label='Grid',
                     enabled_when='object._has_data()'),
            ),
            VGroup(
                Label('Normalise to:'),
                UItem('normalisation_source_filenames',
                      style='simple',
                      enabled_when='object.normalise and object._has_data()'),
            ),
            show_border=True,
        ),
        spring,
        UItem('what_to_plot',
              editor=DefaultOverride(cols=2),
              style='custom',
              enabled_when='object._has_data()'),
        spring,
        UItem('bt_process', enabled_when='object._has_data()'),
        UItem('bt_undo_processing',
              enabled_when='object.undo_state is not None'),
        UItem('bt_save', enabled_when='object._has_data()'),
        label='Process',
        springy=False,
    )

    background_removal_group = VGroup(
        VGroup(
            Label('Manually define:'),
            UItem('bt_manually_define_background',
                  enabled_when='object._has_data()'),
            show_border=True,
        ),
        VGroup(
            Label('Fit polynomial:'),
            HGroup(
                Item('polynomial_order',
                     label='order',
                     enabled_when='object._has_data()'), ),
            UItem('bt_poly_fit', enabled_when='object._has_data()'),
            show_border=True,
        ),
        VGroup(
            Label('Load from file:'),
            UItem('bt_load_background', enabled_when='object._has_data()'),
            show_border=True,
        ),
        label='Backgrnd',
        springy=False,
    )

    convert_xscale_group = VGroup(
        Label('Filename label (prefix_<label>_nnnn.xye):'),
        UItem('filename_field', enabled_when='object._has_data()'),
        UItem(
            'bt_convertscale_abscissa',
            label='Convert/scale abscissa...',
            enabled_when='object._has_data()',
        ),
        label=ur'\u0398 d Q',
        springy=True,
    )

    traits_view = View(
        HGroup(
            VGroup(
                UItem('open_files'),
                UItem('edit_datasets', enabled_when='object._has_data()'),
                UItem('generate_plot', enabled_when='object._has_data()'),
                UItem('help_button'),
                spring,
                spring,
                Tabbed(
                    view_group,
                    process_group,
                    # background_removal_group,
                    convert_xscale_group,
                    springy=False,
                ),
                show_border=False,
            ),
            UItem('container', editor=ComponentEditor(bgcolor='white')),
            show_border=False,
        ),
        resizable=True,
        title=title,
        width=size[0],
        height=size[1])

    #-------------------------------------------------------------------------------------
    # MVC Control

    def _has_data(self):
        return len(self.datasets) != 0

    def __init__(self, *args, **kws):
        """
        self.datasets = [ <XYEDataset>, ..., <XYEDataset> ]
        self.dataset_pairs = set([ (<XYEDataset-p1>, <XYEDataset-p2>),
                                   ...,
                                   (<XYEDataset-p1>, <XYEDataset-p2>) ])
        """
        super(MainApp, self).__init__(*args, **kws)
        self.datasets = []
        self.dataset_pairs = set()
        self.undo_state = None
        self.raw_data_plot = RawDataPlot()
        self.plot = self.raw_data_plot.get_plot()
        self.container = OverlayPlotContainer(self.plot,
                                              bgcolor="white",
                                              use_backbuffer=True,
                                              border_visible=False)
        self.pan_tool = None
        # The list of all options.
        self._options = ['Show legend', 'Show gridlines', 'Show crosslines']
        # The list of currently set options, updated by the UI.
        self.options = self._options
        self.file_paths = []

    def _open_files_changed(self):
        file_list = get_file_list_from_dialog()
        if file_list:
            self.file_paths = file_list

    def _options_changed(self, opts):
        # opts just contains the keys that are true.
        # Create a dict all_options that has True/False for each item.
        all_options = dict.fromkeys(self._options, False)
        true_options = dict.fromkeys(opts, True)
        all_options.update(true_options)
        self.raw_data_plot.show_legend(all_options['Show legend'])
        self.raw_data_plot.show_grids(all_options['Show gridlines'])
        self.raw_data_plot.show_crosslines(all_options['Show crosslines'])
        self.container.request_redraw()

    def _bt_start_peak_select_changed(self):
        self.raw_data_plot.start_range_select()
        self.peak_selecting = True

    def _bt_end_peak_select_changed(self):
        self.peak_selecting = False
        selection_range = self.raw_data_plot.end_range_select()
        if not selection_range:
            return

        range_low, range_high = selection_range
        # fit the peak in all loaded dataseries
        self._get_partners()
        for datapair in self._get_dataset_pairs():
            processing.fit_peaks_for_a_dataset_pair(range_low, range_high,
                                                    datapair, self.normalise)
        editor = PeakFitWindow(dataset_pairs=self._get_dataset_pairs(),
                               range=selection_range)
        editor.edit_traits()

    def _get_dataset_pairs(self):
        datasets_dict = dict([(d.name, d) for d in self.datasets])
        return [ (datasets_dict[file1], datasets_dict[file2]) \
                    for file1, file2 in self.dataset_pairs ]

    def _bt_process_changed(self):
        '''
        Button click event handler for processing.

        Builds a DatasetProcessor from the current UI settings and processes
        the loaded datasets according to the "Positions to process:"
        radiobutton (``self.merge_positions``). The results are stored in
        ``self.processed_datasets`` and handed to ``_plot_processed_datasets``,
        which is where the undo snapshot is taken.
        '''
        def mark_processed(dataset):
            # Append ' (processed)' to the dataset's UI name and clear its UI colour.
            ui = dataset.metadata['ui']
            ui.name = dataset.name + ' (processed)'
            ui.color = None

        processed_datasets = []
        processor = DatasetProcessor(self.normalise, self.correction,
                                     self.align_positions, self.splice,
                                     self.merge, self.merge_regrid,
                                     self.normalisation_source_filenames,
                                     self.datasets)
        # Processing depends on the "Positions to process:" radiobutton
        # selection:
        #   If Splice==True, get all pairs and splice them
        #   If Merge==True, get all pairs and merge them
        #   If Normalise==True, always normalise
        #   If Grid==True, output gridded and ungridded
        # TODO: this dispatch belongs in a processor.process() method.
        if self.merge_positions == 'p12+p34':
            # Pair up datasets corresponding to the radiobutton selection.
            self._get_partners()
            for dataset_pair in self._get_dataset_pairs():
                datasets = processor.splice_overlapping_datasets(dataset_pair)
                for dataset in datasets:
                    mark_processed(dataset)
                processed_datasets.extend(datasets)
        elif self.merge_positions == 'all':
            # Normalise and/or regrid each dataset on its own; no pairing.
            for d in self.datasets:
                dataset = processor.normalise_me(d)
                if dataset is not None:
                    processed_datasets.append(dataset)
                    mark_processed(dataset)
                    # Regrid the normalised result rather than the raw input.
                    d = dataset

                dataset = processor.regrid_me(d)
                if dataset is not None:
                    processed_datasets.append(dataset)
                    mark_processed(dataset)
        else:
            # Pair up datasets corresponding to the radiobutton selection.
            self._get_partners()
            for dataset_pair in self._get_dataset_pairs():
                datasets = processor.process_dataset_pair(dataset_pair)
                for dataset in datasets:
                    mark_processed(dataset)
                processed_datasets.extend(datasets)

        self.processed_datasets = processed_datasets
        self._plot_processed_datasets()

    def _plot_processed_datasets(self):
        self._save_state()
        self.dataset_pairs = set(
        )  # TODO: Check whether this line should be removed
        if 'old' not in self.what_to_plot:
            self.datasets = []
        if 'new' in self.what_to_plot:
            self.datasets.extend(self.processed_datasets)
        self._plot_datasets()

    def _save_state(self):
        self.undo_state = (self.datasets[:], self.dataset_pairs.copy())

    def _restore_state(self):
        if self.undo_state is not None:
            self.datasets, self.dataset_pairs = self.undo_state
            self.undo_state = None

    def _bt_undo_processing_changed(self):
        self._restore_state()
        self._plot_datasets()

    def _bt_save_changed(self):
        """Save every processed dataset under a user-chosen directory and prefix.

        Each dataset is written to ``<dialog directory>/<dialog filename><dataset name>``.
        After that, the dialog path is passed to
        ``open_file_dir_with_default_handler``. Nothing happens if the dialog
        is not confirmed with OK.
        """
        dlg = FileDialog(title='Save results',
                         action='save as',
                         default_filename='prefix_',
                         wildcard='All files (*.*)|*.*')
        if dlg.open() != OK:
            return
        for dataset in self.processed_datasets:
            target = os.path.join(dlg.directory, dlg.filename + dataset.name)
            dataset.save(target)
        open_file_dir_with_default_handler(dlg.path)

    def _save_as_image_changed(self):
        if len(self.datasets) == 0:
            return
        filename = get_save_as_filename()
        if filename:
            PlotOutput.save_as_image(self.container, filename)
            open_file_dir_with_default_handler(filename)

    def _copy_to_clipboard_changed(self):
        if self.datasets:
            PlotOutput.copy_to_clipboard(self.container)

    def _scale_changed(self):
        self._plot_datasets()

    def _get_partner(self, position_index):
        # return index of partner; i.e., 2=>1, 1=>2, 3=>4, 4=>3, 12=>34, 34=>12
        if position_index in [1, 2, 3, 4]:
            partner = ((position_index - 1) ^ 1) + 1
        elif position_index == 12:
            partner = 34
        elif position_index == 34:
            partner = 12
        else:
            raise 'unparsable position'
        return partner

    def _get_position(self, filename):
        m = re.search('_p([0-9]*)_', filename)
        try:
            return int(m.group(1))
        except (AttributeError, ValueError):
            return None

    def _add_dataset_pair(self, filename):
        current_directory, filebase = os.path.split(filename)
        position_index = self._get_position(filename)
        if position_index is None:
            return

        # base filename for the associated position.
        other_filebase = re.sub(
            '_p{}_'.format(position_index),
            '_p{}_'.format(self._get_partner(position_index)), filebase)
        other_filename = os.path.join(current_directory, other_filebase)
        if not os.path.exists(other_filename):
            return

        # OK, we've got the names and paths, now add the actual data and references.
        if other_filename not in self.file_paths:
            self._add_xye_dataset(other_filename)
            # We also need to append the new path to the file_paths List trait which is
            # already populated by the files selected using the file selection dialog
            self.file_paths.append(other_filename)
        self._refresh_normalise_to_list()

    def _get_partners(self):
        """
        Populates the self.dataset_pairs list with all dataset partners in
        self.file_paths corresponding to the merge_positions radiobutton selection.
        """
        matching_re = {
            'all': '',
            'p1+p2': '_p[12]_',
            'p3+p4': '_p[34]_',
            'p12+p34': '_p(?:12|34)_',
        }
        basenames = [os.path.basename(f) for f in self.file_paths]
        filtered_paths = [
            f for f in basenames
            if re.search(matching_re[self.merge_positions], f) is not None
        ]
        self.dataset_pairs = set()
        for filebase in filtered_paths:
            # base filename for the first position.
            position_index = self._get_position(filebase)
            if position_index is None:
                return
            other_filebase = re.sub(
                '_p{}_'.format(position_index),
                '_p{}_'.format(self._get_partner(position_index)), filebase)
            if filebase in basenames and other_filebase in basenames:
                if position_index != 12 and (position_index & 1) == 0:
                    self.dataset_pairs.add((other_filebase, filebase))
                else:
                    self.dataset_pairs.add((filebase, other_filebase))
        return self.dataset_pairs

    def _refresh_normalise_to_list(self):
        """Pass the current file paths to ``g.populate_list``.

        NOTE(review): ``g`` is a module-level object defined elsewhere,
        presumably the normalise-to selector; confirm.
        """
        paths = self.file_paths
        g.populate_list(paths)

    def _file_paths_changed(self, new):
        """
        When the file dialog box is closed with a selection of filenames,
        just generate a list of all the filenames
        """
        self.datasets = []
        # self.file_paths is modified by _add_dataset_pair() so iterate over a copy of it.
        for filename in self.file_paths[:]:
            self._add_xye_dataset(filename)
        self._plot_datasets()
        self.datasets.sort(key=lambda d: d.name)
        self._refresh_normalise_to_list()

    def _load_partners_changed(self):
        for filename in self.file_paths[:]:
            self._add_dataset_pair(filename)
        self._plot_datasets()
        self.datasets.sort(key=lambda d: d.name)

    def _plot_datasets(self, reset_view=True):
        self.raw_data_plot.plot_datasets(self.datasets,
                                         scale=self.scale,
                                         reset_view=reset_view)
        self._options_changed(self.options)
        self.container.request_redraw()

    def _edit_datasets_changed(self):
        """Open a DatasetEditor on the datasets, then replot keeping the view."""
        DatasetEditor(datasets=self.datasets).edit_traits()
        self._plot_datasets(reset_view=False)

    def _generate_plot_changed(self):
        if self.datasets:
            generator = PlotGenerator(datasets=self.datasets)
            generator.show()

    def _help_button_changed(self):
        """Open the help box."""
        HelpBox().edit_traits()

    def _reset_button_changed(self):
        self.raw_data_plot.reset_view()

    def _add_xye_dataset(self, file_path):
        """Load an XYE file, append it to ``self.datasets`` and call create_datasetui.

        Files that raise IOError while loading are skipped silently.
        """
        try:
            loaded = XYEDataset.from_file(file_path)
        except IOError:
            return
        else:
            self.datasets.append(loaded)
            create_datasetui(loaded)
    def _bt_convertscale_abscissa_changed(self):
        """Open a WavelengthEditor on the datasets, then replot keeping the view."""
        WavelengthEditor(datasets=self.datasets,
                         filename_field=self.filename_field).edit_traits()
        self._plot_datasets(reset_view=False)