コード例 #1
0
ファイル: line_plot_ui.py プロジェクト: xiaoyu-wu/phyreslib
class LinePlotUI(HasTraits):
    """Traits UI that shows a single line plot fed by a LineDataSource.

    The plot is refreshed whenever the attached source fires its
    ``data_source_changed`` event.
    """

    # Object supplying the ``xs`` / ``ys`` values that are plotted
    line_data_source = Instance(LineDataSource)

    # Top-level container shown by ``traits_view``
    container = Instance(HPlotContainer)

    # The line plot living inside ``container``
    line_plot = Instance(Plot)

    # Data backing the plot, under the names "line_index" and "line_value"
    pd = Instance(ArrayPlotData)

    traits_view = View(
        Group(UItem('container', editor=ComponentEditor(size=(500, 200)))),
        resizable=True,
    )

    def __init__(self, line_data_source=None):
        super(LinePlotUI, self).__init__()
        # Ignore floating-point "invalid" errors (errstate) while the empty
        # plot is being built.
        with errstate(invalid='ignore'):
            self.create_plot()
        # Attached last, after the plot it will update has been created.
        self.line_data_source = line_data_source

    def create_plot(self):
        """Create empty plot data, a line plot over it and its container."""
        empty_x, empty_y = array([]), array([])
        self.pd = ArrayPlotData(line_index=empty_x, line_value=empty_y)

        # A plain line renderer drawing line_value against line_index
        self.line_plot = Plot(self.pd)
        self.line_plot.plot(("line_index", "line_value"), type="line")

        self.container = HPlotContainer()
        self.container.add(self.line_plot)

    @on_trait_change('line_data_source.data_source_changed')
    def update(self):
        """Copy the source's current xs/ys into the plot data and redraw."""
        source = self.line_data_source
        self.pd.update(line_index=source.xs, line_value=source.ys)
        target = self.container
        target.invalidate_draw()
        target.request_redraw()
コード例 #2
0
    def _load_image_data(self, data):
        """Display *data* as an image plot inside a freshly built container.

        The data is stored under the name 'img' in a new ArrayPlotData, the
        inspector and tools are attached to the resulting image renderer, and
        the new container replaces ``self.image_container.container``.
        """
        plot_data = ArrayPlotData()
        plot = Plot(data=plot_data, padding=[30, 5, 5, 30],
                    default_origin='top left')

        plot_data.set_data('img', data)
        renderer = plot.img_plot('img')[0]

        for attach in (self._add_inspector, self._add_tools):
            attach(renderer)

        container = HPlotContainer()
        container.add(plot)
        container.request_redraw()
        self.image_container.container = container
コード例 #3
0
ファイル: image_browser.py プロジェクト: OSUPychron/pychron
    def _load_image_data(self, data):
        """Display *data* as an image plot inside a freshly built container.

        The data is stored under the name "img" in a new ArrayPlotData, the
        inspector and tools are attached to the resulting image renderer, and
        the new container replaces ``self.image_container.container``.
        """
        plot_data = ArrayPlotData()
        plot = Plot(data=plot_data, padding=[30, 5, 5, 30],
                    default_origin="top left")

        plot_data.set_data("img", data)
        renderer = plot.img_plot("img")[0]

        for attach in (self._add_inspector, self._add_tools):
            attach(renderer)

        container = HPlotContainer()
        container.add(plot)
        container.request_redraw()
        self.image_container.container = container
コード例 #4
0
ファイル: image_viewer.py プロジェクト: UManPychron/pychron
class ImageViewer(HasTraits):
    """Show an image file attached to a 640x480 XY plot with an inspector."""

    container = Instance(HPlotContainer, ())
    plot = Any

    def load_image(self, path):
        """Open the file at *path* and pass it to ``set_image``.

        Does nothing when *path* is not an existing regular file.
        """
        if os.path.isfile(path):
            # Image files are binary: text mode ('r') fails to decode them on
            # Python 3 and mangles line endings on Windows.
            with open(path, 'rb') as fp:
                self.set_image(fp)

    def set_image(self, buf):
        '''
            buf is a file-like object
        '''
        self.container = HPlotContainer()
        # Two points spanning (0, 0)-(640, 480) set up the data space that the
        # ImageUnderlay is attached to.
        pd = ArrayPlotData(x=[0, 640],
                           y=[0, 480])
        padding = [30, 5, 5, 30]
        plot = Plot(data=pd, padding=padding)
        self.plot = plot.plot(('x', 'y'))[0]
        self.plot.index.sort_order = 'ascending'
        # NOTE(review): ImageUnderlay is project code; it receives the open
        # file object through its `path` argument.
        imo = ImageUnderlay(self.plot,
                            padding=padding,
                            path=buf)
        self.plot.overlays.append(imo)

        self._add_tools(self.plot)

        self.container.add(plot)
        self.container.request_redraw()

    def _add_tools(self, plot):
        """Attach an XYInspector tool and its overlay (upper-left) to *plot*."""
        inspector = XYInspector(plot)
        plot.tools.append(inspector)
        plot.overlays.append(XYInspectorOverlay(inspector=inspector,
                                                component=plot,
                                                align='ul',
                                                bgcolor=0xFFFFD2
                                                ))
コード例 #5
0
class ImageViewer(HasTraits):
    """Show an image file attached to a 640x480 XY plot with an inspector."""

    container = Instance(HPlotContainer, ())
    plot = Any

    def load_image(self, path):
        """Open the file at *path* and pass it to ``set_image``.

        Does nothing when *path* is not an existing regular file.
        """
        if os.path.isfile(path):
            # Image files are binary: text mode ('r') fails to decode them on
            # Python 3 and mangles line endings on Windows.
            with open(path, 'rb') as rfile:
                self.set_image(rfile)

    def set_image(self, buf):
        '''
            buf is a file-like object
        '''
        self.container = HPlotContainer()
        # Two points spanning (0, 0)-(640, 480) set up the data space that the
        # ImageUnderlay is attached to.
        pd = ArrayPlotData(x=[0, 640],
                           y=[0, 480])
        padding = [30, 5, 5, 30]
        plot = Plot(data=pd, padding=padding)
        self.plot = plot.plot(('x', 'y'))[0]
        self.plot.index.sort_order = 'ascending'
        # NOTE(review): ImageUnderlay is project code; it receives the open
        # file object through its `path` argument.
        imo = ImageUnderlay(self.plot,
                            padding=padding,
                            path=buf)
        self.plot.overlays.append(imo)

        self._add_tools(self.plot)

        self.container.add(plot)
        self.container.request_redraw()

    def _add_tools(self, plot):
        """Attach an XYInspector tool and its overlay (upper-left) to *plot*."""
        inspector = XYInspector(plot)
        plot.tools.append(inspector)
        plot.overlays.append(XYInspectorOverlay(inspector=inspector,
                                                component=plot,
                                                align='ul',
                                                bgcolor=0xFFFFD2
                                                ))
コード例 #6
0
class MandelbrotPlot(HasTraits):
    """View and Controller for the Mandelbrot Plot interface."""

    mandelbrot_model_view = View(
        VGroup(
            HGroup(
                Item("use_multiprocessing", label="multiprocessing?"),
                Item("number_of_processors", label="processors"),
                label="multiprocessing",
                show_border=True,
            ),
            VGroup(
                Item("max_iterations"),
                Item("x_steps"),
                Item("y_steps"),
                label="calculation",
                show_border=True,
            ),
        ),
        buttons=OKCancelButtons,
        kind="modal",
        resizable=True,
        title="Mandelbrot calculation settings",
        icon=None,
    )
    traits_view = View(
        VGroup(
            HGroup(
                Item("mandelbrot_model", show_label=False),
                Item("colormap"),
                Item("reset_button", show_label=False),
            ),
            Item("container", editor=ComponentEditor(), show_label=False),
            orientation="vertical",
        ),
        resizable=True,
        title="mandelbrot",
    )

    mandelbrot_model = Instance(MandelbrotModel, view=mandelbrot_model_view)
    container = Instance(HPlotContainer)
    colormap = Enum(colormaps)
    reset_button = Button(label="reset")

    # (min_x, max_x, min_y, max_y) of the region shown at start-up and on reset
    _initial_region = (-2.25, 0.75, -1.25, 1.25)

    # Plot data and the Plot that draws it.  Both are created per instance in
    # __init__; as plain class attributes they were shared by every
    # MandelbrotPlot, so all instances drew into the same Plot.
    _plot_data = Instance(ArrayPlotData)
    _plot_object = Instance(Plot)

    def __init__(self, *args, **kwargs):
        """Instantiates the mandelbrot model and necessary objects for the view/controller."""
        super().__init__(*args, **kwargs)
        self._plot_data = ArrayPlotData()
        self._plot_object = Plot(self._plot_data)

        self.mandelbrot_model = MandelbrotModel()
        self._update_with_initial_plot_data()

        self.image_plot = self._create_image_plot()
        self.container = HPlotContainer(padding=10,
                                        fill_padding=True,
                                        bgcolor="white",
                                        use_backbuffer=True)
        self.container.add(self._plot_object)
        self._fix_aspect_ratio()
        self._colormap_changed()
        self._append_tools()

    def _create_image_plot(self):
        """Return a Chaco image plot from the plot object referencing the ArrayPlotData fields."""
        return self._plot_object.img_plot("mandelbrot",
                                          xbounds="x_bounds",
                                          ybounds="y_bounds")[0]

    def _fix_aspect_ratio(self):
        """Set the container's aspect ratio to the x/y extent of the latest grid."""
        x_width = (self.mandelbrot_model.latest_xs[-1] -
                   self.mandelbrot_model.latest_xs[0])
        y_width = (self.mandelbrot_model.latest_ys[-1] -
                   self.mandelbrot_model.latest_ys[0])
        self.container.aspect_ratio = x_width / y_width

    def _update_plot_data(self, mandelbrot_C: numpy.ndarray) -> None:
        """Update the plot_data attribute with passed values.

        Updates the main 2d-array data with mandelbrot_C values
        Updates the x bounds and y bounds of the image with the latest values
         stored on the mandelbrot model.

        :param mandelbrot_C: values to display under the "mandelbrot" name
        """
        self._plot_data.set_data("mandelbrot", mandelbrot_C)
        self._plot_data.set_data("x_bounds", self.mandelbrot_model.latest_xs)
        self._plot_data.set_data("y_bounds", self.mandelbrot_model.latest_ys)
        # image_plot is created in __init__ after the first data update.
        image_plot = getattr(self, "image_plot", None)
        if image_plot is not None:
            image_plot.index.set_data(
                xdata=self.mandelbrot_model.latest_xs,
                ydata=self.mandelbrot_model.latest_ys,
            )

    def _append_tools(self):
        """Add tools to the necessary components."""
        zoom = RecalculatingZoomTool(
            self._zoom_recalculation_method,
            component=self.image_plot,
            tool_mode="box",
            always_on=True,
        )
        self.image_plot.overlays.append(zoom)

    def _zoom_recalculation_method(self, mins: Tuple[float, float],
                                   maxs: Tuple[float, float]):
        """Callable for the RecalculatingZoomTool to recalculate/display the mandelbrot set on zoom events.

        :param mins: min positions in selected zoom range
        :param maxs: max positions in selected zoom range
        :return:
        """
        self._recalculate_region(mins[0], maxs[0], mins[1], maxs[1])

    def _recalculate_region(self, min_x: float, max_x: float, min_y: float,
                            max_y: float) -> None:
        """Recompute the set over the region, store it and refresh the display."""
        mandelbrot_C = self._recalculate_mandelbrot(min_x, max_x, min_y, max_y)
        self._update_plot_data(mandelbrot_C)
        self._redraw()
        self._fix_aspect_ratio()

    def _redraw(self):
        """Invalidate the container's cached drawing and request a repaint."""
        self.container.invalidate_draw()
        self.container.request_redraw()

    def _recalculate_mandelbrot(self, min_x: float, max_x: float, min_y: float,
                                max_y: float) -> numpy.ndarray:
        """Recalculate the mandelbrot set for a given region.

        :param min_x: lower x bound of the region
        :param max_x: upper x bound of the region
        :param min_y: lower y bound of the region
        :param max_y: upper y bound of the region
        :return: the model's result for the grid with its last row/column dropped
        """
        z = self.mandelbrot_model.create_initial_array(min_x, max_x, min_y,
                                                       max_y)
        return self.mandelbrot_model.calculate_mandelbrot(z[:-1, :-1])

    def _mandelbrot_model_changed(self):
        """Method automatically called by Traits when mandelbrot_model attribute changes.

        Recalculates the mandelbrot set for the current range.
        """
        model = self.mandelbrot_model
        # Nothing to do until the model has a grid, and nothing to draw into
        # while __init__ is still building the container.
        if model is None or model.latest_xs is None or self.container is None:
            return
        # _fix_aspect_ratio measures extents as [-1] - [0], i.e. the grids run
        # from the lower to the upper bound, so [0] is the minimum.
        self._recalculate_region(model.latest_xs[0], model.latest_xs[-1],
                                 model.latest_ys[0], model.latest_ys[-1])

    def _colormap_changed(self):
        """Method automatically called by Traits when colormap attribute changes.

        Updates the color map for the image plot with the selected colormap.
        """
        self._cmap = default_colormaps.color_map_name_dict[self.colormap]
        # image_plot is a plain attribute created in __init__; a colormap
        # passed to the constructor fires this handler before it exists.
        # __init__ calls this method again once the plot is built.
        image_plot = getattr(self, "image_plot", None)
        if image_plot is not None:
            value_range = image_plot.color_mapper.range
            image_plot.color_mapper = self._cmap(value_range)
            self.container.request_redraw()

    def _reset_button_fired(self):
        """Method automatically called by Traits when reset_button fired.

        Resets to the initial range and recalculates the plot.
        """
        self._update_with_initial_plot_data()
        # Refresh the same way zoom does, so the backbuffered container does
        # not keep showing the zoomed image.
        self._redraw()
        self._fix_aspect_ratio()

    def _update_with_initial_plot_data(self):
        """Creates the initial mesh-grid and mandelbrot values on the grid."""
        self._initial_z_array = self.mandelbrot_model.create_initial_array(
            *self._initial_region)
        mandelbrot_C = self.mandelbrot_model.calculate_mandelbrot(
            self._initial_z_array[:-1, :-1])
        self._update_plot_data(mandelbrot_C)
