class PlotUI(HasTraits):
    """Contour view of a gridded scalar field with two cross-section plots.

    create_plot() builds a filled contour plot overlaid with contour lines,
    a colour bar, and a horizontal and a vertical cross-section plot that are
    driven by two ``LineInspector`` overlays.  update(model) loads the data.

    NOTE(review): this module contains a second ``class PlotUI`` statement
    further down; when the module is imported that later statement rebinds
    the name, so this definition is shadowed.  Consider deleting one copy.
    """

    # container for all plots
    container = Instance(HPlotContainer)

    # Plot components within this container:
    polyplot = Instance(ContourPolyPlot)
    lineplot = Instance(ContourLinePlot)
    cross_plot = Instance(Plot)
    cross_plot2 = Instance(Plot)
    colorbar = Instance(ColorBar)

    # plot data backing the two cross-section plots
    pd = Instance(ArrayPlotData)

    # view options
    num_levels = Int(15)
    colormap = Enum(colormaps)

    # Traits view definitions:
    traits_view = View(
        Group(UItem('container', editor=ComponentEditor(size=(800, 600)))),
        resizable=True)

    plot_edit_view = View(
        Group(Item('num_levels'),
              Item('colormap')),
        buttons=["OK", "Cancel"])

    #---------------------------------------------------------------------------
    # Private Traits
    #---------------------------------------------------------------------------

    _image_index = Instance(GridDataSource)
    _image_value = Instance(ImageData)

    _cmap = Trait(default_colormaps.jet, Callable)

    #---------------------------------------------------------------------------
    # Public View interface
    #---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        super(PlotUI, self).__init__(*args, **kwargs)
        # FIXME: 'with' wrapping is temporary fix for infinite range in initial
        # color map, which can cause a distracting warning print. This 'with'
        # wrapping should be unnecessary after fix in color_mapper.py.
        with errstate(invalid='ignore'):
            self.create_plot()

    def create_plot(self):
        """Build every plot component and the containers that hold them.

        All data sources start empty; update() fills them in.
        """
        # Grid index (x/y coordinates) and its 2-D range; the listener keeps
        # the cross-section plots in sync with the line-inspector selection.
        self._image_index = GridDataSource(array([]),
                                           array([]),
                                           sort_order=("ascending", "ascending"))
        image_index_range = DataRange2D(self._image_index)
        self._image_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")

        # Scalar values on the grid and their 1-D range (shared by the colour
        # mappers and the colour bar).
        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)

        # Create the contour plots
        self.polyplot = ContourPolyPlot(index=self._image_index,
                                        value=self._image_value,
                                        index_mapper=GridMapper(range=
                                            image_index_range),
                                        color_mapper=\
                                            self._cmap(image_value_range),
                                        levels=self.num_levels)

        self.lineplot = ContourLinePlot(index=self._image_index,
                                        value=self._image_value,
                                        index_mapper=GridMapper(range=
                                            self.polyplot.index_mapper.range),
                                        levels=self.num_levels)

        # Add a left axis to the plot
        left = PlotAxis(orientation='left',
                        title= "y",
                        mapper=self.polyplot.index_mapper._ymapper,
                        component=self.polyplot)
        self.polyplot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = PlotAxis(orientation='bottom',
                          title= "x",
                          mapper=self.polyplot.index_mapper._xmapper,
                          component=self.polyplot)
        self.polyplot.overlays.append(bottom)

        # Add some tools to the plot; the two LineInspectors write their
        # selection into the index metadata (write_metadata=True).
        self.polyplot.tools.append(PanTool(self.polyplot,
                                           constrain_key="shift"))
        self.polyplot.overlays.append(ZoomTool(component=self.polyplot,
                                            tool_mode="box", always_on=False))
        self.polyplot.overlays.append(LineInspector(component=self.polyplot,
                                               axis='index_x',
                                               inspect_mode="indexed",
                                               write_metadata=True,
                                               is_listener=True,
                                               color="white"))
        self.polyplot.overlays.append(LineInspector(component=self.polyplot,
                                               axis='index_y',
                                               inspect_mode="indexed",
                                               write_metadata=True,
                                               color="white",
                                               is_listener=True))

        # Add these two plots to one container
        contour_container = OverlayPlotContainer(padding=20,
                                                 use_backbuffer=True,
                                                 unified_draw=True)
        contour_container.add(self.polyplot)
        contour_container.add(self.lineplot)

        # Create a colorbar
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.polyplot,
                                 padding_top=self.polyplot.padding_top,
                                 padding_bottom=self.polyplot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        self.pd = ArrayPlotData(line_index = array([]),
                                line_value = array([]),
                                scatter_index = array([]),
                                scatter_value = array([]),
                                scatter_color = array([]))

        # Horizontal cross-section; shares the contour plot's x range.
        self.cross_plot = Plot(self.pd, resizable="h")
        self.cross_plot.height = 100
        self.cross_plot.padding = 20
        self.cross_plot.plot(("line_index", "line_value"),
                             line_style="dot")
        self.cross_plot.plot(("scatter_index","scatter_value","scatter_color"),
                             type="cmap_scatter",
                             name="dot",
                             color_mapper=self._cmap(image_value_range),
                             marker="circle",
                             marker_size=8)

        self.cross_plot.index_range = self.polyplot.index_range.x_range

        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))
        self.pd.set_data("scatter_index2", array([]))
        self.pd.set_data("scatter_value2", array([]))
        self.pd.set_data("scatter_color2", array([]))

        # Vertical cross-section; shares the contour plot's y range.
        self.cross_plot2 = Plot(self.pd, width = 140, orientation="v",
                                resizable="v", padding=20, padding_bottom=160)
        self.cross_plot2.plot(("line_index2", "line_value2"),
                             line_style="dot")
        self.cross_plot2.plot(("scatter_index2",
                               "scatter_value2",
                               "scatter_color2"),
                             type="cmap_scatter",
                             name="dot",
                             color_mapper=self._cmap(image_value_range),
                             marker="circle",
                             marker_size=8)

        self.cross_plot2.index_range = self.polyplot.index_range.y_range

        # Create a container and add components
        self.container = HPlotContainer(padding=40, fill_padding=True,
                                        bgcolor = "white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.cross_plot)
        inner_cont.add(contour_container)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.cross_plot2)

    def update(self, model):
        """Load new data from ``model`` and redraw.

        ``model`` must provide ``minz``/``maxz`` (used as the colour-bar and
        cross-section value limits), the grid coordinates ``xs``/``ys`` and
        the values ``zs``.
        """
        self.minz = model.minz
        self.maxz = model.maxz
        self.colorbar.index_mapper.range.low = self.minz
        self.colorbar.index_mapper.range.high = self.maxz
        self._image_index.set_data(model.xs, model.ys)
        self._image_value.data = model.zs
        self.pd.update_data(line_index=model.xs, line_index2=model.ys)
        self.container.invalidate_draw()
        self.container.request_redraw()

    #---------------------------------------------------------------------------
    # Event handlers
    #---------------------------------------------------------------------------

    def _metadata_changed(self, old, new):
        """Slice the data at the line-inspector selection.

        Registered in create_plot() as a listener on the grid index's
        ``metadata_changed`` event; the two positional arguments come from
        the Traits notifier and are not used.

        When the index metadata has a ``"selections"`` entry it is unpacked
        as ``(x_ndx, y_ndx)`` and the row ``zs[y_ndx, :]``, the column
        ``zs[:, x_ndx]`` and the selected value are pushed into ``self.pd``.
        Without a selection the line and marker values are cleared.  Nothing
        happens until update() has supplied data.
        """
        if not hasattr(self, "minz"):
            # update() has not run yet: no value limits, no data to slice.
            return

        self.cross_plot.value_range.low = self.minz
        self.cross_plot.value_range.high = self.maxz
        self.cross_plot2.value_range.low = self.minz
        self.cross_plot2.value_range.high = self.maxz

        metadata = self._image_index.metadata
        if "selections" not in metadata:
            self.pd.update_data({"scatter_value": array([]),
                                 "scatter_value2": array([]),
                                 "line_value": array([]),
                                 "line_value2": array([])})
            return

        x_ndx, y_ndx = metadata["selections"]
        # Compare with None explicitly: row/column 0 is a valid selection.
        if x_ndx is None or y_ndx is None:
            return

        xdata, ydata = self._image_index.get_data()
        xdata, ydata = xdata.get_data(), ydata.get_data()
        zs = self._image_value.data
        picked = zs[y_ndx, x_ndx]
        self.pd.update_data(
            line_value=zs[y_ndx, :],
            line_value2=zs[:, x_ndx],
            scatter_index=array([xdata[x_ndx]]),
            scatter_index2=array([ydata[y_ndx]]),
            scatter_value=array([picked]),
            scatter_value2=array([picked]),
            scatter_color=array([picked]),
            scatter_color2=array([picked]),
        )

    def _colormap_changed(self):
        """Rebuild the colour mappers from the newly selected colormap."""
        self._cmap = default_colormaps.color_map_name_dict[self.colormap]
        if self.polyplot is not None:
            value_range = self.polyplot.color_mapper.range
            self.polyplot.color_mapper = self._cmap(value_range)
            value_range = self.cross_plot.color_mapper.range
            self.cross_plot.color_mapper = self._cmap(value_range)
            # FIXME: change when we decide how best to update plots using
            # the shared colormap in plot object
            self.cross_plot.plots["dot"][0].color_mapper = \
                self._cmap(value_range)
            self.cross_plot2.plots["dot"][0].color_mapper = \
                self._cmap(value_range)
            self.container.request_redraw()

    def _num_levels_changed(self):
        """Forward ``num_levels`` (when above 3) to both contour plots."""
        # This static handler also fires from HasTraits.__init__ when
        # num_levels is passed as a keyword argument, i.e. before
        # create_plot() has built the contour plots.
        if self.polyplot is None or self.lineplot is None:
            return
        if self.num_levels > 3:
            self.polyplot.levels = self.num_levels
            self.lineplot.levels = self.num_levels