コード例 #7
0
class ImageGUI(HasTraits):
    """Browse column-density shot files in a directory and display them.

    The selected shot is shown as a colour-mapped image with a colorbar, a
    crosshair cursor and two cross-section plots that follow the cursor.
    """

    # TODO: default to the last available shot instead of a fixed path
    # shot = File('L:\\data\\app3\\2011\\1108\\110823\\column_5200.ascii')
    # shot = File('/home/pmd/atomcool/lab/data/app3/2012/1203/120307/column_3195.ascii')

    # -- Shot traits
    shotdir = Directory('/home/pmd/atomcool/lab/data/app3/2012/1203/120320/')
    shots = List(Str)
    selectedshot = List(Str)
    namefilter = Str('column')

    # -- Report trait
    report = Str

    # -- Displayed analysis results
    number = Float

    # -- Column density plot container
    column_density = Instance(HPlotContainer)
    # ---- Plot components within this container
    imgplot = Instance(CMapImagePlot)
    cross_plot = Instance(Plot)
    cross_plot2 = Instance(Plot)
    colorbar = Instance(ColorBar)
    # ---- Plot data
    pd = Instance(ArrayPlotData)
    # ---- Colorbar
    num_levels = Int(15)
    # list(): on Python 3 dict.keys() is a view, not a list of choices
    colormap = Enum(list(color_map_name_dict.keys()))

    # -- Crosshair location
    cursor = Instance(BaseCursorTool)
    xy = DelegatesTo('cursor', prefix='current_position')
    xpos = Float(0.)
    ypos = Float(0.)
    xpos_read = Float(0.)
    ypos_read = Float(0.)
    cursor_group = Group(
        Group(Item('xpos', show_label=True),
              Item('xpos_read', show_label=False, style="readonly"),
              orientation='horizontal'),
        Group(Item('ypos', show_label=True),
              Item('ypos_read', show_label=False, style="readonly"),
              orientation='horizontal'),
        orientation='vertical', layout='normal', springy=True)

    # ---------------------------------------------------------------------------
    # Traits View Definitions
    # ---------------------------------------------------------------------------

    traits_view = View(
        Group(
            # -- Directory
            Item('shotdir', style='simple', editor=DirectoryEditor(),
                 width=400, show_label=False, resizable=False),
            # -- Bottom
            HSplit(
                # -- Pane for shot selection
                Group(
                    Item('namefilter', show_label=False, springy=False),
                    Item('shots', show_label=False, width=180, height=360,
                         editor=TabularEditor(selected='selectedshot',
                                              editable=False,
                                              multi_select=True,
                                              adapter=SelectAdapter())),
                    cursor_group,
                    orientation='vertical',
                    layout='normal',
                ),
                # -- Pane for column density plots
                Group(
                    Item('column_density', editor=ComponentEditor(),
                         show_label=False, width=600, height=500,
                         resizable=True),
                    Item('report', show_label=False, width=180,
                         springy=True, style='custom'),
                    layout='tabbed', springy=True),
                # -- Pane for analysis results
                Group(
                    Item('number', show_label=False)
                ),
            ),
            orientation='vertical',
            layout='normal',
        ),
        width=1400, height=500, resizable=True)

    # -- Pop-up view when Plot->Edit is selected from the menu
    plot_edit_view = View(
        Group(Item('num_levels'),
              Item('colormap')),
        buttons=["OK", "Cancel"])

    # ---------------------------------------------------------------------------
    # Private Traits
    # ---------------------------------------------------------------------------

    # -- Represents the region where the data set is defined
    _image_index = Instance(GridDataSource)

    # -- Represents the data that will be plotted on the grid
    _image_value = Instance(ImageData)

    # -- Represents the color map that will be used
    _cmap = Trait(jet, Callable)

    # ---------------------------------------------------------------------------
    # Public View interface
    # ---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        # Run the inherited HasTraits __init__ first so the Traits machinery
        # is set up even though __init__ is overridden, then build the plot.
        super(ImageGUI, self).__init__(*args, **kwargs)
        self.create_plot()

    def create_plot(self):
        """Build the image plot, colorbar, cursor, cross plots and containers."""
        # -- Create the index for the x and y axes and the range over which
        # -- they vary
        self._image_index = GridDataSource(array([]), array([]),
                                           sort_order=("ascending", "ascending"))
        image_index_range = DataRange2D(self._image_index)

        # -- Recompute the cross sections whenever the image index's metadata
        # -- changes (_metadata_changed forwards to _xy_changed)
        self._image_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")

        # -- Create the image values and determine their range
        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)

        # Create the image plot
        self.imgplot = CMapImagePlot(index=self._image_index,
                                     value=self._image_value,
                                     index_mapper=GridMapper(range=image_index_range),
                                     color_mapper=self._cmap(image_value_range))

        # Add a left axis to the plot
        left = PlotAxis(orientation='left',
                        title="axial",
                        mapper=self.imgplot.index_mapper._ymapper,
                        component=self.imgplot)
        self.imgplot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = PlotAxis(orientation='bottom',
                          title="radial",
                          mapper=self.imgplot.index_mapper._xmapper,
                          component=self.imgplot)
        self.imgplot.overlays.append(bottom)

        # Add some tools to the plot
        self.imgplot.tools.append(PanTool(self.imgplot, drag_button="right",
                                          constrain_key="shift"))

        self.imgplot.overlays.append(ZoomTool(component=self.imgplot,
                                              tool_mode="box", always_on=False))

        # Create a colorbar
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.imgplot,
                                 padding_top=self.imgplot.padding_top,
                                 padding_bottom=self.imgplot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        # Add a cursor; it is a rendered component so it goes in the
        # overlays list
        self.cursor = CursorTool(self.imgplot, drag_button="left", color="white")
        self.imgplot.overlays.append(self.cursor)

        # Create the two cross plots
        self.pd = ArrayPlotData(line_index=array([]),
                                line_value=array([]),
                                scatter_index=array([]),
                                scatter_value=array([]),
                                scatter_color=array([]))

        self.cross_plot = Plot(self.pd, resizable="h")
        self.cross_plot.height = 100
        self.cross_plot.padding = 20
        self.cross_plot.plot(("line_index", "line_value"),
                             line_style="dot")
        self.cross_plot.plot(("scatter_index", "scatter_value", "scatter_color"),
                             type="cmap_scatter",
                             name="dot",
                             color_mapper=self._cmap(image_value_range),
                             marker="circle",
                             marker_size=6)

        self.cross_plot.index_range = self.imgplot.index_range.x_range

        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))
        self.pd.set_data("scatter_index2", array([]))
        self.pd.set_data("scatter_value2", array([]))
        self.pd.set_data("scatter_color2", array([]))

        self.cross_plot2 = Plot(self.pd, width=140, orientation="v",
                                resizable="v", padding=20, padding_bottom=160)
        self.cross_plot2.plot(("line_index2", "line_value2"),
                              line_style="dot")
        self.cross_plot2.plot(("scatter_index2", "scatter_value2", "scatter_color2"),
                              type="cmap_scatter",
                              name="dot",
                              color_mapper=self._cmap(image_value_range),
                              marker="circle",
                              marker_size=8)

        self.cross_plot2.index_range = self.imgplot.index_range.y_range

        # Create a container and add sub-containers and components
        self.column_density = HPlotContainer(padding=40, fill_padding=True,
                                             bgcolor="white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.cross_plot)
        self.imgplot.padding = 20
        inner_cont.add(self.imgplot)
        self.column_density.add(self.colorbar)
        self.column_density.add(inner_cont)
        self.column_density.add(self.cross_plot2)

    def update(self):
        """Refresh the shot list, then load and display the chosen shot."""
        self.shots = self.populate_shot_list()
        print(self.selectedshot)
        loaded = self.load_imagedata()
        if loaded is None:
            # load_imagedata returns None (after printing a message) when it
            # cannot determine a file; unpacking None used to raise TypeError.
            return
        imgdata, self.report = loaded
        if imgdata is not None:
            self.minz = imgdata.min()
            self.maxz = imgdata.max()
            self.colorbar.index_mapper.range.low = self.minz
            self.colorbar.index_mapper.range.high = self.maxz
            xs = numpy.linspace(0, imgdata.shape[0], imgdata.shape[0] + 1)
            ys = numpy.linspace(0, imgdata.shape[1], imgdata.shape[1] + 1)
            self._image_index.set_data(xs, ys)
            self._image_value.data = imgdata
            self.pd.set_data("line_index", xs)
            self.pd.set_data("line_index2", ys)
            self.column_density.invalidate_draw()
            self.column_density.request_redraw()

    def populate_shot_list(self):
        """Return the sorted names in ``shotdir`` accepted by ``iscol``.

        Returns an empty list (after printing a message) when the directory
        cannot be listed. Previously this raised UnboundLocalError, and
        listdir's OSError was not caught at all.
        """
        try:
            shot_list = os.listdir(self.shotdir)
            shot_list = sorted(name for name in shot_list
                               if iscol(name, self.namefilter))
        except (OSError, ValueError):
            print(" *** Not a valid directory path ***")
            return []
        return shot_list

    def load_imagedata(self):
        """Load the selected shot (or the first listed one) and its report.

        Returns ``(image_data, report)``, or ``None`` when no usable file name
        is available (empty shot list, or a name without '_').
        """
        try:
            directory = self.shotdir
            if not self.selectedshot:
                filename = self.shots[0]
            else:
                filename = self.selectedshot[0]
            # The shot number is the part of the file name before the first '_'
            shotnum = filename[:filename.index('_')]
        except (ValueError, IndexError):
            print(" *** Not a valid path *** ")
            return None
        print("Loading file #%s from %s" % (filename, directory))
        return (import_data.load(directory, filename),
                import_data.load_report(directory, shotnum))

    # ---------------------------------------------------------------------------
    # Event handlers
    # ---------------------------------------------------------------------------

    def _selectedshot_changed(self):
        print(self.selectedshot)
        self.update()

    def _shots_changed(self):
        # NOTE(review): re-reads the directory whenever the list is replaced;
        # presumably this stops because Traits only re-notifies when the new
        # list compares unequal — confirm.
        self.shots = self.populate_shot_list()

    def _namefilter_changed(self):
        self.shots = self.populate_shot_list()

    def _xpos_changed(self):
        self.cursor.current_position = self.xpos, self.ypos

    def _ypos_changed(self):
        self.cursor.current_position = self.xpos, self.ypos

    def _metadata_changed(self):
        self._xy_changed()

    def _xy_changed(self):
        """Take a cross section of the image data at the cursor position and
        update the line and scatter plots."""
        x_ndx, y_ndx = self.cursor.current_index
        self.xpos_read = x_ndx
        self.ypos_read = y_ndx
        # minz/maxz are only set by update() once a shot has been loaded;
        # before that there is no image to slice.
        if not hasattr(self, "minz"):
            return
        self.cross_plot.value_range.low = self.minz
        self.cross_plot.value_range.high = self.maxz
        self.cross_plot2.value_range.low = self.minz
        self.cross_plot2.value_range.high = self.maxz
        # Compare against None: 0 is a valid row/column index.
        if x_ndx is not None and y_ndx is not None:
            data = self._image_value.data
            # NOTE(review): the line plots index the image as data[x, y] while
            # the scatter markers read data[y, x]; confirm the intended order.
            self.pd.set_data("line_value", data[:, y_ndx])
            self.pd.set_data("line_value2", data[x_ndx, :])
            xdata, ydata = self._image_index.get_data()
            xdata, ydata = xdata.get_data(), ydata.get_data()
            self.pd.set_data("scatter_index", array([ydata[y_ndx]]))
            self.pd.set_data("scatter_index2", array([xdata[x_ndx]]))
            self.pd.set_data("scatter_value", array([data[y_ndx, x_ndx]]))
            self.pd.set_data("scatter_value2", array([data[y_ndx, x_ndx]]))
            self.pd.set_data("scatter_color", array([data[y_ndx, x_ndx]]))
            self.pd.set_data("scatter_color2", array([data[y_ndx, x_ndx]]))
        else:
            self.pd.set_data("scatter_value", array([]))
            self.pd.set_data("scatter_value2", array([]))
            self.pd.set_data("line_value", array([]))
            self.pd.set_data("line_value2", array([]))

    def _colormap_changed(self):
        """Apply the selected colormap to the image and cross-section markers.

        The previous version tested for a ``polyplot`` this class never
        defines, so changing the colormap had no effect.
        """
        self._cmap = color_map_name_dict[self.colormap]
        # The plots exist only after create_plot(); a colormap passed to the
        # constructor fires this handler earlier.
        if self.imgplot is None:
            return
        value_range = self.imgplot.color_mapper.range
        self.imgplot.color_mapper = self._cmap(value_range)
        # The "dot" markers were built on the same value range as the image.
        self.cross_plot.plots["dot"][0].color_mapper = self._cmap(value_range)
        self.cross_plot2.plots["dot"][0].color_mapper = self._cmap(value_range)
        # NOTE(review): the colorbar was built with plot=self.imgplot; confirm
        # that it picks up the new color mapper.
        self.column_density.request_redraw()

    def _num_levels_changed(self):
        # This class defines no polyplot/lineplot; only touch them if they
        # exist, so editing num_levels no longer raises AttributeError.
        if self.num_levels > 3:
            for name in ("polyplot", "lineplot"):
                contour = getattr(self, name, None)
                if contour is not None:
                    contour.levels = self.num_levels
コード例 #8
0
class PlotUI(HasTraits):
    """Contour plot of gridded data with linked cross-section plots.

    The view holds filled and line contour plots with x/y axes, pan, box-zoom
    and two line inspectors, plus a colorbar. It also holds a horizontal
    cross-section plot above the contours and a vertical one beside them.
    Moving the line inspectors updates the cross-sections through
    ``_metadata_changed``. Load data with ``update(model)``.
    """

    # container for all plots
    container = Instance(HPlotContainer)

    # Plot components within this container:
    polyplot = Instance(ContourPolyPlot)
    lineplot = Instance(ContourLinePlot)
    cross_plot = Instance(Plot)
    cross_plot2 = Instance(Plot)
    colorbar = Instance(ColorBar)

    # plot data shared by the two cross-section plots
    pd = Instance(ArrayPlotData)

    # view options
    num_levels = Int(15)
    colormap = Enum(colormaps)

    # Traits view definitions:
    traits_view = View(
        Group(UItem('container', editor=ComponentEditor(size=(800, 600)))),
        resizable=True)

    plot_edit_view = View(
        Group(Item('num_levels'),
              Item('colormap')),
        buttons=["OK", "Cancel"])

    #---------------------------------------------------------------------------
    # Private Traits
    #---------------------------------------------------------------------------

    _image_index = Instance(GridDataSource)
    _image_value = Instance(ImageData)

    _cmap = Trait(default_colormaps.jet, Callable)

    #---------------------------------------------------------------------------
    # Public View interface
    #---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        super(PlotUI, self).__init__(*args, **kwargs)
        # FIXME: 'with' wrapping is temporary fix for infinite range in initial
        # color map, which can cause a distracting warning print. This 'with'
        # wrapping should be unnecessary after fix in color_mapper.py.
        with errstate(invalid='ignore'):
            self.create_plot()

    def create_plot(self):
        """Build every plot component and lay them out in ``self.container``."""

        # Create the mapper, etc
        self._image_index = GridDataSource(array([]),
                                           array([]),
                                           sort_order=("ascending", "ascending"))
        image_index_range = DataRange2D(self._image_index)
        # The LineInspector overlays below write their selection into this
        # data source's metadata; _metadata_changed reacts to it.
        self._image_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")

        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)

        # Create the contour plots
        self.polyplot = ContourPolyPlot(index=self._image_index,
                                        value=self._image_value,
                                        index_mapper=GridMapper(
                                            range=image_index_range),
                                        color_mapper=self._cmap(
                                            image_value_range),
                                        levels=self.num_levels)

        self.lineplot = ContourLinePlot(index=self._image_index,
                                        value=self._image_value,
                                        index_mapper=GridMapper(
                                            range=self.polyplot.index_mapper.range),
                                        levels=self.num_levels)

        # Add a left axis to the plot
        left = PlotAxis(orientation='left',
                        title="y",
                        mapper=self.polyplot.index_mapper._ymapper,
                        component=self.polyplot)
        self.polyplot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = PlotAxis(orientation='bottom',
                          title="x",
                          mapper=self.polyplot.index_mapper._xmapper,
                          component=self.polyplot)
        self.polyplot.overlays.append(bottom)

        # Add some tools to the plot
        self.polyplot.tools.append(PanTool(self.polyplot,
                                           constrain_key="shift"))
        self.polyplot.overlays.append(ZoomTool(component=self.polyplot,
                                               tool_mode="box", always_on=False))
        self.polyplot.overlays.append(LineInspector(component=self.polyplot,
                                                    axis='index_x',
                                                    inspect_mode="indexed",
                                                    write_metadata=True,
                                                    is_listener=True,
                                                    color="white"))
        self.polyplot.overlays.append(LineInspector(component=self.polyplot,
                                                    axis='index_y',
                                                    inspect_mode="indexed",
                                                    write_metadata=True,
                                                    color="white",
                                                    is_listener=True))

        # Add these two plots to one container
        contour_container = OverlayPlotContainer(padding=20,
                                                 use_backbuffer=True,
                                                 unified_draw=True)
        contour_container.add(self.polyplot)
        contour_container.add(self.lineplot)

        # Create a colorbar
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.polyplot,
                                 padding_top=self.polyplot.padding_top,
                                 padding_bottom=self.polyplot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        self.pd = ArrayPlotData(line_index=array([]),
                                line_value=array([]),
                                scatter_index=array([]),
                                scatter_value=array([]),
                                scatter_color=array([]))

        # Horizontal cross-section (above the contours), sharing the x range.
        self.cross_plot = Plot(self.pd, resizable="h")
        self.cross_plot.height = 100
        self.cross_plot.padding = 20
        self.cross_plot.plot(("line_index", "line_value"),
                             line_style="dot")
        self.cross_plot.plot(("scatter_index", "scatter_value", "scatter_color"),
                             type="cmap_scatter",
                             name="dot",
                             color_mapper=self._cmap(image_value_range),
                             marker="circle",
                             marker_size=8)

        self.cross_plot.index_range = self.polyplot.index_range.x_range

        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))
        self.pd.set_data("scatter_index2", array([]))
        self.pd.set_data("scatter_value2", array([]))
        self.pd.set_data("scatter_color2", array([]))

        # Vertical cross-section (beside the contours), sharing the y range.
        self.cross_plot2 = Plot(self.pd, width=140, orientation="v",
                                resizable="v", padding=20, padding_bottom=160)
        self.cross_plot2.plot(("line_index2", "line_value2"),
                              line_style="dot")
        self.cross_plot2.plot(("scatter_index2",
                               "scatter_value2",
                               "scatter_color2"),
                              type="cmap_scatter",
                              name="dot",
                              color_mapper=self._cmap(image_value_range),
                              marker="circle",
                              marker_size=8)

        self.cross_plot2.index_range = self.polyplot.index_range.y_range

        # Create a container and add components
        self.container = HPlotContainer(padding=40, fill_padding=True,
                                        bgcolor="white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.cross_plot)
        inner_cont.add(contour_container)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.cross_plot2)

    def update(self, model):
        """Load a new data set into every plot.

        ``model`` must expose ``xs`` and ``ys`` (grid coordinates), ``zs``
        (the gridded values, indexed ``[y, x]`` by ``_metadata_changed``) and
        ``minz``/``maxz``. The last two are used as the limits of the colorbar
        and of the cross-section value axes.
        """
        self.minz = model.minz
        self.maxz = model.maxz
        self.colorbar.index_mapper.range.low = self.minz
        self.colorbar.index_mapper.range.high = self.maxz
        self._image_index.set_data(model.xs, model.ys)
        self._image_value.data = model.zs
        self.pd.update_data(line_index=model.xs, line_index2=model.ys)
        self.container.invalidate_draw()
        self.container.request_redraw()

    #---------------------------------------------------------------------------
    # Event handlers
    #---------------------------------------------------------------------------

    def _metadata_changed(self, old, new):
        """Take a cross section out of the image data at the line-inspector
        selection, and update the line and scatter plots.

        The selection is read as an ``(x_ndx, y_ndx)`` pair. A ``None`` index
        leaves the plots unchanged; index 0 is a valid row/column. Without any
        selection, the cross-section values are cleared.
        """
        # minz/maxz are only set by update(). Skip the range sync if the
        # inspectors fire before any data has been loaded.
        minz = getattr(self, "minz", None)
        maxz = getattr(self, "maxz", None)
        if minz is not None and maxz is not None:
            for cross in (self.cross_plot, self.cross_plot2):
                cross.value_range.low = minz
                cross.value_range.high = maxz

        metadata = self._image_index.metadata
        if "selections" in metadata:
            x_ndx, y_ndx = metadata["selections"]
            # Explicit None checks: a plain truth test would reject index 0
            # (first row/column).
            if x_ndx is not None and y_ndx is not None:
                xdata, ydata = self._image_index.get_data()
                xdata, ydata = xdata.get_data(), ydata.get_data()
                zs = self._image_value.data
                z_here = zs[y_ndx, x_ndx]
                self.pd.update_data(
                    line_value=zs[y_ndx, :],
                    line_value2=zs[:, x_ndx],
                    scatter_index=array([xdata[x_ndx]]),
                    scatter_index2=array([ydata[y_ndx]]),
                    scatter_value=array([z_here]),
                    scatter_value2=array([z_here]),
                    scatter_color=array([z_here]),
                    scatter_color2=array([z_here])
                )
        else:
            self.pd.update_data({"scatter_value": array([]),
                                 "scatter_value2": array([]),
                                 "line_value": array([]),
                                 "line_value2": array([])})

    def _colormap_changed(self):
        """Switch every color mapper to the newly chosen colormap."""
        self._cmap = default_colormaps.color_map_name_dict[self.colormap]
        if self.polyplot is not None:
            value_range = self.polyplot.color_mapper.range
            self.polyplot.color_mapper = self._cmap(value_range)
            value_range = self.cross_plot.color_mapper.range
            self.cross_plot.color_mapper = self._cmap(value_range)
            # FIXME: change when we decide how best to update plots using
            # the shared colormap in plot object
            self.cross_plot.plots["dot"][0].color_mapper = self._cmap(value_range)
            self.cross_plot2.plots["dot"][0].color_mapper = self._cmap(value_range)
            self.container.request_redraw()

    def _num_levels_changed(self):
        """Apply a new contour level count (values of 3 or less are ignored)."""
        # This handler also fires from HasTraits.__init__ when num_levels is
        # passed as a keyword, before create_plot() has built the contour
        # renderers. create_plot() then reads num_levels itself.
        if self.num_levels > 3 and self.polyplot is not None:
            self.polyplot.levels = self.num_levels
            self.lineplot.levels = self.num_levels