class ImageGUI(HasTraits):
    """Traits UI for browsing shot files and inspecting their images.

    The left pane lists the files of ``shotdir`` accepted by
    ``iscol(name, namefilter)``.  Selecting one loads it with
    ``import_data.load`` and shows it as a colour-mapped image with a colour
    bar, a crosshair cursor and horizontal/vertical cross-section plots
    through the cursor.  The shot report from ``import_data.load_report`` is
    shown in a second tab.
    """

    # TO FIX : put here the last available shot
    #shot = File('L:\\data\\app3\\2011\\1108\\110823\\column_5200.ascii')
    #shot = File('/home/pmd/atomcool/lab/data/app3/2012/1203/120307/column_3195.ascii')

    #-- Shot traits
    shotdir = Directory('/home/pmd/atomcool/lab/data/app3/2012/1203/120320/')
    shots = List(Str)
    selectedshot = List(Str)
    namefilter = Str('column')

    #-- Report trait
    report = Str

    #-- Displayed analysis results
    number = Float

    #-- Column density plot container
    column_density = Instance(HPlotContainer)
    #---- Plot components within this container
    imgplot     = Instance(CMapImagePlot)
    cross_plot  = Instance(Plot)
    cross_plot2 = Instance(Plot)
    colorbar    = Instance(ColorBar)
    #---- Plot data
    pd = Instance(ArrayPlotData)
    #---- Colorbar
    num_levels = Int(15)
    colormap = Enum(color_map_name_dict.keys())

    #-- Crosshair location
    cursor = Instance(BaseCursorTool)
    xy = DelegatesTo('cursor', prefix='current_position')
    xpos = Float(0.)
    ypos = Float(0.)
    xpos_read = Float(0.)
    ypos_read = Float(0.)
    cursor_group = Group(Group(Item('xpos', show_label=True),
                               Item('xpos_read', show_label=False,
                                    style="readonly"),
                               orientation='horizontal'),
                         Group(Item('ypos', show_label=True),
                               Item('ypos_read', show_label=False,
                                    style="readonly"),
                               orientation='horizontal'),
                         orientation='vertical', layout='normal', springy=True)

    #---------------------------------------------------------------------------
    # Traits View Definitions
    #---------------------------------------------------------------------------

    traits_view = View(
        Group(
            #-- Directory
            Item('shotdir', style='simple', editor=DirectoryEditor(),
                 width=400, show_label=False, resizable=False),
            #-- Bottom
            HSplit(
                #-- Pane for shot selection
                Group(
                    Item('namefilter', show_label=False, springy=False),
                    Item('shots', show_label=False, width=180, height=360,
                         editor=TabularEditor(selected='selectedshot',
                                              editable=False,
                                              multi_select=True,
                                              adapter=SelectAdapter())),
                    cursor_group,
                    orientation='vertical',
                    layout='normal',
                ),
                #-- Pane for column density plots
                Group(
                    Item('column_density', editor=ComponentEditor(),
                         show_label=False, width=600, height=500,
                         resizable=True),
                    Item('report', show_label=False, width=180,
                         springy=True, style='custom'),
                    layout='tabbed', springy=True),
                #-- Pane for analysis results
                Group(
                    Item('number', show_label=False)
                )
            ),
            orientation='vertical',
            layout='normal',
        ),
        width=1400, height=500, resizable=True)

    #-- Pop-up view when Plot->Edit is selected from the menu
    plot_edit_view = View(
        Group(Item('num_levels'),
              Item('colormap')),
        buttons=["OK", "Cancel"])

    #---------------------------------------------------------------------------
    # Private Traits
    #---------------------------------------------------------------------------

    #-- Represents the region where the data set is defined
    _image_index = Instance(GridDataSource)

    #-- Represents the data that will be plotted on the grid
    _image_value = Instance(ImageData)

    #-- Represents the color map that will be used
    _cmap = Trait(jet, Callable)

    #---------------------------------------------------------------------------
    # Public View interface
    #---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        # Run HasTraits.__init__ first so the Traits machinery (including any
        # trait values given as keyword arguments) is set up before the plot
        # components are built.
        super(ImageGUI, self).__init__(*args, **kwargs)
        self.create_plot()

    def create_plot(self):
        """Build the image plot, colour bar, cursor and cross-section plots.

        Every data source starts empty; update() fills them in.
        """
        #-- Grid index (x/y coordinates) and the 2-D range it spans
        self._image_index = GridDataSource(array([]), array([]),
                                           sort_order=("ascending", "ascending"))
        image_index_range = DataRange2D(self._image_index)

        #-- Call _metadata_changed whenever the index metadata changes
        self._image_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")

        #-- Image values and their range (shared by all colour mappers)
        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)

        # Create the image plot
        self.imgplot = CMapImagePlot(index=self._image_index,
                                     value=self._image_value,
                                     index_mapper=GridMapper(range=image_index_range),
                                     color_mapper=self._cmap(image_value_range))

        # Add a left axis to the plot
        left = PlotAxis(orientation='left',
                        title="axial",
                        mapper=self.imgplot.index_mapper._ymapper,
                        component=self.imgplot)
        self.imgplot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = PlotAxis(orientation='bottom',
                          title="radial",
                          mapper=self.imgplot.index_mapper._xmapper,
                          component=self.imgplot)
        self.imgplot.overlays.append(bottom)

        # Add some tools to the plot
        self.imgplot.tools.append(PanTool(self.imgplot, drag_button="right",
                                          constrain_key="shift"))

        self.imgplot.overlays.append(ZoomTool(component=self.imgplot,
                                              tool_mode="box",
                                              always_on=False))

        # Create a colorbar
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.imgplot,
                                 padding_top=self.imgplot.padding_top,
                                 padding_bottom=self.imgplot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        # Add a cursor; it is a rendered component, so it goes in overlays
        self.cursor = CursorTool(self.imgplot, drag_button="left",
                                 color="white")
        self.imgplot.overlays.append(self.cursor)

        # Create the two cross plots
        self.pd = ArrayPlotData(line_index=array([]),
                                line_value=array([]),
                                scatter_index=array([]),
                                scatter_value=array([]),
                                scatter_color=array([]))

        # Horizontal cross-section; shares the image's x range
        self.cross_plot = Plot(self.pd, resizable="h")
        self.cross_plot.height = 100
        self.cross_plot.padding = 20
        self.cross_plot.plot(("line_index", "line_value"),
                             line_style="dot")
        self.cross_plot.plot(("scatter_index", "scatter_value",
                              "scatter_color"),
                             type="cmap_scatter",
                             name="dot",
                             color_mapper=self._cmap(image_value_range),
                             marker="circle",
                             marker_size=6)

        self.cross_plot.index_range = self.imgplot.index_range.x_range

        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))
        self.pd.set_data("scatter_index2", array([]))
        self.pd.set_data("scatter_value2", array([]))
        self.pd.set_data("scatter_color2", array([]))

        # Vertical cross-section; shares the image's y range
        self.cross_plot2 = Plot(self.pd, width=140, orientation="v",
                                resizable="v", padding=20, padding_bottom=160)
        self.cross_plot2.plot(("line_index2", "line_value2"),
                              line_style="dot")
        self.cross_plot2.plot(("scatter_index2", "scatter_value2",
                               "scatter_color2"),
                              type="cmap_scatter",
                              name="dot",
                              color_mapper=self._cmap(image_value_range),
                              marker="circle",
                              marker_size=8)

        self.cross_plot2.index_range = self.imgplot.index_range.y_range

        # Create a container and add sub-containers and components
        self.column_density = HPlotContainer(padding=40, fill_padding=True,
                                             bgcolor="white",
                                             use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.cross_plot)
        self.imgplot.padding = 20
        inner_cont.add(self.imgplot)
        self.column_density.add(self.colorbar)
        self.column_density.add(inner_cont)
        self.column_density.add(self.cross_plot2)

    def update(self):
        """Refresh the shot list and display the current shot.

        Loads the shot chosen by load_imagedata(), records its min/max in
        ``minz``/``maxz`` (colour-bar and cross-section limits), pushes the
        image into the plot data sources and redraws.  If nothing can be
        loaded the plots are left unchanged.
        """
        self.shots = self.populate_shot_list()
        print(self.selectedshot)
        loaded = self.load_imagedata()
        if loaded is None:
            # load_imagedata() has already printed why.
            return
        imgdata, self.report = loaded
        if imgdata is None:
            return

        self.minz = imgdata.min()
        self.maxz = imgdata.max()
        self.colorbar.index_mapper.range.low = self.minz
        self.colorbar.index_mapper.range.high = self.maxz

        # Chaco's ImageData lays the array out as (rows, columns) = (y, x),
        # and the grid index holds the pixel *edges* (one more than the
        # pixel count along each axis).
        nrows, ncols = imgdata.shape[:2]
        xs = numpy.linspace(0, ncols, ncols + 1)
        ys = numpy.linspace(0, nrows, nrows + 1)
        self._image_index.set_data(xs, ys)
        self._image_value.data = imgdata

        # The cross-section lines need one abscissa per pixel -- Chaco's line
        # plot draws nothing when index and value lengths differ -- so use
        # the pixel centres rather than the edges.
        self.pd.set_data("line_index", 0.5 * (xs[:-1] + xs[1:]))
        self.pd.set_data("line_index2", 0.5 * (ys[:-1] + ys[1:]))
        self.column_density.invalidate_draw()
        self.column_density.request_redraw()

    def populate_shot_list(self):
        """Return the sorted names in ``shotdir`` accepted by ``iscol``.

        An entry of ``os.listdir(self.shotdir)`` is kept when
        ``iscol(name, self.namefilter)`` is true.  If the directory cannot be
        read, a message is printed and an empty list is returned.
        """
        try:
            names = os.listdir(self.shotdir)
            shot_list = sorted(name for name in names
                               if iscol(name, self.namefilter))
        except (OSError, ValueError):
            # os.listdir reports missing/unreadable directories via OSError.
            print(" *** Not a valid directory path ***")
            return []
        return shot_list

    def load_imagedata(self):
        """Load the current shot and its report.

        Uses the first entry of ``selectedshot``, or the first entry of
        ``shots`` when nothing is selected.  The shot number handed to
        ``import_data.load_report`` is the part of the file name before the
        first underscore.

        Returns:
            ``(image, report)`` as returned by ``import_data.load`` and
            ``import_data.load_report``, or ``None`` (after printing a
            message) when there is no file or its name has no underscore.
        """
        directory = self.shotdir
        try:
            if self.selectedshot:
                filename = self.selectedshot[0]
            else:
                filename = self.shots[0]
            shotnum = filename[:filename.index('_')]
        except (IndexError, ValueError):
            print(" *** Not a valid path *** ")
            return None
        print("Loading file #%s from %s" % (filename, directory))
        return (import_data.load(directory, filename),
                import_data.load_report(directory, shotnum))

    #---------------------------------------------------------------------------
    # Event handlers
    #---------------------------------------------------------------------------

    def _selectedshot_changed(self):
        print(self.selectedshot)
        self.update()

    def _shots_changed(self):
        # NOTE(review): re-reads the directory whenever ``shots`` is assigned,
        # overriding any externally assigned list; presumably Traits' equality
        # check on List stops the re-assignment from recursing -- confirm.
        self.shots = self.populate_shot_list()

    def _namefilter_changed(self):
        self.shots = self.populate_shot_list()

    def _shotdir_changed(self):
        # Re-list the files when a new directory is chosen, as is done when
        # the name filter changes.
        self.shots = self.populate_shot_list()

    def _xpos_changed(self):
        # May fire from HasTraits.__init__, before create_plot() made a cursor.
        if self.cursor is not None:
            self.cursor.current_position = self.xpos, self.ypos

    def _ypos_changed(self):
        if self.cursor is not None:
            self.cursor.current_position = self.xpos, self.ypos

    def _metadata_changed(self):
        self._xy_changed()

    def _xy_changed(self):
        """Update the cursor read-outs and the cross-section plots.

        Runs when the cursor moves (``xy`` delegates to
        ``cursor.current_position``) and from _metadata_changed().  Copies the
        cursor's grid indices into ``xpos_read``/``ypos_read``; once update()
        has loaded an image, shows the image row (horizontal cut) and column
        (vertical cut) through the cursor pixel, with a marker at that pixel.
        """
        if self.cursor is None:
            return
        x_ndx, y_ndx = self.cursor.current_index
        self.xpos_read = x_ndx
        self.ypos_read = y_ndx

        if not hasattr(self, "minz"):
            # update() has not loaded any image yet.
            return
        self.cross_plot.value_range.low = self.minz
        self.cross_plot.value_range.high = self.maxz
        self.cross_plot2.value_range.low = self.minz
        self.cross_plot2.value_range.high = self.maxz

        if x_ndx is None or y_ndx is None:
            return
        image = self._image_value.data
        nrows, ncols = image.shape[:2]
        # The cursor indexes the grid *edges*, which have one more entry than
        # there are pixels along each axis.
        if not (0 <= x_ndx < ncols and 0 <= y_ndx < nrows):
            return

        # Same (row=y, column=x) convention as the image itself.
        self.pd.set_data("line_value", image[y_ndx, :])
        self.pd.set_data("line_value2", image[:, x_ndx])
        # Place the markers on the pixel centres used as the line abscissae.
        self.pd.set_data("scatter_index",
                         array([self.pd.get_data("line_index")[x_ndx]]))
        self.pd.set_data("scatter_index2",
                         array([self.pd.get_data("line_index2")[y_ndx]]))
        picked = image[y_ndx, x_ndx]
        for name in ("scatter_value", "scatter_value2",
                     "scatter_color", "scatter_color2"):
            self.pd.set_data(name, array([picked]))

    def _colormap_changed(self):
        """Apply the newly selected colormap to the image and the markers."""
        self._cmap = color_map_name_dict[self.colormap]
        if self.imgplot is None:
            # Fired from HasTraits.__init__, before create_plot() has run.
            return
        value_range = self.imgplot.color_mapper.range
        self.imgplot.color_mapper = self._cmap(value_range)
        # FIXME: change when we decide how best to update plots using
        # the shared colormap in plot object
        self.cross_plot.plots["dot"][0].color_mapper = self._cmap(value_range)
        self.cross_plot2.plots["dot"][0].color_mapper = self._cmap(value_range)
        self.column_density.request_redraw()

    def _num_levels_changed(self):
        # NOTE(review): carried over from PlotUI.  ImageGUI builds no contour
        # plots, so there is nothing to forward the level count to unless
        # such plots have been attached; without this guard editing
        # num_levels raised AttributeError.
        polyplot = getattr(self, "polyplot", None)
        lineplot = getattr(self, "lineplot", None)
        if polyplot is None or lineplot is None:
            return
        if self.num_levels > 3:
            polyplot.levels = self.num_levels
            lineplot.levels = self.num_levels


class PlotUI(HasTraits):
    
    # container for all plots
    container = Instance(HPlotContainer)
    
    # Plot components within this container:
    polyplot = Instance(ContourPolyPlot)
    lineplot = Instance(ContourLinePlot)
    cross_plot = Instance(Plot)
    cross_plot2 = Instance(Plot)
    colorbar = Instance(ColorBar)
    
    # plot data
    pd = Instance(ArrayPlotData)

    # view options
    num_levels = Int(15)
    colormap = Enum(colormaps)
    
    #Traits view definitions:
    traits_view = View(
        Group(UItem('container', editor=ComponentEditor(size=(800,600)))),
        resizable=True)

    plot_edit_view = View(
        Group(Item('num_levels'),
              Item('colormap')),
              buttons=["OK","Cancel"])

    
    #---------------------------------------------------------------------------
    # Private Traits
    #---------------------------------------------------------------------------

    _image_index = Instance(GridDataSource)
    _image_value = Instance(ImageData)

    _cmap = Trait(default_colormaps.jet, Callable)

    #---------------------------------------------------------------------------
    # Public View interface
    #---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        """Initialise the traits, then build every plot component.

        The plots are built with numpy's "invalid value" warnings suppressed.
        """
        super(PlotUI, self).__init__(*args, **kwargs)
        # FIXME (temporary): the initial colour map has an infinite range,
        # which makes numpy print a distracting warning while the plots are
        # constructed.  Silencing it here should become unnecessary once
        # color_mapper.py is fixed.
        ignore_invalid = errstate(invalid='ignore')
        with ignore_invalid:
            self.create_plot()

    def create_plot(self):
        """Build the contour plots, colour bar, cross-section plots and the
        containers holding them, storing each on the instance.

        Every data source is created empty; the data arrives later through
        ``self._image_index``, ``self._image_value`` and ``self.pd``.
        """

        # Grid index (x/y coordinates) and the 2-D range it spans; the
        # listener calls _metadata_changed whenever the index metadata changes.
        self._image_index = GridDataSource(array([]),
                                          array([]),
                                          sort_order=("ascending","ascending"))
        image_index_range = DataRange2D(self._image_index)
        self._image_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")

        # Values on the grid and their 1-D range; the same range object feeds
        # the colour mappers below and the colour bar's index mapper.
        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)



        # Create the contour plots (filled polygons plus contour lines); the
        # line plot reuses the polygon plot's 2-D index range.
        self.polyplot = ContourPolyPlot(index=self._image_index,
                                        value=self._image_value,
                                        index_mapper=GridMapper(range=
                                            image_index_range),
                                        color_mapper=\
                                            self._cmap(image_value_range),
                                        levels=self.num_levels)

        self.lineplot = ContourLinePlot(index=self._image_index,
                                        value=self._image_value,
                                        index_mapper=GridMapper(range=
                                            self.polyplot.index_mapper.range),
                                        levels=self.num_levels)


        # Add a left axis to the plot
        left = PlotAxis(orientation='left',
                        title= "y",
                        mapper=self.polyplot.index_mapper._ymapper,
                        component=self.polyplot)
        self.polyplot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = PlotAxis(orientation='bottom',
                          title= "x",
                          mapper=self.polyplot.index_mapper._xmapper,
                          component=self.polyplot)
        self.polyplot.overlays.append(bottom)


        # Add some tools to the plot: pan (shift constrains), box zoom, and
        # one LineInspector per axis.  The inspectors are created with
        # write_metadata=True, which is what the metadata_changed listener
        # registered above reacts to.
        self.polyplot.tools.append(PanTool(self.polyplot,
                                           constrain_key="shift"))
        self.polyplot.overlays.append(ZoomTool(component=self.polyplot,
                                            tool_mode="box", always_on=False))
        self.polyplot.overlays.append(LineInspector(component=self.polyplot,
                                               axis='index_x',
                                               inspect_mode="indexed",
                                               write_metadata=True,
                                               is_listener=True,
                                               color="white"))
        self.polyplot.overlays.append(LineInspector(component=self.polyplot,
                                               axis='index_y',
                                               inspect_mode="indexed",
                                               write_metadata=True,
                                               color="white",
                                               is_listener=True))

        # Add these two plots to one container
        contour_container = OverlayPlotContainer(padding=20,
                                                 use_backbuffer=True,
                                                 unified_draw=True)
        contour_container.add(self.polyplot)
        contour_container.add(self.lineplot)


        # Create a colorbar spanning the value range, padded like the
        # contour plot.
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.polyplot,
                                 padding_top=self.polyplot.padding_top,
                                 padding_bottom=self.polyplot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        # Data for the horizontal cross-section: a dotted line plus a single
        # colour-mapped marker ("dot").
        self.pd = ArrayPlotData(line_index = array([]),
                                line_value = array([]),
                                scatter_index = array([]),
                                scatter_value = array([]),
                                scatter_color = array([]))

        self.cross_plot = Plot(self.pd, resizable="h")
        self.cross_plot.height = 100
        self.cross_plot.padding = 20
        self.cross_plot.plot(("line_index", "line_value"),
                             line_style="dot")
        self.cross_plot.plot(("scatter_index","scatter_value","scatter_color"),
                             type="cmap_scatter",
                             name="dot",
                             color_mapper=self._cmap(image_value_range),
                             marker="circle",
                             marker_size=8)

        # Share the contour plot's x range so both use the same x limits.
        self.cross_plot.index_range = self.polyplot.index_range.x_range

        # Same data layout for the vertical cross-section ("...2" keys).
        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))
        self.pd.set_data("scatter_index2", array([]))
        self.pd.set_data("scatter_value2", array([]))
        self.pd.set_data("scatter_color2", array([]))

        self.cross_plot2 = Plot(self.pd, width = 140, orientation="v", resizable="v", padding=20, padding_bottom=160)
        self.cross_plot2.plot(("line_index2", "line_value2"),
                             line_style="dot")
        self.cross_plot2.plot(("scatter_index2","scatter_value2","scatter_color2"),
                             type="cmap_scatter",
                             name="dot",
                             color_mapper=self._cmap(image_value_range),
                             marker="circle",
                             marker_size=8)

        # Share the contour plot's y range.
        self.cross_plot2.index_range = self.polyplot.index_range.y_range



        # Create a container and add components: colour bar | (horizontal
        # cross-section above the contours) | vertical cross-section.
        self.container = HPlotContainer(padding=40, fill_padding=True,
                                        bgcolor = "white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.cross_plot)
        inner_cont.add(contour_container)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.cross_plot2)


    def update(self, model):
        """Push a newly evaluated model into the plots and redraw them.

        Reads ``minz``, ``maxz``, ``xs``, ``ys`` and ``zs`` from *model*.
        The z bounds are cached on ``self`` (``_metadata_changed`` uses them
        for the cross-section value ranges) and copied onto the colorbar's
        range. The grid coordinates and image values are replaced, the
        cross-section index arrays are refreshed, and the whole container is
        invalidated and redrawn.
        """
        self.minz = model.minz
        self.maxz = model.maxz
        cbar_range = self.colorbar.index_mapper.range
        cbar_range.low, cbar_range.high = self.minz, self.maxz
        self._image_index.set_data(model.xs, model.ys)
        self._image_value.data = model.zs
        # x coordinates index the horizontal cross section, y the vertical one
        for key, coords in (("line_index", model.xs),
                            ("line_index2", model.ys)):
            self.pd.set_data(key, coords)
        self.container.invalidate_draw()
        self.container.request_redraw()


    #---------------------------------------------------------------------------
    # Event handlers
    #---------------------------------------------------------------------------

    def _metadata_changed(self, old, new):
        """ This function takes out a cross section from the image data, based
        on the line inspector selections, and updates the line and scatter
        plots.

        Registered on ``self._image_index`` for its metadata-change events;
        the handler arguments are not used. The selection is read from
        ``self._image_index.metadata["selections"]`` as ``(x_ndx, y_ndx)``:

        * no selection entry (or ``None``): every cross-section line and dot
          array is reset to empty, the same state ``create_plot`` starts in;
        * either index is ``None``: the plots are left unchanged;
        * otherwise row ``y_ndx`` and column ``x_ndx`` of the image become
          the two line plots, and the value at their crossing becomes the
          dot marker (position, value and color) on both plots.

        Expects ``self.minz`` / ``self.maxz`` to have been set by ``update``.
        """
        for cross_plot in (self.cross_plot, self.cross_plot2):
            cross_plot.value_range.low = self.minz
            cross_plot.value_range.high = self.maxz

        # ``dict.has_key`` is Python-2 only; ``get`` works everywhere and
        # also copes with the entry being present but None.
        selections = self._image_index.metadata.get("selections")
        if selections is None:
            # Clear the index/color arrays too, so every dot array has the
            # same (zero) length as its value array.
            for key in ("scatter_value", "scatter_value2",
                        "line_value", "line_value2",
                        "scatter_index", "scatter_index2",
                        "scatter_color", "scatter_color2"):
                self.pd.set_data(key, array([]))
            return

        x_ndx, y_ndx = selections
        # Row/column 0 is a valid selection; a truthiness test would skip it,
        # so only None counts as "no index".
        if x_ndx is None or y_ndx is None:
            return

        image = self._image_value.data
        xdata, ydata = self._image_index.get_data()
        xdata, ydata = xdata.get_data(), ydata.get_data()
        self.pd.set_data("line_value", image[y_ndx, :])
        self.pd.set_data("line_value2", image[:, x_ndx])
        self.pd.set_data("scatter_index", array([xdata[x_ndx]]))
        self.pd.set_data("scatter_index2", array([ydata[y_ndx]]))
        # The value at the crossing point is both the dot's position on the
        # value axis and the input to its color mapper.
        for key in ("scatter_value", "scatter_value2",
                    "scatter_color", "scatter_color2"):
            self.pd.set_data(key, array([image[y_ndx, x_ndx]]))

    def _colormap_changed(self):
        """Rebuild the color mappers after the user picks a new colormap.

        The factory named by ``self.colormap`` is stored in ``self._cmap``.
        Once the plots exist, the contour plot's mapper, the cross-section
        plot's mapper and the mappers of both "dot" scatter renderers are
        replaced by new ones from that factory, reusing the existing data
        ranges. The container is then redrawn.
        """
        self._cmap = default_colormaps.color_map_name_dict[self.colormap]
        if self.polyplot is None:
            # Plots have not been built yet; nothing to re-map.
            return
        poly_range = self.polyplot.color_mapper.range
        self.polyplot.color_mapper = self._cmap(poly_range)
        cross_range = self.cross_plot.color_mapper.range
        self.cross_plot.color_mapper = self._cmap(cross_range)
        # FIXME: change when we decide how best to update plots using
        # the shared colormap in plot object
        for owner in (self.cross_plot, self.cross_plot2):
            owner.plots["dot"][0].color_mapper = self._cmap(cross_range)
        self.container.request_redraw()

    def _num_levels_changed(self):
        """Apply a new contour level count to both contour renderers.

        Counts of 3 or less are ignored. Like ``_colormap_changed``, this
        does nothing until the plots exist. Without that guard, passing
        ``num_levels=`` to the constructor would fire this handler before
        ``create_plot`` runs and fail on ``None.levels``.
        """
        if self.polyplot is None or self.lineplot is None:
            return
        levels = self.num_levels
        if levels <= 3:
            return
        for contour_plot in (self.polyplot, self.lineplot):
            contour_plot.levels = levels


class ImagePlot(Atom):
    """Color image plot with linked horizontal and vertical cross sections.

    ``color_plot`` draws ``z`` over the grid ``(x, y)``. Two ``LineInspector``
    overlays publish the cursor's grid indices in the image index metadata
    under ``"selections"``. ``_metadata_changed`` then copies the selected
    row into ``horizontal_cross_plot`` and the selected column into
    ``vertical_cross_plot``. ``container`` holds the colorbar, a vertical
    stack of the horizontal cross section and the image, and the vertical
    cross section.

    NOTE: ``ImagePlot`` is defined a second time later in this module; that
    later definition replaces this one when the module is imported.
    """
    # container for all plots
    container = Typed(HPlotContainer)

    # Plot components within this container:
    color_plot = Typed(CMapImagePlot)
    vertical_cross_plot = Typed(Plot)
    horizontal_cross_plot = Typed(Plot)
    colorbar = Typed(ColorBar)

    # Color mapper spanning [min(z), max(z)], built in __init__. It must be a
    # declared member because Atom objects reject undeclared attributes.
    colormap = Typed(object)

    # plot data
    pd_all = Typed(ArrayPlotData)
    # Data for the horizontal (row; keys "x"/"horiz") and vertical
    # (column; keys "y"/"vert") cross-section plots.
    pd_horiz = Typed(ArrayPlotData)
    pd_vert = Typed(ArrayPlotData)

    # private data storage
    _imag_index = Typed(GridDataSource)
    _image_value = Typed(ImageData)

    def __init__(self, x, y, z):
        """Build the image, cross-section plots, colorbar and container.

        Parameters
        ----------
        x, y : array
            Grid coordinates for the columns and rows of ``z``; declared to
            ``GridDataSource`` as ascending.
        z : 2-D array
            Image values indexed ``z[row, column]``, rows following ``y``.
        """
        super(ImagePlot, self).__init__()
        self.pd_all = ArrayPlotData(imagedata=z)
        # Seed the cross sections with the first row / column so each line
        # plot starts with matching index and value lengths.
        self.pd_horiz = ArrayPlotData(x=x, horiz=z[0, :])
        self.pd_vert = ArrayPlotData(y=y, vert=z[:, 0])

        self._imag_index = GridDataSource(xdata=x,
                                          ydata=y,
                                          sort_order=("ascending",
                                                      "ascending"))
        index_mapper = GridMapper(range=DataRange2D(self._imag_index))
        self._imag_index.on_trait_change(self._metadata_changed,
                                         "metadata_changed")
        self._image_value = ImageData(data=z, value_depth=1)
        color_mapper = jet(DataRange1D(self._image_value))

        self.color_plot = CMapImagePlot(index=self._imag_index,
                                        index_mapper=index_mapper,
                                        value=self._image_value,
                                        value_mapper=color_mapper,
                                        padding=20,
                                        use_backbuffer=True,
                                        unified_draw=True)

        # Axes on the image plot
        self.color_plot.overlays.append(
            PlotAxis(orientation='left',
                     title="Frequency (GHz)",
                     mapper=self.color_plot.index_mapper._ymapper,
                     component=self.color_plot))
        self.color_plot.overlays.append(
            PlotAxis(orientation='bottom',
                     title="Time (us)",
                     mapper=self.color_plot.index_mapper._xmapper,
                     component=self.color_plot))

        # Pan and box-zoom tools
        self.color_plot.tools.append(
            PanTool(self.color_plot, constrain_key="shift"))
        self.color_plot.overlays.append(
            ZoomTool(component=self.color_plot,
                     tool_mode="box",
                     always_on=False))

        # Horizontal and vertical line inspectors; with write_metadata=True
        # they feed the "selections" entry read by _metadata_changed.
        for axis in ('index_x', 'index_y'):
            self.color_plot.overlays.append(
                LineInspector(component=self.color_plot,
                              axis=axis,
                              inspect_mode="indexed",
                              write_metadata=True,
                              is_listener=True,
                              color="white"))

        z_range = DataRange1D(low=amin(z), high=amax(z))
        self.colormap = jet(z_range)

        # Colorbar over the full z range
        self.colorbar = ColorBar(index_mapper=LinearMapper(range=z_range),
                                 plot=self.color_plot,
                                 padding_top=self.color_plot.padding_top,
                                 padding_bottom=self.color_plot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        # Horizontal cross section (one image row), sharing the image x range
        horiz = Plot(self.pd_horiz, resizable="h")
        horiz.height = 100
        horiz.padding = 20
        horiz.plot(("x", "horiz"))
        horiz.index_range = self.color_plot.index_range.x_range
        self.horizontal_cross_plot = horiz

        # Vertical cross section (one image column), sharing the image y range
        vert = Plot(self.pd_vert,
                    width=140,
                    orientation="v",
                    resizable="v",
                    padding=20,
                    padding_bottom=160)
        vert.plot(("y", "vert"))
        vert.index_range = self.color_plot.index_range.y_range
        self.vertical_cross_plot = vert

        # Create a container and add components
        self.container = HPlotContainer(padding=40,
                                        fill_padding=True,
                                        bgcolor="white",
                                        use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.horizontal_cross_plot)
        inner_cont.add(self.color_plot)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.vertical_cross_plot)

    def _metadata_changed(self, old, new):
        """ This function takes out a cross section from the image data, based
        on the line inspector selections, and updates the line plots.

        The handler arguments are not used. When
        ``self._imag_index.metadata["selections"]`` holds two non-None
        indices ``(x_ndx, y_ndx)``, row ``y_ndx`` of the image goes to the
        horizontal cross section and column ``x_ndx`` to the vertical one.
        Otherwise the last cross sections stay on screen.
        """
        # ``dict.has_key`` is Python-2 only; ``get`` works everywhere.
        selections = self._imag_index.metadata.get("selections")
        if selections is None:
            return
        x_ndx, y_ndx = selections
        # Index 0 is a valid row/column, so test against None explicitly.
        if x_ndx is None or y_ndx is None:
            return
        image = self._image_value.data
        self.pd_horiz.set_data("horiz", image[y_ndx, :])
        self.pd_vert.set_data("vert", image[:, x_ndx])

#if __name__ == "__main__":

#    filename="/Users/thomasaref/Dropbox/Dad stuff/sample3/digitizer/lt/sample3_digitizer_f_sweep_t_300mk_100nspulse.hdf5"
#
#    with h5py.File(filename, 'r') as f:
#
#        time=f["Traces"]["d - AvgTrace - t"][:]
#        Magvec=f["Traces"]["d - AvgTrace - Magvec"][:]
#        frequency=f["Data"]["Data"][:]
#    #    for name in f["Data"]:
#    #        print name
#
#    time=squeeze(time)
#    Magvec=squeeze(Magvec)
#    frequency=squeeze(frequency)
#
#    x = time[:,0]*1.0e6
#    y = frequency[0,:]/1.0e9
#    z=transpose(Magvec*1000.0)
#
#    ip=ImagePlot(xs,ys,z)
#    ip.configure_traits()

#class Image_Plot(Atom):
#    plot_control=Instance(Plot_Control)
#    xtitle=DelegatesTo('plot_control')
#    ytitle=DelegatesTo('plot_control')
#    ztitle=DelegatesTo('plot_control')
#    request_redraw=DelegatesTo('plot_control')
#    #ykeys=DelegatesTo('plot_control')
#    container = Typed(HPlotContainer)
#    color_plot = Typed(CMapImagePlot)
#    plot=Instance(Plot)
#    vertical_cross_plot = Typed(Plot)
#    horizontal_cross_plot = Typed(Plot)
#    colorbar = Typed(ColorBar)
#    pd_all = Instance(ArrayPlotData)
#    _image_index=Instance(GridDataSource)
#    _image_value=Instance(ImageData)
#    data=Dict()
#
#    pd=Instance(ArrayPlotData)
#
#    traits_view = View(Group(Item('container', editor=ComponentEditor(), show_label=False),
#                             orientation='horizontal'),
#        width=1000, height=700, resizable=True, title="Chaco Plot")
#
#    def _xtitle_changed(self):
#        self.horiz_cross_plot.x_axis.title=self.xtitle
#        self.plot.x_axis.title=self.xtitle
#
#    def _ytitle_changed(self):
#        self.vert_cross_plot.y_axis.title=self.ytitle
#        self.plot.y_axis.title=self.ytitle
#
#    def _request_redraw_fired(self):
#        self.color_plot.request_redraw()
#        self.horiz_cross_plot.request_redraw()
#        self.vert_cross_plot.request_redraw()
#
#    def __init__(self, data, plot_control):
#        super(Image_Plot, self).__init__()
#        self.plot_control=plot_control
#        z=zeros((len(data['y']['0']), len(data['x']['0'])))
#        z[:] = nan
#        for key, item in data['z'].iteritems():
#            z[int(key)]=item
#        x=data['x']['0']
#        y=data['y']['0']
#        self.pd = ArrayPlotData(z=z, x=x, y=y, horiz=z[0, :], vert=z[:, 0])
#        self.plot=Plot(self.pd, padding=50, fill_padding=True,
#                        bgcolor="white", use_backbuffer=True,  unified_draw=True)
#        xgrid, ygrid = meshgrid(x, y)
#
#        color_plot=self.plot.img_plot('z', name="img_plot", xbounds=xgrid, ybounds=ygrid)[0]
#        self._image_index = color_plot.index #GridDataSource(xdata=x, ydata=y, sort_order=("ascending","ascending"))
#        self._image_index.on_trait_change(self._metadata_changed, "metadata_changed")
#        self._image_value=color_plot.value
#        self.value_range=DataRange1D(self._image_value)
#        color_plot.color_mapper = jet(self.value_range)
#        color_plot.tools.append(PanTool(color_plot,
#                                           constrain_key="shift"))
#        color_plot.overlays.append(ZoomTool(component=color_plot,
#                                            tool_mode="box", always_on=False))
#
#        color_plot.overlays.append(LineInspector(component=color_plot,
#                                               axis='index_x',
#                                               inspect_mode="indexed",
#                                               write_metadata=True,
#                                               is_listener=True,
#                                               color="white"))
#
#        color_plot.overlays.append(LineInspector(component=color_plot,
#                                               axis='index_y',
#                                               inspect_mode="indexed",
#                                               write_metadata=True,
#                                               color="white",
#                                               is_listener=True))
#
#        cbar_index_mapper = LinearMapper(range=self.value_range)
#        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
#                                 plot=color_plot,
#                                 padding_top=color_plot.padding_top,
#                                 padding_bottom=color_plot.padding_bottom,
#                                 padding_right=40,
#                                 resizable='v',
#                                 width=30)#, ytitle="Magvec (mV)")
#
#        #create horizontal line plot
#        self.horiz_cross_plot = Plot(self.pd, resizable="h", height=100, padding=50)
#        self.horiz_cross_plot.plot(("x", "horiz"))#,
#        self.horiz_cross_plot.index_range = color_plot.index_range.x_range
#
#        #create vertical line plot
#        self.vert_cross_plot = Plot(self.pd, width = 100, orientation="v",
#                                resizable="v", padding=50, padding_bottom=250)
#        self.vert_cross_plot.plot(("y", "vert"))
#        self.vert_cross_plot.index_range = color_plot.index_range.y_range
#        #self.vert_cross_plot.x_axis.tick_label_formatter = lambda x: '%.2g'%x
#        self.color_plot=color_plot
#
#        self.container = HPlotContainer(padding=40, fill_padding=True,
#                                        bgcolor = "white", use_backbuffer=False)
#        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
#        inner_cont.add(self.horiz_cross_plot)
#        inner_cont.add(self.plot)
#        self.container.add(self.colorbar)
#        self.container.add(inner_cont)
#        self.container.add(self.vert_cross_plot)
#        #self.vert_cross_plot.y_axis.title="Frequency"
#        #self.horiz_cross_plot.x_axis.title="Time (us)"
#
#    def _metadata_changed(self, old, new):
#        if self._image_index.metadata.has_key("selections"):
#            x_ndx, y_ndx = self._image_index.metadata["selections"]
#            if y_ndx and x_ndx:
#                self.pd.set_data("horiz", self._image_value.data[y_ndx,:])
#                self.pd.set_data("vert", self._image_value.data[:,x_ndx])
#
#class Line_Plot(HasTraits):
#    plot_control=Instance(Plot_Control)
#    request_redraw=DelegatesTo('plot_control')
#    new_plot=DelegatesTo('plot_control')
#    xtitle=DelegatesTo('plot_control')
#    ytitle=DelegatesTo('plot_control')
#    title=DelegatesTo('plot_control')
#    show_legend=DelegatesTo('plot_control')
#    xyformat=DelegatesTo('plot_control')
#    plot=Instance(Plot)
#    keymap=DelegatesTo('plot_control') #Dict()
#    color_index=Int()
#    mycolors=List([ 'blue', 'red', 'green', 'purple',  'black', 'darkgray', 'cyan', 'magenta', 'orange'])
#    value_scale=DelegatesTo('plot_control')
#    index_scale=DelegatesTo('plot_control')
#    xcomplex=DelegatesTo('plot_control')
#    ycomplex=DelegatesTo('plot_control')
#
#    xkeys=DelegatesTo('plot_control')
#    zkeys=DelegatesTo('plot_control')
#    xindices=DelegatesTo('plot_control')
#    zindices=DelegatesTo('plot_control')
#    pd = Instance(ArrayPlotData)
#
#    def _value_scale_changed(self):
#         #if self.color_index!=0:
#             self.plot.value_scale = self.value_scale
#             self.plot.request_redraw()
#
#    def _index_scale_changed(self):
#         #if self.color_index!=0:
#             self.plot.index_scale = self.index_scale
#             self.plot.request_redraw()
#
#    def _show_legend_changed(self):
#        self.plot.legend.visible = self.show_legend
#        self.plot.request_redraw()
#
#    def _title_changed(self):
#        self.plot.title = self.title
#        self.plot.request_redraw()
#
#    def _xtitle_changed(self):
#        self.plot.x_axis.title=self.xtitle
#        self.plot.request_redraw()
#
#    def _ytitle_changed(self):
#        self.plot.y_axis.title=self.ytitle
#        self.plot.request_redraw()
#
#    def _request_redraw_fired(self):
#        self.plot.request_redraw()
#
#    def _new_plot_fired(self):
#        for key in self.plot.plots.keys():
#            self.remove_plot(key)
#        self.color_index=0
#        for n, name in enumerate(self.zkeys):
#            key='z'+str(name)
#            self.add_plot(key)
#
#    def _zkeys_changed(self,  name, old, new):
#        #print self.pd.list_data()
#        n=0
#        for key in self.pd.list_data():
#            if int(key[1:]) in new:
#                self.add_plot(key)
#                n=n+1
#            else:
#                self.remove_plot(key)
#
##        for key, plot in self.plot.plots.iteritems():
##            if int(key[1:]) in new:
##                if self.xyformat.t_color=="transparent" or self.xyformat.t_color==(1.0, 1.0, 1.0, 1.0) :
##                    color=self.mycolors[mod(n, len(self.mycolors))]
##                else:
##                    color=self.xyformat.t_color
##                plot[0].color=color
##                #plot[0].outline_color=self.xyformat.outline_color,
##                n=n+1
##
##            else:
##               plot[0].color="none"
#               #plot[0].outline_color="none"
#
#    def add_plot(self, key, z, xkey='x0', x=None):
#        if key not in self.plot.plots.keys() and key[0]!='x':
#            if self.xyformat.t_color=="transparent" or self.xyformat.t_color==(1.0, 1.0, 1.0, 1.0) :
#                color=self.mycolors[mod(self.color_index, len(self.mycolors))]
#            else:
#                color=self.xyformat.t_color
#
#            if x!=None:
#                self.pd.set_data(xkey, x)
#            self.pd.set_data(key, z)
#
#            #if self.color_index<len(self.xkeys):
#            #    xkey='x'+str(self.xkeys[self.color_index])
#            #else:
#            #    xkey='x'+str(self.xkeys[0])
#            self.plot.plot((xkey, key),
#                           name=key,
#                           type=self.xyformat.plot_type,
#                           line_width=self.xyformat.line_width,
#                           color=color,
#                           outline_color=self.xyformat.outline_color,
#                           marker = self.xyformat.marker,
#                           marker_size = self.xyformat.marker_size)
#            self.color_index=self.color_index+1
#
#    def remove_plot(self, key):
#        if key in self.plot.plots.keys():
#            self.plot.delplot(key)
#
#    def __init__(self, data, plot_control, *args, **kws):
#        super(Line_Plot, self).__init__(*args, **kws)
#        self.plot_control=plot_control
#        self.pd = ArrayPlotData()
#
#        for name, arr in sorted(data['z'].iteritems()):
#                self.pd.set_data('z'+str(name), arr)
#
#        for name, arr in sorted(data['x'].iteritems()):
#                self.pd.set_data('x'+str(name), arr)
#
#        plot = Plot(self.pd, padding=50, fill_padding=True,
#                        bgcolor="white", use_backbuffer=True)
#
#        # Attach some tools to the plot
#        plot.tools.append(PanTool(plot))
#        zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
#        plot.overlays.append(zoom)
#        plot.legend.tools.append(LegendTool(plot.legend, drag_button="right"))
#        self.plot=plot
#
#        for n, item in enumerate(self.zkeys):
#                key='z'+str(item)
#                if self.xyformat.t_color=="transparent" or self.xyformat.t_color==(1.0, 1.0, 1.0, 1.0) :
#                    color=self.mycolors[mod(n, len(self.mycolors))]
#                else:
#                    color=self.xyformat.t_color
#                if n<len(self.xkeys):
#                    xkey='x'+str(self.xkeys[n])
#                else:
#                    xkey='x'+str(self.xkeys[0])
#                #n=n+1
#                #self.pd.set_data(key, magphase(self._image_value.data[int(item)], self.ycomplex))
#                self.plot.plot((xkey, key),
#                                name=key,
#                                type=self.xyformat.plot_type,
#                                line_width=self.xyformat.line_width,
#                                color=color,
#                                outline_color=self.xyformat.outline_color,
#                                marker = self.xyformat.marker,
#                                marker_size = self.xyformat.marker_size)
#        self.plot.value_scale = self.value_scale
#        self.plot.index_scale= self.index_scale
#
#
#    traits_view = View(Item('plot', style='custom',editor=ComponentEditor(),
#                             show_label=False),
#                    resizable=True, title="Chaco Plot",
#                    width=800, height=700, #kind='modal',
#                    buttons=[OKButton, CancelButton]
#                    )
class ImagePlot(Atom):
    """Color image plot with linked horizontal and vertical cross sections.

    ``color_plot`` draws ``z`` over the grid ``(x, y)``. Two ``LineInspector``
    overlays publish the cursor's grid indices in the image index metadata
    under ``"selections"``. ``_metadata_changed`` then copies the selected
    row into ``horizontal_cross_plot`` and the selected column into
    ``vertical_cross_plot``. ``container`` holds the colorbar, a vertical
    stack of the horizontal cross section and the image, and the vertical
    cross section.
    """
    # container for all plots
    container = Typed(HPlotContainer)

    # Plot components within this container:
    color_plot = Typed(CMapImagePlot)
    vertical_cross_plot = Typed(Plot)
    horizontal_cross_plot = Typed(Plot)
    colorbar = Typed(ColorBar)

    # Color mapper spanning [min(z), max(z)], built in __init__. It must be a
    # declared member because Atom objects reject undeclared attributes.
    colormap = Typed(object)

    # plot data
    pd_all = Typed(ArrayPlotData)
    # Data for the horizontal (row; keys "x"/"horiz") and vertical
    # (column; keys "y"/"vert") cross-section plots.
    pd_horiz = Typed(ArrayPlotData)
    pd_vert = Typed(ArrayPlotData)

    # private data storage
    _imag_index = Typed(GridDataSource)
    _image_value = Typed(ImageData)

    def __init__(self, x, y, z):
        """Build the image, cross-section plots, colorbar and container.

        Parameters
        ----------
        x, y : array
            Grid coordinates for the columns and rows of ``z``; declared to
            ``GridDataSource`` as ascending.
        z : 2-D array
            Image values indexed ``z[row, column]``, rows following ``y``.
        """
        super(ImagePlot, self).__init__()
        self.pd_all = ArrayPlotData(imagedata=z)
        # Seed the cross sections with the first row / column so each line
        # plot starts with matching index and value lengths.
        self.pd_horiz = ArrayPlotData(x=x, horiz=z[0, :])
        self.pd_vert = ArrayPlotData(y=y, vert=z[:, 0])

        self._imag_index = GridDataSource(xdata=x,
                                          ydata=y,
                                          sort_order=("ascending",
                                                      "ascending"))
        index_mapper = GridMapper(range=DataRange2D(self._imag_index))
        self._imag_index.on_trait_change(self._metadata_changed,
                                         "metadata_changed")
        self._image_value = ImageData(data=z, value_depth=1)
        color_mapper = jet(DataRange1D(self._image_value))

        self.color_plot = CMapImagePlot(index=self._imag_index,
                                        index_mapper=index_mapper,
                                        value=self._image_value,
                                        value_mapper=color_mapper,
                                        padding=20,
                                        use_backbuffer=True,
                                        unified_draw=True)

        # Axes on the image plot
        self.color_plot.overlays.append(
            PlotAxis(orientation='left',
                     title="Frequency (GHz)",
                     mapper=self.color_plot.index_mapper._ymapper,
                     component=self.color_plot))
        self.color_plot.overlays.append(
            PlotAxis(orientation='bottom',
                     title="Time (us)",
                     mapper=self.color_plot.index_mapper._xmapper,
                     component=self.color_plot))

        # Pan and box-zoom tools
        self.color_plot.tools.append(
            PanTool(self.color_plot, constrain_key="shift"))
        self.color_plot.overlays.append(
            ZoomTool(component=self.color_plot,
                     tool_mode="box",
                     always_on=False))

        # Horizontal and vertical line inspectors; with write_metadata=True
        # they feed the "selections" entry read by _metadata_changed.
        for axis in ('index_x', 'index_y'):
            self.color_plot.overlays.append(
                LineInspector(component=self.color_plot,
                              axis=axis,
                              inspect_mode="indexed",
                              write_metadata=True,
                              is_listener=True,
                              color="white"))

        z_range = DataRange1D(low=amin(z), high=amax(z))
        self.colormap = jet(z_range)

        # Colorbar over the full z range
        self.colorbar = ColorBar(index_mapper=LinearMapper(range=z_range),
                                 plot=self.color_plot,
                                 padding_top=self.color_plot.padding_top,
                                 padding_bottom=self.color_plot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        # Horizontal cross section (one image row), sharing the image x range
        horiz = Plot(self.pd_horiz, resizable="h")
        horiz.height = 100
        horiz.padding = 20
        horiz.plot(("x", "horiz"))
        horiz.index_range = self.color_plot.index_range.x_range
        self.horizontal_cross_plot = horiz

        # Vertical cross section (one image column), sharing the image y range
        vert = Plot(self.pd_vert,
                    width=140,
                    orientation="v",
                    resizable="v",
                    padding=20,
                    padding_bottom=160)
        vert.plot(("y", "vert"))
        vert.index_range = self.color_plot.index_range.y_range
        self.vertical_cross_plot = vert

        # Create a container and add components
        self.container = HPlotContainer(padding=40,
                                        fill_padding=True,
                                        bgcolor="white",
                                        use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.horizontal_cross_plot)
        inner_cont.add(self.color_plot)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.vertical_cross_plot)

    def _metadata_changed(self, old, new):
        """ This function takes out a cross section from the image data, based
        on the line inspector selections, and updates the line plots.

        The handler arguments are not used. When
        ``self._imag_index.metadata["selections"]`` holds two non-None
        indices ``(x_ndx, y_ndx)``, row ``y_ndx`` of the image goes to the
        horizontal cross section and column ``x_ndx`` to the vertical one.
        Otherwise the last cross sections stay on screen.
        """
        # ``dict.has_key`` is Python-2 only; ``get`` works everywhere.
        selections = self._imag_index.metadata.get("selections")
        if selections is None:
            return
        x_ndx, y_ndx = selections
        # Index 0 is a valid row/column, so test against None explicitly.
        if x_ndx is None or y_ndx is None:
            return
        image = self._image_value.data
        self.pd_horiz.set_data("horiz", image[y_ndx, :])
        self.pd_vert.set_data("vert", image[:, x_ndx])


#if __name__ == "__main__":

#    filename="/Users/thomasaref/Dropbox/Dad stuff/sample3/digitizer/lt/sample3_digitizer_f_sweep_t_300mk_100nspulse.hdf5"
#
#    with h5py.File(filename, 'r') as f:
#
#        time=f["Traces"]["d - AvgTrace - t"][:]
#        Magvec=f["Traces"]["d - AvgTrace - Magvec"][:]
#        frequency=f["Data"]["Data"][:]
#    #    for name in f["Data"]:
#    #        print name
#
#    time=squeeze(time)
#    Magvec=squeeze(Magvec)
#    frequency=squeeze(frequency)
#
#    x = time[:,0]*1.0e6
#    y = frequency[0,:]/1.0e9
#    z=transpose(Magvec*1000.0)
#
#    ip=ImagePlot(xs,ys,z)
#    ip.configure_traits()

#class Image_Plot(Atom):
#    plot_control=Instance(Plot_Control)
#    xtitle=DelegatesTo('plot_control')
#    ytitle=DelegatesTo('plot_control')
#    ztitle=DelegatesTo('plot_control')
#    request_redraw=DelegatesTo('plot_control')
#    #ykeys=DelegatesTo('plot_control')
#    container = Typed(HPlotContainer)
#    color_plot = Typed(CMapImagePlot)
#    plot=Instance(Plot)
#    vertical_cross_plot = Typed(Plot)
#    horizontal_cross_plot = Typed(Plot)
#    colorbar = Typed(ColorBar)
#    pd_all = Instance(ArrayPlotData)
#    _image_index=Instance(GridDataSource)
#    _image_value=Instance(ImageData)
#    data=Dict()
#
#    pd=Instance(ArrayPlotData)
#
#    traits_view = View(Group(Item('container', editor=ComponentEditor(), show_label=False),
#                             orientation='horizontal'),
#        width=1000, height=700, resizable=True, title="Chaco Plot")
#
#    def _xtitle_changed(self):
#        self.horiz_cross_plot.x_axis.title=self.xtitle
#        self.plot.x_axis.title=self.xtitle
#
#    def _ytitle_changed(self):
#        self.vert_cross_plot.y_axis.title=self.ytitle
#        self.plot.y_axis.title=self.ytitle
#
#    def _request_redraw_fired(self):
#        self.color_plot.request_redraw()
#        self.horiz_cross_plot.request_redraw()
#        self.vert_cross_plot.request_redraw()
#
#    def __init__(self, data, plot_control):
#        super(Image_Plot, self).__init__()
#        self.plot_control=plot_control
#        z=zeros((len(data['y']['0']), len(data['x']['0'])))
#        z[:] = nan
#        for key, item in data['z'].iteritems():
#            z[int(key)]=item
#        x=data['x']['0']
#        y=data['y']['0']
#        self.pd = ArrayPlotData(z=z, x=x, y=y, horiz=z[0, :], vert=z[:, 0])
#        self.plot=Plot(self.pd, padding=50, fill_padding=True,
#                        bgcolor="white", use_backbuffer=True,  unified_draw=True)
#        xgrid, ygrid = meshgrid(x, y)
#
#        color_plot=self.plot.img_plot('z', name="img_plot", xbounds=xgrid, ybounds=ygrid)[0]
#        self._image_index = color_plot.index #GridDataSource(xdata=x, ydata=y, sort_order=("ascending","ascending"))
#        self._image_index.on_trait_change(self._metadata_changed, "metadata_changed")
#        self._image_value=color_plot.value
#        self.value_range=DataRange1D(self._image_value)
#        color_plot.color_mapper = jet(self.value_range)
#        color_plot.tools.append(PanTool(color_plot,
#                                           constrain_key="shift"))
#        color_plot.overlays.append(ZoomTool(component=color_plot,
#                                            tool_mode="box", always_on=False))
#
#        color_plot.overlays.append(LineInspector(component=color_plot,
#                                               axis='index_x',
#                                               inspect_mode="indexed",
#                                               write_metadata=True,
#                                               is_listener=True,
#                                               color="white"))
#
#        color_plot.overlays.append(LineInspector(component=color_plot,
#                                               axis='index_y',
#                                               inspect_mode="indexed",
#                                               write_metadata=True,
#                                               color="white",
#                                               is_listener=True))
#
#        cbar_index_mapper = LinearMapper(range=self.value_range)
#        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
#                                 plot=color_plot,
#                                 padding_top=color_plot.padding_top,
#                                 padding_bottom=color_plot.padding_bottom,
#                                 padding_right=40,
#                                 resizable='v',
#                                 width=30)#, ytitle="Magvec (mV)")
#
#        #create horizontal line plot
#        self.horiz_cross_plot = Plot(self.pd, resizable="h", height=100, padding=50)
#        self.horiz_cross_plot.plot(("x", "horiz"))#,
#        self.horiz_cross_plot.index_range = color_plot.index_range.x_range
#
#        #create vertical line plot
#        self.vert_cross_plot = Plot(self.pd, width = 100, orientation="v",
#                                resizable="v", padding=50, padding_bottom=250)
#        self.vert_cross_plot.plot(("y", "vert"))
#        self.vert_cross_plot.index_range = color_plot.index_range.y_range
#        #self.vert_cross_plot.x_axis.tick_label_formatter = lambda x: '%.2g'%x
#        self.color_plot=color_plot
#
#        self.container = HPlotContainer(padding=40, fill_padding=True,
#                                        bgcolor = "white", use_backbuffer=False)
#        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
#        inner_cont.add(self.horiz_cross_plot)
#        inner_cont.add(self.plot)
#        self.container.add(self.colorbar)
#        self.container.add(inner_cont)
#        self.container.add(self.vert_cross_plot)
#        #self.vert_cross_plot.y_axis.title="Frequency"
#        #self.horiz_cross_plot.x_axis.title="Time (us)"
#
#    def _metadata_changed(self, old, new):
#        if self._image_index.metadata.has_key("selections"):
#            x_ndx, y_ndx = self._image_index.metadata["selections"]
#            if y_ndx and x_ndx:
#                self.pd.set_data("horiz", self._image_value.data[y_ndx,:])
#                self.pd.set_data("vert", self._image_value.data[:,x_ndx])
#
#class Line_Plot(HasTraits):
#    plot_control=Instance(Plot_Control)
#    request_redraw=DelegatesTo('plot_control')
#    new_plot=DelegatesTo('plot_control')
#    xtitle=DelegatesTo('plot_control')
#    ytitle=DelegatesTo('plot_control')
#    title=DelegatesTo('plot_control')
#    show_legend=DelegatesTo('plot_control')
#    xyformat=DelegatesTo('plot_control')
#    plot=Instance(Plot)
#    keymap=DelegatesTo('plot_control') #Dict()
#    color_index=Int()
#    mycolors=List([ 'blue', 'red', 'green', 'purple',  'black', 'darkgray', 'cyan', 'magenta', 'orange'])
#    value_scale=DelegatesTo('plot_control')
#    index_scale=DelegatesTo('plot_control')
#    xcomplex=DelegatesTo('plot_control')
#    ycomplex=DelegatesTo('plot_control')
#
#    xkeys=DelegatesTo('plot_control')
#    zkeys=DelegatesTo('plot_control')
#    xindices=DelegatesTo('plot_control')
#    zindices=DelegatesTo('plot_control')
#    pd = Instance(ArrayPlotData)
#
#    def _value_scale_changed(self):
#         #if self.color_index!=0:
#             self.plot.value_scale = self.value_scale
#             self.plot.request_redraw()
#
#    def _index_scale_changed(self):
#         #if self.color_index!=0:
#             self.plot.index_scale = self.index_scale
#             self.plot.request_redraw()
#
#    def _show_legend_changed(self):
#        self.plot.legend.visible = self.show_legend
#        self.plot.request_redraw()
#
#    def _title_changed(self):
#        self.plot.title = self.title
#        self.plot.request_redraw()
#
#    def _xtitle_changed(self):
#        self.plot.x_axis.title=self.xtitle
#        self.plot.request_redraw()
#
#    def _ytitle_changed(self):
#        self.plot.y_axis.title=self.ytitle
#        self.plot.request_redraw()
#
#    def _request_redraw_fired(self):
#        self.plot.request_redraw()
#
#    def _new_plot_fired(self):
#        for key in self.plot.plots.keys():
#            self.remove_plot(key)
#        self.color_index=0
#        for n, name in enumerate(self.zkeys):
#            key='z'+str(name)
#            self.add_plot(key)
#
#    def _zkeys_changed(self,  name, old, new):
#        #print self.pd.list_data()
#        n=0
#        for key in self.pd.list_data():
#            if int(key[1:]) in new:
#                self.add_plot(key)
#                n=n+1
#            else:
#                self.remove_plot(key)
#
##        for key, plot in self.plot.plots.iteritems():
##            if int(key[1:]) in new:
##                if self.xyformat.t_color=="transparent" or self.xyformat.t_color==(1.0, 1.0, 1.0, 1.0) :
##                    color=self.mycolors[mod(n, len(self.mycolors))]
##                else:
##                    color=self.xyformat.t_color
##                plot[0].color=color
##                #plot[0].outline_color=self.xyformat.outline_color,
##                n=n+1
##
##            else:
##               plot[0].color="none"
#               #plot[0].outline_color="none"
#
#    def add_plot(self, key, z, xkey='x0', x=None):
#        if key not in self.plot.plots.keys() and key[0]!='x':
#            if self.xyformat.t_color=="transparent" or self.xyformat.t_color==(1.0, 1.0, 1.0, 1.0) :
#                color=self.mycolors[mod(self.color_index, len(self.mycolors))]
#            else:
#                color=self.xyformat.t_color
#
#            if x!=None:
#                self.pd.set_data(xkey, x)
#            self.pd.set_data(key, z)
#
#            #if self.color_index<len(self.xkeys):
#            #    xkey='x'+str(self.xkeys[self.color_index])
#            #else:
#            #    xkey='x'+str(self.xkeys[0])
#            self.plot.plot((xkey, key),
#                           name=key,
#                           type=self.xyformat.plot_type,
#                           line_width=self.xyformat.line_width,
#                           color=color,
#                           outline_color=self.xyformat.outline_color,
#                           marker = self.xyformat.marker,
#                           marker_size = self.xyformat.marker_size)
#            self.color_index=self.color_index+1
#
#    def remove_plot(self, key):
#        if key in self.plot.plots.keys():
#            self.plot.delplot(key)
#
#    def __init__(self, data, plot_control, *args, **kws):
#        super(Line_Plot, self).__init__(*args, **kws)
#        self.plot_control=plot_control
#        self.pd = ArrayPlotData()
#
#        for name, arr in sorted(data['z'].iteritems()):
#                self.pd.set_data('z'+str(name), arr)
#
#        for name, arr in sorted(data['x'].iteritems()):
#                self.pd.set_data('x'+str(name), arr)
#
#        plot = Plot(self.pd, padding=50, fill_padding=True,
#                        bgcolor="white", use_backbuffer=True)
#
#        # Attach some tools to the plot
#        plot.tools.append(PanTool(plot))
#        zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
#        plot.overlays.append(zoom)
#        plot.legend.tools.append(LegendTool(plot.legend, drag_button="right"))
#        self.plot=plot
#
#        for n, item in enumerate(self.zkeys):
#                key='z'+str(item)
#                if self.xyformat.t_color=="transparent" or self.xyformat.t_color==(1.0, 1.0, 1.0, 1.0) :
#                    color=self.mycolors[mod(n, len(self.mycolors))]
#                else:
#                    color=self.xyformat.t_color
#                if n<len(self.xkeys):
#                    xkey='x'+str(self.xkeys[n])
#                else:
#                    xkey='x'+str(self.xkeys[0])
#                #n=n+1
#                #self.pd.set_data(key, magphase(self._image_value.data[int(item)], self.ycomplex))
#                self.plot.plot((xkey, key),
#                                name=key,
#                                type=self.xyformat.plot_type,
#                                line_width=self.xyformat.line_width,
#                                color=color,
#                                outline_color=self.xyformat.outline_color,
#                                marker = self.xyformat.marker,
#                                marker_size = self.xyformat.marker_size)
#        self.plot.value_scale = self.value_scale
#        self.plot.index_scale= self.index_scale
#
#
#    traits_view = View(Item('plot', style='custom',editor=ComponentEditor(),
#                             show_label=False),
#                    resizable=True, title="Chaco Plot",
#                    width=800, height=700, #kind='modal',
#                    buttons=[OKButton, CancelButton]
#                    )
class myImagePlot(HasTraits):
    """Colour-mapped image of ``z`` with linked horizontal/vertical cross sections.

    The ``HPlotContainer`` holds, in order: the colorbar, a ``VPlotContainer``
    stacking the horizontal cross-section plot with the image, and the vertical
    cross-section plot.  Two ``LineInspector`` overlays on the image are
    configured with ``write_metadata=True``; :meth:`_metadata_changed` is hooked
    to the image index's ``metadata_changed`` event and copies the selected row
    and column of ``z`` into the cross-section plots.
    """

    # container for all plots
    container = Instance(HPlotContainer)

    # Plot components within this container
    color_plot = Instance(CMapImagePlot)
    horiz_cross_plot = Instance(Plot)  # a row of z plotted against x
    vert_cross_plot = Instance(Plot)  # a column of z plotted against y
    colorbar = Instance(ColorBar)

    # Older trait names.  They are never assigned by this class (the two
    # cross plots above are the ones actually built) but are kept so existing
    # code that refers to them does not break.
    vertical_cross_plot = Instance(Plot)
    horizontal_cross_plot = Instance(Plot)

    # plot data
    pd_all = Instance(ArrayPlotData)  # key "imagedata": the full image
    pd_horiz = Instance(ArrayPlotData)  # keys "x", "horiz"
    pd_vert = Instance(ArrayPlotData)  # keys "y", "vert"

    # private data storage
    _imag_index = Instance(GridDataSource)
    _image_value = Instance(ImageData)

    traits_view = View(
        Item('container', editor=ComponentEditor(), show_label=False),
        width=1000, height=700, resizable=True, title="Chaco Plot")

    def __init__(self, x, y, z):
        """Build all plot components for the image ``z``.

        Parameters
        ----------
        x : 1-D array
            Grid x data for the image and index of the horizontal cross plot.
        y : 1-D array
            Grid y data for the image and index of the vertical cross plot.
        z : 2-D array
            Image values.  Rows are plotted against ``x`` and columns against
            ``y``, i.e. ``z`` is indexed ``z[y_index, x_index]``.
        """
        super(myImagePlot, self).__init__()

        # Row/column shown in the cross plots before the user selects anything.
        # Clamped so images with fewer than 5 rows or 6 columns do not raise
        # IndexError; larger images show the same row/column as before.
        row0 = min(4, z.shape[0] - 1)
        col0 = min(5, z.shape[1] - 1)
        self.pd_all = ArrayPlotData(imagedata=z)
        self.pd_horiz = ArrayPlotData(x=x, horiz=z[row0, :])
        self.pd_vert = ArrayPlotData(y=y, vert=z[:, col0])

        self._imag_index = GridDataSource(xdata=x, ydata=y,
                                          sort_order=("ascending", "ascending"))
        index_mapper = GridMapper(range=DataRange2D(self._imag_index))
        # The line inspectors below write their selection into this source's
        # metadata; refresh the cross sections whenever it changes.
        self._imag_index.on_trait_change(self._metadata_changed,
                                         "metadata_changed")
        self._image_value = ImageData(data=z, value_depth=1)
        color_mapper = jet(DataRange1D(self._image_value))

        self.color_plot = CMapImagePlot(
            index=self._imag_index,
            index_mapper=index_mapper,
            value=self._image_value,
            value_mapper=color_mapper,
            padding=20,
            use_backbuffer=True,
            unified_draw=True)

        # Axes on the image plot
        left = PlotAxis(orientation='left',
                        title="Frequency (GHz)",
                        mapper=self.color_plot.index_mapper._ymapper,
                        component=self.color_plot)
        self.color_plot.overlays.append(left)

        bottom = PlotAxis(orientation='bottom',
                          title="Time (us)",
                          mapper=self.color_plot.index_mapper._xmapper,
                          component=self.color_plot)
        self.color_plot.overlays.append(bottom)

        # Panning (constrained with shift) and box zoom
        self.color_plot.tools.append(PanTool(self.color_plot,
                                             constrain_key="shift"))
        self.color_plot.overlays.append(ZoomTool(component=self.color_plot,
                                                 tool_mode="box",
                                                 always_on=False))

        # Line inspectors along x and y that publish the selected cell indices
        self.color_plot.overlays.append(LineInspector(component=self.color_plot,
                                                      axis='index_x',
                                                      inspect_mode="indexed",
                                                      write_metadata=True,
                                                      is_listener=True,
                                                      color="white"))

        self.color_plot.overlays.append(LineInspector(component=self.color_plot,
                                                      axis='index_y',
                                                      inspect_mode="indexed",
                                                      write_metadata=True,
                                                      color="white",
                                                      is_listener=True))

        # The colorbar axis spans the data's min..max
        myrange = DataRange1D(low=amin(z), high=amax(z))
        self.colormap = jet(myrange)

        # Create a colorbar
        cbar_index_mapper = LinearMapper(range=myrange)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.color_plot,
                                 padding_top=self.color_plot.padding_top,
                                 padding_bottom=self.color_plot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        # Horizontal cross-section plot; shares the image's x range object
        self.horiz_cross_plot = Plot(self.pd_horiz, resizable="h")
        self.horiz_cross_plot.height = 100
        self.horiz_cross_plot.padding = 20
        self.horiz_cross_plot.plot(("x", "horiz"))
        self.horiz_cross_plot.index_range = self.color_plot.index_range.x_range

        # Vertical cross-section plot; shares the image's y range object
        self.vert_cross_plot = Plot(self.pd_vert, width=140, orientation="v",
                                    resizable="v", padding=20,
                                    padding_bottom=160)
        self.vert_cross_plot.plot(("y", "vert"))
        self.vert_cross_plot.index_range = self.color_plot.index_range.y_range

        # Create a container and add components (add order = layout order)
        self.container = HPlotContainer(padding=40, fill_padding=True,
                                        bgcolor="white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.horiz_cross_plot)
        inner_cont.add(self.color_plot)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.vert_cross_plot)

    def _metadata_changed(self, old, new):
        """Update the cross-section plots from the line-inspector selection.

        Reads the ``(x_ndx, y_ndx)`` pair stored under ``"selections"`` in the
        image index metadata, then pushes row ``y_ndx`` of the image into the
        horizontal cross plot and column ``x_ndx`` into the vertical one.
        Does nothing when there is no selection.  The two positional arguments
        come from the trait-change notifier and are not used.
        """
        metadata = self._imag_index.metadata
        if "selections" not in metadata:
            return
        x_ndx, y_ndx = metadata["selections"]
        # 0 is a valid index but falsy; only None means "no selection".
        if x_ndx is None or y_ndx is None:
            return
        image = self._image_value.data
        self.pd_horiz.set_data("horiz", image[y_ndx, :])
        self.pd_vert.set_data("vert", image[:, x_ndx])
class ImageGUI(HasTraits):

    # TO FIX : put here the last available shot
    shot = File("L:\\data\\app3\\2011\\1108\\110823\\column_5200.ascii")

    # ---------------------------------------------------------------------------
    # Traits View Definitions
    # ---------------------------------------------------------------------------

    traits_view = View(
        HSplit(
            Item(
                "shot",
                style="custom",
                editor=FileEditor(filter=["column_*.ascii"]),
                show_label=False,
                resizable=True,
                width=400,
            ),
            Item("container", editor=ComponentEditor(), show_label=False, width=800, height=800),
        ),
        width=1200,
        height=800,
        resizable=True,
        title="APPARATUS 3 :: Analyze Images",
    )

    plot_edit_view = View(Group(Item("num_levels"), Item("colormap")), buttons=["OK", "Cancel"])

    num_levels = Int(15)
    colormap = Enum(color_map_name_dict.keys())

    # ---------------------------------------------------------------------------
    # Private Traits
    # ---------------------------------------------------------------------------

    _image_index = Instance(GridDataSource)
    _image_value = Instance(ImageData)

    _cmap = Trait(jet, Callable)

    # ---------------------------------------------------------------------------
    # Public View interface
    # ---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        """Initialise the traits from *args*/**kwargs*, then build the plots."""
        base = super(ImageGUI, self)
        base.__init__(*args, **kwargs)
        # Plot construction happens only once trait initialisation is done.
        self.create_plot()

    def create_plot(self):
        """Build every plot component and lay them out in ``self.container``.

        All data sources start empty; ``update()`` fills them once a shot has
        been loaded.  Attributes assigned here: ``_image_index``,
        ``_image_value``, ``imgplot``, ``colorbar``, ``pd``, ``cross_plot``,
        ``cross_plot2`` and ``container``.  The container receives, in order,
        the colorbar, a vertical stack of ``cross_plot`` and the image, and
        ``cross_plot2``.
        """

        # Create the mapper, etc
        # Image grid and values start empty; update() replaces them later.
        self._image_index = GridDataSource(array([]), array([]), sort_order=("ascending", "ascending"))
        image_index_range = DataRange2D(self._image_index)

        # The LineInspector overlays below are configured with
        # write_metadata=True on this plot; _metadata_changed reads the
        # "selections" entry of this source's metadata to refresh the cross plots.
        self._image_index.on_trait_change(self._metadata_changed, "metadata_changed")

        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)

        # Create the image plot
        # Its colour mapper is built from the current self._cmap factory.
        self.imgplot = CMapImagePlot(
            index=self._image_index,
            value=self._image_value,
            index_mapper=GridMapper(range=image_index_range),
            color_mapper=self._cmap(image_value_range),
        )

        # Create the contour plots
        # NOTE(review): the contour plots are disabled, so self.polyplot and
        # self.lineplot are never created here, although _colormap_changed and
        # _num_levels_changed refer to them.
        # ~ self.polyplot = ContourPolyPlot(index=self._image_index,
        # ~ value=self._image_value,
        # ~ index_mapper=GridMapper(range=
        # ~ image_index_range),
        # ~ color_mapper=\
        # ~ self._cmap(image_value_range),
        # ~ levels=self.num_levels)

        # ~ self.lineplot = ContourLinePlot(index=self._image_index,
        # ~ value=self._image_value,
        # ~ index_mapper=GridMapper(range=
        # ~ self.polyplot.index_mapper.range),
        # ~ levels=self.num_levels)

        # Add a left axis to the plot
        left = PlotAxis(
            orientation="left", title="axial", mapper=self.imgplot.index_mapper._ymapper, component=self.imgplot
        )
        self.imgplot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = PlotAxis(
            orientation="bottom", title="radial", mapper=self.imgplot.index_mapper._xmapper, component=self.imgplot
        )
        self.imgplot.overlays.append(bottom)

        # Add some tools to the plot
        # (box zoom, plus one line inspector per image axis)
        # ~ self.polyplot.tools.append(PanTool(self.polyplot,
        # ~ constrain_key="shift"))
        self.imgplot.overlays.append(ZoomTool(component=self.imgplot, tool_mode="box", always_on=False))
        self.imgplot.overlays.append(
            LineInspector(
                component=self.imgplot,
                axis="index_x",
                inspect_mode="indexed",
                write_metadata=True,
                is_listener=False,
                color="white",
            )
        )
        self.imgplot.overlays.append(
            LineInspector(
                component=self.imgplot,
                axis="index_y",
                inspect_mode="indexed",
                write_metadata=True,
                color="white",
                is_listener=False,
            )
        )

        # Add these two plots to one container
        # (only the image plot is added while the contour plots are disabled)
        contour_container = OverlayPlotContainer(padding=20, use_backbuffer=True, unified_draw=True)
        contour_container.add(self.imgplot)
        # ~ contour_container.add(self.polyplot)
        # ~ contour_container.add(self.lineplot)

        # Create a colorbar
        # Its index mapper shares image_value_range with the image plot;
        # update() sets that range's low/high through this mapper.
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(
            index_mapper=cbar_index_mapper,
            plot=self.imgplot,
            padding_top=self.imgplot.padding_top,
            padding_bottom=self.imgplot.padding_bottom,
            padding_right=40,
            resizable="v",
            width=30,
        )

        # Create the two cross plots
        # self.pd holds data for both. Keys without a suffix belong to
        # cross_plot (a row of the image). Keys ending in "2" belong to
        # cross_plot2 (a column of the image). The "scatter_*" arrays hold the
        # single selected pixel drawn as a colour-mapped dot.
        self.pd = ArrayPlotData(
            line_index=array([]),
            line_value=array([]),
            scatter_index=array([]),
            scatter_value=array([]),
            scatter_color=array([]),
        )

        self.cross_plot = Plot(self.pd, resizable="h")
        self.cross_plot.height = 100
        self.cross_plot.padding = 20
        self.cross_plot.plot(("line_index", "line_value"), line_style="dot")
        self.cross_plot.plot(
            ("scatter_index", "scatter_value", "scatter_color"),
            type="cmap_scatter",
            name="dot",
            color_mapper=self._cmap(image_value_range),
            marker="circle",
            marker_size=8,
        )

        # Share the image's x range object so both index axes stay in step.
        self.cross_plot.index_range = self.imgplot.index_range.x_range

        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))
        self.pd.set_data("scatter_index2", array([]))
        self.pd.set_data("scatter_value2", array([]))
        self.pd.set_data("scatter_color2", array([]))

        self.cross_plot2 = Plot(self.pd, width=140, orientation="v", resizable="v", padding=20, padding_bottom=160)
        self.cross_plot2.plot(("line_index2", "line_value2"), line_style="dot")
        self.cross_plot2.plot(
            ("scatter_index2", "scatter_value2", "scatter_color2"),
            type="cmap_scatter",
            name="dot",
            color_mapper=self._cmap(image_value_range),
            marker="circle",
            marker_size=8,
        )

        # Share the image's y range object for the vertical cross plot.
        self.cross_plot2.index_range = self.imgplot.index_range.y_range

        # Create a container and add components
        # (the order of the add() calls determines the layout order)
        self.container = HPlotContainer(padding=40, fill_padding=True, bgcolor="white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.cross_plot)
        inner_cont.add(contour_container)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.cross_plot2)

    def update(self):
        """Load the currently selected shot and push it into the plots.

        Does nothing if ``load_imagedata()`` returns ``None``.  Otherwise:

        * stores the image min/max as ``self.minz`` / ``self.maxz`` and uses
          them as the colorbar range;
        * rebuilds the image grid with one cell per pixel (cell edges at
          ``0, 1, ..., n``);
        * sets the cross-plot index arrays, then redraws the container.
        """
        imgdata = self.load_imagedata()
        if imgdata is None:
            return
        self.minz = imgdata.min()
        self.maxz = imgdata.max()
        self.colorbar.index_mapper.range.low = self.minz
        self.colorbar.index_mapper.range.high = self.maxz
        # A non-transposed Chaco ImageData is indexed [row, column] = [y, x]
        # (as _metadata_changed assumes when slicing data[y_ndx, :] and
        # data[:, x_ndx]). The x edges must therefore come from the column
        # count and the y edges from the row count. They were previously
        # swapped, which mis-sized the grid for non-square images.
        n_rows = imgdata.shape[0]
        n_cols = imgdata.shape[1]
        xs = numpy.linspace(0, n_cols, n_cols + 1)
        ys = numpy.linspace(0, n_rows, n_rows + 1)
        print(xs)
        print(ys)
        self._image_index.set_data(xs, ys)
        self._image_value.data = imgdata
        self.pd.set_data("line_index", xs)
        self.pd.set_data("line_index2", ys)
        self.container.invalidate_draw()
        self.container.request_redraw()

    def load_imagedata(self):
        """Load the image for the currently selected ``shot`` file.

        ``shot`` is expected to look like ``<drive>:\\<dirs>\\column_<num>.ascii``
        (see the trait default and the file-editor filter). The drive prefix
        and the file name are stripped to get the directory, keeping its
        trailing backslash. The text between the last ``_`` and ``.ascii`` is
        taken as the shot number.

        Returns
        -------
        The result of ``load(directory, shotnum)``, or ``None`` (after printing
        a message) when ``shot`` lacks one of the ``:\\``, ``\\``, ``_`` or
        ``.ascii`` markers.
        """
        shot = self.shot
        try:
            # str.index / str.rindex raise ValueError when a marker is missing.
            directory = shot[shot.index(":\\") + 2 : shot.rindex("\\") + 1]
            shotnum = shot[shot.rindex("_") + 1 : shot.rindex(".ascii")]
        except ValueError:
            print(" *** Not a valid column density path *** ")
            return None
        print(directory)
        print(shotnum)
        return load(directory, shotnum)

    # ---------------------------------------------------------------------------
    # Event handlers
    # ---------------------------------------------------------------------------

    def _shot_changed(self):
        """Traits handler: reload and redraw whenever ``shot`` is reassigned."""
        reload_current_shot = self.update
        reload_current_shot()

    def _metadata_changed(self, old, new):
        """Refresh the cross-section plots from the line-inspector selection.

        Hooked (in ``create_plot``) to the image index's ``metadata_changed``
        event. If the metadata holds ``"selections"`` as ``(x_ndx, y_ndx)``:

        * row ``y_ndx`` and column ``x_ndx`` of the image go into the two line
          plots;
        * the selected pixel is shown as a colour-mapped dot on each plot.

        Without a selection the line and dot values are cleared. The two
        positional arguments come from the trait notifier and are not used.
        """
        if not hasattr(self, "minz"):
            # update() has not loaded an image yet: there is nothing to slice,
            # and minz/maxz (used below) do not exist.
            return
        # Fix both cross plots' value axes to the full image range.
        self.cross_plot.value_range.low = self.minz
        self.cross_plot.value_range.high = self.maxz
        self.cross_plot2.value_range.low = self.minz
        self.cross_plot2.value_range.high = self.maxz
        metadata = self._image_index.metadata
        if "selections" in metadata:
            x_ndx, y_ndx = metadata["selections"]
            # 0 is a valid row/column but falsy; only None means "no selection".
            if x_ndx is not None and y_ndx is not None:
                image = self._image_value.data
                self.pd.set_data("line_value", image[y_ndx, :])
                self.pd.set_data("line_value2", image[:, x_ndx])
                xdata, ydata = self._image_index.get_data()
                xdata, ydata = xdata.get_data(), ydata.get_data()
                picked = image[y_ndx, x_ndx]
                self.pd.set_data("scatter_index", array([xdata[x_ndx]]))
                self.pd.set_data("scatter_index2", array([ydata[y_ndx]]))
                self.pd.set_data("scatter_value", array([picked]))
                self.pd.set_data("scatter_value2", array([picked]))
                self.pd.set_data("scatter_color", array([picked]))
                self.pd.set_data("scatter_color2", array([picked]))
        else:
            self.pd.set_data("scatter_value", array([]))
            self.pd.set_data("scatter_value2", array([]))
            self.pd.set_data("line_value", array([]))
            self.pd.set_data("line_value2", array([]))

    def _colormap_changed(self):
        """Apply the newly selected colormap to every colour-mapped component.

        ``self._cmap`` becomes the factory registered under ``self.colormap``
        in ``color_map_name_dict``. Each existing component then gets a fresh
        mapper from that factory over its current value range. Components
        that do not exist yet are skipped, e.g. when ``colormap`` is passed to
        the constructor before ``create_plot`` has run.
        """
        self._cmap = color_map_name_dict[self.colormap]
        # Contour plot from the original design. create_plot does not build
        # it, but it is still handled for callers that attach one.
        if hasattr(self, "polyplot"):
            value_range = self.polyplot.color_mapper.range
            self.polyplot.color_mapper = self._cmap(value_range)
        # The image itself. Previously it was updated only inside the
        # polyplot branch above, which is never taken, so colormap edits had
        # no visible effect.
        if hasattr(self, "imgplot"):
            value_range = self.imgplot.color_mapper.range
            self.imgplot.color_mapper = self._cmap(value_range)
        if hasattr(self, "cross_plot") and hasattr(self, "cross_plot2"):
            value_range = self.cross_plot.color_mapper.range
            self.cross_plot.color_mapper = self._cmap(value_range)
            # FIXME: change when we decide how best to update plots using
            # the shared colormap in plot object
            self.cross_plot.plots["dot"][0].color_mapper = self._cmap(value_range)
            self.cross_plot2.plots["dot"][0].color_mapper = self._cmap(value_range)
        if hasattr(self, "container"):
            self.container.request_redraw()

    def _num_levels_changed(self):
        if self.num_levels > 3:
            self.polyplot.levels = self.num_levels
            self.lineplot.levels = self.num_levels