コード例 #9
0
ファイル: ucc.py プロジェクト: magnunor/analyzarr
class CellCropper(StackViewer):
    """Template picker dialog that crops cells out of an image stack.

    Workflow:
    1. The user drags a cursor to choose the top-left corner of a square
       template of side ``tmp_size``.
    2. "Find Peaks" cross-correlates the template against every image and
       locates correlation peaks.
    3. ``crop_cells`` cuts a ``tmp_size`` x ``tmp_size`` patch at each peak
       whose score lies inside the colorbar threshold selection.
    """
    template = Array
    CC = Array
    # One array per image; columns are x, y and correlation score
    # (see render_scatplot and crop_cells).
    peaks = List
    zero = Int(0)
    tmp_size = Range(low=2, high=512, value=64, cols=4)
    max_pos_x = Property(depends_on=['tmp_size'])
    max_pos_y = Property(depends_on=['tmp_size'])
    top = Range(low='zero', high='max_pos_y', value=20, cols=4)
    left = Range(low='zero', high='max_pos_x', value=20, cols=4)
    is_square = Bool
    tmp_plot = Instance(Plot)
    findpeaks = Button
    peak_width = Range(low=2, high=200, value=10)
    ShowCC = Bool
    numpeaks_total = Int(0, cols=5)
    numpeaks_img = Int(0, cols=5)
    OK_custom = OK_custom_handler
    thresh = Trait(None, None, List, Tuple, Array)
    thresh_upper = Float(1.0)
    thresh_lower = Float(-1.0)
    tmp_img_idx = Int(0)

    csr = Instance(BaseCursorTool)

    traits_view = View(
        Group(
            Group(
                Item("img_container", editor=ComponentEditor(), show_label=False),
                HGroup(
                    Item("ShowCC", editor=BooleanEditor(), label="Show cross correlation image"),
                    Spring(),
                    Item("prev_img", editor=ButtonEditor(label="<"), show_label=False, enabled_when='numfiles > 1'),
                    Item("next_img", editor=ButtonEditor(label=">"), show_label=False, enabled_when='numfiles > 1'),
                    ),
                label="Original image", show_border=True, trait_modified="tab_selected",
                orientation='vertical',),
            VGroup(
                Group(
                    HGroup(
                        Item("left", label="Left coordinate", style="custom"),
                        Spring(),
                        Item("top", label="Top coordinate", style="custom"),
                        ),
                    Item("tmp_size", label="Template size", style="custom"),
                    Item("tmp_plot", editor=ComponentEditor(height=256, width=256), show_label=False, resizable=True),
                    label="Template", show_border=True),
                Group(
                    HGroup(
                        Item("peak_width", label="Peak width", style="custom"),
                        Spring(),
                        Item("findpeaks", editor=ButtonEditor(label="Find Peaks"), show_label=False),
                        ),
                    HGroup(
                        Item("thresh_lower", label="Threshold Lower Value", editor=TextEditor(evaluate=float,
                                                                                             format_str='%1.4f')),
                        Spring(),
                        Item("thresh_upper", label="Threshold Upper Value", editor=TextEditor(evaluate=float,
                                                                                             format_str='%1.4f')),
                    ),
                    HGroup(
                        Item("numpeaks_img", label="Number of Cells selected (this image)", style='readonly'),
                        Spring(),
                        Item("numpeaks_total", label="Total", style='readonly'),
                        ),
                    label="Peak parameters", show_border=True),
                ),
            orientation='horizontal'),
        buttons=[Action(name='OK', enabled_when='numpeaks_total > 0'),
                 CancelButton],
        title="Template Picker",
        handler=OK_custom, kind='livemodal',
        key_bindings=key_bindings,
        width=940, height=530, resizable=True)

    def __init__(self, controller, *args, **kw):
        super(CellCropper, self).__init__(controller, *args, **kw)
        # OpenCV is only probed here so the dialog can abort early; the
        # correlation itself is computed by cv_funcs.xcorr.
        try:
            import cv  # noqa: F401
        except ImportError:
            try:
                import cv2.cv as cv  # noqa: F401
            except ImportError:
                print("OpenCV unavailable.  Can't do cross correlation without it.  Aborting.")
                return None
        self.OK_custom = OK_custom_handler()
        self.template = self.data[self.top:self.top + self.tmp_size,
                                  self.left:self.left + self.tmp_size]
        tmp_plot_data = ArrayPlotData(imagedata=self.template)
        tmp_plot = Plot(tmp_plot_data, default_origin="top left")
        tmp_plot.img_plot("imagedata", colormap=jet)
        tmp_plot.aspect_ratio = 1.0
        self.tmp_plot = tmp_plot
        self.tmp_plotdata = tmp_plot_data
        self.crop_sig = None

    def render_image(self):
        """Render the current image in gray and attach the template cursor."""
        plot = super(CellCropper, self).render_image()
        img = plot.img_plot("imagedata", colormap=gray)[0]
        csr = CursorTool(img, drag_button='left', color='white',
                         line_width=2.0)
        self.csr = csr
        csr.current_position = self.left, self.top
        img.overlays.append(csr)
        self.img_plot = plot
        return plot

    # TODO: use base class render_scatter_overlay method
    def render_scatplot(self):
        """Build a scatter plot of the current image's peaks, colored by score."""
        current = self.peaks[self.img_idx]
        peakdata = ArrayPlotData()
        peakdata.set_data("index", current[:, 0])
        peakdata.set_data("value", current[:, 1])
        peakdata.set_data("color", current[:, 2])
        scatplot = Plot(peakdata, aspect_ratio=self.img_plot.aspect_ratio,
                        default_origin="top left")
        scatplot.plot(("index", "value", "color"),
                      type="cmap_scatter",
                      name="my_plot",
                      color_mapper=jet(DataRange1D(low=0.0,
                                                   high=1.0)),
                      marker="circle",
                      fill_alpha=0.5,
                      marker_size=6,
                      )
        scatplot.x_grid.visible = False
        scatplot.y_grid.visible = False
        # Share the image plot's 2D range so the overlay stays aligned.
        scatplot.range2d = self.img_plot.range2d
        self.scatplot = scatplot
        self.peakdata = peakdata
        return scatplot

    def _image_plot_container(self):
        """Build the image container; add the peak overlay and colorbar when
        the current image has peaks."""
        plot = self.render_image()

        # Create a container to position the plot and the colorbar side-by-side
        self.container = OverlayPlotContainer()
        self.container.add(plot)
        self.img_container = HPlotContainer(use_backbuffer=False)
        self.img_container.add(self.container)
        self.img_container.bgcolor = "white"

        if self.numpeaks_img > 0:
            scatplot = self.render_scatplot()
            self.container.add(scatplot)
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
        return self.img_container

    def draw_colorbar(self):
        """Create a colorbar for the peak scatter plot.

        The colorbar gets a range-selection tool, which serves as the
        threshold selector.
        """
        scatplot = self.scatplot
        cmap_renderer = scatplot.plots["my_plot"][0]
        selection = ColormappedSelectionOverlay(cmap_renderer, fade_alpha=0.35,
                                                selection_type="range")
        cmap_renderer.overlays.append(selection)
        # Re-apply any threshold chosen before this colorbar existed.
        if self.thresh is not None:
            cmap_renderer.color_data.metadata['selections'] = self.thresh
            cmap_renderer.color_data.metadata_changed = {'selections': self.thresh}
        # Create the colorbar, handing in the appropriate range and colormap
        colorbar = ColorBar(
            index_mapper=LinearMapper(
                range=DataRange1D(
                    low=-1.0,
                    high=1.0)
                ),
            orientation='v',
            resizable='v',
            width=30,
            padding=20)
        colorbar_selection = RangeSelection(component=colorbar)
        colorbar.tools.append(colorbar_selection)
        colorbar.overlays.append(
            RangeSelectionOverlay(component=colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray",
                                  metadata_name='selections'))
        self.cbar_selection = colorbar_selection
        self.cmap_renderer = cmap_renderer
        colorbar.plot = cmap_renderer
        colorbar.padding_top = scatplot.padding_top
        colorbar.padding_bottom = scatplot.padding_bottom
        self.colorbar = colorbar
        return colorbar

    @on_trait_change('ShowCC')
    def toggle_cc_view(self):
        """Switch the main view between the image and its cross-correlation."""
        if self.ShowCC:
            self.update_CC()
            grid_data_source = self.img_plot.range2d.sources[0]
            grid_data_source.set_data(np.arange(self.CC.shape[1]),
                                      np.arange(self.CC.shape[0]))
        else:
            self.img_plotdata.set_data("imagedata", self.data)
        self.redraw_plots()

    @on_trait_change("img_idx")
    def update_img_depth(self):
        """
        TODO: We look up the index in the model - first get a list of files, then get the name
        of the file at the given index.
        """
        super(CellCropper, self).update_img_depth()
        if self.ShowCC:
            self.update_CC()
        self.redraw_plots()

    def _get_max_pos_x(self):
        """Largest allowed left coordinate, or None if it is not positive."""
        max_pos_x = self.data.shape[-1] - self.tmp_size - 1
        return max_pos_x if max_pos_x > 0 else None

    def _get_max_pos_y(self):
        """Largest allowed top coordinate, or None if it is not positive."""
        max_pos_y = self.data.shape[-2] - self.tmp_size - 1
        return max_pos_y if max_pos_y > 0 else None

    @on_trait_change('left, top')
    def update_csr_position(self):
        """Move the cursor to follow the Left/Top coordinate editors."""
        # The cursor only exists once render_image() has run. Checking for it,
        # rather than for left > 0, lets a left coordinate of 0 be shown too.
        if self.csr is not None:
            self.csr.current_position = self.left, self.top

    @on_trait_change('csr:current_position')
    def update_top_left(self):
        """Copy the cursor position into left/top, clamped to the valid area."""
        x, y = self.csr.current_position
        # A cursor at (0, 0) is ignored.
        if x > 0 or y > 0:
            if x > self.max_pos_x:
                if y < self.max_pos_y:
                    # Past the right edge only: clamp left, follow vertically
                    # (mirrors the bottom-edge branch below).
                    self.left, self.top = self.max_pos_x, y
                else:
                    self.csr.current_position = self.max_pos_x, self.max_pos_y
            elif y > self.max_pos_y:
                self.left, self.top = x, self.max_pos_y
            else:
                self.left, self.top = x, y

    @on_trait_change('left, top, tmp_size')
    def update_tmp_plot(self):
        """Re-cut the template at the new position/size and reset any peaks."""
        self.template = self.data[self.top:self.top + self.tmp_size,
                                  self.left:self.left + self.tmp_size]
        self.tmp_plotdata.set_data("imagedata", self.template)
        grid_data_source = self.tmp_plot.range2d.sources[0]
        grid_data_source.set_data(np.arange(self.tmp_size), np.arange(self.tmp_size))
        self.tmp_img_idx = self.img_idx
        if self.numpeaks_total > 0:
            print("clearing peaks")
            self.peaks = [np.array([[0, 0, -1]])]
        self.update_CC()

    def update_CC(self):
        """When ShowCC is on, recompute and display the cross-correlation of
        the template with the current image."""
        if self.ShowCC:
            self.CC = cv_funcs.xcorr(self.template, self.data)
            self.img_plotdata.set_data("imagedata", self.CC)

    @on_trait_change('cbar_selection:selection')
    def update_thresh(self):
        """Mirror the colorbar range selection into the threshold fields and
        the scatter plot's selection metadata."""
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
            self.cmap_renderer.color_data.metadata['selections'] = thresh
            self.thresh_lower = thresh[0]
            self.thresh_upper = thresh[1]
            self.cmap_renderer.color_data.metadata_changed = {'selections': thresh}
            self.container.request_redraw()
            self.img_container.request_redraw()
        except (AttributeError, TypeError, IndexError):
            # Best effort: e.g. a cleared (None) selection cannot be indexed,
            # or the colorbar has not been built yet.
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        """Apply thresholds typed into the Lower/Upper value fields."""
        self.thresh = [self.thresh_lower, self.thresh_upper]
        cmap_renderer = getattr(self, "cmap_renderer", None)
        if cmap_renderer is None:
            # No peak colorbar yet; draw_colorbar() applies self.thresh when
            # it is created.
            return
        cmap_renderer.color_data.metadata['selections'] = self.thresh
        cmap_renderer.color_data.metadata_changed = {'selections': self.thresh}
        self.container.request_redraw()
        self.img_container.request_redraw()

    @on_trait_change('peaks,cbar_selection:selection,img_idx')
    def calc_numpeaks(self):
        """Count peaks whose score is inside the threshold: in total and for
        the current image."""
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
        except AttributeError:
            # No colorbar (and hence no selection) exists yet.
            thresh = None
        if thresh is None or len(thresh) == 0:
            thresh = (-1, 1)
        low, high = thresh[0], thresh[1]
        self.numpeaks_total = int(np.sum([
            np.sum(np.ma.masked_inside(pks[:, 2], low, high).mask)
            for pks in self.peaks]))
        try:
            self.numpeaks_img = int(np.sum(
                np.ma.masked_inside(self.peaks[self.img_idx][:, 2], low, high).mask))
        except IndexError:
            # No peak array for the current image.
            self.numpeaks_img = 0

    @on_trait_change('findpeaks')
    def locate_peaks(self):
        """Cross-correlate the template with every image and find its peaks."""
        peaks = []
        progress = ProgressDialog(title="Peak finder progress",
                                  message="Finding peaks on %s images" % self.numfiles,
                                  max=self.numfiles, show_time=True, can_cancel=False)
        progress.open()
        for idx in range(self.numfiles):
            self.controller.set_active_index(idx)
            self.data = self.controller.get_active_image()[:]
            self.CC = cv_funcs.xcorr(self.template, self.data)
            # peak finder needs peaks greater than 1.  Multiply by 255 to scale them.
            pks = pc.two_dim_findpeaks(self.CC * 255, peak_width=self.peak_width,
                                       medfilt_radius=None)
            pks[:, 2] = pks[:, 2] / 255.
            peaks.append(pks)
            progress.update(idx + 1)
        self.peaks = peaks
        self.redraw_plots()

    def mask_peaks(self, idx):
        """Return image ``idx``'s peaks as a masked array whose score column is
        masked where it falls outside the current threshold selection."""
        cbar_selection = getattr(self, "cbar_selection", None)
        thresh = None if cbar_selection is None else cbar_selection.selection
        # No colorbar, or no (or an empty) range selection: use the full
        # (-1, 1) range, as calc_numpeaks does.
        if thresh is None or len(thresh) == 0:
            thresh = (-1, 1)
        mpeaks = np.ma.asarray(self.peaks[idx])
        mpeaks[:, 2] = np.ma.masked_outside(mpeaks[:, 2], thresh[0], thresh[1])
        return mpeaks

    @on_trait_change("peaks")
    def redraw_plots(self):
        """Re-render the image and, when there are peaks, the peak overlay and
        colorbar."""
        oldplot = self.img_plot
        self.container.remove(oldplot)
        newplot = self.render_image()
        self.container.add(newplot)
        self.img_plot = newplot

        try:
            # The scatter overlay and colorbar only exist (and are only in the
            # containers) after peaks have been drawn; removal is best-effort.
            oldscat = self.scatplot
            self.container.remove(oldscat)
            oldcolorbar = self.colorbar
            self.img_container.remove(oldcolorbar)
        except Exception:
            pass

        if self.numpeaks_img > 0:
            newscat = self.render_scatplot()
            self.container.add(newscat)
            self.scatplot = newscat
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
            self.colorbar = colorbar

        self.container.request_redraw()
        self.img_container.request_redraw()

    def crop_cells(self):
        """Crop a template-sized patch at every in-threshold peak of every image.

        The patches are handed to the controller together with their peak rows.
        """
        print("cropping cells...")
        tmp_sz = self.tmp_size
        for idx in range(self.numfiles):
            self.controller.set_active_index(idx)
            self.data = self.controller.get_active_image()
            self.name = self.controller.get_active_name()
            # filter the peaks that are outside the selected threshold
            peaks = np.ma.compress_rows(self.mask_peaks(idx))
            if peaks.shape[0] == 0:
                continue
            # Peak coordinates share an array with the float score column, so
            # cast them before using them as slice bounds.
            corners = peaks[:, :2].astype(int)
            data = np.zeros((peaks.shape[0], tmp_sz, tmp_sz))
            for i, (x, y) in enumerate(corners):
                # crop the cells from the given locations
                data[i, :, :] = self.data[y:y + tmp_sz, x:x + tmp_sz]
            # send the data to the controller for storage in the chest
            self.controller.add_cells(name=self.name, data=data,
                                      locations=peaks)
コード例 #10
0
class PlotUI(HasTraits):
    """Contour plot of a 2-D scalar field with linked cross-section plots.

    A filled-contour plot and a contour-line plot are overlaid and placed in
    a horizontal container together with a colorbar and two cross-section
    plots (``cross_plot`` for a row, ``cross_plot2`` for a column).  Two
    line inspectors on the contour plot write the selected indices into the
    grid's metadata; ``_metadata_changed`` turns that selection into the
    cross-section curves and markers.  Data is supplied through ``update``.
    """

    # container for all plots
    container = Instance(HPlotContainer)

    # Plot components within this container:
    polyplot = Instance(ContourPolyPlot)
    lineplot = Instance(ContourLinePlot)
    cross_plot = Instance(Plot)
    cross_plot2 = Instance(Plot)
    colorbar = Instance(ColorBar)

    # plot data
    pd = Instance(ArrayPlotData)

    # view options
    num_levels = Int(15)
    colormap = Enum(colormaps)

    #Traits view definitions:
    traits_view = View(
        Group(UItem('container', editor=ComponentEditor(size=(800,600)))),
        resizable=True)

    plot_edit_view = View(
        Group(Item('num_levels'),
              Item('colormap')),
              buttons=["OK","Cancel"])


    #---------------------------------------------------------------------------
    # Private Traits
    #---------------------------------------------------------------------------

    _image_index = Instance(GridDataSource)
    _image_value = Instance(ImageData)

    _cmap = Trait(default_colormaps.jet, Callable)

    #---------------------------------------------------------------------------
    # Public View interface
    #---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        super(PlotUI, self).__init__(*args, **kwargs)
        # FIXME: 'with' wrapping is temporary fix for infinite range in initial
        # color map, which can cause a distracting warning print. This 'with'
        # wrapping should be unnecessary after fix in color_mapper.py.
        with errstate(invalid='ignore'):
            self.create_plot()

    def create_plot(self):
        """Build the empty data sources, renderers, tools and containers."""

        # Create the mapper, etc
        self._image_index = GridDataSource(array([]),
                                          array([]),
                                          sort_order=("ascending","ascending"))
        image_index_range = DataRange2D(self._image_index)
        # The line inspectors write their selection into this source's
        # metadata; react to it by recomputing the cross sections.
        self._image_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")

        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)

        # Create the contour plots
        self.polyplot = ContourPolyPlot(index=self._image_index,
                                        value=self._image_value,
                                        index_mapper=GridMapper(range=
                                            image_index_range),
                                        color_mapper=\
                                            self._cmap(image_value_range),
                                        levels=self.num_levels)

        # Share the polyplot's index range so both contour layers pan/zoom
        # together.
        self.lineplot = ContourLinePlot(index=self._image_index,
                                        value=self._image_value,
                                        index_mapper=GridMapper(range=
                                            self.polyplot.index_mapper.range),
                                        levels=self.num_levels)


        # Add a left axis to the plot
        left = PlotAxis(orientation='left',
                        title= "y",
                        mapper=self.polyplot.index_mapper._ymapper,
                        component=self.polyplot)
        self.polyplot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = PlotAxis(orientation='bottom',
                          title= "x",
                          mapper=self.polyplot.index_mapper._xmapper,
                          component=self.polyplot)
        self.polyplot.overlays.append(bottom)


        # Add some tools to the plot
        self.polyplot.tools.append(PanTool(self.polyplot,
                                           constrain_key="shift"))
        self.polyplot.overlays.append(ZoomTool(component=self.polyplot,
                                            tool_mode="box", always_on=False))
        self.polyplot.overlays.append(LineInspector(component=self.polyplot,
                                               axis='index_x',
                                               inspect_mode="indexed",
                                               write_metadata=True,
                                               is_listener=True,
                                               color="white"))
        self.polyplot.overlays.append(LineInspector(component=self.polyplot,
                                               axis='index_y',
                                               inspect_mode="indexed",
                                               write_metadata=True,
                                               color="white",
                                               is_listener=True))

        # Add these two plots to one container
        contour_container = OverlayPlotContainer(padding=20,
                                                 use_backbuffer=True,
                                                 unified_draw=True)
        contour_container.add(self.polyplot)
        contour_container.add(self.lineplot)


        # Create a colorbar
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.polyplot,
                                 padding_top=self.polyplot.padding_top,
                                 padding_bottom=self.polyplot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        self.pd = ArrayPlotData(line_index = array([]),
                                line_value = array([]),
                                scatter_index = array([]),
                                scatter_value = array([]),
                                scatter_color = array([]))

        # Row cross section: a dotted curve plus a colormapped marker at the
        # inspected point.
        self.cross_plot = Plot(self.pd, resizable="h")
        self.cross_plot.height = 100
        self.cross_plot.padding = 20
        self.cross_plot.plot(("line_index", "line_value"),
                             line_style="dot")
        self.cross_plot.plot(("scatter_index","scatter_value","scatter_color"),
                             type="cmap_scatter",
                             name="dot",
                             color_mapper=self._cmap(image_value_range),
                             marker="circle",
                             marker_size=8)

        self.cross_plot.index_range = self.polyplot.index_range.x_range

        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))
        self.pd.set_data("scatter_index2", array([]))
        self.pd.set_data("scatter_value2", array([]))
        self.pd.set_data("scatter_color2", array([]))

        # Column cross section, drawn vertically.
        self.cross_plot2 = Plot(self.pd, width = 140, orientation="v", resizable="v", padding=20, padding_bottom=160)
        self.cross_plot2.plot(("line_index2", "line_value2"),
                             line_style="dot")
        self.cross_plot2.plot(("scatter_index2","scatter_value2","scatter_color2"),
                             type="cmap_scatter",
                             name="dot",
                             color_mapper=self._cmap(image_value_range),
                             marker="circle",
                             marker_size=8)

        self.cross_plot2.index_range = self.polyplot.index_range.y_range



        # Create a container and add components
        self.container = HPlotContainer(padding=40, fill_padding=True,
                                        bgcolor = "white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.cross_plot)
        inner_cont.add(contour_container)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.cross_plot2)


    def update(self, model):
        """Load new data from *model* and redraw.

        *model* must provide ``xs``, ``ys`` (grid coordinates), ``zs`` (the
        values) and ``minz`` / ``maxz`` (the value range used for the
        colorbar and the cross-section plots).
        """
        self.minz = model.minz
        self.maxz = model.maxz
        self.colorbar.index_mapper.range.low = self.minz
        self.colorbar.index_mapper.range.high = self.maxz
        self._image_index.set_data(model.xs, model.ys)
        self._image_value.data = model.zs
        self.pd.set_data("line_index", model.xs)
        self.pd.set_data("line_index2", model.ys)
        self.container.invalidate_draw()
        self.container.request_redraw()


    #---------------------------------------------------------------------------
    # Event handlers
    #---------------------------------------------------------------------------

    def _metadata_changed(self, old, new):
        """Take a cross section of the image at the line-inspector selection.

        With a ``"selections"`` entry ``(x_ndx, y_ndx)`` in the grid metadata,
        row ``y_ndx`` and column ``x_ndx`` of the image become the two
        cross-section curves and the selected point becomes a marker on each.
        Without a selection, all cross-section values and markers are
        cleared.
        """
        # NOTE(review): minz/maxz are only assigned in update(); this handler
        # assumes update() ran before the inspectors first fire.
        self.cross_plot.value_range.low = self.minz
        self.cross_plot.value_range.high = self.maxz
        self.cross_plot2.value_range.low = self.minz
        self.cross_plot2.value_range.high = self.maxz
        if "selections" in self._image_index.metadata:
            x_ndx, y_ndx = self._image_index.metadata["selections"]
            # Compare with None explicitly: index 0 (first row/column) is a
            # valid selection but is falsy.
            if y_ndx is not None and x_ndx is not None:
                zs = self._image_value.data
                self.pd.set_data("line_value", zs[y_ndx,:])
                self.pd.set_data("line_value2", zs[:,x_ndx])
                xdata, ydata = self._image_index.get_data()
                xdata, ydata = xdata.get_data(), ydata.get_data()
                self.pd.set_data("scatter_index", array([xdata[x_ndx]]))
                self.pd.set_data("scatter_index2", array([ydata[y_ndx]]))
                z = zs[y_ndx, x_ndx]
                for name in ("scatter_value", "scatter_value2",
                             "scatter_color", "scatter_color2"):
                    self.pd.set_data(name, array([z]))
        else:
            # Clear every marker array (index, value and color) so their
            # lengths stay consistent with one another.
            for name in ("scatter_index", "scatter_index2",
                         "scatter_value", "scatter_value2",
                         "scatter_color", "scatter_color2",
                         "line_value", "line_value2"):
                self.pd.set_data(name, array([]))

    def _colormap_changed(self):
        """Apply the newly chosen colormap to the contour plot and markers."""
        self._cmap = default_colormaps.color_map_name_dict[self.colormap]
        if self.polyplot is not None:
            value_range = self.polyplot.color_mapper.range
            self.polyplot.color_mapper = self._cmap(value_range)
            value_range = self.cross_plot.color_mapper.range
            self.cross_plot.color_mapper = self._cmap(value_range)
            # FIXME: change when we decide how best to update plots using
            # the shared colormap in plot object
            self.cross_plot.plots["dot"][0].color_mapper = self._cmap(value_range)
            self.cross_plot2.plots["dot"][0].color_mapper = self._cmap(value_range)
            self.container.request_redraw()

    def _num_levels_changed(self):
        """Propagate the contour level count; values of 3 or less are ignored."""
        if self.num_levels > 3:
            self.polyplot.levels = self.num_levels
            self.lineplot.levels = self.num_levels
コード例 #11
0
ファイル: image_plot_ui.py プロジェクト: xiaoyu-wu/phyreslib
class ImagePlotUI(HasTraits):
    """Colormapped image plot with a colorbar, fed by an ImageDataSource.

    Whenever the attached ``image_data_source`` fires
    ``data_source_changed``, its ``xs`` / ``ys`` grid and ``zs`` values are
    loaded into the plot and the colorbar range is reset to the value range.
    """
    # Image data source
    image_data_source = Instance(ImageDataSource)

    # container for all plots
    container = Instance(HPlotContainer)

    # Plot components within this container
    plot = Instance(CMapImagePlot)
    colorbar = Instance(ColorBar)

    # View options
    colormap = Enum(colormaps)

    # Traits view definitions:
    traits_view = View(Group(
        UItem('container', editor=ComponentEditor(size=(500, 450)))),
                       resizable=True)

    plot_edit_view = View(Group(Item('colormap')), buttons=["OK", "Cancel"])

    # -------------------------------------------------------------------------
    # Private Traits
    # -------------------------------------------------------------------------

    _image_index = Instance(GridDataSource)
    _image_value = Instance(ImageData)

    _cmap = Trait(default_colormaps.gray, Callable)

    # -------------------------------------------------------------------------
    # Public View interface
    # -------------------------------------------------------------------------

    def __init__(self, image_data_source=None):
        super(ImagePlotUI, self).__init__()
        # Suppress invalid-value floating point warnings while the plot is
        # built on empty data.
        with errstate(invalid='ignore'):
            self.create_plot()
        self.image_data_source = image_data_source

    def create_plot(self):
        """Build the empty data sources, image renderer, axes, tools and
        colorbar, and place them in ``self.container``."""

        # Create the mapper, etc
        self._image_index = GridDataSource(array([]),
                                           array([]),
                                           sort_order=("ascending",
                                                       "ascending"))
        image_index_range = DataRange2D(self._image_index)

        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)

        # Create the colormapped scalar plot
        self.plot = CMapImagePlot(
            index=self._image_index,
            index_mapper=GridMapper(range=image_index_range),
            value=self._image_value,
            value_mapper=self._cmap(image_value_range))

        # Add a left axis to the plot
        left = PlotAxis(orientation='left',
                        title="y",
                        mapper=self.plot.index_mapper._ymapper,
                        component=self.plot)
        self.plot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = PlotAxis(orientation='bottom',
                          title="x",
                          mapper=self.plot.index_mapper._xmapper,
                          component=self.plot)
        self.plot.overlays.append(bottom)

        # Add some tools to the plot
        self.plot.tools.append(PanTool(self.plot, constrain_key="shift"))
        self.plot.overlays.append(
            ZoomTool(component=self.plot, tool_mode="box", always_on=False))

        # Create a colorbar
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.plot,
                                 padding_top=self.plot.padding_top,
                                 padding_bottom=self.plot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=10)

        # Create a container and add components
        self.container = HPlotContainer(padding=40,
                                        fill_padding=True,
                                        bgcolor="white",
                                        use_backbuffer=False)
        self.container.add(self.colorbar)
        self.container.add(self.plot)

    @on_trait_change('image_data_source.data_source_changed')
    def update_plot(self):
        """Load the data source's grid and values into the plot and redraw."""
        xs = self.image_data_source.xs
        ys = self.image_data_source.ys
        zs = self.image_data_source.zs
        # ndarray.min()/max() raise ValueError on a zero-size array, so only
        # rescale the colorbar when there are values to take a range from.
        if zs.size > 0:
            self.colorbar.index_mapper.range.low = zs.min()
            self.colorbar.index_mapper.range.high = zs.max()
        self._image_index.set_data(xs, ys)
        self._image_value.data = zs
        self.container.invalidate_draw()
        self.container.request_redraw()

    # -------------------------------------------------------------------------
    # Event handlers
    # -------------------------------------------------------------------------

    def _colormap_changed(self):
        """Apply the newly selected colormap, keeping the current range."""
        self._cmap = default_colormaps.color_map_name_dict[self.colormap]
        if self.plot is not None:
            value_range = self.plot.color_mapper.range
            self.plot.color_mapper = self._cmap(value_range)
            self.container.request_redraw()
コード例 #12
0
ファイル: ucc.py プロジェクト: Gazworth/hyperspy
class TemplatePicker(HasTraits):
    template = Array
    CC = Array
    peaks = List
    zero=Int(0)
    tmp_size = Range(low=2, high=512, value=64, cols=4)
    max_pos_x=Property(depends_on=['tmp_size'])
    max_pos_y=Property(depends_on=['tmp_size'])
    top = Range(low='zero',high='max_pos_y', value=20, cols=4)
    left = Range(low='zero',high='max_pos_x', value=20, cols=4)
    is_square = Bool
    img_plot = Instance(Plot)
    tmp_plot = Instance(Plot)
    findpeaks = Button
    next_img = Button
    prev_img = Button
    peak_width = Range(low=2, high=200, value=10)
    tab_selected = Event
    ShowCC = Bool
    img_container = Instance(Component)
    container = Instance(Component)
    colorbar= Instance(Component)
    numpeaks_total = Int(0,cols=5)
    numpeaks_img = Int(0,cols=5)
    OK_custom=OK_custom_handler
    cbar_selection = Instance(RangeSelection)
    cbar_selected = Event
    thresh=Trait(None,None,List,Tuple,Array)
    thresh_upper=Float(1.0)
    thresh_lower=Float(0.0)
    numfiles=Int(1)
    img_idx=Int(0)
    tmp_img_idx=Int(0)

    csr=Instance(BaseCursorTool)

    traits_view = View(
        Group(
            Group(
                Item("img_container",editor=ComponentEditor(), show_label=False),
                HGroup(
                    Item("ShowCC", editor=BooleanEditor(), label="Show cross correlation image"),
                    Spring(),
                    Item("prev_img",editor=ButtonEditor(label="<"),show_label=False, enabled_when='numfiles > 1'),
                    Item("next_img",editor=ButtonEditor(label=">"),show_label=False, enabled_when='numfiles > 1'),
                    ),
                label="Original image", show_border=True, trait_modified="tab_selected",
                orientation='vertical',),
            VGroup(
                Group(
                    HGroup(
                        Item("left", label="Left coordinate", style="custom"),
                        Spring(),
                        Item("top", label="Top coordinate", style="custom"),
                        ),
                    Item("tmp_size", label="Template size", style="custom"),
                    Item("tmp_plot",editor=ComponentEditor(height=256, width=256), show_label=False, resizable=True),
                    label="Template", show_border=True),
                Group(
                    HGroup(
                        Item("peak_width", label="Peak width", style="custom"),
                        Spring(),
                        Item("findpeaks",editor=ButtonEditor(label="Find Peaks"),show_label=False),
                        ),
                    HGroup(
                        Item("thresh_lower",label="Threshold Lower Value", editor=TextEditor(evaluate=float,
                                                                                             format_str='%1.4f')),
                        Spring(),
                        Item("thresh_upper",label="Threshold Upper Value", editor=TextEditor(evaluate=float,
                                                                                             format_str='%1.4f')),
                    ),
                    HGroup(
                        Item("numpeaks_img",label="Number of Cells selected (this image)",style='readonly'),
                        Spring(),
                        Item("numpeaks_total",label="Total",style='readonly'),                          
                        ),
                    label="Peak parameters", show_border=True),
                ),
            orientation='horizontal'),
        buttons = [ Action(name='OK', enabled_when = 'numpeaks_total > 0' ),
            CancelButton ],
        title="Template Picker",
        handler=OK_custom, kind='livemodal',
        key_bindings = key_bindings,
        width=940, height=530,resizable=True)        

    def __init__(self, signal_instance):
        super(TemplatePicker, self).__init__()
        try:
            import cv
        except:
            try:
                import cv2.cv as cv
            except:
                print "OpenCV unavailable.  Can't do cross correlation without it.  Aborting."
                return None
        self.OK_custom=OK_custom_handler()
        self.sig=signal_instance
        if not hasattr(self.sig.mapped_parameters,"original_files"):
            self.titles=[os.path.splitext(self.sig.mapped_parameters.title)[0]]
        else:
            self.numfiles=len(self.sig.mapped_parameters.original_files.keys())
            self.titles=self.sig.mapped_parameters.original_files.keys()
        tmp_plot_data=ArrayPlotData(imagedata=self.sig.data[self.img_idx,self.top:self.top+self.tmp_size,self.left:self.left+self.tmp_size])
        tmp_plot=Plot(tmp_plot_data,default_origin="top left")
        tmp_plot.img_plot("imagedata", colormap=jet)
        tmp_plot.aspect_ratio=1.0
        self.tmp_plot=tmp_plot
        self.tmp_plotdata=tmp_plot_data
        self.img_plotdata=ArrayPlotData(imagedata=self.sig.data[self.img_idx,:,:])
        self.img_container=self._image_plot_container()

        self.crop_sig=None

    def render_image(self):
        plot = Plot(self.img_plotdata,default_origin="top left")
        img=plot.img_plot("imagedata", colormap=gray)[0]
        plot.title="%s of %s: "%(self.img_idx+1,self.numfiles)+self.titles[self.img_idx]
        plot.aspect_ratio=float(self.sig.data.shape[2])/float(self.sig.data.shape[1])

        csr = CursorTool(img, drag_button='left', color='white',
                         line_width=2.0)
        self.csr=csr
        csr.current_position=self.left, self.top
        img.overlays.append(csr)

        # attach the rectangle tool
        plot.tools.append(PanTool(plot,drag_button="right"))
        zoom = ZoomTool(plot, tool_mode="box", always_on=False, aspect_ratio=plot.aspect_ratio)
        plot.overlays.append(zoom)
        self.img_plot=plot
        return plot

    def render_scatplot(self):
        peakdata=ArrayPlotData()
        peakdata.set_data("index",self.peaks[self.img_idx][:,0])
        peakdata.set_data("value",self.peaks[self.img_idx][:,1])
        peakdata.set_data("color",self.peaks[self.img_idx][:,2])
        scatplot=Plot(peakdata,aspect_ratio=self.img_plot.aspect_ratio,default_origin="top left")
        scatplot.plot(("index", "value", "color"),
                      type="cmap_scatter",
                      name="my_plot",
                      color_mapper=jet(DataRange1D(low = 0.0,
                                       high = 1.0)),
                      marker = "circle",
                      fill_alpha = 0.5,
                      marker_size = 6,
                      )
        scatplot.x_grid.visible = False
        scatplot.y_grid.visible = False
        scatplot.range2d=self.img_plot.range2d
        self.scatplot=scatplot
        self.peakdata=peakdata
        return scatplot

    def _image_plot_container(self):
        plot = self.render_image()

        # Create a container to position the plot and the colorbar side-by-side
        self.container=OverlayPlotContainer()
        self.container.add(plot)
        self.img_container = HPlotContainer(use_backbuffer = False)
        self.img_container.add(self.container)
        self.img_container.bgcolor = "white"

        if self.numpeaks_img>0:
            scatplot = self.render_scatplot()
            self.container.add(scatplot)
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
        return self.img_container

    def draw_colorbar(self):
        scatplot=self.scatplot
        cmap_renderer = scatplot.plots["my_plot"][0]
        selection = ColormappedSelectionOverlay(cmap_renderer, fade_alpha=0.35, 
                                                selection_type="range")
        cmap_renderer.overlays.append(selection)
        if self.thresh is not None:
            cmap_renderer.color_data.metadata['selections']=self.thresh
            cmap_renderer.color_data.metadata_changed={'selections':self.thresh}
        # Create the colorbar, handing in the appropriate range and colormap
        colormap=scatplot.color_mapper
        colorbar = ColorBar(index_mapper=LinearMapper(range=DataRange1D(low = 0.0,
                                       high = 1.0)),
                        orientation='v',
                        resizable='v',
                        width=30,
                        padding=20)
        colorbar_selection=RangeSelection(component=colorbar)
        colorbar.tools.append(colorbar_selection)
        ovr=colorbar.overlays.append(RangeSelectionOverlay(component=colorbar,
                                                   border_color="white",
                                                   alpha=0.8,
                                                   fill_color="lightgray",
                                                   metadata_name='selections'))
        #ipshell('colorbar, colorbar_selection and ovr available:')
        self.cbar_selection=colorbar_selection
        self.cmap_renderer=cmap_renderer
        colorbar.plot = cmap_renderer
        colorbar.padding_top = scatplot.padding_top
        colorbar.padding_bottom = scatplot.padding_bottom
        self.colorbar=colorbar
        return colorbar

    @on_trait_change('ShowCC')
    def toggle_cc_view(self):
        if self.ShowCC:
            self.update_CC()
            grid_data_source = self.img_plot.range2d.sources[0]
            grid_data_source.set_data(np.arange(self.CC.shape[1]), 
                                      np.arange(self.CC.shape[0]))
        else:
            self.img_plotdata.set_data("imagedata",self.sig.data[self.img_idx,:,:])
        self.redraw_plots()

    @on_trait_change("img_idx")
    def update_img_depth(self):
        if self.ShowCC:
            self.update_CC()
        else:
            self.img_plotdata.set_data("imagedata",self.sig.data[self.img_idx,:,:])
        self.img_plot.title="%s of %s: "%(self.img_idx+1,self.numfiles)+self.titles[self.img_idx]
        self.redraw_plots()        

    def _get_max_pos_x(self):
        max_pos_x=self.sig.data.shape[-1]-self.tmp_size-1
        if max_pos_x>0:
            return max_pos_x
        else:
            return None

    def _get_max_pos_y(self):
        max_pos_y=self.sig.data.shape[-2]-self.tmp_size-1
        if max_pos_y>0:
            return max_pos_y
        else:
            return None

    @on_trait_change('next_img')
    def increase_img_idx(self,info):
        if self.img_idx==(self.numfiles-1):
            self.img_idx=0
        else:
            self.img_idx+=1

    @on_trait_change('prev_img')
    def decrease_img_idx(self,info):
        if self.img_idx==0:
            self.img_idx=self.numfiles-1
        else:
            self.img_idx-=1

    @on_trait_change('left, top')
    def update_csr_position(self):
        if self.left>0:        
            self.csr.current_position=self.left,self.top

    @on_trait_change('csr:current_position')
    def update_top_left(self):
        if self.csr.current_position[0]>0 or self.csr.current_position[1]>0:
            if self.csr.current_position[0]>self.max_pos_x:
                if self.csr.current_position[1]<self.max_pos_y:
                    self.top=self.csr.current_position[1]
                else:
                    self.csr.current_position=self.max_pos_x, self.max_pos_y
            elif self.csr.current_position[1]>self.max_pos_y:
                self.left,self.top=self.csr.current_position[0],self.max_pos_y
            else:
                self.left,self.top=self.csr.current_position
        
    @on_trait_change('left, top, tmp_size')
    def update_tmp_plot(self):
        self.tmp_plotdata.set_data("imagedata", 
                                   self.sig.data[self.img_idx,self.top:self.top+self.tmp_size,self.left:self.left+self.tmp_size])
        grid_data_source = self.tmp_plot.range2d.sources[0]
        grid_data_source.set_data(np.arange(self.tmp_size), np.arange(self.tmp_size))
        self.tmp_img_idx=self.img_idx
        if self.numpeaks_total>0:
            print "clearing peaks"
            self.peaks=[np.array([[0,0,-1]])]
        return

    @on_trait_change('left, top, tmp_size')
    def update_CC(self):
        if self.ShowCC:
            self.CC = cv_funcs.xcorr(self.sig.data[self.tmp_img_idx,self.top:self.top+self.tmp_size,
                                                   self.left:self.left+self.tmp_size],
                                     self.sig.data[self.img_idx,:,:])
            self.img_plotdata.set_data("imagedata",self.CC)

    @on_trait_change('cbar_selection:selection')
    def update_thresh(self):
        try:
            thresh=self.cbar_selection.selection
            self.thresh=thresh
            self.cmap_renderer.color_data.metadata['selections']=thresh
            self.thresh_lower=thresh[0]
            self.thresh_upper=thresh[1]
            #cmap_renderer.color_data.metadata['selection_masks']=self.thresh
            self.cmap_renderer.color_data.metadata_changed={'selections':thresh}
            self.container.request_redraw()
            self.img_container.request_redraw()
        except:
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        self.thresh=[self.thresh_lower,self.thresh_upper]
        self.cmap_renderer.color_data.metadata['selections']=self.thresh
        self.cmap_renderer.color_data.metadata_changed={'selections':self.thresh}
        self.container.request_redraw()
        self.img_container.request_redraw()

    @on_trait_change('peaks,cbar_selection:selection,img_idx')
    def calc_numpeaks(self):
        try:
            thresh=self.cbar_selection.selection
            self.thresh=thresh
        except:
            thresh=[]
        if thresh==[] or thresh==() or thresh==None:
            thresh=(0,1)
        self.numpeaks_total=int(np.sum([np.sum(np.ma.masked_inside(self.peaks[i][:,2],thresh[0],thresh[1]).mask) for i in xrange(len(self.peaks))]))
        try:
            self.numpeaks_img=int(np.sum(np.ma.masked_inside(self.peaks[self.img_idx][:,2],thresh[0],thresh[1]).mask))
        except:
            self.numpeaks_img=0

    @on_trait_change('findpeaks')
    def locate_peaks(self):
        from hyperspy import peak_char as pc
        peaks=[]
        """from hyperspy.misc.progressbar import ProgressBar, \
            Percentage, RotatingMarker, ETA, Bar
        widgets = ['Locating peaks: ', Percentage(), ' ', Bar(marker=RotatingMarker()),
                   ' ', ETA()]
        pbar = ProgressBar(widgets=widgets, maxval=100).start()"""
        progress = ProgressDialog(title="Peak finder progress", message="Finding peaks on %s images"%self.numfiles, max=self.numfiles, show_time=True, can_cancel=False)
        progress.open()
        for idx in xrange(self.numfiles):
            #pbar.update(float(idx)/self.numfiles*100)
            self.CC = cv_funcs.xcorr(self.sig.data[self.tmp_img_idx,
                                                   self.top:self.top+self.tmp_size,
                                                   self.left:self.left+self.tmp_size],
                                               self.sig.data[idx,:,:])
            # peak finder needs peaks greater than 1.  Multiply by 255 to scale them.
            pks=pc.two_dim_findpeaks(self.CC*255, peak_width=self.peak_width, medfilt_radius=None)
            pks[:,2]=pks[:,2]/255.
            peaks.append(pks)
            progress.update(idx+1)
        #pbar.finish()
        self.peaks=peaks
        
    def mask_peaks(self,idx):
        thresh=self.cbar_selection.selection
        if thresh==[]:
            thresh=(0,1)
        mpeaks=np.ma.asarray(self.peaks[idx])
        mpeaks[:,2]=np.ma.masked_outside(mpeaks[:,2],thresh[0],thresh[1])
        return mpeaks

    @on_trait_change("peaks")
    def redraw_plots(self):
        oldplot=self.img_plot
        self.container.remove(oldplot)
        newplot=self.render_image()
        self.container.add(newplot)
        self.img_plot=newplot

        try:
            # if these haven't been created before, this will fail.  wrap in try to prevent that.
            oldscat=self.scatplot
            self.container.remove(oldscat)
            oldcolorbar = self.colorbar
            self.img_container.remove(oldcolorbar)
        except:
            pass

        if self.numpeaks_img>0:
            newscat=self.render_scatplot()
            self.container.add(newscat)
            self.scatplot=newscat
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
            self.colorbar=colorbar

        self.container.request_redraw()
        self.img_container.request_redraw()

    def crop_cells_stack(self):
        from hyperspy.signals.aggregate import AggregateCells
        if self.numfiles==1:
            self.crop_sig=self.crop_cells()
            return
        else:
            crop_agg=[]
            for idx in xrange(self.numfiles):
                peaks=np.ma.compress_rows(self.mask_peaks(idx))
                if peaks.any():
                    crop_agg.append(self.crop_cells(idx))
            self.crop_sig=AggregateCells(*crop_agg)
            return

    def crop_cells(self,idx=0):
        print "cropping cells..."
        from hyperspy.signals.image import Image
        # filter the peaks that are outside the selected threshold
        peaks=np.ma.compress_rows(self.mask_peaks(idx))
        tmp_sz=self.tmp_size
        data=np.zeros((peaks.shape[0],tmp_sz,tmp_sz))
        if not hasattr(self.sig.mapped_parameters,"original_files"):
            parent=self.sig
        else:
            parent=self.sig.mapped_parameters.original_files[self.titles[idx]]
        pmp=parent.mapped_parameters
        positions=np.zeros((peaks.shape[0],1),dtype=[('filename','a256'),('id','i4'),('position','f4',(1,2))])
        for i in xrange(peaks.shape[0]):
            # crop the cells from the given locations
            data[i,:,:]=self.sig.data[idx,peaks[i,1]:peaks[i,1]+tmp_sz,peaks[i,0]:peaks[i,0]+tmp_sz]
            positions[i]=(self.titles[idx],i,peaks[i,:2])
            crop_sig=Image({'data':data,
                            'mapped_parameters':{
                               'title':'Cropped cells from %s'%self.titles[idx],
                               'record_by':'image',
                               'locations':positions,
                               'original_files':{pmp.title:parent},
                               }
                            })
            
        return crop_sig
コード例 #13
0
ファイル: plotter_2d.py プロジェクト: MatthieuDartiailh/PyHQC
class Plotter2D(HasPreferenceTraits):
    """Colormapped 2D image plot with a colorbar and UI toolbars.

    The image is drawn by a ``Plot2D`` fed from ``self.data`` under the key
    ``'c'``; a ``ColorBar`` sits next to it in an ``HPlotContainer``.  Label,
    axis-format and index updates are bound in ``__init__`` with
    ``dispatch='ui'`` so they may be triggered from other threads.
    """

    plot = Instance(Plot2D)
    colorbar = Instance(ColorBar)
    container = Instance(HPlotContainer)

    zoom_bar_plot = Instance(ZoomBar)
    zoom_bar_colorbar = Instance(ZoomBar)
    pan_bar = Instance(PanBar)
    range_bar = Instance(RangeBar)

    # Plot data; the image is stored under the key 'c'.
    data = Instance(ArrayPlotData, ())
    # Extent of the image, used by _update_plots_index to rebuild the grid.
    x_min = Float(0.0)
    x_max = Float(1.0)
    y_min = Float(0.0)
    y_max = Float(1.0)

    add_contour = Bool(False)
    x_axis_label = Str
    y_axis_label = Str
    c_axis_label = Str

    x_axis_formatter = Instance(AxisFormatter)
    y_axis_formatter = Instance(AxisFormatter)
    c_axis_formatter = Instance(AxisFormatter)

    colormap = Enum(color_map_name_dict.keys(), preference = 'async')
    _cmap = Trait(Greys, Callable)

    # Fired by request_update_plots_index(); handled on the UI thread.
    update_index = Event

    traits_view = View(
                    Group(
                        Group(
                            UItem('container', editor=ComponentEditor()),
                            VGroup(
                                UItem('zoom_bar_colorbar',style = 'custom'),
                                ),
                            orientation = 'horizontal',
                            ),
                        Group(
                            Group(
                                UItem('zoom_bar_plot', style = 'custom'),
                                UItem('pan_bar', style = 'custom'),
                                UItem('range_bar', style = 'custom'),
                                Group(
                                    UItem('colormap'),
                                    label = 'Color map',
                                    ),
                                orientation = 'horizontal',
                                ),
                            orientation = 'vertical',
                            ),
                        orientation = 'vertical',
                        ),
                    resizable=True
                    )

    preference_view = View(
                        HGroup(
                            VGroup(
                                Item('x_axis_formatter', style = 'custom',
                                     editor = InstanceEditor(
                                                 view = 'preference_view'),
                                     label = 'X axis',
                                     ),
                                Item('y_axis_formatter', style = 'custom',
                                     editor = InstanceEditor(
                                                 view = 'preference_view'),
                                     label = 'Y axis',
                                     ),
                                ),
                            Item('c_axis_formatter', style = 'custom',
                                 editor = InstanceEditor(
                                             view = 'preference_view'),
                                 label = 'C axis',
                                 ),
                            show_border = True,
                            label = 'Axis format',
                            ),
                        )

    def __init__(self, **kwargs):

        super(Plotter2D, self).__init__(**kwargs)

        self.x_axis_formatter = AxisFormatter(pref_name='X axis format',
                                              pref_parent=self)
        self.y_axis_formatter = AxisFormatter(pref_name='Y axis format',
                                              pref_parent=self)
        self.c_axis_formatter = AxisFormatter(pref_name='C axis format',
                                              pref_parent=self)

        self.data = ArrayPlotData()
        self.plot = Plot2D(self.data)
        self.plot.padding = (80, 50, 10, 40)
        self.plot.x_axis.tick_label_formatter = \
            self.x_axis_formatter.float_format
        self.plot.y_axis.tick_label_formatter = \
            self.y_axis_formatter.float_format
        self.pan_bar = PanBar(self.plot)
        self.zoom_bar_plot = zoom_bar(self.plot, x=True, y=True, reset=True)

        # Dummy plot so that the color bar can be correctly initialized
        xs = linspace(-2, 2, 600)
        ys = linspace(-1.2, 1.2, 300)
        self.x_min = xs[0]
        self.x_max = xs[-1]
        self.y_min = ys[0]
        self.y_max = ys[-1]
        x, y = meshgrid(xs, ys)
        z = x*y
        self.data.set_data('c', z)
        self.plot.img_plot('c',
                           name='c',
                           colormap=self._cmap,
                           xbounds=(self.x_min, self.x_max),
                           ybounds=(self.y_min, self.y_max),
                           )

        # Create the colorbar, the appropriate range and colormap are handled
        # at the plot creation
        self.colorbar = ColorBar(
                            index_mapper=LinearMapper(
                                range=self.plot.color_mapper.range),
                            color_mapper=self.plot.color_mapper,
                            plot=self.plot,
                            orientation='v',
                            resizable='v',
                            width=20,
                            padding=10)

        # Align the colorbar vertically with the plot area.
        self.colorbar.padding_top = self.plot.padding_top
        self.colorbar.padding_bottom = self.plot.padding_bottom

        self.colorbar._axis.tick_label_formatter = \
            self.c_axis_formatter.float_format

        self.container = HPlotContainer(self.plot,
                                        self.colorbar,
                                        use_backbuffer=True,
                                        bgcolor="lightgray")

        # Add pan and zoom tools to the colorbar
        self.colorbar.tools.append(PanTool(self.colorbar,
                                           constrain_direction="y",
                                           constrain=True)
                                   )
        self.zoom_bar_colorbar = zoom_bar(self.colorbar,
                                          box=False,
                                          reset=True,
                                          orientation='vertical'
                                          )

        # Add the range bar now that we are sure that we have a color_mapper
        self.range_bar = RangeBar(self.plot)
        self.x_axis_label = 'X'
        self.y_axis_label = 'Y'
        self.c_axis_label = 'C'
        self.sync_trait('x_axis_label', self.range_bar, alias='x_name')
        self.sync_trait('y_axis_label', self.range_bar, alias='y_name')
        self.sync_trait('c_axis_label', self.range_bar, alias='c_name')

        # Dynamically bind the update methods for traits likely to be updated
        # from another thread; dispatch='ui' runs them on the UI thread.
        self.on_trait_change(self.new_x_label, 'x_axis_label',
                             dispatch='ui')
        self.on_trait_change(self.new_y_label, 'y_axis_label',
                             dispatch='ui')
        self.on_trait_change(self.new_c_label, 'c_axis_label',
                             dispatch='ui')
        self.on_trait_change(self.new_x_axis_format, 'x_axis_formatter.+',
                             dispatch='ui')
        self.on_trait_change(self.new_y_axis_format, 'y_axis_formatter.+',
                             dispatch='ui')
        self.on_trait_change(self.new_c_axis_format, 'c_axis_formatter.+',
                             dispatch='ui')
        self.on_trait_change(self._update_plots_index, 'update_index',
                             dispatch='ui')

        # set the default colormap in the editor
        self.colormap = 'Blues'

        self.preference_init()

    def new_x_label(self, new):
        """Set the x-axis title (bound to 'x_axis_label' in __init__)."""
        self.plot.x_axis.title = new

    def new_y_label(self, new):
        """Set the y-axis title (bound to 'y_axis_label' in __init__)."""
        self.plot.y_axis.title = new

    def new_c_label(self, new):
        """Set the colorbar title (bound to 'c_axis_label' in __init__)."""
        self.colorbar._axis.title = new

    @on_trait_change('colormap')
    def new_colormap(self, new):
        """Swap the colormap of every image/contour renderer, keeping ranges."""
        self._cmap = color_map_name_dict[new]
        for plots in self.plot.plots.itervalues():
            for plot in plots:
                if isinstance(plot, (ImagePlot, CMapImagePlot,
                                     ContourPolyPlot)):
                    value_range = plot.color_mapper.range
                    plot.color_mapper = self._cmap(value_range)
                    self.plot.color_mapper = self._cmap(value_range)

        self.container.request_redraw()

    def new_x_axis_format(self):
        """Redraw after an x_axis_formatter change (bound in __init__)."""
        self.plot.x_axis._invalidate()
        self.plot.invalidate_and_redraw()

    def new_y_axis_format(self):
        """Redraw after a y_axis_formatter change (bound in __init__)."""
        self.plot.y_axis._invalidate()
        self.plot.invalidate_and_redraw()

    def new_c_axis_format(self):
        """Redraw after a c_axis_formatter change (bound in __init__)."""
        self.colorbar._axis._invalidate()
        self.plot.invalidate_and_redraw()

    def request_update_plots_index(self):
        """Fire 'update_index' so _update_plots_index runs on the UI thread."""
        self.update_index = True

    def _update_plots_index(self):
        """Rebuild the grid index so 'c' spans [x_min, x_max] x [y_min, y_max].

        The grid holds one more edge than the array has columns (x) or rows
        (y).  Every renderer of the plot gets a fresh GridDataSource.
        """
        if 'c' in self.data.list_data():
            array = self.data.get_data('c')
            xs = linspace(self.x_min, self.x_max, array.shape[1] + 1)
            ys = linspace(self.y_min, self.y_max, array.shape[0] + 1)
            self.plot.range2d.remove(self.plot.index)
            self.plot.index = GridDataSource(xs, ys,
                                        sort_order=('ascending', 'ascending'))
            self.plot.range2d.add(self.plot.index)
            for plots in self.plot.plots.itervalues():
                for plot in plots:
                    plot.index = GridDataSource(xs, ys,
                                        sort_order=('ascending', 'ascending'))
コード例 #14
0
class Demo(HasTraits):
    """Two side-by-side Chaco plots of an image overlaid with contours.

    The image is read from ``../images/GIRLS-IN-SPACE.jpg`` (located through
    ``find_resource``) and can be replaced via ``_load``; ``_save`` renders
    the container to ``_save_file``.
    """
    pd = Instance(ArrayPlotData, ())
    plot = Instance(HPlotContainer)

    _load_file = File(
        find_resource('imageAlignment', '../images/GIRLS-IN-SPACE.jpg',
        '../images/GIRLS-IN-SPACE.jpg', return_path=True))
    _save_file = File

    load_file_view = View(
        Item('_load_file'),
        buttons=OKCancelButtons,
        kind='livemodal',
        width=400,
        resizable=True,
    )

    save_file_view = View(
        Item('_save_file'),
        buttons=OKCancelButtons,
        kind='livemodal',
        width=400,
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        super(Demo, self).__init__(*args, **kwargs)

        from imread import imread
        imarray = imread(find_resource('imageAlignment', '../images/GIRLS-IN-SPACE.jpg',
            '../images/GIRLS-IN-SPACE.jpg', return_path=True))

        self.pd = ArrayPlotData(imagedata=imarray)
        self.plot = HPlotContainer()

        titles = ["I KEEP DANCE", "ING ON MY OWN"]

        # Replace 'imagedata' with the image at self._load_file, if readable.
        self._load()

        # Synthetic surface for the contour overlays.  It is identical for
        # every plot, so compute it once instead of once per iteration.
        xs = linspace(0, 334*pi, 333)
        ys = linspace(0, 334*pi, 333)
        x, y = meshgrid(xs, ys)
        z = x*y

        # One plot per title, each with its own data object.
        for title in titles:
            _pd = ArrayPlotData()
            _pd.set_data("drawdata", z)
            _pd.set_data("imagedata", self.pd.get_data('imagedata'))

            plc = Plot(_pd,
                title=title,
                padding=50, border_visible=True, overlay_border=True)

            self.plot.add(plc)

            plc.img_plot("imagedata",
                alpha=0.95)

            # Create a contour polygon plot of the data
            plc.contour_plot("drawdata",
                              type="poly",
                              poly_cmap=jet,
                              xbounds=(0, 499),
                              ybounds=(0, 582),
                              alpha=0.35)

            # Create a contour line plot for the data, too
            plc.contour_plot("drawdata",
                              type="line",
                              xbounds=(0, 499),
                              ybounds=(0, 582),
                              alpha=0.35)

            plc.legend.visible = True

            # Pan with the mouse; box zoom is available but not always on.
            plc.tools.append(PanTool(plc))
            zoom = ZoomTool(component=plc, tool_mode="box", always_on=False)
            plc.overlays.append(zoom)

            # NOTE(review): Chaco components expose `bgcolor`; `bg_color` may
            # have no effect here -- confirm the intent.
            plc.bg_color = None
            plc.fill_padding = True

    def default_traits_view(self):
        """Return the main view: the plot container plus a File menu.

        NOTE(review): `size` and `title` are not defined in this class;
        presumably they are module-level names -- confirm.
        """
        traits_view = View(
            Group(
                Item('plot',
                    editor=ComponentEditor(size=size),
                    show_label=False),
                orientation="vertical"),
            menubar=MenuBar(
                Menu(Action(name="Save Plot", action="save"),
                     Action(name="Load Plot", action="load"),
                     Separator(), CloseAction,
                     name="File")),
            resizable=True,
            title=title,
            handler=ImageFileController)
        return traits_view
    
    '''
    def _plot_default(self):
        
        # Create some x-y data series to plot
        x = linspace(-2.0, 10.0, 400)
        self.pd = pd = ArrayPlotData(index=x, y0=jn(0,x), default_origin="top left")

        # Create some line plots of some of the data
        plot1 = Plot(self.pd,
            title="render_style = hold",
            padding=50, border_visible=True, overlay_border=True)

        plot1.legend.visible = True
        plot1.plot(("index", "y0"), name="j_0", color="red", render_style="hold")
        
        plot1.padding = 50
        plot1.padding_top = 75
        plot1.tools.append(PanTool(plot1))
        #zoom = ZoomTool(component=plot1, tool_mode="box", always_on=False)
        #plot1.overlays.append(zoom)

        # Attach some tools to the plot
        attach_tools(plot1)

        # Create a second scatter plot of one of the datasets, linking its
        # range to the first plot
        plot2 = Plot(self.pd, range2d=plot1.range2d,
            title="render_style = connectedhold",
            padding=50, border_visible=True, overlay_border=True)

        plot2.plot(('index', 'y0'), color="blue", render_style="connectedhold")
        
        plot2.padding = 50
        plot2.padding_top = 75
        plot2.tools.append(PanTool(plot2))
        #zoom = ZoomTool(component=plot2, tool_mode="box", always_on=False)
        #plot2.overlays.append(zoom)
        
        attach_tools(plot2)

        # Create a container and add our plots
        container = HPlotContainer()
        container.add(plot1)
        container.add(plot2)
        return container
    '''

    def _save(self):
        """Render the plot container to the file named by `_save_file`."""
        win_size = self.plot.outer_bounds
        plot_gc = PlotGraphicsContext(win_size)
        plot_gc.render_component(self.plot)
        plot_gc.save(self._save_file)

    def _load(self):
        """Best-effort: load `_load_file` into 'imagedata' and redraw.

        Any failure is reported on stdout and otherwise ignored, leaving the
        current image in place.
        """
        try:
            image = ImageData.fromfile(self._load_file)
            self.pd.set_data('imagedata', image._data)
            self.plot.title = "YO DOGG: %s" % os.path.basename(self._load_file)
            self.plot.request_redraw()
        except Exception as exc:
            print("YO DOGG: %s" % exc)
コード例 #15
0
ファイル: plot2d.py プロジェクト: MatthieuDartiailh/HQCSim
class Plot2D(BasePlot):
    """Colormapped 2D plot of the data described by ``c_info``.

    The image is stored in ``self.data`` under the key ``'c'`` and drawn by a
    ``ChacoPlot2D`` renderer placed in an ``HPlotContainer`` next to a
    ``ColorBar``.  Changing ``x_min``/``x_max``/``y_min``/``y_max`` rebuilds
    the grid index so the image spans those bounds.
    """
    #: Colorbar of the plot (ideally this should be abstracted away)
    colorbar = Typed(ColorBar)

    #: Zoom bar attached to the colorbar.
    zoom_colorbar = Typed(ZoomBar)

    #: Container for Chaco components
    container = Typed(HPlotContainer)

    #: Infos object used to retrieve the data.
    c_info = Typed(AbstractInfo)

    #: Bounds of the plot.
    x_min = Float()
    x_max = Float(1.0)
    y_min = Float()
    y_max = Float(1.0)

    #: Axis labels.
    x_axis = Str()
    y_axis = Str()
    c_axis = Str()

    #: Known colormaps from which the user can choose.
    colormap = Enum(*sorted(color_map_name_dict.keys())).tag(pref=True)

    #: Currently selected colormap.
    _cmap = Value(Greys)

    def __init__(self, **kwargs):

        super(Plot2D, self).__init__(**kwargs)

        self.renderer = ChacoPlot2D(self.data)
        self.renderer.padding = (80, 50, 10, 40)

        # Dummy plot so that the color bar can be correctly initialized
        xs = linspace(-2, 2, 600)
        ys = linspace(-1.2, 1.2, 300)
        x, y = meshgrid(xs, ys)
        z = x*y
        self.data.set_data('c', z)
        self.renderer.img_plot('c', name='c',
                               colormap=self._cmap,
                               xbounds=(self.x_min, self.x_max),
                               ybounds=(self.y_min, self.y_max))

        # Add basic tools and ways to activate them in public API
        zoom = BetterSelectingZoom(self.renderer, tool_mode="box",
                                   always_on=False)
        self.renderer.overlays.append(zoom)
        self.renderer.tools.append(PanTool(self.renderer,
                                           restrict_to_data=True))

        # Create the colorbar, the appropriate range and colormap are handled
        # at the plot creation
        mapper = LinearMapper(range=self.renderer.color_mapper.range)
        self.colorbar = ColorBar(index_mapper=mapper,
                                 color_mapper=self.renderer.color_mapper,
                                 plot=self.renderer,
                                 orientation='v',
                                 resizable='v',
                                 width=20,
                                 padding=10)

        # Align the colorbar vertically with the plot area.
        self.colorbar.padding_top = self.renderer.padding_top
        self.colorbar.padding_bottom = self.renderer.padding_bottom

        self.container = HPlotContainer(self.renderer,
                                        self.colorbar,
                                        use_backbuffer=True,
                                        bgcolor="lightgray")

        # Add pan and zoom tools to the colorbar
        self.colorbar.tools.append(PanTool(self.colorbar,
                                           constrain_direction="y",
                                           constrain=True)
                                   )
        self.zoom_colorbar = zoom_bar(self.colorbar,
                                      box=False,
                                      reset=True,
                                      orientation='vertical'
                                      )
        self.colormap = 'Blues'

    def export_data(self, path):
        """Write the 'c' array to *path* as a tab-separated text file.

        '.dat' is appended to *path* if missing.  The experiment header and
        the c_info header come first, each non-empty line prefixed by '#';
        values are written with the '%.6e' format.
        """
        if not path.endswith('.dat'):
            path += '.dat'
        header = self.experiment.make_header()
        header += '\n' + self.c_info.make_header(self.experiment)
        arr = self.data.get_data('c')

        with open(path, 'wb') as f:
            header = ['#' + l for l in header.split('\n') if l]
            f.write('\n'.join(header) + '\n')
            savetxt(f, arr, fmt='%.6e', delimiter='\t')

    def auto_scale(self):
        """Reset the plot's 2D range and the colormap range to 'auto'."""
        self.renderer.range2d.set_bounds(('auto', 'auto'), ('auto', 'auto'))
        self.renderer.color_mapper.range.set_bounds('auto', 'auto')

    # For the time being stage is unused (will try to refine stuff if it is
    # needed)
    def update_data(self, stage):
        """Reload 'c' from c_info and rebuild the grid index.

        Only data with exactly two dimensions is used; it is stored
        transposed.  Does nothing when c_info is not set.
        """
        exp = self.experiment
        if self.c_info:
            data = self.c_info.gather_data(exp)
            if len(data.shape) == 2:
                self.data.set_data('c', data.T)
                self.update_plots_index()

    def update_plots_index(self):
        """Rebuild the grid index so 'c' spans [x_min, x_max] x [y_min, y_max].

        The grid holds one more edge than the array has columns (x) or rows
        (y).  Every renderer gets a fresh GridDataSource.
        """
        if 'c' in self.data.list_data():
            array = self.data.get_data('c')
            xs = linspace(self.x_min, self.x_max, array.shape[1] + 1)
            ys = linspace(self.y_min, self.y_max, array.shape[0] + 1)
            self.renderer.range2d.remove(self.renderer.index)
            self.renderer.index = GridDataSource(xs, ys,
                                                 sort_order=('ascending',
                                                             'ascending'))
            self.renderer.range2d.add(self.renderer.index)
            for plots in self.renderer.plots.itervalues():
                for plot in plots:
                    plot.index = GridDataSource(xs, ys,
                                                sort_order=('ascending',
                                                            'ascending'))

    @classmethod
    def build_view(cls, plot):
        """Return the widget displaying *plot*."""
        return Plot2DItem(plot=plot)

    def preferences_from_members(self):
        """Return the base preferences plus a nested 'c_info' entry."""
        d = super(Plot2D, self).preferences_from_members()
        d['c_info'] = self.c_info.preferences_from_members()

        return d

    def update_members_from_preferences(self, config):
        """Restore members, rebuild c_info from config['c_info'], reload data.

        The info class is the entry of DATA_INFOS whose ``__name__`` equals
        ``config['c_info']['info_class']``; IndexError if none matches.
        """
        super(Plot2D, self).update_members_from_preferences(config)
        c_config = config['c_info']
        info = [c for c in DATA_INFOS
                if c.__name__ == c_config['info_class']][0]()
        info.update_members_from_preferences(c_config)

        self.c_info = info
        self.update_data(None)

    def _post_setattr_x_axis(self, old, new):
        self.renderer.x_axis.title = new
        self.container.request_redraw()

    def _post_setattr_y_axis(self, old, new):
        self.renderer.y_axis.title = new
        self.container.request_redraw()

    def _post_setattr_c_axis(self, old, new):
        self.colorbar._axis.title = new
        self.container.request_redraw()

    def _post_setattr_colormap(self, old, new):
        """Swap the colormap of every image/contour renderer, keeping ranges."""
        self._cmap = color_map_name_dict[new]
        for plots in self.renderer.plots.itervalues():
            for plot in plots:
                if isinstance(plot, (ImagePlot, CMapImagePlot,
                                     ContourPolyPlot)):
                    value_range = plot.color_mapper.range
                    plot.color_mapper = self._cmap(value_range)
                    self.renderer.color_mapper = self._cmap(value_range)

        self.container.request_redraw()

    def _post_setattr_x_min(self, old, new):
        """Rebuild the grid index for the new lower x bound."""
        self.update_plots_index()

    def _post_setattr_x_max(self, old, new):
        """Rebuild the grid index for the new upper x bound."""
        self.update_plots_index()

    def _post_setattr_y_min(self, old, new):
        """Rebuild the grid index for the new lower y bound."""
        self.update_plots_index()

    def _post_setattr_y_max(self, old, new):
        """Rebuild the grid index for the new upper y bound."""
        self.update_plots_index()
コード例 #16
0
class ImageGUI(HasTraits):

    # TO FIX : put here the last available shot
    shot = File("L:\\data\\app3\\2011\\1108\\110823\\column_5200.ascii")

    # ---------------------------------------------------------------------------
    # Traits View Definitions
    # ---------------------------------------------------------------------------

    traits_view = View(
        HSplit(
            Item(
                "shot",
                style="custom",
                editor=FileEditor(filter=["column_*.ascii"]),
                show_label=False,
                resizable=True,
                width=400,
            ),
            Item("container", editor=ComponentEditor(), show_label=False, width=800, height=800),
        ),
        width=1200,
        height=800,
        resizable=True,
        title="APPARATUS 3 :: Analyze Images",
    )

    plot_edit_view = View(Group(Item("num_levels"), Item("colormap")), buttons=["OK", "Cancel"])

    num_levels = Int(15)
    colormap = Enum(color_map_name_dict.keys())

    # ---------------------------------------------------------------------------
    # Private Traits
    # ---------------------------------------------------------------------------

    _image_index = Instance(GridDataSource)
    _image_value = Instance(ImageData)

    _cmap = Trait(jet, Callable)

    # ---------------------------------------------------------------------------
    # Public View interface
    # ---------------------------------------------------------------------------

    def __init__(self, *positional, **options):
        """Initialise the traits, then build all Chaco plot components."""
        super(ImageGUI, self).__init__(*positional, **options)
        # Components depend on trait values such as _cmap, so build last.
        self.create_plot()

    def create_plot(self):
        """Build every Chaco component of the GUI.

        Creates an initially empty colormapped image (``self.imgplot``) with
        left/bottom axes, a box-zoom overlay and two line inspectors, a
        colorbar, two cross-section plots (``self.cross_plot`` and
        ``self.cross_plot2``) sharing ``self.pd``, and ``self.container``
        that holds them all.  Data is filled in later by ``update()``.
        """

        # Create the mapper, etc
        self._image_index = GridDataSource(array([]), array([]), sort_order=("ascending", "ascending"))
        image_index_range = DataRange2D(self._image_index)

        # Recompute the cross sections whenever the index metadata changes
        # (presumably written by the LineInspectors below, which use
        # write_metadata=True).
        self._image_index.on_trait_change(self._metadata_changed, "metadata_changed")

        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)

        # Create the image plot
        self.imgplot = CMapImagePlot(
            index=self._image_index,
            value=self._image_value,
            index_mapper=GridMapper(range=image_index_range),
            color_mapper=self._cmap(image_value_range),
        )

        # Create the contour plots
        # ~ self.polyplot = ContourPolyPlot(index=self._image_index,
        # ~ value=self._image_value,
        # ~ index_mapper=GridMapper(range=
        # ~ image_index_range),
        # ~ color_mapper=\
        # ~ self._cmap(image_value_range),
        # ~ levels=self.num_levels)

        # ~ self.lineplot = ContourLinePlot(index=self._image_index,
        # ~ value=self._image_value,
        # ~ index_mapper=GridMapper(range=
        # ~ self.polyplot.index_mapper.range),
        # ~ levels=self.num_levels)

        # Add a left axis to the plot
        left = PlotAxis(
            orientation="left", title="axial", mapper=self.imgplot.index_mapper._ymapper, component=self.imgplot
        )
        self.imgplot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = PlotAxis(
            orientation="bottom", title="radial", mapper=self.imgplot.index_mapper._xmapper, component=self.imgplot
        )
        self.imgplot.overlays.append(bottom)

        # Add some tools to the plot
        # ~ self.polyplot.tools.append(PanTool(self.polyplot,
        # ~ constrain_key="shift"))
        self.imgplot.overlays.append(ZoomTool(component=self.imgplot, tool_mode="box", always_on=False))
        # One inspector per image axis; both write the selected index into
        # the index metadata rather than listening for it.
        self.imgplot.overlays.append(
            LineInspector(
                component=self.imgplot,
                axis="index_x",
                inspect_mode="indexed",
                write_metadata=True,
                is_listener=False,
                color="white",
            )
        )
        self.imgplot.overlays.append(
            LineInspector(
                component=self.imgplot,
                axis="index_y",
                inspect_mode="indexed",
                write_metadata=True,
                color="white",
                is_listener=False,
            )
        )

        # Add these two plots to one container
        contour_container = OverlayPlotContainer(padding=20, use_backbuffer=True, unified_draw=True)
        contour_container.add(self.imgplot)
        # ~ contour_container.add(self.polyplot)
        # ~ contour_container.add(self.lineplot)

        # Create a colorbar
        # NOTE(review): no color_mapper is passed here; presumably ColorBar
        # takes it from `plot` -- confirm against the Chaco version in use.
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(
            index_mapper=cbar_index_mapper,
            plot=self.imgplot,
            padding_top=self.imgplot.padding_top,
            padding_bottom=self.imgplot.padding_bottom,
            padding_right=40,
            resizable="v",
            width=30,
        )

        # Create the two cross plots
        self.pd = ArrayPlotData(
            line_index=array([]),
            line_value=array([]),
            scatter_index=array([]),
            scatter_value=array([]),
            scatter_color=array([]),
        )

        # First cross plot: a dotted line plus a colormapped marker ("dot")
        # for the selected pixel.
        self.cross_plot = Plot(self.pd, resizable="h")
        self.cross_plot.height = 100
        self.cross_plot.padding = 20
        self.cross_plot.plot(("line_index", "line_value"), line_style="dot")
        self.cross_plot.plot(
            ("scatter_index", "scatter_value", "scatter_color"),
            type="cmap_scatter",
            name="dot",
            color_mapper=self._cmap(image_value_range),
            marker="circle",
            marker_size=8,
        )

        # Share the image's x range so the cross section pans/zooms with it.
        self.cross_plot.index_range = self.imgplot.index_range.x_range

        # Data keys for the second (vertical) cross plot.
        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))
        self.pd.set_data("scatter_index2", array([]))
        self.pd.set_data("scatter_value2", array([]))
        self.pd.set_data("scatter_color2", array([]))

        self.cross_plot2 = Plot(self.pd, width=140, orientation="v", resizable="v", padding=20, padding_bottom=160)
        self.cross_plot2.plot(("line_index2", "line_value2"), line_style="dot")
        self.cross_plot2.plot(
            ("scatter_index2", "scatter_value2", "scatter_color2"),
            type="cmap_scatter",
            name="dot",
            color_mapper=self._cmap(image_value_range),
            marker="circle",
            marker_size=8,
        )

        # Share the image's y range with the vertical cross section.
        self.cross_plot2.index_range = self.imgplot.index_range.y_range

        # Create a container and add components
        self.container = HPlotContainer(padding=40, fill_padding=True, bgcolor="white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.cross_plot)
        inner_cont.add(contour_container)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.cross_plot2)

    def update(self):
        """Reload the current shot and push it into the image and cross plots.

        Does nothing when ``load_imagedata()`` returns None.  Otherwise it:
        - records the data range in ``self.minz``/``self.maxz`` (read by
          ``_metadata_changed``);
        - rescales the colorbar;
        - rebuilds the pixel-edge grid and redraws.
        """
        imgdata = self.load_imagedata()
        if imgdata is None:
            return
        self.minz = imgdata.min()
        self.maxz = imgdata.max()
        self.colorbar.index_mapper.range.low = self.minz
        self.colorbar.index_mapper.range.high = self.maxz
        # The image is indexed as data[row, col] == data[y, x] (see
        # _metadata_changed), so the x edges follow the column count and the
        # y edges the row count -- one more edge than pixels on each axis.
        # (They used to be swapped, which is wrong for non-square images.)
        nrows, ncols = imgdata.shape[0], imgdata.shape[1]
        xs = numpy.linspace(0, ncols, ncols + 1)
        ys = numpy.linspace(0, nrows, nrows + 1)
        self._image_index.set_data(xs, ys)
        self._image_value.data = imgdata
        # A cross section has one value per pixel, so index it by the left
        # pixel edges (the same positions the scatter marker uses).  This keeps
        # index and value arrays the same length for the line plots.
        self.pd.set_data("line_index", xs[:-1])
        self.pd.set_data("line_index2", ys[:-1])
        self.container.invalidate_draw()
        self.container.request_redraw()

    def load_imagedata(self):
        """Split ``self.shot`` into directory and shot number, then load it.

        The directory is the text between the first ``":\\"`` and the last
        backslash (inclusive).  The shot number is the text between the last
        ``"_"`` and ``".ascii"``.  Returns the result of the module-level
        ``load(directory, shotnum)``.  If the path cannot be parsed, prints a
        message and returns None.
        """
        shot = self.shot
        try:
            # Named data_dir rather than `dir` to avoid shadowing the builtin.
            data_dir = shot[shot.index(":\\") + 2: shot.rindex("\\") + 1]
            shotnum = shot[shot.rindex("_") + 1: shot.rindex(".ascii")]
        except ValueError:
            # str.index/rindex raise ValueError when a marker is missing.
            print(" *** Not a valid column density path *** ")
            return None
        return load(data_dir, shotnum)

    # ---------------------------------------------------------------------------
    # Event handlers
    # ---------------------------------------------------------------------------

    def _shot_changed(self):
        self.update()

    def _metadata_changed(self, old, new):
        """ This function takes out a cross section from the image data, based
        on the line inspector selections, and updates the line and scatter
        plots."""

        self.cross_plot.value_range.low = self.minz
        self.cross_plot.value_range.high = self.maxz
        self.cross_plot2.value_range.low = self.minz
        self.cross_plot2.value_range.high = self.maxz
        if self._image_index.metadata.has_key("selections"):
            x_ndx, y_ndx = self._image_index.metadata["selections"]
            if y_ndx and x_ndx:
                self.pd.set_data("line_value", self._image_value.data[y_ndx, :])
                self.pd.set_data("line_value2", self._image_value.data[:, x_ndx])
                xdata, ydata = self._image_index.get_data()
                xdata, ydata = xdata.get_data(), ydata.get_data()
                self.pd.set_data("scatter_index", array([xdata[x_ndx]]))
                self.pd.set_data("scatter_index2", array([ydata[y_ndx]]))
                self.pd.set_data("scatter_value", array([self._image_value.data[y_ndx, x_ndx]]))
                self.pd.set_data("scatter_value2", array([self._image_value.data[y_ndx, x_ndx]]))
                self.pd.set_data("scatter_color", array([self._image_value.data[y_ndx, x_ndx]]))
                self.pd.set_data("scatter_color2", array([self._image_value.data[y_ndx, x_ndx]]))
        else:
            self.pd.set_data("scatter_value", array([]))
            self.pd.set_data("scatter_value2", array([]))
            self.pd.set_data("line_value", array([]))
            self.pd.set_data("line_value2", array([]))

    def _colormap_changed(self):
        """Switch the colormapped plots to the colormap named by ``colormap``.

        Looks the name up in ``color_map_name_dict`` and stores the resulting
        factory in ``self._cmap``. If ``polyplot`` exists, the color mappers
        of ``polyplot``, ``cross_plot`` and both cross plots' ``"dot"``
        renderers are rebuilt with the new factory, keeping their current
        value ranges, and the container is redrawn.
        """
        self._cmap = color_map_name_dict[self.colormap]
        # polyplot may be missing or still None if this fires before the
        # plots are built; hasattr() alone would not catch the None case.
        if getattr(self, "polyplot", None) is None:
            return
        value_range = self.polyplot.color_mapper.range
        self.polyplot.color_mapper = self._cmap(value_range)
        value_range = self.cross_plot.color_mapper.range
        self.cross_plot.color_mapper = self._cmap(value_range)
        # FIXME: change when we decide how best to update plots using
        # the shared colormap in plot object
        self.cross_plot.plots["dot"][0].color_mapper = self._cmap(value_range)
        self.cross_plot2.plots["dot"][0].color_mapper = self._cmap(value_range)
        self.container.request_redraw()

    def _num_levels_changed(self):
        if self.num_levels > 3:
            self.polyplot.levels = self.num_levels
            self.lineplot.levels = self.num_levels