示例#1
0
 def test_empty(self):
     """A GridDataSource built without arguments exposes empty defaults."""
     ds = GridDataSource()

     # assertEqual replaces the deprecated ``assert_`` alias (removed in
     # Python 3.12) and reports both operands when a comparison fails.
     self.assertEqual(ds.sort_order, ('none', 'none'))
     self.assertEqual(ds.index_dimension, 'image')
     self.assertEqual(ds.value_dimension, 'scalar')
     self.assertEqual(ds.metadata, {"selections": [], "annotations": []})

     # Both axes hold empty arrays and the bounds are all zero.
     xdata, ydata = ds.get_data()
     assert_ary_(xdata.get_data(), array([]))
     assert_ary_(ydata.get_data(), array([]))
     self.assertEqual(ds.get_bounds(), ((0, 0), (0, 0)))
 def test_empty(self):
     """Verify the defaults of a GridDataSource created without data."""
     source = GridDataSource()

     # Trait defaults, checked in the same order as before.
     expected_defaults = (
         ('sort_order', ('none', 'none')),
         ('index_dimension', 'image'),
         ('value_dimension', 'scalar'),
         ('metadata', {"selections":[], "annotations":[]}),
     )
     for trait_name, expected in expected_defaults:
         self.assertEqual(getattr(source, trait_name), expected)

     # Each axis data source must report an empty array.
     x_axis, y_axis = source.get_data()
     for axis in (x_axis, y_axis):
         assert_array_equal(axis.get_data(), array([]))

     origin = (0, 0)
     self.assertEqual(source.get_bounds(), (origin, origin))
示例#3
0
    def test_init(self):
        """Constructor arguments are stored and reflected in the bounds."""
        test_xd = array([1,2,3])
        test_yd = array([1.5, 0.5, -0.5, -1.5])
        test_sort_order = ('ascending', 'descending')

        ds = GridDataSource(xdata=test_xd, ydata=test_yd,
                            sort_order=test_sort_order)

        # assertEqual replaces the deprecated ``assert_`` alias (removed in
        # Python 3.12) and reports both operands when a comparison fails.
        self.assertEqual(ds.sort_order, test_sort_order)
        xd, yd = ds.get_data()
        assert_ary_(xd.get_data(), test_xd)
        assert_ary_(yd.get_data(), test_yd)
        # Bounds are reported as ((x_min, y_min), (x_max, y_max)).
        self.assertEqual(ds.get_bounds(),
                         ((min(test_xd), min(test_yd)),
                          (max(test_xd), max(test_yd))))
class GridDataSourceTestCase(UnittestTools, unittest.TestCase):
    """Tests for GridDataSource defaults, construction, updates and metadata."""

    def setUp(self):
        # Every test starts from the same grid: three x values, four y values.
        grid_kwargs = {
            'xdata': array([1, 2, 3]),
            'ydata': array([1.5, 0.5, -0.5, -1.5]),
            'sort_order': ('ascending', 'descending'),
        }
        self.data_source = GridDataSource(**grid_kwargs)

    def test_empty(self):
        # A source created without arguments reports empty defaults.
        source = GridDataSource()
        expected_defaults = (
            ('sort_order', ('none', 'none')),
            ('index_dimension', 'image'),
            ('value_dimension', 'scalar'),
            ('metadata', {"selections": [], "annotations": []}),
        )
        for trait_name, expected in expected_defaults:
            self.assertEqual(getattr(source, trait_name), expected)

        x_axis, y_axis = source.get_data()
        for axis in (x_axis, y_axis):
            assert_array_equal(axis.get_data(), array([]))

        origin = (0, 0)
        self.assertEqual(source.get_bounds(), (origin, origin))

    def test_init(self):
        # The fixture built in setUp must expose exactly what it was given.
        expected_x = array([1, 2, 3])
        expected_y = array([1.5, 0.5, -0.5, -1.5])
        source = self.data_source

        self.assertEqual(source.sort_order, ('ascending', 'descending'))
        x_axis, y_axis = source.get_data()
        assert_array_equal(x_axis.get_data(), expected_x)
        assert_array_equal(y_axis.get_data(), expected_y)

        lower = (min(expected_x), min(expected_y))
        upper = (max(expected_x), max(expected_y))
        self.assertEqual(source.get_bounds(), (lower, upper))

    def test_set_data(self):
        # set_data replaces both axes and the sort order in one call.
        new_x = array([0, 2, 4])
        new_y = array([0, 1, 2, 3, 4, 5])
        new_order = ('none', 'none')
        source = self.data_source

        source.set_data(xdata=new_x, ydata=new_y, sort_order=('none', 'none'))

        self.assertEqual(source.sort_order, new_order)
        x_axis, y_axis = source.get_data()
        assert_array_equal(x_axis.get_data(), new_x)
        assert_array_equal(y_axis.get_data(), new_y)

        lower = (min(new_x), min(new_y))
        upper = (max(new_x), max(new_y))
        self.assertEqual(source.get_bounds(), (lower, upper))

    def test_metadata(self):
        default_metadata = {'annotations': [], 'selections': []}
        self.assertEqual(self.data_source.metadata, default_metadata)

    def test_metadata_changed(self):
        # Replacing the whole dict fires ``metadata_changed`` exactly once.
        source = self.data_source
        with self.assertTraitChanges(source, 'metadata_changed', count=1):
            source.metadata = {'new_metadata': True}

    def test_metadata_items_changed(self):
        # Mutating the dict in place also fires ``metadata_changed`` once.
        source = self.data_source
        with self.assertTraitChanges(source, 'metadata_changed', count=1):
            source.metadata['new_metadata'] = True
示例#5
0
class ImagePlotUI(HasTraits):
    """Traits UI showing a colormapped image plot next to its colorbar.

    The plot is built once by ``create_plot`` with empty data and is refreshed
    by ``update_plot``, which is hooked to
    ``image_data_source.data_source_changed``.
    """

    # Image data source
    image_data_source = Instance(ImageDataSource)

    # container for all plots
    container = Instance(HPlotContainer)

    # Plot components within this container
    plot = Instance(CMapImagePlot)
    colorbar = Instance(ColorBar)

    # View options
    colormap = Enum(colormaps)

    # Traits view definitions:
    traits_view = View(Group(
        UItem('container', editor=ComponentEditor(size=(500, 450)))),
                       resizable=True)

    plot_edit_view = View(Group(Item('colormap')), buttons=["OK", "Cancel"])

    # -------------------------------------------------------------------------
    # Private Traits
    # -------------------------------------------------------------------------

    _image_index = Instance(GridDataSource)
    _image_value = Instance(ImageData)

    _cmap = Trait(default_colormaps.gray, Callable)

    # -------------------------------------------------------------------------
    # Public View interface
    # -------------------------------------------------------------------------

    def __init__(self, image_data_source=None):
        """Build the (empty) plot, then attach ``image_data_source``.

        Parameters
        ----------
        image_data_source : ImageDataSource or None
            Object providing ``xs``, ``ys`` and ``zs``; may be omitted and
            assigned later.
        """
        super(ImagePlotUI, self).__init__()
        # The plot is created from empty arrays; numpy "invalid value"
        # warnings are silenced while that happens.
        with errstate(invalid='ignore'):
            self.create_plot()
        # Assigned last so that the plot components already exist when
        # ``update_plot`` runs.
        self.image_data_source = image_data_source

    def create_plot(self):
        """Create the image plot, its axes, tools, colorbar and container."""

        # Create the mapper, etc
        self._image_index = GridDataSource(array([]),
                                           array([]),
                                           sort_order=("ascending",
                                                       "ascending"))
        image_index_range = DataRange2D(self._image_index)
        # self._image_index.on_trait_change(self._metadata_changed,
        #                                   "metadata_changed")

        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)

        # Create the colormapped scalar plot
        self.plot = CMapImagePlot(
            index=self._image_index,
            index_mapper=GridMapper(range=image_index_range),
            value=self._image_value,
            value_mapper=self._cmap(image_value_range))

        # Add a left axis to the plot
        left = PlotAxis(orientation='left',
                        title="y",
                        mapper=self.plot.index_mapper._ymapper,
                        component=self.plot)
        self.plot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = PlotAxis(orientation='bottom',
                          title="x",
                          mapper=self.plot.index_mapper._xmapper,
                          component=self.plot)
        self.plot.overlays.append(bottom)

        # Add some tools to the plot
        self.plot.tools.append(PanTool(self.plot, constrain_key="shift"))
        self.plot.overlays.append(
            ZoomTool(component=self.plot, tool_mode="box", always_on=False))

        # Create a colorbar that shares the image's value range
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.plot,
                                 padding_top=self.plot.padding_top,
                                 padding_bottom=self.plot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=10)

        # Create a container and add components (colorbar first, then plot)
        self.container = HPlotContainer(padding=40,
                                        fill_padding=True,
                                        bgcolor="white",
                                        use_backbuffer=False)
        self.container.add(self.colorbar)
        self.container.add(self.plot)

    @on_trait_change('image_data_source.data_source_changed')
    def update_plot(self):
        """Copy ``xs``/``ys``/``zs`` from the data source into the plot.

        Does nothing when no data source is attached.
        """
        source = self.image_data_source
        if source is None:
            # ``None`` is the constructor default; without this guard the
            # attribute reads below would raise AttributeError.
            return

        xs = source.xs
        ys = source.ys
        zs = source.zs
        # Pin the colorbar range to the extremes of the new image values.
        self.colorbar.index_mapper.range.low = zs.min()
        self.colorbar.index_mapper.range.high = zs.max()
        self._image_index.set_data(xs, ys)
        self._image_value.data = zs
        self.container.invalidate_draw()
        self.container.request_redraw()

    # -------------------------------------------------------------------------
    # Event handlers
    # -------------------------------------------------------------------------

    def _colormap_changed(self):
        """Swap the plot's color mapper for the newly selected colormap."""
        self._cmap = default_colormaps.color_map_name_dict[self.colormap]
        if self.plot is not None:
            # Reuse the existing range so the value scaling is preserved.
            value_range = self.plot.color_mapper.range
            self.plot.color_mapper = self._cmap(value_range)
            self.container.request_redraw()
    def __init__(self, x, y, z):
        """Build a colormapped image plot with a colorbar and two cross plots.

        Parameters
        ----------
        x, y :
            Passed as ``xdata`` / ``ydata`` to the GridDataSource that indexes
            the image.
        z :
            Image values; used for the ArrayPlotData, the ImageData and (via
            ``amin`` / ``amax``) the colorbar range.
            (presumably a 2-D array -- TODO confirm)
        """
        super(ImagePlot, self).__init__()
        self.pd_all = ArrayPlotData(imagedata=z)
        # NOTE(review): ``self.pd_horiz`` and ``self.pd_vert`` are read further
        # down, but the assignments below are commented out. Unless they are
        # provided elsewhere in the class this will fail -- confirm.
        #self.pd_horiz = ArrayPlotData(x=x, horiz=z[4, :])
        #self.pd_vert = ArrayPlotData(y=y, vert=z[:,5])

        # Index (x/y grid) of the image; metadata changes are forwarded to
        # ``self._metadata_changed``.
        self._imag_index = GridDataSource(xdata=x,
                                          ydata=y,
                                          sort_order=("ascending",
                                                      "ascending"))
        index_mapper = GridMapper(range=DataRange2D(self._imag_index))
        self._imag_index.on_trait_change(self._metadata_changed,
                                         "metadata_changed")
        # Image values, colour-mapped with ``jet`` over their own data range.
        self._image_value = ImageData(data=z, value_depth=1)
        color_mapper = jet(DataRange1D(self._image_value))

        self.color_plot = CMapImagePlot(index=self._imag_index,
                                        index_mapper=index_mapper,
                                        value=self._image_value,
                                        value_mapper=color_mapper,
                                        padding=20,
                                        use_backbuffer=True,
                                        unified_draw=True)

        # Add axes to the image plot
        left = PlotAxis(orientation='left',
                        title="Frequency (GHz)",
                        mapper=self.color_plot.index_mapper._ymapper,
                        component=self.color_plot)

        self.color_plot.overlays.append(left)

        bottom = PlotAxis(orientation='bottom',
                          title="Time (us)",
                          mapper=self.color_plot.index_mapper._xmapper,
                          component=self.color_plot)
        self.color_plot.overlays.append(bottom)

        # Pan and box-zoom tools on the image plot
        self.color_plot.tools.append(
            PanTool(self.color_plot, constrain_key="shift"))
        self.color_plot.overlays.append(
            ZoomTool(component=self.color_plot,
                     tool_mode="box",
                     always_on=False))

        # Add line inspector tools for the horizontal and vertical directions
        self.color_plot.overlays.append(
            LineInspector(component=self.color_plot,
                          axis='index_x',
                          inspect_mode="indexed",
                          write_metadata=True,
                          is_listener=True,
                          color="white"))

        self.color_plot.overlays.append(
            LineInspector(component=self.color_plot,
                          axis='index_y',
                          inspect_mode="indexed",
                          write_metadata=True,
                          color="white",
                          is_listener=True))

        # A second jet mapper over the fixed [amin(z), amax(z)] range; the
        # same range object drives the colorbar below.
        myrange = DataRange1D(low=amin(z), high=amax(z))
        cmap = jet
        self.colormap = cmap(myrange)

        # Create a colorbar
        cbar_index_mapper = LinearMapper(range=myrange)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.color_plot,
                                 padding_top=self.color_plot.padding_top,
                                 padding_bottom=self.color_plot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)  #, ytitle="Magvec (mV)")

        # Create the horizontal line plot (reads self.pd_horiz, see note above)
        self.horiz_cross_plot = Plot(self.pd_horiz, resizable="h")
        self.horiz_cross_plot.height = 100
        self.horiz_cross_plot.padding = 20
        self.horiz_cross_plot.plot(("x", "horiz"))  #,
        #line_style="dot")
        #        self.cross_plot.plot(("scatter_index","scatter_value","scatter_color"),
        #                             type="cmap_scatter",
        #                             name="dot",
        #                             color_mapper=self._cmap(image_value_range),
        #                             marker="circle",
        #                             marker_size=8)

        # Share the image's x range so both plots pan/zoom together.
        self.horiz_cross_plot.index_range = self.color_plot.index_range.x_range

        # Create the vertical line plot (reads self.pd_vert, see note above)
        self.vert_cross_plot = Plot(self.pd_vert,
                                    width=140,
                                    orientation="v",
                                    resizable="v",
                                    padding=20,
                                    padding_bottom=160)
        self.vert_cross_plot.plot(("y", "vert"))  #,
        #                             line_style="dot")
        # self.vert_cross_plot.xtitle="Magvec (mV)"
        #       self.vertica_cross_plot.plot(("vertical_scatter_index",
        #                              "vertical_scatter_value",
        #                              "vertical_scatter_color"),
        #                            type="cmap_scatter",
        #                            name="dot",
        #                            color_mapper=self._cmap(image_value_range),
        #                            marker="circle",
        #                           marker_size=8)

        # Share the image's y range with the vertical cross plot.
        self.vert_cross_plot.index_range = self.color_plot.index_range.y_range

        # Create a container and add components:
        # colorbar | (horizontal cross plot + image, stacked) | vertical plot
        self.container = HPlotContainer(padding=40,
                                        fill_padding=True,
                                        bgcolor="white",
                                        use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.horiz_cross_plot)
        inner_cont.add(self.color_plot)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.vert_cross_plot)
示例#7
0
class ImagePlot(HasTraits):
    """Traits UI titled "Scan Image Reconstruction".

    This part of the class declares the traits shown in the window, the
    buttons whose ``_<name>_fired`` handlers follow, the view layout, and the
    per-plot record dictionaries ``plotA`` .. ``plotE``.
    """

    # Plot containers and text fields shown in the UI
    plot = Instance(GridPlotContainer)
    LineScans = Instance(GridPlotContainer)
    # Displayed with the label "File" in the view below
    value = Str
    saved = Bool(True)
    filedir = Str
    filename = Str
    path = Str
    # Axis limits as text; the handlers parse them with float()
    xLo = Str
    xHi = Str
    yLo = Str
    yHi = Str
    # Colour limits as text; parsed with float() in _colormap_fired
    Cmin = Str
    Cmax = Str
    Cmin2 = Str
    Cmax2 = Str
    # Set to True by _load_button_fired once a file has been loaded
    loaded = False

    # Kept mutually exclusive by their _changed handlers
    Rtrue = Bool(True)
    R2true = Bool(False)
    Fluortrue = Bool(False)
    # Catch variables (xLow..yHigh are overwritten with floats by
    # _catch_bounds_fired)
    bounds = []
    xLow = []
    yLow = []
    xHigh = []
    yHigh = []

    vscale = None
    notes = Str

    # UI buttons
    reload_button = Button("Reload")
    save_plot_button = Button("Save Plots..")
    save_LineScans_button = Button("Save LineScans..")
    load_button = Button("Load")
    load1_button =Button("Load 1")
    load2_button =Button("Load 2")
    load3_button =Button("Load 3")
    load4_button =Button("Load 4")
    catch_bounds = Button("Catch Bounds")
    to_csv = Button("Generate .CSV")
    scale_set = Button("Set Scale")
    reset = Button("Reset")
    nextfile = Button("Next")
    prevfile = Button("Prev")
    flatten = Button("Flatten")
    readline = Button("LineScan")
    colormap = Button("ApplyColorMap")
    genfilelist = Button("Generate File List")
    out_to_mat = Button("Generate .mat")
    special = Button("Special")

    presmooth = Bool
    # Kept mutually exclusive by their _changed handlers
    plot1button = Bool
    plot2button = Bool
    plot3button = Bool
    plot4button = Bool
    fastline = Str
    slowline = Str

    fastscanline = 0
    slowscanline = 0

    # NOTE: these rebind fastline/slowline, declared as ``Str`` just above,
    # to the plain string "0".
    fastline  = str(fastscanline)
    slowline  = str(slowscanline)

    # Line statistics, filled in as strings by _readline_fired
    fastlinevar = Str
    fastlinemean = Str
    fastlinemax = Str
    fastlinemin = Str
    slowlinevar = Str
    slowlinemean = Str
    slowlinemax = Str
    slowlinemin = Str

    x_f = Bool(label = "Fast Axis x:")
    y_f = Bool(label = "y:")
    z_f = Bool(label = "z:")
    x_s = Bool(label = "Slow Axis x:")
    y_s = Bool(label = "y:")
    z_s = Bool(label = "z:")
    xbounds=None
    ybounds=None



    # Wildcard patterns for the file dialogs: .zis for loading, .png for
    # saving plots, .mat for the MATLAB export
    file_wildcard = Str("zis File (*.zis)|*.zis|All files|*")
    file_wildcard2 = Str("png File (*.png)|*.png|All files|*")
    file_wildcard3 = Str("mat File (*.mat)|*.mat|All files|*")
    # ---- View building blocks -------------------------------------------
    fastlinevars = Group(   Item('fastlinevar'),
                            Item('fastlinemean'),
                            Item('fastlinemax'),
                            Item('fastlinemin'),show_border=True)
    slowlinevars = Group(   Item('slowlinevar'),
                            Item('slowlinemean'),
                            Item('slowlinemax'),
                            Item('slowlinemin'),show_border=True)
    # "LineScan" tab: the line-scan plots beside the line selection controls
    # and statistics. 'plte' is referenced by name; it is declared below.
    linescangroup = HGroup(Item('LineScans', editor=ComponentEditor(),show_label=False),
                        VGroup(
                            Item(name="plte", style = 'simple'),
                            HGroup(
                            Item('slowline'),
                            Item('fastline'),),
                            UItem('readline'),
                            fastlinevars,
                            slowlinevars
                            ),label = "LineScan")
    # Colour limit fields and the export/utility buttons
    colorgroup = VGroup(HGroup(
                            Item('Cmin',width = -50),
                            Item('Cmax',width = -50),
                            Item('Cmin2',width = -50),
                            Item('Cmax2',width = -50),
                            ),
                        HGroup(
                            UItem('colormap'),
                            UItem('genfilelist'),
                            UItem('out_to_mat'),
                            UItem('special'),
                            ),
                        show_border =True)

    # Tabbed area: main plot, line scans, free-text notes
    tabs = Group(Item('plot', editor=ComponentEditor(),show_label=False),
                linescangroup,
                Item('notes', style = 'custom', width=.1), layout = "tabbed")
    # Selectors offering the same seven choices for each plot
    # (presumably the quantity each plot displays -- TODO confirm)
    plta = Enum("R", "Phase", "R0","R/R0","Fluor", "X", "Y")
    pltb = Enum("R", "Phase", "R0","R/R0","Fluor", "X", "Y")
    pltc = Enum("R", "Phase", "R0","R/R0","Fluor", "X", "Y")
    pltd = Enum("R", "Phase", "R0","R/R0","Fluor", "X", "Y")
    plte = Enum("R", "Phase", "R0","R/R0","Fluor", "X", "Y")

    # Top-left panel: file name, save/load/navigation buttons, plot selectors
    lefttop = Group(
                VGroup(
                    HGroup(
                        Item('value', label = "File", width = 300),
                        Item('presmooth'),
                        ),
                    HGroup(
                        UItem('save_plot_button'),
                        UItem('save_LineScans_button'),
                        UItem('load_button'),
                        UItem('nextfile'),
                        UItem('prevfile'),
                        ),
                    HGroup(
                    Item(name="plta", style = 'simple'),
                    Item(name="pltb", style = 'simple'),
                    Item(name="pltc", style = 'simple'),
                    Item(name="pltd", style = 'simple'),
                    UItem('reload_button'),
                    ),
                    ),
                )

    # Top-right panel: axis limit fields, scaling buttons, per-plot loaders
    righttop = Group(
                    VGroup(
                        HGroup(
                            Item('xLo',width = -50, height = 25),
                            Item('xHi',width = -50),
                            Item('yLo',width = -50, height = 25),
                            Item('yHi',width = -50),
                            ),
                        HGroup(
                            UItem('flatten'),
                            UItem('catch_bounds'),
                            UItem('scale_set'),
                        ),
                        HGroup(
                            UItem("load1_button"),
                            UItem("load2_button"),
                            UItem("load3_button"),
                            UItem("load4_button"),
                        ),
                    ),
                show_border = True)

    # Main window: the three control panels in a row above the tabs
    traits_view = View(
            VGroup(
                    HGroup(
                        lefttop,
                        righttop,
                        colorgroup,
                    ),
                    tabs,
            ),
            width=1200, height=700, resizable=True, title="Scan Image Reconstruction")


    # Per-plot records. Note that "plot" holds the Enum trait object created
    # above (not a selected value); "data" and "shape" start as empty strings
    # and "range" as a 2x2 array of zeros.
    plotA = {"plot": plta, "data" : "", "shape" : "", "range" : np.zeros((2,2))}
    plotB = {"plot": pltb, "data" : "", "shape" : "", "range" : np.zeros((2,2))}
    plotC = {"plot": pltc, "data" : "", "shape" : "", "range" : np.zeros((2,2))}
    plotD = {"plot": pltd, "data" : "", "shape" : "", "range" : np.zeros((2,2))}
    plotE = {"plot": plte, "data" : "", "shape" : "", "range" : np.zeros((2,2))}

    def _Rtrue_changed(self):
        """When ``Rtrue`` is switched on, switch ``R2true`` and ``Fluortrue`` off."""
        if not self.Rtrue:
            return
        self.R2true = False
        self.Fluortrue = False
    def _R2true_changed(self):
        """When ``R2true`` is switched on, switch ``Rtrue`` and ``Fluortrue`` off."""
        if not self.R2true:
            return
        self.Rtrue = False
        self.Fluortrue = False
    def _Fluortrue_changed(self):
        """When ``Fluortrue`` is switched on, switch ``Rtrue`` and ``R2true`` off."""
        if not self.Fluortrue:
            return
        self.Rtrue = False
        self.R2true = False
    def _plot1button_changed(self):
        """Selecting plot 1 deselects plots 2-4 and calls ``_plot2``."""
        if not self.plot1button:
            return
        # Clear the other selectors in their original order.
        for other in ('plot2button', 'plot3button', 'plot4button'):
            setattr(self, other, False)
        self._plot2()
    def _plot2button_changed(self):
        """Selecting plot 2 deselects plots 1, 3 and 4 and calls ``_plot2``."""
        if not self.plot2button:
            return
        # Clear the other selectors in their original order.
        for other in ('plot1button', 'plot3button', 'plot4button'):
            setattr(self, other, False)
        self._plot2()
    def _plot3button_changed(self):
        """Selecting plot 3 deselects plots 1, 2 and 4 and calls ``_plot2``."""
        if not self.plot3button:
            return
        # Clear the other selectors in their original order.
        for other in ('plot1button', 'plot2button', 'plot4button'):
            setattr(self, other, False)
        self._plot2()
    def _plot4button_changed(self):
        """Selecting plot 4 deselects plots 1-3 and calls ``_plot2``."""
        if not self.plot4button:
            return
        # Clear the other selectors in their original order.
        for other in ('plot1button', 'plot2button', 'plot3button'):
            setattr(self, other, False)
        self._plot2()

    def _colormap_fired(self):
        """Apply the colour limits typed into the Cmin/Cmax fields.

        ``plot1`` and ``plot4`` take Cmin/Cmax and ``plot2`` takes
        Cmin2/Cmax2. ``plot1lines`` takes Cmin/Cmax when plot 1, 3 or 4 is
        selected and Cmin2/Cmax2 when plot 2 is selected.

        Raises:
            ValueError: if one of the four fields does not hold a number.
        """
        # Parse everything first so a malformed field fails before any plot
        # has been touched.
        low, high = float(self.Cmin), float(self.Cmax)
        low2, high2 = float(self.Cmin2), float(self.Cmax2)

        for image_plot, (range_low, range_high) in ((self.plot1, (low, high)),
                                                    (self.plot2, (low2, high2)),
                                                    (self.plot4, (low, high))):
            image_plot.color_mapper.range.low = range_low
            image_plot.color_mapper.range.high = range_high

        if self.plot1button or self.plot3button or self.plot4button:
            self.plot1lines.color_mapper.range.low = low
            self.plot1lines.color_mapper.range.high = high
        if self.plot2button:
            self.plot1lines.color_mapper.range.low = low2
            self.plot1lines.color_mapper.range.high = high2

        # Parenthesised form prints the same text on Python 2 and Python 3.
        print("Color Mapped")
    def _scale_set_fired(self):
        """Apply the xLo/xHi/yLo/yHi fields to every plot's visible range.

        ``plot1`` .. ``plot4`` receive the limits unchanged. ``plot1lines``
        receives them scaled by ``plotE["shape"][axis] /
        plotE["range"][axis][1]`` (presumably a conversion from axis units to
        array indices -- TODO confirm). ``slow_plot`` follows the x limits and
        ``fast_plot`` the y limits.

        Raises:
            ValueError: if one of the four fields does not hold a number.
        """
        # Parse once, up front: a malformed field then fails before any plot
        # has been partially rescaled.
        x_low, x_high = float(self.xLo), float(self.xHi)
        y_low, y_high = float(self.yLo), float(self.yHi)

        for image_plot in (self.plot1, self.plot2, self.plot3, self.plot4):
            view = image_plot.range2d
            view.x_range.low = x_low
            view.x_range.high = x_high
            view.y_range.low = y_low
            view.y_range.high = y_high

        x_extent = self.plotE["range"][0][1]
        x_count = self.plotE["shape"][0]
        y_extent = self.plotE["range"][1][1]
        y_count = self.plotE["shape"][1]
        lines_view = self.plot1lines.range2d
        lines_view.x_range.low = x_low / x_extent * x_count
        lines_view.x_range.high = x_high / x_extent * x_count
        lines_view.y_range.low = y_low / y_extent * y_count
        lines_view.y_range.high = y_high / y_extent * y_count

        # The line plots only have an x range to adjust.
        self.slow_plot.range2d.x_range.low = x_low
        self.slow_plot.range2d.x_range.high = x_high
        self.fast_plot.range2d.x_range.low = y_low
        self.fast_plot.range2d.x_range.high = y_high


    def _reset_fired(self):
        """Clear ``vscale`` and the Cmin/Cmax fields, then refresh the plot."""
        self.vscale = None
        self.Cmin = self.Cmax = ""
        self._refresh()
    def _value_changed(self):
        """Re-assign ``value`` to its own current content when it changes."""
        # self.saved = False
        current = self.value
        self.value = current

    def _refresh(self):
        """Redraw through ``_plot``; print a notice instead of failing."""
        try:
            self._plot()
        except Exception:
            # Deliberately best-effort, but no longer swallows
            # KeyboardInterrupt / SystemExit as the bare ``except:`` did.
            print("Option will be applied when plotting")

    def _readline_fired(self):
        """Extract the selected scan lines and refresh the line-scan display.

        ``fastline`` / ``slowline`` are read as integer line numbers. The
        corresponding column and row of the image are pushed to the plot data
        as "line_value2" and "line_value", and the statistics fields are
        updated.

        Raises:
            ValueError: if ``fastline`` or ``slowline`` is not an integer.
        """
        # Line numbers are entered as integer indices.
        fast_index = int(self.fastline)
        slow_index = int(self.slowline)
        self.fastscanline = fast_index
        self.slowscanline = slow_index

        # The unused slowstart/slowend/faststart/fastend values were removed:
        # they were never read and raised when the bound fields were empty.

        image = self._image_value.data

        # NOTE(review): the statistics use column ``slowline`` and row
        # ``fastline``, the opposite pairing of the traces plotted below.
        # Kept as in the original -- confirm which pairing is intended.
        fastarray = np.array(image[:, slow_index])
        slowarray = np.array(image[fast_index, :])

        self.pd.set_data("line_value2", image[:, fast_index])
        self.pd.set_data("line_value", image[slow_index, :])

        # The "var" fields hold the standard deviation (np.std).
        self.fastlinevar = str(np.std(fastarray))
        self.fastlinemean = str(np.mean(fastarray))
        self.fastlinemax = str(np.amax(fastarray))
        self.fastlinemin = str(np.amin(fastarray))
        self.slowlinevar = str(np.std(slowarray))
        self.slowlinemean = str(np.mean(slowarray))
        self.slowlinemax = str(np.amax(slowarray))
        self.slowlinemin = str(np.amin(slowarray))

        # The slow-line title previously showed the fast line number.
        self.slow_plot.title = "Slowline : " + str(slow_index)
        self.fast_plot.title = "Fastline : " + str(fast_index)


    def _flatten_fired(self):
        """Handler for the ``flatten`` button: delegates to ``_flatten``."""
        self._flatten()

    def _out_to_mat_fired(self):
        """Ask for a target file and export plots A-D as a MATLAB .mat file.

        The records are stored under the variable names "plotA" .. "plotD".
        """
        default_name = (self.filename + "_MATLAB_Phase"
                        + str(self.plotA["notes"]["start"][2]))
        dialog = FileDialog(default_filename=default_name,
                            action="save as",
                            wildcard=self.file_wildcard3)
        dialog.open()
        if dialog.return_code != OK:
            return

        path = os.path.join(dialog.directory, dialog.filename)
        # savemat needs a {name: value} mapping. The previous set literal of
        # dicts raised TypeError (dicts are unhashable) before anything was
        # written.
        dataset = {
            "plotA": self.plotA,
            "plotB": self.plotB,
            "plotC": self.plotC,
            "plotD": self.plotD,
        }
        scio.savemat(path, dataset, appendmat=True)

    def _flatten(self):
        """Subtract a fitted straight line from every row of plots A-D.

        For each row of each record's ``'data'`` a least-squares line is
        fitted against the element index (``stats.linregress``) and
        subtracted in place. The plots are then redrawn via ``_plot``.
        """
        # One loop replaces four copy-pasted blocks; the per-element Python
        # loop is replaced by a single row-wise array assignment.
        for record in (self.plotA, self.plotB, self.plotC, self.plotD):
            data = record['data']
            for row_index, row in enumerate(data):
                positions = np.arange(0, int(len(row)))
                slope, intercept = stats.linregress(positions,
                                                    np.array(row))[:2]
                data[row_index, :] = row - (slope * positions + intercept)

        self._plot()
        # Parenthesised form prints the same text on Python 2 and Python 3.
        print("Flattened")



    def _catch_bounds_fired(self):
        """Copy the visible range of ``plot1`` into the bound fields.

        The limits are stored as floats in ``xLow`` .. ``yHigh`` and as text
        in ``xLo`` .. ``yHi``. If they cannot be read, a hint is printed.
        """
        try:
            view = self.plot1.range2d
            self.xLow = float(view.x_range.low)
            self.xHigh = float(view.x_range.high)
            self.yLow = float(view.y_range.low)
            self.yHigh = float(view.y_range.high)

            self.xLo = str(self.xLow)
            self.xHi = str(self.xHigh)
            self.yLo = str(self.yLow)
            self.yHi = str(self.yHigh)
        except Exception:
            # Still best-effort, but no longer swallows KeyboardInterrupt /
            # SystemExit as the bare ``except:`` did.
            print("Please plot first")

    def _save_plot_button_fired(self):
        """Ask for a target file and render the ``plot`` container into it."""
        save_dialog = FileDialog(default_filename=self.filename + "_Plots_",
                                 action="save as",
                                 wildcard=self.file_wildcard2)
        save_dialog.open()
        if save_dialog.return_code != OK:
            return

        target = os.path.join(save_dialog.directory, save_dialog.filename)
        # self.plot.do_layout(force=True)
        graphics_context = PlotGraphicsContext(self.plot.outer_bounds)
        graphics_context.render_component(self.plot)
        graphics_context.save(target)

    def _save_LineScans_button_fired(self):
        """Ask for a target file and render the ``LineScans`` container into it."""
        save_dialog = FileDialog(default_filename=self.filename + "_LineScan_",
                                 action="save as",
                                 wildcard=self.file_wildcard2)
        save_dialog.open()
        if save_dialog.return_code != OK:
            return

        target = os.path.join(save_dialog.directory, save_dialog.filename)
        # Force a layout pass before rendering.
        self.LineScans.do_layout(force=True)
        graphics_context = PlotGraphicsContext(self.LineScans.outer_bounds)
        graphics_context.render_component(self.LineScans)
        graphics_context.save(target)




    def _load_button_fired(self):
        """Ask for a .zis file, index the scans beside it, then load and plot.

        ``allfiles`` is rebuilt with every file of the chosen directory whose
        name contains ".zis" but neither ".png" nor ".mat".
        """
        dialog = FileDialog(action="open", wildcard=self.file_wildcard)
        dialog.open()
        if dialog.return_code != OK:
            return

        self.value = dialog.filename
        self.path = dialog.path
        self.filedir = dialog.directory

        # os.walk yields (dirpath, dirnames, filenames). Only the file names
        # of the chosen directory itself are used: the old triple loop also
        # iterated the directory names and the files of sub-directories,
        # whose paths were then built as if they lived in ``filedir``.
        try:
            filenames = next(os.walk(self.filedir))[2]
        except StopIteration:
            filenames = []

        # The "\\" separator is kept because the next/previous/file-list
        # handlers split the stored entries on ``filedir + "\\"``.
        self.allfiles = [
            self.filedir + "\\" + afile
            for afile in filenames
            if ".zis" in afile and ".png" not in afile and ".mat" not in afile
        ]

        self.filename = dialog.filename
        self.saved = True
        self.loaded = True
        self._loader(self.path)
        self._plot()

    def _nextfile_fired(self):
        """Switch to the entry that follows the current file in ``self.allfiles``.

        Does nothing unless a file has been loaded, and nothing when the
        current path is missing from the list or is already its last entry.
        The list itself is the one built by the load button handler.
        """
        if not self.loaded:
            return

        entries = list(self.allfiles)
        try:
            position = entries.index(self.path)
        except ValueError:
            # Current path is not part of the indexed directory.
            return
        if position + 1 >= len(entries):
            return

        following = entries[position + 1]
        self.path = following
        # Entries are "<filedir>\<name>"; keep only the file name for display.
        _prefix, self.value = following.split(self.filedir + "\\")
        self.filename = self.value
        self._loader(self.path)
        self._plot()


    def _prevfile_fired(self):
        """Switch to the entry that precedes the current file in ``self.allfiles``.

        Does nothing unless a file has been loaded, when the current path is
        not in the list, or when it is already the first entry (the previous
        implementation raised UnboundLocalError in that last case).
        """
        if not self.loaded:
            return

        current = self.path
        previous = None
        for candidate in self.allfiles:
            if candidate == current:
                break
            previous = candidate
        else:
            # Current path is not part of the indexed directory.
            return

        if previous is None:
            # Already at the first file: there is nothing to step back to.
            return

        self.path = previous
        # Entries are "<filedir>\<name>"; keep only the file name for display.
        _prefix, self.value = previous.split(self.filedir + "\\")
        self.filename = self.value
        self._loader(self.path)
        self._plot()

    def _genfilelist_fired(self):
        """Write a pickled event list describing every indexed scan file.

        Each file of ``self.allfiles`` is unpickled and contributes one entry
        (trial number, its scan settings rendered as "key value" lines, empty
        notes and the time taken from the file name, which is expected to look
        like "<folder>_<time>.<ext>"). Files that cannot be read or parsed are
        reported and skipped. The result is dumped to
        "<cwd>/eventlist_<folder>/filelist.log".

        Prints a hint and returns when no folder has been loaded yet, and
        writes nothing when not a single file could be parsed.
        """
        if not self.loaded:
            print("Please load a folder first")
            return

        eventlist = {"Description": "",
                     "0": {'trial': 0, "settings": "", "notes": "", "time": ""}}
        foldername = None
        trial = 1
        for afile in self.allfiles:
            # Entries are "<filedir>\<name>"; recover the bare file name.
            _prefix, currentfilename = afile.split(self.filedir + "\\")
            print("Working on file : " + currentfilename)
            try:
                # NOTE(security): pickle.load can execute arbitrary code; only
                # point this at scan files from a trusted source.
                with open(afile, 'rb') as currentfile:
                    data = pickle.load(currentfile)

                foldername, stamp = currentfilename.split("_")
                stamp, _extension = stamp.split(".")
                settings = data['settings']['scan']
                strsettings = "".join(
                    str(key) + " " + str(value) + "\n"
                    for key, value in settings.items())
                eventlist[str(trial)] = {'trial': trial,
                                         "settings": strsettings,
                                         "notes": "",
                                         "time": stamp}
                trial += 1
            except Exception:
                # Best effort: one unreadable file must not abort the listing.
                print("\tcorrupt file, skipping")

        if foldername is None:
            # Without a parsed file name there is no folder to name the
            # output after (this used to fail with an unbound variable).
            print("No readable scan files found; no file list written")
            return

        # Save data to a data-logger compatible file.
        target_dir = os.getcwd() + "/eventlist_" + foldername
        if not os.path.isdir(target_dir):
            print("made")
            os.makedirs(target_dir)
        # Binary mode keeps the pickle stream intact on every platform.
        with open(target_dir + "/filelist.log", "wb") as log_file:
            pickle.dump(eventlist, log_file)

        print("File Write Complete")

    def _reload_button_fired(self):
        """Load the file at ``self.path`` again and redraw the plots."""
        current_path = self.path
        self._loader(current_path)
        self._plot()
    #-----------------------------------------------
    # Private API
    #-----------------------------------------------

    def _plot(self):
        """Rebuild the image-plot grid from plotA..plotD, then the line scans.

        Creates four image plots (one per ``self.plotA`` .. ``self.plotD``
        dictionary, using its "data", "plot" and "range" entries), gives each a
        pan tool, a box zoom tool and a vertical colorbar, places the eight
        components in a 2x4 GridPlotContainer and publishes that container as
        ``self.plot``. ``self.notes`` is refreshed from panel A's scan notes.
        Ends by calling ``_plot2`` to rebuild the line-scan view.
        """
        print "...plotting"
        # Render panel A's notes dictionary as one "key value" line per entry.
        self.notes = ""
        for key, value in self.plotA["notes"].iteritems() :
            self.notes += str(key) + " " + str(value)+ "\n"

        self.container1 = GridPlotContainer(shape = (2,4), spacing = (0,0), use_backbuffer=True,
	                                     valign = 'top', bgcolor = 'white')

        print"\t assigning data"
        # One ArrayPlotData per panel, each exposing its array as "imagedata".
        self.plotdata1 = ArrayPlotData(imagedata = self.plotA["data"])
        self.plotdata2 = ArrayPlotData(imagedata = self.plotB["data"])
        self.plotdata3 = ArrayPlotData(imagedata = self.plotC["data"])
        self.plotdata4 = ArrayPlotData(imagedata = self.plotD["data"])



        print"\t calling names"
        # Only the first panel's title also carries the "start" entry of the notes.
        self.plot1 = Plot(self.plotdata1, title = self.plotA["plot"]+ str(self.plotA["notes"]["start"]))
        self.plot2 = Plot(self.plotdata2, title = self.plotB["plot"])
        self.plot3 = Plot(self.plotdata3, title = self.plotC["plot"])
        self.plot4 = Plot(self.plotdata4, title = self.plotD["plot"])


        # "range" is read as [[x_low, x_high], [y_low, y_high]] for each panel.
        self.plot1.img_plot("imagedata", xbounds = (self.plotA["range"][0][0],self.plotA["range"][0][1]), ybounds = (self.plotA["range"][1][0],self.plotA["range"][1][1]))
        self.plot2.img_plot("imagedata", xbounds = (self.plotB["range"][0][0],self.plotB["range"][0][1]), ybounds = (self.plotB["range"][1][0],self.plotB["range"][1][1]))
        self.plot3.img_plot("imagedata", xbounds = (self.plotC["range"][0][0],self.plotC["range"][0][1]), ybounds = (self.plotC["range"][1][0],self.plotC["range"][1][1]))
        self.plot4.img_plot("imagedata", xbounds = (self.plotD["range"][0][0],self.plotD["range"][0][1]), ybounds = (self.plotD["range"][1][0],self.plotD["range"][1][1]))



#        self.scale = Str(self.plot3.color_mapper.high)
#        plot1.index_axis.title = str(f) + ' (um)'

##ADD TOOLS
        # Every image plot gets a pan tool plus a box zoom that is not always on.
        self.plot1.tools.append(PanTool(self.plot1))
        zoom1 = ZoomTool(component=self.plot1, tool_mode="box", always_on=False)
        self.plot1.overlays.append(zoom1)

        self.plot2.tools.append(PanTool(self.plot2))
        zoom2 = ZoomTool(component=self.plot2, tool_mode="box", always_on=False)
        self.plot2.overlays.append(zoom2)

        self.plot3.tools.append(PanTool(self.plot3))
        zoom3 = ZoomTool(component=self.plot3, tool_mode="box", always_on=False)
        self.plot3.overlays.append(zoom3)

        self.plot4.tools.append(PanTool(self.plot4))
        zoom4 = ZoomTool(component=self.plot4, tool_mode="box", always_on=False)
        self.plot4.overlays.append(zoom4)

##ADD COLORBARS
        # Each colorbar shares the color mapper (and its range) of its plot.
        self.colorbar1 = ColorBar(index_mapper=LinearMapper(range=self.plot1.color_mapper.range),
                        color_mapper=self.plot1.color_mapper,
                        orientation='v',
                        resizable='v',
                        width=20,
                        padding=5)
        self.colorbar1.plot = self.plot1
        self.colorbar1.padding_left = 45
        self.colorbar1.padding_right= 5

        self.colorbar2 = ColorBar(index_mapper=LinearMapper(range=self.plot2.color_mapper.range),
                        color_mapper=self.plot2.color_mapper,
                        orientation='v',
                        resizable='v',
                        width=20,
                        padding=5)
        self.colorbar2.plot = self.plot2
        self.colorbar2.padding_left = 10

        self.colorbar3 = ColorBar(index_mapper=LinearMapper(range=self.plot3.color_mapper.range),
                        color_mapper=self.plot3.color_mapper,
                        orientation='v',
                        resizable='v',
                        width=20,
                        padding=5)
        self.colorbar3.plot = self.plot3
        self.colorbar3.padding_left = 45
        self.colorbar3.padding_right= 5

        self.colorbar4 = ColorBar(index_mapper=LinearMapper(range=self.plot4.color_mapper.range),
                        color_mapper=self.plot4.color_mapper,
                        orientation='v',
                        resizable='v',
                        width=20,
                        padding=5)
        self.colorbar4.plot = self.plot4
        self.colorbar4.padding_left = 15



        # Colorbars can be panned along y and range-zoomed with a right-button drag.
        self.colorbar1.tools.append(PanTool(self.colorbar1, constrain_direction="y", constrain=True))
        self.zoom_overlay1 = ZoomTool(self.colorbar1, axis="index", tool_mode="range",
                            always_on=True, drag_button="right")
        self.colorbar1.overlays.append(self.zoom_overlay1)

        self.colorbar2.tools.append(PanTool(self.colorbar2, constrain_direction="y", constrain=True))
        self.zoom_overlay2 = ZoomTool(self.colorbar2, axis="index", tool_mode="range",
                            always_on=True, drag_button="right")
        self.colorbar2.overlays.append(self.zoom_overlay2)

        self.colorbar3.tools.append(PanTool(self.colorbar3, constrain_direction="y", constrain=True))
        self.zoom_overlay3 = ZoomTool(self.colorbar3, axis="index", tool_mode="range",
                            always_on=True, drag_button="right")
        self.colorbar3.overlays.append(self.zoom_overlay3)

        self.colorbar4.tools.append(PanTool(self.colorbar4, constrain_direction="y", constrain=True))
        self.zoom_overlay4 = ZoomTool(self.colorbar4, axis="index", tool_mode="range",
                            always_on=True, drag_button="right")
        self.colorbar4.overlays.append(self.zoom_overlay4)


        # Add order: colorbar, plot, plot, colorbar -- twice (eight components
        # for the 2x4 grid; presumably filled row by row -- confirm).
        self.container1.add(self.colorbar1)
        self.container1.add(self.plot1)
        self.container1.add(self.plot2)
        self.container1.add(self.colorbar2)
        self.container1.add(self.colorbar3)
        self.container1.add(self.plot3)
        self.container1.add(self.plot4)
        self.container1.add(self.colorbar4)

        # Plots 1 and 2 get their x axis on top; plots 2 and 4 get their
        # y axis on the right.
        self.plot1.padding_right = 5
        self.plot2.padding_left = 5
        self.plot1.padding_bottom = 15
        self.plot2.padding_bottom = 15
        self.plot3.padding_top = 15
        self.plot4.padding_top = 15
        self.plot1.x_axis.orientation = "top"
        self.plot2.x_axis.orientation = "top"
        self.plot2.y_axis.orientation = "right"
        self.plot3.padding_right = 5
        self.plot4.padding_left = 5
        self.plot4.y_axis.orientation = "right"

        # Copy the vertical padding of each plot (as adjusted just above) onto
        # its colorbar.
        self.colorbar1.padding_top = self.plot1.padding_top
        self.colorbar1.padding_bottom = self.plot1.padding_bottom
        self.colorbar2.padding_top = self.plot2.padding_top
        self.colorbar2.padding_bottom = self.plot2.padding_bottom
        self.colorbar3.padding_top = self.plot3.padding_top
        self.colorbar3.padding_bottom = self.plot3.padding_bottom
        self.colorbar4.padding_top = self.plot4.padding_top
        self.colorbar4.padding_bottom = self.plot4.padding_bottom

        # Only the first panel gets an image inspector tool and overlay.
        imgtool = ImageInspectorTool(self.plot1)
        self.plot1.tools.append(imgtool)
        overlay = ImageInspectorOverlay(component=self.plot1, image_inspector=imgtool,
                                    bgcolor="white", border_visible=True)
        self.plot1.overlays.append(overlay)

        # Publish the finished grid, then rebuild the line-scan view.
        self.plot = self.container1

        self._plot2()




#######line plots##############################################################################################

    def _plot2(self):
        """Rebuild the line-scan view from ``self.plotE`` and publish it.

        Creates an image plot of panel E with pan/zoom tools, two line
        inspectors and a colorbar, plus two line plots ("Slowline" and
        "Fastline") fed from one row and one column of the image. Everything is
        nested into grid containers and stored as ``self.LineScans``. Ends by
        calling ``_readline_fired`` and ``_scale_set_fired``.
        """

        print"...plotting line scans"
        title = str(self.plotE['plot']) + str(self.plotE["notes"]["start"])
        self.plotdata5 = ArrayPlotData(imagedata = self.plotE['data'])



        self._image_index = GridDataSource(array([]),
                                          array([]),
                                          sort_order=("ascending","ascending"))

        # Evenly spaced coordinates spanning panel E's x and y ranges, with
        # shape[0] and shape[1] points respectively.
        self.xs = linspace(self.plotE["range"][0][0], self.plotE["range"][0][1], self.plotE["shape"][0])
        self.ys = linspace(self.plotE["range"][1][0], self.plotE["range"][1][1], self.plotE["shape"][1])

        self._image_index.set_data(self.xs, self.ys)

        # NOTE(review): image_index_range is not used later in this method.
        image_index_range = DataRange2D(self._image_index)


        self._image_value = ImageData(data=array([]), value_depth=1)
        self._image_value.data = self.plotE['data']
        # NOTE(review): image_value_range is not used later in this method.
        image_value_range = DataRange1D(self._image_value)

        # NOTE(review): s and f are derived from the axis flags but never read
        # afterwards in this method; when several flags are set the last wins.
        s = ""
        f = ""
        if self.x_s: s = "X"
        if self.y_s: s = "Y"
        if self.z_s: s = "Z"
        if self.x_f: f = "X"
        if self.y_f: f = "Y"
        if self.z_f: f = "Z"


        self.plot1lines = Plot(self.plotdata5,  title = title)
        self.plot1lines.img_plot("imagedata", xbounds = (self.plotE["range"][0][0],self.plotE["range"][0][1]), ybounds = (self.plotE["range"][1][0],self.plotE["range"][1][1]), colormap=jet)
        # NOTE(review): img_plot is called a second time here, without bounds
        # or colormap, and the inspector is attached to that second renderer.
        # Presumably the intent was to reuse the renderer returned by the call
        # above -- confirm before changing.
        img_plot = self.plot1lines.img_plot("imagedata")[0]
        imgtool = ImageInspectorTool(img_plot)
        img_plot.tools.append(imgtool)
        overlay = ImageInspectorOverlay(component=img_plot, image_inspector=imgtool,
                                    bgcolor="white", border_visible=True)
        self.plot1lines.overlays.append(overlay)
##ADD TOOLS

        self.plot1lines.tools.append(PanTool(self.plot1lines))
        zoom1 = ZoomTool(component=self.plot1lines, tool_mode="box", always_on=False)
        self.plot1lines.overlays.append(zoom1)

        # Two white line inspectors, one per index axis; both write metadata
        # and neither listens for it.
        self.plot1lines.overlays.append(LineInspector(component=self.plot1lines,
                                               axis='index_x',
                                               inspect_mode="indexed",
                                               write_metadata=True,
                                               is_listener=False,
                                               #constrain_key="right",
                                               color="white"))
        self.plot1lines.overlays.append(LineInspector(component=self.plot1lines,
                                               axis='index_y',
                                               inspect_mode="indexed",
                                               write_metadata=True,
                                               color="white",
                                               is_listener=False))




##ADD COLORBAR

        # The colorbar shares the image plot's color mapper and its range.
        self.colorbar5  = ColorBar(index_mapper=LinearMapper(range=self.plot1lines.color_mapper.range),
                        color_mapper=self.plot1lines.color_mapper,
                        orientation='v',
                        resizable='v',
                        width=20,
                        padding=5)
        self.colorbar5.plot = self.plot1lines


        # Pan along y and range-zoom with a right-button drag.
        self.colorbar5.tools.append(PanTool(self.colorbar5, constrain_direction="y", constrain=True))
        self.zoom_overlay5 = ZoomTool(self.colorbar5, axis="index", tool_mode="range",
                            always_on=True, drag_button="right")
        self.colorbar5.overlays.append(self.zoom_overlay5)

        # Shared data object for both line plots; the series start out empty
        # and are filled further down.
        self.pd = ArrayPlotData(line_index = array([]),
                                line_value = array([]))



        self.slow_plot = Plot(self.pd,   title = "Slowline : " + self.slowline)
        self.slow_plot.plot(("line_index", "line_value"),
                             line_style='solid')

        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))

        self.fast_plot = Plot(self.pd,   title = "Fastline : " + self.fastline)
        self.fast_plot.plot(("line_index2", "line_value2"),
                             line_style='solid')


        # The "Slowline" series is row ``fastscanline`` of the image against
        # xs; the "Fastline" series is column ``slowscanline`` against ys.
        self.pd.set_data("line_index", self.xs)
        self.pd.set_data("line_index2", self.ys)
        self.pd.set_data("line_value",
                                 self._image_value.data[self.fastscanline,:])
        self.pd.set_data("line_value2",
                                 self._image_value.data[:,self.slowscanline])

        # padding_left is assigned twice below; the later value (50) is the
        # one that remains.
        self.colorbar5.padding= 0
        self.colorbar5.padding_left = 15
        #self.colorbar5.height = 400
        self.colorbar5.padding_top =50
        self.colorbar5.padding_bottom = 0
        self.colorbar5.padding_right = 25
        self.colorbar5.padding_left = 50

        self.plot1lines.width = 300
        self.plot1lines.padding_top = 50

        self.plot1lines.index_axis.title = 'fast axis (um)'
        self.plot1lines.value_axis.title = 'slow axis (um)'

        self.slow_plot.width = 100
        self.slow_plot.padding_right = 20

        self.fast_plot.width = 100
        self.fast_plot.padding_right = 20

        # Nesting: container2 holds [colorbar5, container4]; container4 holds
        # [container3, image plot]; container3 holds [fast_plot, slow_plot].
        self.container2 = GridPlotContainer(shape = (1,2), spacing = ((0,0)), use_backbuffer=True,
	                                     valign = 'top', halign = 'center', bgcolor = 'white')
        self.container3 = GridPlotContainer(shape = (2,1), spacing = (0,0), use_backbuffer=True,
	                                     valign = 'top', halign = 'center', bgcolor = 'grey')
        self.container4 = GridPlotContainer(shape = (1,2), spacing = (0,0), use_backbuffer=True,
	                                     valign = 'top', halign = 'center', bgcolor = 'grey')

        self.container2.add(self.colorbar5)
        self.container3.add(self.fast_plot)
        self.container3.add(self.slow_plot)
        self.container4.add(self.container3)
        self.container4.add(self.plot1lines)
        self.container2.add(self.container4)
        self.LineScans = self.container2

        # Re-run the line-readout and scale handlers on the fresh plots.
        self._readline_fired()
        self._scale_set_fired()

    def _load_file(self,i):
        """Unpickle the file at path ``i`` and store the result in ``self.data``.

        ``i`` is converted with ``str()`` before use. The file is closed even
        when unpickling fails; the error itself still propagates to the caller.
        """
        path = str(i)
        # NOTE(security): pickle.load can execute arbitrary code; only open
        # scan files that come from a trusted source.
        with open(path, 'rb') as handle:
            print("\n\n" + path + " is open")
            self.data = pickle.load(handle)
        print("\nfile.closed")
        return
    #########################################################
    def _special_fired(self):
        """Experimental: stitch several scans into panels A/B and replot.

        Prompts (through ``_single_load``) for files in two passes, one for
        ``plotA`` and one for ``plotB``, writes into two zero-filled arrays of
        shape (size*a, size*b), assigns those arrays to the "data" of panels
        A-D, resets panel A's range and shape, and calls ``_plot``.

        NOTE(review): this handler looks unfinished -- see the inline notes on
        the loop bounds and on the slice assignments before relying on it.
        """
        #load stitch
        print "All resolutions must be the same"
        # Grid layout and tile size are hard-coded; the interactive prompts
        # are commented out.
        a,b = 2,3 #raw_input("\tset rows and cols: ")
        size = 512 #raw_input("\tsize:")
        self.stitchdataA = np.zeros((size*a, size*b))
        self.stitchdataB = np.zeros((size*a, size*b))
        # i counts loaded files across the whole pass, j within one inner run.
        # NOTE(review): both counters start at 1 and use "<", so each inner run
        # loads b-1 files, and ``col`` can grow past ``a`` although it indexes
        # the first (size*a) dimension -- confirm the intended tiling.
        i = 1
        col = 0
        while i < a*b:
            j = 1
            row = 0
            while j < b:
                self._load_file(self._single_load())
                self.plotA['plot'] = self.plta
                self.plotA = self._pick_plots(self.plotA)
                print col*size, row*size
                # NOTE(review): the slice stops at col*shape[0] / row*shape[1]
                # (not (col+1)*...), and the right-hand side is the whole
                # plotA dictionary rather than plotA["data"] -- presumably a
                # bug; confirm.
                self.stitchdataA[col*size : col*self.plotA["shape"][0], row*size : row*self.plotA["shape"][1]] = self.plotA
                row = row+1
                j = j+1
                i = i+1
            col = col+1


        # Second pass: same procedure for panel B.
        i = 1
        col = 0
        while i < a*b:
            j = 1
            row = 0
            while j < b:
                self._load_file(self._single_load())
                self.plotB['plot'] = self.pltb
                self.plotB = self._pick_plots(self.plotB)
                # NOTE(review): same slice/right-hand-side concerns as above.
                self.stitchdataB[col*size : col*self.plotB["shape"][0], row*size : row*self.plotB["shape"][1]] = self.plotB
                row = row+1
                j = j+1
                i = i+1
            col = col+1

        # Panels C and D reuse the stitched A and B arrays.
        self.plotA["data"] = self.stitchdataA
        self.plotB["data"] = self.stitchdataB
        self.plotC["data"] = self.stitchdataA
        self.plotD["data"] = self.stitchdataB
        # Only panel A's range and shape are updated to the stitched size.
        self.plotA["range"][0][0] = 0
        self.plotA["range"][1][0] = 0
        self.plotA["range"][0][1] = size*a
        self.plotA["range"][1][1] = size*b
        self.plotA["shape"] = (size*b, size*a)
        self._plot()

        gc.collect()
        return

    ########################################################

    def _pick_plots(self, plotdata):
        """Fill ``plotdata`` from ``self.data`` for its selected quantity.

        ``plotdata`` is a panel dictionary whose "plot" entry names the
        quantity to show: "R", "Phase", "R0", "R/R0", "Fluor", "X" or "Y".
        The dictionary is updated in place and also returned:

        * "notes"  -- the scan settings (``self.data['settings']['scan']``)
        * "range"  -- upper x bound from "fast_axis_range", upper y bound from
                      the "range" entry of the axis numbered 1
        * "shape"  -- (points on axis 0, points on axis 1)
        * "data"   -- the selected quantity reshaped to "shape"

        If computing the quantity fails (for example a missing channel), a
        message is printed and an all-zero array is used instead; an unknown
        selection also yields zeros. Raises ValueError when the scan settings
        do not name both axis 0 and axis 1 (previously an unbound-variable
        error).
        """
        print("...loading plot")
        scan = self.data['settings']['scan']
        plotdata["notes"] = scan

        print("\t filling shapes")
        fast = slow = None
        for x in range(3):
            if scan['axes'][x] == 0:
                fast = scan['npoints'][x]
                plotdata["range"][0][1] = scan["fast_axis_range"]
            if scan['axes'][x] == 1:
                slow = scan['npoints'][x]
                plotdata["range"][1][1] = scan['range'][x]
        if fast is None or slow is None:
            raise ValueError("scan settings must assign both axis 0 and axis 1")

        plotdata["shape"] = (fast, slow)

        def channel(name, index):
            # Looked up lazily so that only the channels a selection actually
            # needs have to be present in the file.
            return np.array(self.data['data'][name][index])

        def magnitude(index):
            x_part = channel('lia_x', index)
            y_part = channel('lia_y', index)
            return np.sqrt(np.multiply(x_part, x_part) +
                           np.multiply(y_part, y_part))

        # NOTE(review): indices "0" and "3" presumably select the signal and
        # reference lock-in channels -- confirm against the acquisition code.
        builders = {
            "R": lambda: magnitude("0"),
            "Phase": lambda: np.arctan2(channel('lia_y', "0"),
                                        channel('lia_x', "0")),
            "R0": lambda: magnitude("3"),
            "R/R0": lambda: magnitude("0") / magnitude("3"),
            "Fluor": lambda: np.array(self.data['data']['dio2']),
            "X": lambda: channel('lia_x', "0"),
            "Y": lambda: channel('lia_y', "0"),
        }

        print("\t filling data")
        data = np.zeros((fast, slow))
        try:
            selected = plotdata['plot']
            if selected in builders:
                print("\t\tplotting " + selected)
                data = builders[selected]()
        except Exception:
            # Best effort: keep the zero-filled fallback so the GUI can still draw.
            print("Process failed-- check dropdown assignments")

        data.shape = (plotdata["shape"])
        plotdata['data'] = data
        return plotdata

    def _single_load(self):
        """Show an "open" dialog; return the chosen path, or None if not confirmed."""
        chooser = FileDialog(action="open", wildcard=self.file_wildcard)
        chooser.open()
        return chooser.path if chooser.return_code == OK else None



    def _loader(self, i):
        """Load the file at path ``i`` and rebuild the five panel dictionaries.

        After unpickling, every panel (plotA .. plotE) first receives its
        drop-down selection (plta .. plte) and is then refilled through
        ``_pick_plots``. When the x lower bound field is still empty, the four
        range fields are initialised from panel A's range.
        """
        self._load_file(i)

        panels = ("A", "B", "C", "D", "E")

        # Copy all selections first, then recompute, as two separate passes.
        for letter in panels:
            getattr(self, "plot" + letter)['plot'] = getattr(self, "plt" + letter.lower())
        for letter in panels:
            attribute = "plot" + letter
            setattr(self, attribute, self._pick_plots(getattr(self, attribute)))

        if self.xLo == "":
            print("...populating ranges")
            self.xLo = str(self.plotA["range"][0][0])
            self.xHi = str(self.plotA["range"][0][1])
            self.yLo = str(self.plotA["range"][1][0])
            self.yHi = str(self.plotA["range"][1][1])

        gc.collect()
        return

    def _load1_button_fired(self):
        """Load a single file chosen by the user into panel A and replot.

        Returns without touching anything when the file dialog is dismissed
        (previously the missing path was passed on and opened as "None").
        """
        path = self._single_load()
        if path is None:
            return

        # Open the file at the chosen path and unpickle it into self.data.
        self._load_file(path)

        # Apply the current drop-down selection and rebuild the 2-D array.
        self.plotA['plot'] = self.plta
        self.plotA = self._pick_plots(self.plotA)
        gc.collect()
        self._plot()
        return

    def _load2_button_fired(self):
        """Load a single file chosen by the user into panel B and replot.

        Returns without touching anything when the file dialog is dismissed
        (previously the missing path was passed on and opened as "None").
        """
        path = self._single_load()
        if path is None:
            return

        # Open the file at the chosen path and unpickle it into self.data.
        self._load_file(path)

        # Apply the current drop-down selection and rebuild the 2-D array.
        self.plotB['plot'] = self.pltb
        self.plotB = self._pick_plots(self.plotB)
        gc.collect()
        self._plot()
        return

    def _load3_button_fired(self):
        """Load a single file chosen by the user into panel C and replot.

        Returns without touching anything when the file dialog is dismissed
        (previously the missing path was passed on and opened as "None").
        """
        path = self._single_load()
        if path is None:
            return

        # Open the file at the chosen path and unpickle it into self.data.
        self._load_file(path)

        # Apply the current drop-down selection and rebuild the 2-D array.
        self.plotC['plot'] = self.pltc
        self.plotC = self._pick_plots(self.plotC)
        gc.collect()
        self._plot()
        return

    def _load4_button_fired(self):
        """Load a single file chosen by the user into panel D and replot.

        Returns without touching anything when the file dialog is dismissed
        (previously the missing path was passed on and opened as "None").
        """
        path = self._single_load()
        if path is None:
            return

        # Open the file at the chosen path and unpickle it into self.data.
        self._load_file(path)

        # Apply the current drop-down selection and rebuild the 2-D array.
        self.plotD['plot'] = self.pltd
        self.plotD = self._pick_plots(self.plotD)
        gc.collect()
        self._plot()
        return
class ImageGUI(HasTraits):
    
    # TO FIX : put here the last available shot
    #shot = File('L:\\data\\app3\\2011\\1108\\110823\\column_5200.ascii')
    #shot = File('/home/pmd/atomcool/lab/data/app3/2012/1203/120307/column_3195.ascii')

    #-- Shot traits
    shotdir = Directory('/home/pmd/atomcool/lab/data/app3/2012/1203/120320/')
    shots = List(Str)
    selectedshot = List(Str)
    namefilter = Str('column')

    #-- Report trait
    report = Str

    #-- Displayed analysis results
    number = Float
     
    #-- Column density plot container
    column_density = Instance(HPlotContainer)
    #---- Plot components within this container
    imgplot     = Instance(CMapImagePlot)
    cross_plot  = Instance(Plot)
    cross_plot2 = Instance(Plot)
    colorbar    = Instance(ColorBar)
    #---- Plot data
    pd = Instance(ArrayPlotData)
    #---- Colorbar 
    num_levels = Int(15)
    colormap = Enum(color_map_name_dict.keys())

    #-- Crosshair location
    cursor = Instance(BaseCursorTool)
    xy = DelegatesTo('cursor', prefix='current_position')
    xpos = Float(0.)
    ypos = Float(0.)
    xpos_read = Float(0.)
    ypos_read = Float(0.)
    cursor_group = Group( Group(Item('xpos', show_label=True), 
	                        Item('xpos_read', show_label=False, style="readonly"),
				orientation='horizontal'),
			  Group(Item('ypos', show_label=True), 
				Item('ypos_read', show_label=False, style="readonly"),
				orientation='horizontal'),
		          orientation='vertical', layout='normal',springy=True)

    
    #---------------------------------------------------------------------------
    # Traits View Definitions
    #---------------------------------------------------------------------------
    
    traits_view = View(
                    Group(
                      #Directory
                      Item( 'shotdir',style='simple', editor=DirectoryEditor(), width = 400, \
				      show_label=False, resizable=False ),
                      #Bottom
                      HSplit(
		        #-- Pane for shot selection
        	        Group(
		          Item( 'namefilter', show_label=False,springy=False),		
                          Item( 'shots',show_label=False, width=180, height= 360, \
					editor = TabularEditor(selected='selectedshot',\
					editable=False,multi_select=True,\
					adapter=SelectAdapter()) ),
			  cursor_group,
                          orientation='vertical',
		          layout='normal', ),

		        #-- Pane for column density plots
			Group(
			  Item('column_density',editor=ComponentEditor(), \
                                           show_label=False, width=600, height=500, \
                                           resizable=True ), 
			  Item('report',show_label=False, width=180, \
					springy=True, style='custom' ),
			  layout='tabbed', springy=True),

			#-- Pane for analysis results
			Group(
		          Item('number',show_label=False)
			  )
                      ),
                      orientation='vertical',
                      layout='normal',
                    ),
                  width=1400, height=500, resizable=True)
    
    #-- Pop-up view shown when Plot->Edit is selected from the menu
    plot_edit_view = View(
                    Group(Item('num_levels'),
                          Item('colormap')),
                          buttons=["OK","Cancel"])
                          
    
    #---------------------------------------------------------------------------
    # Private Traits
    #---------------------------------------------------------------------------

    #-- Represents the region where the data set is defined
    _image_index = Instance(GridDataSource) 

    #-- Represents the data that will be plotted on the grid
    _image_value = Instance(ImageData)

    #-- Represents the color map that will be used
    _cmap = Trait(jet, Callable)
    
    
    #---------------------------------------------------------------------------
    # Public View interface
    #---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        """Set up the traits machinery, then build the plot components.

        The inherited constructor is run first so that all traits are properly
        initialised even though ``__init__`` is overridden; only afterwards is
        the plot created.
        """
        super(ImageGUI, self).__init__(*args, **kwargs)
        self.create_plot()



    def create_plot(self):
        """Build the column-density display and store it in ``self.column_density``.

        Creates an initially empty colour-mapped image plot with left/bottom
        axes, pan and zoom tools, a colorbar and a cursor, plus two cross
        plots that share the image's x and y index ranges, and assembles all
        of them into a horizontal container.
        """
        # Index for the x and y axes, and the range over which they vary.
        self._image_index = GridDataSource(array([]), array([]),
                                           sort_order=("ascending", "ascending"))
        image_index_range = DataRange2D(self._image_index)

        # Have _metadata_changed called whenever the index fires
        # "metadata_changed" (presumably how mouse movement is tracked).
        self._image_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")

        # Image values and their range.
        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)

        # The image plot itself.
        self.imgplot = CMapImagePlot(index=self._image_index,
                                     value=self._image_value,
                                     index_mapper=GridMapper(range=image_index_range),
                                     color_mapper=self._cmap(image_value_range))

        # Left axis.
        left = PlotAxis(orientation='left',
                        title="axial",
                        mapper=self.imgplot.index_mapper._ymapper,
                        component=self.imgplot)
        self.imgplot.overlays.append(left)

        # Bottom axis.
        bottom = PlotAxis(orientation='bottom',
                          title="radial",
                          mapper=self.imgplot.index_mapper._xmapper,
                          component=self.imgplot)
        self.imgplot.overlays.append(bottom)

        # Tools: right-button pan (constrained with shift) and box zoom.
        self.imgplot.tools.append(PanTool(self.imgplot, drag_button="right",
                                          constrain_key="shift"))
        self.imgplot.overlays.append(ZoomTool(component=self.imgplot,
                                              tool_mode="box", always_on=False))

        # Colorbar driven by the image value range.
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.imgplot,
                                 padding_top=self.imgplot.padding_top,
                                 padding_bottom=self.imgplot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        # Cursor; it is a rendered component, so it goes in the overlays list.
        self.cursor = CursorTool(self.imgplot, drag_button="left", color="white")
        self.imgplot.overlays.append(self.cursor)

        # Shared data for the two cross plots; every series starts out empty.
        self.pd = ArrayPlotData(line_index=array([]),
                                line_value=array([]),
                                scatter_index=array([]),
                                scatter_value=array([]),
                                scatter_color=array([]))

        # Horizontal cross plot, tied to the image's x range.
        self.cross_plot = Plot(self.pd, resizable="h")
        self.cross_plot.height = 100
        self.cross_plot.padding = 20
        self.cross_plot.plot(("line_index", "line_value"),
                             line_style="dot")
        self.cross_plot.plot(("scatter_index", "scatter_value", "scatter_color"),
                             type="cmap_scatter",
                             name="dot",
                             color_mapper=self._cmap(image_value_range),
                             marker="circle",
                             marker_size=6)

        self.cross_plot.index_range = self.imgplot.index_range.x_range

        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))
        self.pd.set_data("scatter_index2", array([]))
        self.pd.set_data("scatter_value2", array([]))
        self.pd.set_data("scatter_color2", array([]))

        # Vertical cross plot, tied to the image's y range.
        self.cross_plot2 = Plot(self.pd, width=140, orientation="v",
                                resizable="v", padding=20, padding_bottom=160)
        self.cross_plot2.plot(("line_index2", "line_value2"),
                              line_style="dot")
        self.cross_plot2.plot(("scatter_index2", "scatter_value2", "scatter_color2"),
                              type="cmap_scatter",
                              name="dot",
                              color_mapper=self._cmap(image_value_range),
                              marker="circle",
                              marker_size=8)

        self.cross_plot2.index_range = self.imgplot.index_range.y_range

        # Outer container: colorbar | (cross plot + image) | vertical cross plot.
        self.column_density = HPlotContainer(padding=40, fill_padding=True,
                                             bgcolor="white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.cross_plot)
        self.imgplot.padding = 20
        inner_cont.add(self.imgplot)
        self.column_density.add(self.colorbar)
        self.column_density.add(inner_cont)
        self.column_density.add(self.cross_plot2)

    def update(self):
        """Refresh the shot list, load the selected shot and redraw the plots.

        The shot list is rebuilt first, then ``load_imagedata`` is asked for
        the image and its report.  If loading fails (``load_imagedata``
        returns ``None``, or the image itself is ``None``) the plots are left
        untouched instead of raising on the tuple unpacking.
        """
        self.shots = self.populate_shot_list()
        print(self.selectedshot)

        loaded = self.load_imagedata()
        if loaded is None:
            # load_imagedata() signals an invalid path with a bare None,
            # which cannot be unpacked into (image, report).
            return
        imgdata, self.report = loaded
        if imgdata is None:
            return

        self.minz = imgdata.min()
        self.maxz = imgdata.max()
        self.colorbar.index_mapper.range.low = self.minz
        self.colorbar.index_mapper.range.high = self.maxz

        # One grid coordinate per pixel edge along each array axis.
        xs = numpy.linspace(0, imgdata.shape[0], imgdata.shape[0] + 1)
        ys = numpy.linspace(0, imgdata.shape[1], imgdata.shape[1] + 1)
        self._image_index.set_data(xs, ys)
        self._image_value.data = imgdata
        self.pd.set_data("line_index", xs)
        self.pd.set_data("line_index2", ys)
        self.column_density.invalidate_draw()
        self.column_density.request_redraw()

    def populate_shot_list(self):
        """Return the sorted names in ``self.shotdir`` accepted by ``iscol``.

        Each directory entry is kept when ``iscol(name, self.namefilter)`` is
        true.  If the directory cannot be listed, a message is printed and an
        empty list is returned.  If filtering raises ``ValueError``, the
        message is printed and the unfiltered listing is returned, as before.
        """
        # Pre-bind the result so the error path never hits an unbound local.
        shot_list = []
        try:
            shot_list = os.listdir(self.shotdir)
            shot_list = sorted(name for name in shot_list
                               if iscol(name, self.namefilter))
        except (OSError, ValueError):
            # os.listdir reports a missing/invalid directory with OSError;
            # ValueError is kept for backward compatibility.
            print(" *** Not a valid directory path ***")
        return shot_list

    def load_imagedata(self):
        """Load the image and report for the selected (or first) shot.

        The file name is the first entry of ``self.selectedshot`` or, when
        nothing is selected, the first entry of ``self.shots``.  The shot
        number is the part of the file name before the first underscore.

        Returns:
            A tuple ``(import_data.load(...), import_data.load_report(...))``,
            or ``None`` when no shot is available or the file name contains no
            underscore.
        """
        directory = self.shotdir
        try:
            if self.selectedshot:
                filename = self.selectedshot[0]
            else:
                # IndexError here means the shot list is empty as well.
                filename = self.shots[0]
            shotnum = filename[:filename.index('_')]
        except (ValueError, IndexError):
            print(" *** Not a valid path *** ")
            return None

        print("Loading file #%s from %s" % (filename, directory))
        return (import_data.load(directory, filename),
                import_data.load_report(directory, shotnum))


    #---------------------------------------------------------------------------
    # Event handlers
    #---------------------------------------------------------------------------
    
    def _selectedshot_changed(self):
	print self.selectedshot
        self.update()

    def _shots_changed(self):
        self.shots = self.populate_shot_list()
	return

    def _namefilter_changed(self):
	self.shots = self.populate_shot_list()
	return

  
    def _xpos_changed(self):
	self.cursor.current_position = self.xpos, self.ypos
    def _ypos_changed(self):
	self.cursor.current_position = self.xpos, self.ypos

    def _metadata_changed(self):
	self._xy_changed()
	    
    def _xy_changed(self):
	self.xpos_read = self.cursor.current_index[0]
	self.ypos_read = self.cursor.current_index[1]
	#print self.cursor.current_index
        """ This function takes out a cross section from the image data, based
        on the cursor selections, and updates the line and scatter
        plots."""
        self.cross_plot.value_range.low = self.minz
        self.cross_plot.value_range.high = self.maxz
        self.cross_plot2.value_range.low = self.minz
        self.cross_plot2.value_range.high = self.maxz
        if True:
            x_ndx, y_ndx = self.cursor.current_index
            if y_ndx and x_ndx:
                self.pd.set_data("line_value",
				self._image_value.data[:,y_ndx])
                self.pd.set_data("line_value2",
				self._image_value.data[x_ndx,:])
                xdata, ydata = self._image_index.get_data()
                xdata, ydata = xdata.get_data(), ydata.get_data()
                self.pd.set_data("scatter_index", array([ydata[y_ndx]]))
                self.pd.set_data("scatter_index2", array([xdata[x_ndx]]))
                self.pd.set_data("scatter_value",
                    array([self._image_value.data[y_ndx, x_ndx]]))
                self.pd.set_data("scatter_value2",
                    array([self._image_value.data[y_ndx, x_ndx]]))
                self.pd.set_data("scatter_color",
                    array([self._image_value.data[y_ndx, x_ndx]]))
                self.pd.set_data("scatter_color2",
                    array([self._image_value.data[y_ndx, x_ndx]]))
        else:
            self.pd.set_data("scatter_value", array([]))
            self.pd.set_data("scatter_value2", array([]))
            self.pd.set_data("line_value", array([]))
            self.pd.set_data("line_value2", array([]))

    def _colormap_changed(self):
        """Switch to the colormap named by ``self.colormap`` and redraw.

        The colormap factory is looked up in ``color_map_name_dict`` and
        applied to the contour plot and both cross plots, each keeping its
        current value range.
        """
        self._cmap = color_map_name_dict[self.colormap]
        # NOTE(review): the visible plot-building code uses ``self.imgplot``;
        # confirm ``polyplot`` actually exists on this class, otherwise the
        # rest of this handler never runs.
        if not hasattr(self, "polyplot"):
            return

        poly_range = self.polyplot.color_mapper.range
        self.polyplot.color_mapper = self._cmap(poly_range)

        cross_range = self.cross_plot.color_mapper.range
        self.cross_plot.color_mapper = self._cmap(cross_range)
        # FIXME: change when we decide how best to update plots using
        # the shared colormap in plot object
        for cross in (self.cross_plot, self.cross_plot2):
            cross.plots["dot"][0].color_mapper = self._cmap(cross_range)
        self.column_density.request_redraw()

    def _num_levels_changed(self):
        if self.num_levels > 3:
            self.polyplot.levels = self.num_levels
            self.lineplot.levels = self.num_levels
class PlotUI(HasTraits):
    """Traits view holding a contour plot, a colorbar and two cross plots.

    ``create_plot`` builds a filled contour plot overlaid with a line contour
    plot, a colorbar, and two cross-section plots, and lays them out in
    ``container``.  ``update`` pushes new data from a model object into the
    plots; the ``_*_changed`` handlers react to inspector metadata, colormap
    and contour-level changes.
    """

    # container for all plots
    container = Instance(HPlotContainer)

    # Plot components within this container:
    polyplot = Instance(ContourPolyPlot)
    lineplot = Instance(ContourLinePlot)
    cross_plot = Instance(Plot)
    cross_plot2 = Instance(Plot)
    colorbar = Instance(ColorBar)

    # plot data
    pd = Instance(ArrayPlotData)

    # view options
    num_levels = Int(15)
    colormap = Enum(colormaps)

    # Traits view definitions:
    traits_view = View(
        Group(UItem('container', editor=ComponentEditor(size=(800, 600)))),
        resizable=True)

    plot_edit_view = View(
        Group(Item('num_levels'),
              Item('colormap')),
        buttons=["OK", "Cancel"])

    #---------------------------------------------------------------------------
    # Private Traits
    #---------------------------------------------------------------------------

    _image_index = Instance(GridDataSource)
    _image_value = Instance(ImageData)

    _cmap = Trait(default_colormaps.jet, Callable)

    #---------------------------------------------------------------------------
    # Public View interface
    #---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        """Initialise the traits and build all plot components."""
        super(PlotUI, self).__init__(*args, **kwargs)
        # FIXME: 'with' wrapping is temporary fix for infinite range in initial
        # color map, which can cause a distracting warning print. This 'with'
        # wrapping should be unnecessary after fix in color_mapper.py.
        with errstate(invalid='ignore'):
            self.create_plot()

    def create_plot(self):
        """Create the data sources, plots, tools and containers.

        All data sources start out empty; ``update`` fills them later.
        """
        # Create the mapper, etc
        self._image_index = GridDataSource(array([]),
                                           array([]),
                                           sort_order=("ascending",
                                                       "ascending"))
        image_index_range = DataRange2D(self._image_index)
        self._image_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")

        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)

        # Create the contour plots
        self.polyplot = ContourPolyPlot(
            index=self._image_index,
            value=self._image_value,
            index_mapper=GridMapper(range=image_index_range),
            color_mapper=self._cmap(image_value_range),
            levels=self.num_levels)

        # The line contours share the index range of the filled contours.
        self.lineplot = ContourLinePlot(
            index=self._image_index,
            value=self._image_value,
            index_mapper=GridMapper(range=self.polyplot.index_mapper.range),
            levels=self.num_levels)

        # Add a left axis to the plot
        left = PlotAxis(orientation='left',
                        title="y",
                        mapper=self.polyplot.index_mapper._ymapper,
                        component=self.polyplot)
        self.polyplot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = PlotAxis(orientation='bottom',
                          title="x",
                          mapper=self.polyplot.index_mapper._xmapper,
                          component=self.polyplot)
        self.polyplot.overlays.append(bottom)

        # Add some tools to the plot
        self.polyplot.tools.append(PanTool(self.polyplot,
                                           constrain_key="shift"))
        self.polyplot.overlays.append(ZoomTool(component=self.polyplot,
                                               tool_mode="box",
                                               always_on=False))
        # Two line inspectors (one per axis) write the selected indices into
        # the index metadata, which _metadata_changed reads.
        self.polyplot.overlays.append(LineInspector(component=self.polyplot,
                                                    axis='index_x',
                                                    inspect_mode="indexed",
                                                    write_metadata=True,
                                                    is_listener=True,
                                                    color="white"))
        self.polyplot.overlays.append(LineInspector(component=self.polyplot,
                                                    axis='index_y',
                                                    inspect_mode="indexed",
                                                    write_metadata=True,
                                                    color="white",
                                                    is_listener=True))

        # Add these two plots to one container
        contour_container = OverlayPlotContainer(padding=20,
                                                 use_backbuffer=True,
                                                 unified_draw=True)
        contour_container.add(self.polyplot)
        contour_container.add(self.lineplot)

        # Create a colorbar
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.polyplot,
                                 padding_top=self.polyplot.padding_top,
                                 padding_bottom=self.polyplot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        self.pd = ArrayPlotData(line_index=array([]),
                                line_value=array([]),
                                scatter_index=array([]),
                                scatter_value=array([]),
                                scatter_color=array([]))

        # First cross plot: shares the x range of the contour plot.
        self.cross_plot = Plot(self.pd, resizable="h")
        self.cross_plot.height = 100
        self.cross_plot.padding = 20
        self.cross_plot.plot(("line_index", "line_value"),
                             line_style="dot")
        self.cross_plot.plot(("scatter_index", "scatter_value",
                              "scatter_color"),
                             type="cmap_scatter",
                             name="dot",
                             color_mapper=self._cmap(image_value_range),
                             marker="circle",
                             marker_size=8)

        self.cross_plot.index_range = self.polyplot.index_range.x_range

        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))
        self.pd.set_data("scatter_index2", array([]))
        self.pd.set_data("scatter_value2", array([]))
        self.pd.set_data("scatter_color2", array([]))

        # Second cross plot: vertical, shares the y range of the contour plot.
        self.cross_plot2 = Plot(self.pd, width=140, orientation="v",
                                resizable="v", padding=20, padding_bottom=160)
        self.cross_plot2.plot(("line_index2", "line_value2"),
                              line_style="dot")
        self.cross_plot2.plot(("scatter_index2",
                               "scatter_value2",
                               "scatter_color2"),
                              type="cmap_scatter",
                              name="dot",
                              color_mapper=self._cmap(image_value_range),
                              marker="circle",
                              marker_size=8)

        self.cross_plot2.index_range = self.polyplot.index_range.y_range

        # Create a container and add components
        self.container = HPlotContainer(padding=40, fill_padding=True,
                                        bgcolor="white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.cross_plot)
        inner_cont.add(contour_container)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.cross_plot2)

    def update(self, model):
        """Push the data of ``model`` into the plots and request a redraw.

        Reads ``minz``, ``maxz``, ``xs``, ``ys`` and ``zs`` from ``model``.
        """
        self.minz = model.minz
        self.maxz = model.maxz
        self.colorbar.index_mapper.range.low = self.minz
        self.colorbar.index_mapper.range.high = self.maxz
        self._image_index.set_data(model.xs, model.ys)
        self._image_value.data = model.zs
        self.pd.update_data(line_index=model.xs, line_index2=model.ys)
        self.container.invalidate_draw()
        self.container.request_redraw()

    #---------------------------------------------------------------------------
    # Event handlers
    #---------------------------------------------------------------------------

    def _metadata_changed(self, old, new):
        """ This function takes out a cross section from the image data, based
        on the line inspector selections, and updates the line and scatter
        plots."""
        # NOTE(review): minz/maxz are only assigned in update(); this handler
        # presumably never fires before the first update() -- confirm.
        self.cross_plot.value_range.low = self.minz
        self.cross_plot.value_range.high = self.maxz
        self.cross_plot2.value_range.low = self.minz
        self.cross_plot2.value_range.high = self.maxz
        # ``in`` instead of dict.has_key(), which no longer exists in
        # Python 3.
        if "selections" in self._image_index.metadata:
            x_ndx, y_ndx = self._image_index.metadata["selections"]
            # NOTE(review): this truthiness test also skips index 0 --
            # confirm whether that is intended.
            if y_ndx and x_ndx:
                xdata, ydata = self._image_index.get_data()
                xdata, ydata = xdata.get_data(), ydata.get_data()
                self.pd.update_data(
                    line_value=self._image_value.data[y_ndx, :],
                    line_value2=self._image_value.data[:, x_ndx],
                    scatter_index=array([xdata[x_ndx]]),
                    scatter_index2=array([ydata[y_ndx]]),
                    scatter_value=array([self._image_value.data[y_ndx, x_ndx]]),
                    scatter_value2=array([self._image_value.data[y_ndx, x_ndx]]),
                    scatter_color=array([self._image_value.data[y_ndx, x_ndx]]),
                    scatter_color2=array([self._image_value.data[y_ndx, x_ndx]])
                )
        else:
            # No selection: clear the cross-section curves and markers.
            self.pd.update_data({"scatter_value": array([]),
                                 "scatter_value2": array([]),
                                 "line_value": array([]),
                                 "line_value2": array([])})

    def _colormap_changed(self):
        """Apply the colormap named by ``colormap`` to all plots and redraw."""
        self._cmap = default_colormaps.color_map_name_dict[self.colormap]
        if self.polyplot is not None:
            value_range = self.polyplot.color_mapper.range
            self.polyplot.color_mapper = self._cmap(value_range)
            value_range = self.cross_plot.color_mapper.range
            self.cross_plot.color_mapper = self._cmap(value_range)
            # FIXME: change when we decide how best to update plots using
            # the shared colormap in plot object
            self.cross_plot.plots["dot"][0].color_mapper = \
                self._cmap(value_range)
            self.cross_plot2.plots["dot"][0].color_mapper = \
                self._cmap(value_range)
            self.container.request_redraw()

    def _num_levels_changed(self):
        """Apply ``num_levels`` to both contour plots when it exceeds 3."""
        if self.num_levels > 3:
            self.polyplot.levels = self.num_levels
            self.lineplot.levels = self.num_levels
class PlotUI(HasTraits):
    """Traits view holding a contour plot, a colorbar and two cross plots.

    ``create_plot`` builds a filled contour plot overlaid with a line contour
    plot, a colorbar, and two cross-section plots, and lays them out in
    ``container``.  ``update`` pushes new data from a model object into the
    plots; the ``_*_changed`` handlers react to inspector metadata, colormap
    and contour-level changes.
    """

    # container for all plots
    container = Instance(HPlotContainer)

    # Plot components within this container:
    polyplot = Instance(ContourPolyPlot)
    lineplot = Instance(ContourLinePlot)
    cross_plot = Instance(Plot)
    cross_plot2 = Instance(Plot)
    colorbar = Instance(ColorBar)

    # plot data
    pd = Instance(ArrayPlotData)

    # view options
    num_levels = Int(15)
    colormap = Enum(colormaps)

    # Traits view definitions:
    traits_view = View(
        Group(UItem('container', editor=ComponentEditor(size=(800, 600)))),
        resizable=True)

    plot_edit_view = View(
        Group(Item('num_levels'),
              Item('colormap')),
        buttons=["OK", "Cancel"])

    #---------------------------------------------------------------------------
    # Private Traits
    #---------------------------------------------------------------------------

    _image_index = Instance(GridDataSource)
    _image_value = Instance(ImageData)

    _cmap = Trait(default_colormaps.jet, Callable)

    #---------------------------------------------------------------------------
    # Public View interface
    #---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        """Initialise the traits and build all plot components."""
        super(PlotUI, self).__init__(*args, **kwargs)
        # FIXME: 'with' wrapping is temporary fix for infinite range in initial
        # color map, which can cause a distracting warning print. This 'with'
        # wrapping should be unnecessary after fix in color_mapper.py.
        with errstate(invalid='ignore'):
            self.create_plot()

    def create_plot(self):
        """Create the data sources, plots, tools and containers.

        All data sources start out empty; ``update`` fills them later.
        """
        # Create the mapper, etc
        self._image_index = GridDataSource(array([]),
                                           array([]),
                                           sort_order=("ascending",
                                                       "ascending"))
        image_index_range = DataRange2D(self._image_index)
        self._image_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")

        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)

        # Create the contour plots
        self.polyplot = ContourPolyPlot(
            index=self._image_index,
            value=self._image_value,
            index_mapper=GridMapper(range=image_index_range),
            color_mapper=self._cmap(image_value_range),
            levels=self.num_levels)

        # The line contours share the index range of the filled contours.
        self.lineplot = ContourLinePlot(
            index=self._image_index,
            value=self._image_value,
            index_mapper=GridMapper(range=self.polyplot.index_mapper.range),
            levels=self.num_levels)

        # Add a left axis to the plot
        left = PlotAxis(orientation='left',
                        title="y",
                        mapper=self.polyplot.index_mapper._ymapper,
                        component=self.polyplot)
        self.polyplot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = PlotAxis(orientation='bottom',
                          title="x",
                          mapper=self.polyplot.index_mapper._xmapper,
                          component=self.polyplot)
        self.polyplot.overlays.append(bottom)

        # Add some tools to the plot
        self.polyplot.tools.append(PanTool(self.polyplot,
                                           constrain_key="shift"))
        self.polyplot.overlays.append(ZoomTool(component=self.polyplot,
                                               tool_mode="box",
                                               always_on=False))
        # Two line inspectors (one per axis) write the selected indices into
        # the index metadata, which _metadata_changed reads.
        self.polyplot.overlays.append(LineInspector(component=self.polyplot,
                                                    axis='index_x',
                                                    inspect_mode="indexed",
                                                    write_metadata=True,
                                                    is_listener=True,
                                                    color="white"))
        self.polyplot.overlays.append(LineInspector(component=self.polyplot,
                                                    axis='index_y',
                                                    inspect_mode="indexed",
                                                    write_metadata=True,
                                                    color="white",
                                                    is_listener=True))

        # Add these two plots to one container
        contour_container = OverlayPlotContainer(padding=20,
                                                 use_backbuffer=True,
                                                 unified_draw=True)
        contour_container.add(self.polyplot)
        contour_container.add(self.lineplot)

        # Create a colorbar
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.polyplot,
                                 padding_top=self.polyplot.padding_top,
                                 padding_bottom=self.polyplot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        self.pd = ArrayPlotData(line_index=array([]),
                                line_value=array([]),
                                scatter_index=array([]),
                                scatter_value=array([]),
                                scatter_color=array([]))

        # First cross plot: shares the x range of the contour plot.
        self.cross_plot = Plot(self.pd, resizable="h")
        self.cross_plot.height = 100
        self.cross_plot.padding = 20
        self.cross_plot.plot(("line_index", "line_value"),
                             line_style="dot")
        self.cross_plot.plot(("scatter_index", "scatter_value",
                              "scatter_color"),
                             type="cmap_scatter",
                             name="dot",
                             color_mapper=self._cmap(image_value_range),
                             marker="circle",
                             marker_size=8)

        self.cross_plot.index_range = self.polyplot.index_range.x_range

        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))
        self.pd.set_data("scatter_index2", array([]))
        self.pd.set_data("scatter_value2", array([]))
        self.pd.set_data("scatter_color2", array([]))

        # Second cross plot: vertical, shares the y range of the contour plot.
        self.cross_plot2 = Plot(self.pd, width=140, orientation="v",
                                resizable="v", padding=20, padding_bottom=160)
        self.cross_plot2.plot(("line_index2", "line_value2"),
                              line_style="dot")
        self.cross_plot2.plot(("scatter_index2", "scatter_value2",
                               "scatter_color2"),
                              type="cmap_scatter",
                              name="dot",
                              color_mapper=self._cmap(image_value_range),
                              marker="circle",
                              marker_size=8)

        self.cross_plot2.index_range = self.polyplot.index_range.y_range

        # Create a container and add components
        self.container = HPlotContainer(padding=40, fill_padding=True,
                                        bgcolor="white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.cross_plot)
        inner_cont.add(contour_container)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.cross_plot2)

    def update(self, model):
        """Push the data of ``model`` into the plots and request a redraw.

        Reads ``minz``, ``maxz``, ``xs``, ``ys`` and ``zs`` from ``model``.
        """
        self.minz = model.minz
        self.maxz = model.maxz
        self.colorbar.index_mapper.range.low = self.minz
        self.colorbar.index_mapper.range.high = self.maxz
        self._image_index.set_data(model.xs, model.ys)
        self._image_value.data = model.zs
        self.pd.set_data("line_index", model.xs)
        self.pd.set_data("line_index2", model.ys)
        self.container.invalidate_draw()
        self.container.request_redraw()

    #---------------------------------------------------------------------------
    # Event handlers
    #---------------------------------------------------------------------------

    def _metadata_changed(self, old, new):
        """ This function takes out a cross section from the image data, based
        on the line inspector selections, and updates the line and scatter
        plots."""
        # NOTE(review): minz/maxz are only assigned in update(); this handler
        # presumably never fires before the first update() -- confirm.
        self.cross_plot.value_range.low = self.minz
        self.cross_plot.value_range.high = self.maxz
        self.cross_plot2.value_range.low = self.minz
        self.cross_plot2.value_range.high = self.maxz
        # ``in`` instead of dict.has_key(), which no longer exists in
        # Python 3.
        if "selections" in self._image_index.metadata:
            x_ndx, y_ndx = self._image_index.metadata["selections"]
            # NOTE(review): this truthiness test also skips index 0 --
            # confirm whether that is intended.
            if y_ndx and x_ndx:
                self.pd.set_data("line_value",
                                 self._image_value.data[y_ndx, :])
                self.pd.set_data("line_value2",
                                 self._image_value.data[:, x_ndx])
                xdata, ydata = self._image_index.get_data()
                xdata, ydata = xdata.get_data(), ydata.get_data()
                self.pd.set_data("scatter_index", array([xdata[x_ndx]]))
                self.pd.set_data("scatter_index2", array([ydata[y_ndx]]))
                self.pd.set_data("scatter_value",
                    array([self._image_value.data[y_ndx, x_ndx]]))
                self.pd.set_data("scatter_value2",
                    array([self._image_value.data[y_ndx, x_ndx]]))
                self.pd.set_data("scatter_color",
                    array([self._image_value.data[y_ndx, x_ndx]]))
                self.pd.set_data("scatter_color2",
                    array([self._image_value.data[y_ndx, x_ndx]]))
        else:
            # No selection: clear the cross-section curves and markers.
            self.pd.set_data("scatter_value", array([]))
            self.pd.set_data("scatter_value2", array([]))
            self.pd.set_data("line_value", array([]))
            self.pd.set_data("line_value2", array([]))

    def _colormap_changed(self):
        """Apply the colormap named by ``colormap`` to all plots and redraw."""
        self._cmap = default_colormaps.color_map_name_dict[self.colormap]
        if self.polyplot is not None:
            value_range = self.polyplot.color_mapper.range
            self.polyplot.color_mapper = self._cmap(value_range)
            value_range = self.cross_plot.color_mapper.range
            self.cross_plot.color_mapper = self._cmap(value_range)
            # FIXME: change when we decide how best to update plots using
            # the shared colormap in plot object
            self.cross_plot.plots["dot"][0].color_mapper = \
                self._cmap(value_range)
            self.cross_plot2.plots["dot"][0].color_mapper = \
                self._cmap(value_range)
            self.container.request_redraw()

    def _num_levels_changed(self):
        """Apply ``num_levels`` to both contour plots when it exceeds 3."""
        if self.num_levels > 3:
            self.polyplot.levels = self.num_levels
            self.lineplot.levels = self.num_levels
def get_image_index_and_mapper(image):
    """Build a grid index and a matching grid mapper for an image array.

    The first two dimensions of ``image.shape`` are read as
    ``(height, width)``: image data is indexed ``[y, x]``.  The x axis
    therefore gets ``width + 1`` grid edges and the y axis ``height + 1``.

    Parameters
    ----------
    image : array
        Image data with at least two dimensions.

    Returns
    -------
    index : GridDataSource
        Grid whose xdata is ``0..width`` and whose ydata is ``0..height``.
    index_mapper : GridMapper
        Mapper whose 2-D range spans ``(0, 0)`` to ``(width, height)``.
    """
    height, width = image.shape[:2]
    # xdata must follow the number of columns and ydata the number of rows.
    # Passing them the other way round is only harmless for square images.
    index = GridDataSource(np.arange(width + 1), np.arange(height + 1))
    index_mapper = GridMapper(
        range=DataRange2D(low=(0, 0), high=(width, height)))
    return index, index_mapper
    def __init__(self, x,y,z):
        """Build the color image plot, its colorbar and two cross plots.

        ``x`` and ``y`` are handed to ``GridDataSource`` as ``xdata`` and
        ``ydata``; ``z`` becomes the image values.  The image plot gets a
        left and a bottom axis, a pan tool, a box zoom tool and two white
        ``LineInspector`` overlays (one per index axis).  A colorbar, a
        horizontal and a vertical line plot are then created and everything
        is laid out in ``self.container``.

        NOTE(review): ``self.pd_horiz`` and ``self.pd_vert`` are read below,
        but the lines that would create them are commented out in this
        method -- confirm they are provided elsewhere.
        """
        super(ImagePlot, self).__init__()
        self.pd_all = ArrayPlotData(imagedata = z)
        #self.pd_horiz = ArrayPlotData(x=x, horiz=z[4, :])
        #self.pd_vert = ArrayPlotData(y=y, vert=z[:,5])

        # Grid index shared by the image plot; its "metadata_changed" event
        # drives _metadata_changed.
        self._imag_index = GridDataSource(xdata=x, ydata=y, sort_order=("ascending","ascending"))
        index_mapper = GridMapper(range=DataRange2D(self._imag_index))
        self._imag_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")
        self._image_value = ImageData(data=z, value_depth=1)
        color_mapper = jet(DataRange1D(self._image_value))

        self.color_plot= CMapImagePlot(
            index=self._imag_index,
            index_mapper=index_mapper,
            value=self._image_value,
            value_mapper=color_mapper,
            padding=20,
            use_backbuffer=True,
            unified_draw=True)

        #Add axes to image plot
        left = PlotAxis(orientation='left',
                        title= "Frequency (GHz)",
                        mapper=self.color_plot.index_mapper._ymapper,
                        component=self.color_plot)

        self.color_plot.overlays.append(left)

        bottom = PlotAxis(orientation='bottom',
                        title= "Time (us)",
                        mapper=self.color_plot.index_mapper._xmapper,
                        component=self.color_plot)
        self.color_plot.overlays.append(bottom)

        # Pan tool plus a box zoom overlay on the image plot.
        self.color_plot.tools.append(PanTool(self.color_plot,
                                           constrain_key="shift"))
        self.color_plot.overlays.append(ZoomTool(component=self.color_plot,
                                            tool_mode="box", always_on=False))

        #Add line inspector tool for horizontal and vertical
        self.color_plot.overlays.append(LineInspector(component=self.color_plot,
                                               axis='index_x',
                                               inspect_mode="indexed",
                                               write_metadata=True,
                                               is_listener=True,
                                               color="white"))

        self.color_plot.overlays.append(LineInspector(component=self.color_plot,
                                               axis='index_y',
                                               inspect_mode="indexed",
                                               write_metadata=True,
                                               color="white",
                                               is_listener=True))

        # Fixed range spanning the minimum and maximum of z; used for
        # self.colormap and for the colorbar's index mapper.
        myrange = DataRange1D(low=amin(z),
                              high=amax(z))
        cmap=jet
        self.colormap = cmap(myrange)

        # Create a colorbar
        cbar_index_mapper = LinearMapper(range=myrange)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.color_plot,
                                 padding_top=self.color_plot.padding_top,
                                 padding_bottom=self.color_plot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)#, ytitle="Magvec (mV)")

        #create horizontal line plot
        self.horiz_cross_plot = Plot(self.pd_horiz, resizable="h")
        self.horiz_cross_plot.height = 100
        self.horiz_cross_plot.padding = 20
        self.horiz_cross_plot.plot(("x", "horiz"))#,
                             #line_style="dot")
#        self.cross_plot.plot(("scatter_index","scatter_value","scatter_color"),
#                             type="cmap_scatter",
#                             name="dot",
#                             color_mapper=self._cmap(image_value_range),
#                             marker="circle",
#                             marker_size=8)

        # Share the image plot's x range with the horizontal cross plot.
        self.horiz_cross_plot.index_range = self.color_plot.index_range.x_range

        #create vertical line plot
        self.vert_cross_plot = Plot(self.pd_vert, width = 140, orientation="v",
                                resizable="v", padding=20, padding_bottom=160)
        self.vert_cross_plot.plot(("y", "vert"))#,
#                             line_style="dot")
       # self.vert_cross_plot.xtitle="Magvec (mV)"
 #       self.vertica_cross_plot.plot(("vertical_scatter_index",
 #                              "vertical_scatter_value",
 #                              "vertical_scatter_color"),
 #                            type="cmap_scatter",
 #                            name="dot",
 #                            color_mapper=self._cmap(image_value_range),
 #                            marker="circle",
  #                           marker_size=8)

        # Share the image plot's y range with the vertical cross plot.
        self.vert_cross_plot.index_range = self.color_plot.index_range.y_range

        # Create a container and add components
        self.container = HPlotContainer(padding=40, fill_padding=True,
                                        bgcolor = "white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.horiz_cross_plot)
        inner_cont.add(self.color_plot)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.vert_cross_plot)
class ImagePlot(Atom):
    """Color image plot with a colorbar and two cross-section line plots.

    The image is drawn as a ``CMapImagePlot``.  Two ``LineInspector``
    overlays write the selected grid indices into the index metadata, and
    ``_metadata_changed`` copies the selected row and column of the image
    into the horizontal and vertical line plots.
    """

    # Lets other objects hold weak references to an instance.
    # NOTE(review): added because traits is expected to weak-reference the
    # owner of the bound method registered with ``on_trait_change`` --
    # confirm against the traits version in use.
    __slots__ = ('__weakref__',)

    # container for all plots
    container = Typed(HPlotContainer)

    # Plot components within this container:
    color_plot = Typed(CMapImagePlot)
    horiz_cross_plot = Typed(Plot)
    vert_cross_plot = Typed(Plot)
    # Long-form names kept as aliases of the two cross plots above.
    vertical_cross_plot = Typed(Plot)
    horizontal_cross_plot = Typed(Plot)
    colorbar = Typed(ColorBar)
    # Color mapper built on the fixed min/max range of the image values.
    colormap = Typed(object)

    # plot data
    pd_all = Typed(ArrayPlotData)
    pd_horiz = Typed(ArrayPlotData)
    pd_vert = Typed(ArrayPlotData)

    # private data storage
    _imag_index = Typed(GridDataSource)
    _image_value = Typed(ImageData)

    def __init__(self, x, y, z):
        """Build the image plot, its tools, the colorbar and cross plots.

        Parameters
        ----------
        x, y : array
            Grid coordinates handed to ``GridDataSource`` as ``xdata`` and
            ``ydata``; also used as the index data of the cross plots.
        z : 2-D array
            Image values, indexed ``[y, x]``.
        """
        super(ImagePlot, self).__init__()
        self.pd_all = ArrayPlotData(imagedata=z)
        # Seed the cross sections with the first row and column so the line
        # plots have data before any selection has been made.
        self.pd_horiz = ArrayPlotData(x=x, horiz=z[0, :])
        self.pd_vert = ArrayPlotData(y=y, vert=z[:, 0])

        self._build_color_plot(x, y, z)
        self._build_colorbar(z)
        self._build_cross_plots()
        self._build_container()

    def _build_color_plot(self, x, y, z):
        """Create the image plot with its axes, tools and line inspectors."""
        self._imag_index = GridDataSource(
            xdata=x, ydata=y, sort_order=("ascending", "ascending"))
        index_mapper = GridMapper(range=DataRange2D(self._imag_index))
        # The line inspectors write into the index metadata; react to it.
        self._imag_index.on_trait_change(self._metadata_changed,
                                         "metadata_changed")
        self._image_value = ImageData(data=z, value_depth=1)
        color_mapper = jet(DataRange1D(self._image_value))

        plot = CMapImagePlot(
            index=self._imag_index,
            index_mapper=index_mapper,
            value=self._image_value,
            value_mapper=color_mapper,
            padding=20,
            use_backbuffer=True,
            unified_draw=True)
        self.color_plot = plot

        # Axes on the image plot.
        plot.overlays.append(PlotAxis(orientation='left',
                                      title="Frequency (GHz)",
                                      mapper=plot.index_mapper._ymapper,
                                      component=plot))
        plot.overlays.append(PlotAxis(orientation='bottom',
                                      title="Time (us)",
                                      mapper=plot.index_mapper._xmapper,
                                      component=plot))

        plot.tools.append(PanTool(plot, constrain_key="shift"))
        plot.overlays.append(ZoomTool(component=plot,
                                      tool_mode="box", always_on=False))

        # One line inspector per index axis.
        for axis in ('index_x', 'index_y'):
            plot.overlays.append(LineInspector(component=plot,
                                               axis=axis,
                                               inspect_mode="indexed",
                                               write_metadata=True,
                                               is_listener=True,
                                               color="white"))

    def _build_colorbar(self, z):
        """Create the colorbar over the fixed min/max range of ``z``."""
        value_range = DataRange1D(low=amin(z), high=amax(z))
        self.colormap = jet(value_range)
        self.colorbar = ColorBar(index_mapper=LinearMapper(range=value_range),
                                 plot=self.color_plot,
                                 padding_top=self.color_plot.padding_top,
                                 padding_bottom=self.color_plot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

    def _build_cross_plots(self):
        """Create the line plots that share the image plot's index ranges."""
        horiz = Plot(self.pd_horiz, resizable="h")
        horiz.height = 100
        horiz.padding = 20
        horiz.plot(("x", "horiz"))
        horiz.index_range = self.color_plot.index_range.x_range
        self.horiz_cross_plot = horiz
        self.horizontal_cross_plot = horiz

        vert = Plot(self.pd_vert, width=140, orientation="v",
                    resizable="v", padding=20, padding_bottom=160)
        vert.plot(("y", "vert"))
        vert.index_range = self.color_plot.index_range.y_range
        self.vert_cross_plot = vert
        self.vertical_cross_plot = vert

    def _build_container(self):
        """Lay out colorbar, image plus horizontal plot, and vertical plot."""
        self.container = HPlotContainer(padding=40, fill_padding=True,
                                        bgcolor="white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.horiz_cross_plot)
        inner_cont.add(self.color_plot)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.vert_cross_plot)

    def _metadata_changed(self, old, new):
        """Update the cross-section plots from the inspector selection.

        Reads the ``"selections"`` entry of the index metadata as an
        ``(x_ndx, y_ndx)`` pair and copies the matching image row and column
        into the horizontal and vertical line plots.  Does nothing when the
        entry is missing, is not a pair, or holds ``None`` for either index.

        Parameters
        ----------
        old, new
            Supplied by the trait notification; not used here.
        """
        selections = self._imag_index.metadata.get("selections")
        if selections is None or len(selections) != 2:
            return
        x_ndx, y_ndx = selections
        # Compare against None explicitly: index 0 is a valid selection.
        if x_ndx is None or y_ndx is None:
            return
        self.pd_horiz.set_data("horiz", self._image_value.data[y_ndx, :])
        self.pd_vert.set_data("vert", self._image_value.data[:, x_ndx])


#if __name__ == "__main__":

#    filename="/Users/thomasaref/Dropbox/Dad stuff/sample3/digitizer/lt/sample3_digitizer_f_sweep_t_300mk_100nspulse.hdf5"
#
#    with h5py.File(filename, 'r') as f:
#
#        time=f["Traces"]["d - AvgTrace - t"][:]
#        Magvec=f["Traces"]["d - AvgTrace - Magvec"][:]
#        frequency=f["Data"]["Data"][:]
#    #    for name in f["Data"]:
#    #        print name
#
#    time=squeeze(time)
#    Magvec=squeeze(Magvec)
#    frequency=squeeze(frequency)
#
#    x = time[:,0]*1.0e6
#    y = frequency[0,:]/1.0e9
#    z=transpose(Magvec*1000.0)
#
#    ip=ImagePlot(x,y,z)
#    ip.configure_traits()

#class Image_Plot(Atom):
#    plot_control=Instance(Plot_Control)
#    xtitle=DelegatesTo('plot_control')
#    ytitle=DelegatesTo('plot_control')
#    ztitle=DelegatesTo('plot_control')
#    request_redraw=DelegatesTo('plot_control')
#    #ykeys=DelegatesTo('plot_control')
#    container = Typed(HPlotContainer)
#    color_plot = Typed(CMapImagePlot)
#    plot=Instance(Plot)
#    vertical_cross_plot = Typed(Plot)
#    horizontal_cross_plot = Typed(Plot)
#    colorbar = Typed(ColorBar)
#    pd_all = Instance(ArrayPlotData)
#    _image_index=Instance(GridDataSource)
#    _image_value=Instance(ImageData)
#    data=Dict()
#
#    pd=Instance(ArrayPlotData)
#
#    traits_view = View(Group(Item('container', editor=ComponentEditor(), show_label=False),
#                             orientation='horizontal'),
#        width=1000, height=700, resizable=True, title="Chaco Plot")
#
#    def _xtitle_changed(self):
#        self.horiz_cross_plot.x_axis.title=self.xtitle
#        self.plot.x_axis.title=self.xtitle
#
#    def _ytitle_changed(self):
#        self.vert_cross_plot.y_axis.title=self.ytitle
#        self.plot.y_axis.title=self.ytitle
#
#    def _request_redraw_fired(self):
#        self.color_plot.request_redraw()
#        self.horiz_cross_plot.request_redraw()
#        self.vert_cross_plot.request_redraw()
#
#    def __init__(self, data, plot_control):
#        super(Image_Plot, self).__init__()
#        self.plot_control=plot_control
#        z=zeros((len(data['y']['0']), len(data['x']['0'])))
#        z[:] = nan
#        for key, item in data['z'].iteritems():
#            z[int(key)]=item
#        x=data['x']['0']
#        y=data['y']['0']
#        self.pd = ArrayPlotData(z=z, x=x, y=y, horiz=z[0, :], vert=z[:, 0])
#        self.plot=Plot(self.pd, padding=50, fill_padding=True,
#                        bgcolor="white", use_backbuffer=True,  unified_draw=True)
#        xgrid, ygrid = meshgrid(x, y)
#
#        color_plot=self.plot.img_plot('z', name="img_plot", xbounds=xgrid, ybounds=ygrid)[0]
#        self._image_index = color_plot.index #GridDataSource(xdata=x, ydata=y, sort_order=("ascending","ascending"))
#        self._image_index.on_trait_change(self._metadata_changed, "metadata_changed")
#        self._image_value=color_plot.value
#        self.value_range=DataRange1D(self._image_value)
#        color_plot.color_mapper = jet(self.value_range)
#        color_plot.tools.append(PanTool(color_plot,
#                                           constrain_key="shift"))
#        color_plot.overlays.append(ZoomTool(component=color_plot,
#                                            tool_mode="box", always_on=False))
#
#        color_plot.overlays.append(LineInspector(component=color_plot,
#                                               axis='index_x',
#                                               inspect_mode="indexed",
#                                               write_metadata=True,
#                                               is_listener=True,
#                                               color="white"))
#
#        color_plot.overlays.append(LineInspector(component=color_plot,
#                                               axis='index_y',
#                                               inspect_mode="indexed",
#                                               write_metadata=True,
#                                               color="white",
#                                               is_listener=True))
#
#        cbar_index_mapper = LinearMapper(range=self.value_range)
#        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
#                                 plot=color_plot,
#                                 padding_top=color_plot.padding_top,
#                                 padding_bottom=color_plot.padding_bottom,
#                                 padding_right=40,
#                                 resizable='v',
#                                 width=30)#, ytitle="Magvec (mV)")
#
#        #create horizontal line plot
#        self.horiz_cross_plot = Plot(self.pd, resizable="h", height=100, padding=50)
#        self.horiz_cross_plot.plot(("x", "horiz"))#,
#        self.horiz_cross_plot.index_range = color_plot.index_range.x_range
#
#        #create vertical line plot
#        self.vert_cross_plot = Plot(self.pd, width = 100, orientation="v",
#                                resizable="v", padding=50, padding_bottom=250)
#        self.vert_cross_plot.plot(("y", "vert"))
#        self.vert_cross_plot.index_range = color_plot.index_range.y_range
#        #self.vert_cross_plot.x_axis.tick_label_formatter = lambda x: '%.2g'%x
#        self.color_plot=color_plot
#
#        self.container = HPlotContainer(padding=40, fill_padding=True,
#                                        bgcolor = "white", use_backbuffer=False)
#        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
#        inner_cont.add(self.horiz_cross_plot)
#        inner_cont.add(self.plot)
#        self.container.add(self.colorbar)
#        self.container.add(inner_cont)
#        self.container.add(self.vert_cross_plot)
#        #self.vert_cross_plot.y_axis.title="Frequency"
#        #self.horiz_cross_plot.x_axis.title="Time (us)"
#
#    def _metadata_changed(self, old, new):
#        if self._image_index.metadata.has_key("selections"):
#            x_ndx, y_ndx = self._image_index.metadata["selections"]
#            if y_ndx and x_ndx:
#                self.pd.set_data("horiz", self._image_value.data[y_ndx,:])
#                self.pd.set_data("vert", self._image_value.data[:,x_ndx])
#
#class Line_Plot(HasTraits):
#    plot_control=Instance(Plot_Control)
#    request_redraw=DelegatesTo('plot_control')
#    new_plot=DelegatesTo('plot_control')
#    xtitle=DelegatesTo('plot_control')
#    ytitle=DelegatesTo('plot_control')
#    title=DelegatesTo('plot_control')
#    show_legend=DelegatesTo('plot_control')
#    xyformat=DelegatesTo('plot_control')
#    plot=Instance(Plot)
#    keymap=DelegatesTo('plot_control') #Dict()
#    color_index=Int()
#    mycolors=List([ 'blue', 'red', 'green', 'purple',  'black', 'darkgray', 'cyan', 'magenta', 'orange'])
#    value_scale=DelegatesTo('plot_control')
#    index_scale=DelegatesTo('plot_control')
#    xcomplex=DelegatesTo('plot_control')
#    ycomplex=DelegatesTo('plot_control')
#
#    xkeys=DelegatesTo('plot_control')
#    zkeys=DelegatesTo('plot_control')
#    xindices=DelegatesTo('plot_control')
#    zindices=DelegatesTo('plot_control')
#    pd = Instance(ArrayPlotData)
#
#    def _value_scale_changed(self):
#         #if self.color_index!=0:
#             self.plot.value_scale = self.value_scale
#             self.plot.request_redraw()
#
#    def _index_scale_changed(self):
#         #if self.color_index!=0:
#             self.plot.index_scale = self.index_scale
#             self.plot.request_redraw()
#
#    def _show_legend_changed(self):
#        self.plot.legend.visible = self.show_legend
#        self.plot.request_redraw()
#
#    def _title_changed(self):
#        self.plot.title = self.title
#        self.plot.request_redraw()
#
#    def _xtitle_changed(self):
#        self.plot.x_axis.title=self.xtitle
#        self.plot.request_redraw()
#
#    def _ytitle_changed(self):
#        self.plot.y_axis.title=self.ytitle
#        self.plot.request_redraw()
#
#    def _request_redraw_fired(self):
#        self.plot.request_redraw()
#
#    def _new_plot_fired(self):
#        for key in self.plot.plots.keys():
#            self.remove_plot(key)
#        self.color_index=0
#        for n, name in enumerate(self.zkeys):
#            key='z'+str(name)
#            self.add_plot(key)
#
#    def _zkeys_changed(self,  name, old, new):
#        #print self.pd.list_data()
#        n=0
#        for key in self.pd.list_data():
#            if int(key[1:]) in new:
#                self.add_plot(key)
#                n=n+1
#            else:
#                self.remove_plot(key)
#
##        for key, plot in self.plot.plots.iteritems():
##            if int(key[1:]) in new:
##                if self.xyformat.t_color=="transparent" or self.xyformat.t_color==(1.0, 1.0, 1.0, 1.0) :
##                    color=self.mycolors[mod(n, len(self.mycolors))]
##                else:
##                    color=self.xyformat.t_color
##                plot[0].color=color
##                #plot[0].outline_color=self.xyformat.outline_color,
##                n=n+1
##
##            else:
##               plot[0].color="none"
#               #plot[0].outline_color="none"
#
#    def add_plot(self, key, z, xkey='x0', x=None):
#        if key not in self.plot.plots.keys() and key[0]!='x':
#            if self.xyformat.t_color=="transparent" or self.xyformat.t_color==(1.0, 1.0, 1.0, 1.0) :
#                color=self.mycolors[mod(self.color_index, len(self.mycolors))]
#            else:
#                color=self.xyformat.t_color
#
#            if x!=None:
#                self.pd.set_data(xkey, x)
#            self.pd.set_data(key, z)
#
#            #if self.color_index<len(self.xkeys):
#            #    xkey='x'+str(self.xkeys[self.color_index])
#            #else:
#            #    xkey='x'+str(self.xkeys[0])
#            self.plot.plot((xkey, key),
#                           name=key,
#                           type=self.xyformat.plot_type,
#                           line_width=self.xyformat.line_width,
#                           color=color,
#                           outline_color=self.xyformat.outline_color,
#                           marker = self.xyformat.marker,
#                           marker_size = self.xyformat.marker_size)
#            self.color_index=self.color_index+1
#
#    def remove_plot(self, key):
#        if key in self.plot.plots.keys():
#            self.plot.delplot(key)
#
#    def __init__(self, data, plot_control, *args, **kws):
#        super(Line_Plot, self).__init__(*args, **kws)
#        self.plot_control=plot_control
#        self.pd = ArrayPlotData()
#
#        for name, arr in sorted(data['z'].iteritems()):
#                self.pd.set_data('z'+str(name), arr)
#
#        for name, arr in sorted(data['x'].iteritems()):
#                self.pd.set_data('x'+str(name), arr)
#
#        plot = Plot(self.pd, padding=50, fill_padding=True,
#                        bgcolor="white", use_backbuffer=True)
#
#        # Attach some tools to the plot
#        plot.tools.append(PanTool(plot))
#        zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
#        plot.overlays.append(zoom)
#        plot.legend.tools.append(LegendTool(plot.legend, drag_button="right"))
#        self.plot=plot
#
#        for n, item in enumerate(self.zkeys):
#                key='z'+str(item)
#                if self.xyformat.t_color=="transparent" or self.xyformat.t_color==(1.0, 1.0, 1.0, 1.0) :
#                    color=self.mycolors[mod(n, len(self.mycolors))]
#                else:
#                    color=self.xyformat.t_color
#                if n<len(self.xkeys):
#                    xkey='x'+str(self.xkeys[n])
#                else:
#                    xkey='x'+str(self.xkeys[0])
#                #n=n+1
#                #self.pd.set_data(key, magphase(self._image_value.data[int(item)], self.ycomplex))
#                self.plot.plot((xkey, key),
#                                name=key,
#                                type=self.xyformat.plot_type,
#                                line_width=self.xyformat.line_width,
#                                color=color,
#                                outline_color=self.xyformat.outline_color,
#                                marker = self.xyformat.marker,
#                                marker_size = self.xyformat.marker_size)
#        self.plot.value_scale = self.value_scale
#        self.plot.index_scale= self.index_scale
#
#
#    traits_view = View(Item('plot', style='custom',editor=ComponentEditor(),
#                             show_label=False),
#                    resizable=True, title="Chaco Plot",
#                    width=800, height=700, #kind='modal',
#                    buttons=[OKButton, CancelButton]
#                    )
示例#14
0
class myImagePlot(HasTraits):
    """Colour-mapped image of ``z`` with a colorbar and two cross-section plots.

    The image is drawn over the grid given by ``x`` and ``y``.  Two
    ``LineInspector`` overlays are attached to the image with
    ``write_metadata=True``, and ``_metadata_changed`` is registered on the
    index data source's ``metadata_changed`` event; it refreshes the
    horizontal and vertical cross-section line plots from the selected
    row and column of the image data.
    """
    # Container for all plots.
    container = Instance(HPlotContainer)

    # Plot components within this container.
    # NOTE(review): __init__ stores the two line plots as ``horiz_cross_plot``
    # and ``vert_cross_plot``; the ``horizontal_cross_plot`` and
    # ``vertical_cross_plot`` traits declared here are never assigned.
    color_plot = Instance(CMapImagePlot)
    vertical_cross_plot = Instance(Plot)
    horizontal_cross_plot = Instance(Plot)
    colorbar = Instance(ColorBar)

    # Plot data.
    pd_all = Instance(ArrayPlotData)
    pd_horiz = Instance(ArrayPlotData)
    pd_vert = Instance(ArrayPlotData)

    # Private data storage.
    _imag_index = Instance(GridDataSource)
    _image_value = Instance(ImageData)

    traits_view = View(
        Item('container', editor=ComponentEditor(), show_label=False),
        width=1000, height=700, resizable=True, title="Chaco Plot")

    def __init__(self, x, y, z):
        """Build the image plot, its colorbar and both cross-section plots.

        Parameters
        ----------
        x, y
            Grid coordinates; passed to ``GridDataSource`` as ``xdata`` and
            ``ydata`` and used as the index of the cross-section plots.
        z
            Image data.  It is indexed as ``z[4, :]`` and ``z[:, 5]`` for the
            initial cross sections, so it needs at least 5 rows and
            6 columns.
        """
        super(myImagePlot, self).__init__()
        self.pd_all = ArrayPlotData(imagedata=z)
        # Initial cross sections: row 4 and column 5 of the image.
        self.pd_horiz = ArrayPlotData(x=x, horiz=z[4, :])
        self.pd_vert = ArrayPlotData(y=y, vert=z[:, 5])

        self._imag_index = GridDataSource(xdata=x, ydata=y,
                                          sort_order=("ascending", "ascending"))
        index_mapper = GridMapper(range=DataRange2D(self._imag_index))
        self._imag_index.on_trait_change(self._metadata_changed,
                                         "metadata_changed")
        self._image_value = ImageData(data=z, value_depth=1)
        color_mapper = jet(DataRange1D(self._image_value))

        self.color_plot = CMapImagePlot(
            index=self._imag_index,
            index_mapper=index_mapper,
            value=self._image_value,
            value_mapper=color_mapper,
            padding=20,
            use_backbuffer=True,
            unified_draw=True)

        # Add axes to the image plot.
        left = PlotAxis(orientation='left',
                        title="Frequency (GHz)",
                        mapper=self.color_plot.index_mapper._ymapper,
                        component=self.color_plot)
        self.color_plot.overlays.append(left)

        bottom = PlotAxis(orientation='bottom',
                          title="Time (us)",
                          mapper=self.color_plot.index_mapper._xmapper,
                          component=self.color_plot)
        self.color_plot.overlays.append(bottom)

        self.color_plot.tools.append(PanTool(self.color_plot,
                                             constrain_key="shift"))
        self.color_plot.overlays.append(ZoomTool(component=self.color_plot,
                                                 tool_mode="box",
                                                 always_on=False))

        # Line inspectors for the horizontal and vertical cross sections.
        self.color_plot.overlays.append(LineInspector(component=self.color_plot,
                                                      axis='index_x',
                                                      inspect_mode="indexed",
                                                      write_metadata=True,
                                                      is_listener=True,
                                                      color="white"))

        self.color_plot.overlays.append(LineInspector(component=self.color_plot,
                                                      axis='index_y',
                                                      inspect_mode="indexed",
                                                      write_metadata=True,
                                                      color="white",
                                                      is_listener=True))

        # Fixed value range spanning the full image data.
        myrange = DataRange1D(low=amin(z),
                              high=amax(z))
        self.colormap = jet(myrange)

        # Create a colorbar.
        cbar_index_mapper = LinearMapper(range=myrange)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.color_plot,
                                 padding_top=self.color_plot.padding_top,
                                 padding_bottom=self.color_plot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        # Horizontal cross-section plot; shares the image's x range.
        self.horiz_cross_plot = Plot(self.pd_horiz, resizable="h")
        self.horiz_cross_plot.height = 100
        self.horiz_cross_plot.padding = 20
        self.horiz_cross_plot.plot(("x", "horiz"))
        self.horiz_cross_plot.index_range = self.color_plot.index_range.x_range

        # Vertical cross-section plot; shares the image's y range.
        self.vert_cross_plot = Plot(self.pd_vert, width=140, orientation="v",
                                    resizable="v", padding=20,
                                    padding_bottom=160)
        self.vert_cross_plot.plot(("y", "vert"))
        self.vert_cross_plot.index_range = self.color_plot.index_range.y_range

        # Create a container and add components.
        self.container = HPlotContainer(padding=40, fill_padding=True,
                                        bgcolor="white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.horiz_cross_plot)
        inner_cont.add(self.color_plot)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.vert_cross_plot)

    def _metadata_changed(self, old, new):
        """Refresh both cross-section plots from the current selection.

        Reads the ``"selections"`` entry of the index data source's metadata
        as an ``(x_ndx, y_ndx)`` pair and copies row ``y_ndx`` and column
        ``x_ndx`` of the image data into the horizontal and vertical line
        plots.  Nothing happens when the entry is missing, does not hold
        exactly two items, or either index is ``None``.
        """
        # ``.get`` replaces ``has_key``, which no longer exists in Python 3.
        selections = self._imag_index.metadata.get("selections")
        # An empty or otherwise malformed selection cannot be unpacked.
        if selections is None or len(selections) != 2:
            return
        x_ndx, y_ndx = selections
        # Compare against None explicitly: index 0 is a valid row/column
        # and must not be treated as "no selection".
        if x_ndx is None or y_ndx is None:
            return
        self.pd_horiz.set_data("horiz", self._image_value.data[y_ndx, :])
        self.pd_vert.set_data("vert", self._image_value.data[:, x_ndx])
示例#15
0
    def _plot2(self):
        """Build the line-scan view: image plot, colorbar and two line plots.

        Reads the scan description from ``self.plotE`` (keys ``'plot'``,
        ``'notes'``, ``'data'``, ``'range'`` and ``'shape'``), stores the
        assembled container in ``self.LineScans`` and finishes by calling
        ``self._readline_fired()`` and ``self._scale_set_fired()``.
        """
        # Parenthesised so the statement parses under both Python 2 and 3.
        print("...plotting line scans")
        title = str(self.plotE['plot']) + str(self.plotE["notes"]["start"])
        self.plotdata5 = ArrayPlotData(imagedata=self.plotE['data'])

        self._image_index = GridDataSource(array([]),
                                           array([]),
                                           sort_order=("ascending", "ascending"))

        x_low, x_high = self.plotE["range"][0][0], self.plotE["range"][0][1]
        y_low, y_high = self.plotE["range"][1][0], self.plotE["range"][1][1]
        self.xs = linspace(x_low, x_high, self.plotE["shape"][0])
        self.ys = linspace(y_low, y_high, self.plotE["shape"][1])

        self._image_index.set_data(self.xs, self.ys)

        # NOTE(review): neither range object is used below; they are kept in
        # case constructing them has side effects on the data sources --
        # confirm and remove if not.
        image_index_range = DataRange2D(self._image_index)

        self._image_value = ImageData(data=array([]), value_depth=1)
        self._image_value.data = self.plotE['data']
        image_value_range = DataRange1D(self._image_value)

        self.plot1lines = Plot(self.plotdata5, title=title)
        # Create the image renderer once and keep a handle on it.  Calling
        # img_plot() a second time would add another renderer without the
        # bounds and colormap given here, and the inspector would attach to
        # that one instead.
        img_plot = self.plot1lines.img_plot("imagedata",
                                            xbounds=(x_low, x_high),
                                            ybounds=(y_low, y_high),
                                            colormap=jet)[0]
        imgtool = ImageInspectorTool(img_plot)
        img_plot.tools.append(imgtool)
        overlay = ImageInspectorOverlay(component=img_plot,
                                        image_inspector=imgtool,
                                        bgcolor="white", border_visible=True)
        self.plot1lines.overlays.append(overlay)

        # Add tools.
        self.plot1lines.tools.append(PanTool(self.plot1lines))
        zoom1 = ZoomTool(component=self.plot1lines, tool_mode="box",
                         always_on=False)
        self.plot1lines.overlays.append(zoom1)

        self.plot1lines.overlays.append(LineInspector(component=self.plot1lines,
                                                      axis='index_x',
                                                      inspect_mode="indexed",
                                                      write_metadata=True,
                                                      is_listener=False,
                                                      color="white"))
        self.plot1lines.overlays.append(LineInspector(component=self.plot1lines,
                                                      axis='index_y',
                                                      inspect_mode="indexed",
                                                      write_metadata=True,
                                                      color="white",
                                                      is_listener=False))

        # Add colorbar.
        self.colorbar5 = ColorBar(
            index_mapper=LinearMapper(range=self.plot1lines.color_mapper.range),
            color_mapper=self.plot1lines.color_mapper,
            orientation='v',
            resizable='v',
            width=20,
            padding=5)
        self.colorbar5.plot = self.plot1lines

        self.colorbar5.tools.append(PanTool(self.colorbar5,
                                            constrain_direction="y",
                                            constrain=True))
        self.zoom_overlay5 = ZoomTool(self.colorbar5, axis="index",
                                      tool_mode="range",
                                      always_on=True, drag_button="right")
        self.colorbar5.overlays.append(self.zoom_overlay5)

        # Line plots start out empty and are filled in further down.
        self.pd = ArrayPlotData(line_index=array([]),
                                line_value=array([]))

        self.slow_plot = Plot(self.pd, title="Slowline : " + self.slowline)
        self.slow_plot.plot(("line_index", "line_value"),
                            line_style='solid')

        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))

        self.fast_plot = Plot(self.pd, title="Fastline : " + self.fastline)
        self.fast_plot.plot(("line_index2", "line_value2"),
                            line_style='solid')

        self.pd.set_data("line_index", self.xs)
        self.pd.set_data("line_index2", self.ys)
        self.pd.set_data("line_value",
                         self._image_value.data[self.fastscanline, :])
        self.pd.set_data("line_value2",
                         self._image_value.data[:, self.slowscanline])

        # Layout tweaks.
        self.colorbar5.padding = 0
        self.colorbar5.padding_top = 50
        self.colorbar5.padding_bottom = 0
        self.colorbar5.padding_right = 25
        self.colorbar5.padding_left = 50

        self.plot1lines.width = 300
        self.plot1lines.padding_top = 50

        self.plot1lines.index_axis.title = 'fast axis (um)'
        self.plot1lines.value_axis.title = 'slow axis (um)'

        self.slow_plot.width = 100
        self.slow_plot.padding_right = 20

        self.fast_plot.width = 100
        self.fast_plot.padding_right = 20

        # container2 holds the colorbar next to container4; container4 holds
        # the stacked line plots (container3) next to the image plot.
        self.container2 = GridPlotContainer(shape=(1, 2), spacing=(0, 0),
                                            use_backbuffer=True,
                                            valign='top', halign='center',
                                            bgcolor='white')
        self.container3 = GridPlotContainer(shape=(2, 1), spacing=(0, 0),
                                            use_backbuffer=True,
                                            valign='top', halign='center',
                                            bgcolor='grey')
        self.container4 = GridPlotContainer(shape=(1, 2), spacing=(0, 0),
                                            use_backbuffer=True,
                                            valign='top', halign='center',
                                            bgcolor='grey')

        self.container2.add(self.colorbar5)
        self.container3.add(self.fast_plot)
        self.container3.add(self.slow_plot)
        self.container4.add(self.container3)
        self.container4.add(self.plot1lines)
        self.container2.add(self.container4)
        self.LineScans = self.container2

        self._readline_fired()
        self._scale_set_fired()
 def setUp(self):
     """Create the grid data source shared by the tests."""
     x_values = array([1, 2, 3])
     y_values = array([1.5, 0.5, -0.5, -1.5])
     ordering = ('ascending', 'descending')
     self.data_source = GridDataSource(xdata=x_values,
                                       ydata=y_values,
                                       sort_order=ordering)
示例#17
0
 def setUp(self):
     """Build a GridMapper over a data range fed by two sample arrays."""
     self.x_ary = array([5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
     self.y_ary = array([1.0, 1.0, 2.0, 2.0, 3.0, 3.0])
     grid_source = GridDataSource(xdata=self.x_ary, ydata=self.y_ary)
     grid_range = DataRange2D(grid_source)
     self.mapper = GridMapper(range=grid_range)
class GridDataSourceTestCase(UnittestTools, unittest.TestCase):
    """Tests for GridDataSource defaults, data access and metadata events."""

    def setUp(self):
        """Create the populated grid data source used by most tests."""
        x_values = array([1, 2, 3])
        y_values = array([1.5, 0.5, -0.5, -1.5])
        self.data_source = GridDataSource(
            xdata=x_values,
            ydata=y_values,
            sort_order=('ascending', 'descending'))

    def test_empty(self):
        """A source built without arguments exposes empty defaults."""
        source = GridDataSource()
        self.assertEqual(source.sort_order, ('none', 'none'))
        self.assertEqual(source.index_dimension, 'image')
        self.assertEqual(source.value_dimension, 'scalar')
        self.assertEqual(source.metadata,
                         {"selections":[], "annotations":[]})
        x_source, y_source = source.get_data()
        assert_array_equal(x_source.get_data(), array([]))
        assert_array_equal(y_source.get_data(), array([]))
        self.assertEqual(source.get_bounds(), ((0,0),(0,0)))

    def test_init(self):
        """Constructor arguments are reflected by the accessors."""
        expected_x = array([1, 2, 3])
        expected_y = array([1.5, 0.5, -0.5, -1.5])
        expected_order = ('ascending', 'descending')
        expected_bounds = ((min(expected_x), min(expected_y)),
                           (max(expected_x), max(expected_y)))

        source = self.data_source
        self.assertEqual(source.sort_order, expected_order)
        x_source, y_source = source.get_data()
        assert_array_equal(x_source.get_data(), expected_x)
        assert_array_equal(y_source.get_data(), expected_y)
        self.assertEqual(source.get_bounds(), expected_bounds)

    def test_set_data(self):
        """set_data() replaces the arrays, the sort order and the bounds."""
        new_x = array([0,2,4])
        new_y = array([0,1,2,3,4,5])
        new_order = ('none', 'none')
        expected_bounds = ((min(new_x), min(new_y)),
                           (max(new_x), max(new_y)))

        source = self.data_source
        source.set_data(xdata=new_x, ydata=new_y, sort_order=new_order)

        self.assertEqual(source.sort_order, new_order)
        x_source, y_source = source.get_data()
        assert_array_equal(x_source.get_data(), new_x)
        assert_array_equal(y_source.get_data(), new_y)
        self.assertEqual(source.get_bounds(), expected_bounds)

    def test_metadata(self):
        """The default metadata holds empty annotation and selection lists."""
        expected = {'annotations': [], 'selections': []}
        self.assertEqual(self.data_source.metadata, expected)

    def test_metadata_changed(self):
        """Replacing the metadata dict fires metadata_changed exactly once."""
        source = self.data_source
        with self.assertTraitChanges(source, 'metadata_changed', count=1):
            source.metadata = {'new_metadata': True}

    def test_metadata_items_changed(self):
        """Setting a metadata key fires metadata_changed exactly once."""
        source = self.data_source
        with self.assertTraitChanges(source, 'metadata_changed', count=1):
            source.metadata['new_metadata'] = True
    def create_plot(self):
        """Assemble the contour plots, colorbar and both cross-section plots.

        Populates ``self._image_index``, ``self._image_value``,
        ``self.polyplot``, ``self.lineplot``, ``self.colorbar``, ``self.pd``,
        ``self.cross_plot``, ``self.cross_plot2`` and ``self.container``.
        All data sources start out empty.
        """
        # Index (grid) data source and its 2-D range.
        self._image_index = GridDataSource(
            array([]), array([]), sort_order=("ascending","ascending"))
        index_source = self._image_index
        grid_range = DataRange2D(index_source)
        index_source.on_trait_change(self._metadata_changed,
                                     "metadata_changed")

        # Value (image) data source and its 1-D range.
        self._image_value = ImageData(data=array([]), value_depth=1)
        value_source = self._image_value
        value_range = DataRange1D(value_source)

        # Filled contour plot.
        poly_mapper = GridMapper(range=grid_range)
        poly_colors = self._cmap(value_range)
        self.polyplot = ContourPolyPlot(index=index_source,
                                        value=value_source,
                                        index_mapper=poly_mapper,
                                        color_mapper=poly_colors,
                                        levels=self.num_levels)
        polyplot = self.polyplot

        # Contour line plot sharing the filled plot's index range.
        line_mapper = GridMapper(range=polyplot.index_mapper.range)
        self.lineplot = ContourLinePlot(index=index_source,
                                        value=value_source,
                                        index_mapper=line_mapper,
                                        levels=self.num_levels)

        # Left axis first, then bottom axis.
        axis_specs = (
            ('left', "y", polyplot.index_mapper._ymapper),
            ('bottom', "x", polyplot.index_mapper._xmapper),
        )
        for side, label, axis_mapper in axis_specs:
            polyplot.overlays.append(PlotAxis(orientation=side,
                                              title=label,
                                              mapper=axis_mapper,
                                              component=polyplot))

        # Pan / zoom tools and one line inspector per image axis.
        polyplot.tools.append(PanTool(polyplot, constrain_key="shift"))
        polyplot.overlays.append(ZoomTool(component=polyplot,
                                          tool_mode="box",
                                          always_on=False))
        for inspected_axis in ('index_x', 'index_y'):
            polyplot.overlays.append(LineInspector(component=polyplot,
                                                   axis=inspected_axis,
                                                   inspect_mode="indexed",
                                                   write_metadata=True,
                                                   is_listener=True,
                                                   color="white"))

        # Both contour plots are stacked in one overlay container.
        contour_container = OverlayPlotContainer(padding=20,
                                                 use_backbuffer=True,
                                                 unified_draw=True)
        for renderer in (polyplot, self.lineplot):
            contour_container.add(renderer)

        # Colorbar driven by the image value range.
        self.colorbar = ColorBar(index_mapper=LinearMapper(range=value_range),
                                 plot=polyplot,
                                 padding_top=polyplot.padding_top,
                                 padding_bottom=polyplot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        # Plot data shared by both cross-section plots.
        self.pd = ArrayPlotData(line_index=array([]),
                                line_value=array([]),
                                scatter_index=array([]),
                                scatter_value=array([]),
                                scatter_color=array([]))

        # Horizontal cross-section plot.
        self.cross_plot = Plot(self.pd, resizable="h")
        horizontal = self.cross_plot
        horizontal.height = 100
        horizontal.padding = 20
        horizontal.plot(("line_index", "line_value"), line_style="dot")
        horizontal.plot(("scatter_index", "scatter_value", "scatter_color"),
                        type="cmap_scatter",
                        name="dot",
                        color_mapper=self._cmap(value_range),
                        marker="circle",
                        marker_size=8)
        horizontal.index_range = polyplot.index_range.x_range

        # Register the series used by the vertical cross-section plot.
        for series in ("line_index2", "line_value2", "scatter_index2",
                       "scatter_value2", "scatter_color2"):
            self.pd.set_data(series, array([]))

        # Vertical cross-section plot.
        self.cross_plot2 = Plot(self.pd, width=140, orientation="v",
                                resizable="v", padding=20, padding_bottom=160)
        vertical = self.cross_plot2
        vertical.plot(("line_index2", "line_value2"), line_style="dot")
        vertical.plot(("scatter_index2", "scatter_value2", "scatter_color2"),
                      type="cmap_scatter",
                      name="dot",
                      color_mapper=self._cmap(value_range),
                      marker="circle",
                      marker_size=8)
        vertical.index_range = polyplot.index_range.y_range

        # Top-level layout: colorbar | (cross plot over contours) | cross plot 2.
        self.container = HPlotContainer(padding=40, fill_padding=True,
                                        bgcolor="white", use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        for component in (horizontal, contour_container):
            inner_cont.add(component)
        for component in (self.colorbar, inner_cont, vertical):
            self.container.add(component)
 def setUp(self):
     """Create the grid data source fixture used by the tests."""
     fixture_kwargs = {
         'xdata': array([1, 2, 3]),
         'ydata': array([1.5, 0.5, -0.5, -1.5]),
         'sort_order': ('ascending', 'descending'),
     }
     self.data_source = GridDataSource(**fixture_kwargs)
    def create_plot(self):
        """Create every plot component and lay them out in ``self.container``.

        Builds empty index and value data sources, a filled contour plot and
        a contour line plot over them, a colorbar, and a horizontal and a
        vertical cross-section plot backed by ``self.pd``.
        """
        # --- data sources and ranges ---------------------------------------
        self._image_index = GridDataSource(array([]),
                                           array([]),
                                           sort_order=("ascending","ascending"))
        xy_range = DataRange2D(self._image_index)
        self._image_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")

        self._image_value = ImageData(data=array([]), value_depth=1)
        z_range = DataRange1D(self._image_value)

        # --- contour renderers ----------------------------------------------
        self.polyplot = ContourPolyPlot(
            index=self._image_index,
            value=self._image_value,
            index_mapper=GridMapper(range=xy_range),
            color_mapper=self._cmap(z_range),
            levels=self.num_levels)
        filled = self.polyplot

        self.lineplot = ContourLinePlot(
            index=self._image_index,
            value=self._image_value,
            index_mapper=GridMapper(range=filled.index_mapper.range),
            levels=self.num_levels)

        # --- axes -----------------------------------------------------------
        y_axis = PlotAxis(orientation='left',
                          title="y",
                          mapper=filled.index_mapper._ymapper,
                          component=filled)
        filled.overlays.append(y_axis)

        x_axis = PlotAxis(orientation='bottom',
                          title="x",
                          mapper=filled.index_mapper._xmapper,
                          component=filled)
        filled.overlays.append(x_axis)

        # --- tools ----------------------------------------------------------
        def make_inspector(axis_name):
            # Both inspectors differ only in the axis they follow.
            return LineInspector(component=filled,
                                 axis=axis_name,
                                 inspect_mode="indexed",
                                 write_metadata=True,
                                 is_listener=True,
                                 color="white")

        filled.tools.append(PanTool(filled, constrain_key="shift"))
        filled.overlays.append(
            ZoomTool(component=filled, tool_mode="box", always_on=False))
        filled.overlays.append(make_inspector('index_x'))
        filled.overlays.append(make_inspector('index_y'))

        # --- contour container ----------------------------------------------
        contour_container = OverlayPlotContainer(padding=20,
                                                 use_backbuffer=True,
                                                 unified_draw=True)
        contour_container.add(filled)
        contour_container.add(self.lineplot)

        # --- colorbar -------------------------------------------------------
        bar_mapper = LinearMapper(range=z_range)
        self.colorbar = ColorBar(index_mapper=bar_mapper,
                                 plot=filled,
                                 padding_top=filled.padding_top,
                                 padding_bottom=filled.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        # --- horizontal cross section ---------------------------------------
        self.pd = ArrayPlotData(line_index=array([]),
                                line_value=array([]),
                                scatter_index=array([]),
                                scatter_value=array([]),
                                scatter_color=array([]))
        plot_data = self.pd

        self.cross_plot = Plot(plot_data, resizable="h")
        self.cross_plot.height = 100
        self.cross_plot.padding = 20
        self.cross_plot.plot(("line_index", "line_value"),
                             line_style="dot")
        self.cross_plot.plot(
            ("scatter_index", "scatter_value", "scatter_color"),
            type="cmap_scatter",
            name="dot",
            color_mapper=self._cmap(z_range),
            marker="circle",
            marker_size=8)
        self.cross_plot.index_range = filled.index_range.x_range

        # --- vertical cross section -----------------------------------------
        second_series = ("line_index2", "line_value2", "scatter_index2",
                         "scatter_value2", "scatter_color2")
        for series_name in second_series:
            plot_data.set_data(series_name, array([]))

        self.cross_plot2 = Plot(plot_data, width=140, orientation="v",
                                resizable="v", padding=20, padding_bottom=160)
        self.cross_plot2.plot(("line_index2", "line_value2"),
                              line_style="dot")
        self.cross_plot2.plot(
            ("scatter_index2", "scatter_value2", "scatter_color2"),
            type="cmap_scatter",
            name="dot",
            color_mapper=self._cmap(z_range),
            marker="circle",
            marker_size=8)
        self.cross_plot2.index_range = filled.index_range.y_range

        # --- final layout ---------------------------------------------------
        self.container = HPlotContainer(padding=40, fill_padding=True,
                                        bgcolor="white", use_backbuffer=False)
        stacked = VPlotContainer(padding=0, use_backbuffer=True)
        stacked.add(self.cross_plot)
        stacked.add(contour_container)
        self.container.add(self.colorbar)
        self.container.add(stacked)
        self.container.add(self.cross_plot2)
class ImagePlot(Atom):
    """Colour-mapped image plot with a colorbar and two cross-section plots.

    The image is drawn from the 2-D array ``z`` over the grid given by ``x``
    and ``y``.  Two ``LineInspector`` overlays on the image are created with
    ``write_metadata=True``; ``_metadata_changed`` reads the ``"selections"``
    entry of the index metadata and refreshes the row and column
    cross-sections accordingly.
    """

    # traits' on_trait_change() is handed a bound method of this object in
    # __init__.  NOTE(review): traits is understood to keep such listeners by
    # weak reference, which Atom objects only permit when this slot is
    # declared -- confirm against the installed traits/atom versions.
    __slots__ = '__weakref__'

    # container for all plots
    container = Typed(HPlotContainer)

    # Plot components within this container:
    color_plot = Typed(CMapImagePlot)
    vertical_cross_plot = Typed(Plot)
    horizontal_cross_plot = Typed(Plot)
    # Short names used by the construction code.  Atom objects have no
    # instance __dict__, so every attribute assigned in __init__ must be
    # declared as a member here.
    vert_cross_plot = Typed(Plot)
    horiz_cross_plot = Typed(Plot)
    colorbar = Typed(ColorBar)
    colormap = Typed(object)

    # plot data
    pd_all = Typed(ArrayPlotData)
    pd_horiz = Typed(ArrayPlotData)
    pd_vert = Typed(ArrayPlotData)
    # private data storage
    _imag_index = Typed(GridDataSource)
    _image_value = Typed(ImageData)

    def __init__(self, x, y, z):
        """Build all plot components for the image ``z``.

        Parameters
        ----------
        x, y :
            Grid coordinates handed to ``GridDataSource`` as ``xdata`` and
            ``ydata``; they are also the index data of the horizontal and
            vertical cross-section plots.
        z :
            2-D array of image values, indexed as ``z[row, column]``.  It
            must contain at least one row and one column because the first
            row/column seed the cross-section plots.
        """
        super(ImagePlot, self).__init__()
        self.pd_all = ArrayPlotData(imagedata=z)
        # The cross-section plots below read "x"/"horiz" and "y"/"vert" from
        # these containers, so they must exist before the plots are built.
        # Until the inspectors report a selection, show the first row/column.
        self.pd_horiz = ArrayPlotData(x=x, horiz=z[0, :])
        self.pd_vert = ArrayPlotData(y=y, vert=z[:, 0])

        self._imag_index = GridDataSource(xdata=x,
                                          ydata=y,
                                          sort_order=("ascending",
                                                      "ascending"))
        index_mapper = GridMapper(range=DataRange2D(self._imag_index))
        # Be told whenever the index metadata (written by the line
        # inspectors) changes.
        self._imag_index.on_trait_change(self._metadata_changed,
                                         "metadata_changed")
        self._image_value = ImageData(data=z, value_depth=1)
        color_mapper = jet(DataRange1D(self._image_value))

        self.color_plot = CMapImagePlot(index=self._imag_index,
                                        index_mapper=index_mapper,
                                        value=self._image_value,
                                        value_mapper=color_mapper,
                                        padding=20,
                                        use_backbuffer=True,
                                        unified_draw=True)

        # Add axes to image plot
        left = PlotAxis(orientation='left',
                        title="Frequency (GHz)",
                        mapper=self.color_plot.index_mapper._ymapper,
                        component=self.color_plot)
        self.color_plot.overlays.append(left)

        bottom = PlotAxis(orientation='bottom',
                          title="Time (us)",
                          mapper=self.color_plot.index_mapper._xmapper,
                          component=self.color_plot)
        self.color_plot.overlays.append(bottom)

        # Pan / zoom interaction
        self.color_plot.tools.append(
            PanTool(self.color_plot, constrain_key="shift"))
        self.color_plot.overlays.append(
            ZoomTool(component=self.color_plot,
                     tool_mode="box",
                     always_on=False))

        # Add line inspector tool for horizontal and vertical
        self.color_plot.overlays.append(
            LineInspector(component=self.color_plot,
                          axis='index_x',
                          inspect_mode="indexed",
                          write_metadata=True,
                          is_listener=True,
                          color="white"))

        self.color_plot.overlays.append(
            LineInspector(component=self.color_plot,
                          axis='index_y',
                          inspect_mode="indexed",
                          write_metadata=True,
                          color="white",
                          is_listener=True))

        # Fixed value range spanning the data, used for the colorbar.
        myrange = DataRange1D(low=amin(z), high=amax(z))
        cmap = jet
        self.colormap = cmap(myrange)

        # Create a colorbar
        cbar_index_mapper = LinearMapper(range=myrange)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.color_plot,
                                 padding_top=self.color_plot.padding_top,
                                 padding_bottom=self.color_plot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        # create horizontal line plot
        self.horiz_cross_plot = Plot(self.pd_horiz, resizable="h")
        self.horiz_cross_plot.height = 100
        self.horiz_cross_plot.padding = 20
        self.horiz_cross_plot.plot(("x", "horiz"))
        # Share the image's x range so panning/zooming stays in step.
        self.horiz_cross_plot.index_range = self.color_plot.index_range.x_range

        # create vertical line plot
        self.vert_cross_plot = Plot(self.pd_vert,
                                    width=140,
                                    orientation="v",
                                    resizable="v",
                                    padding=20,
                                    padding_bottom=160)
        self.vert_cross_plot.plot(("y", "vert"))
        self.vert_cross_plot.index_range = self.color_plot.index_range.y_range

        # Populate the long-named members declared on the class as well, so
        # either spelling refers to the same plot.
        self.horizontal_cross_plot = self.horiz_cross_plot
        self.vertical_cross_plot = self.vert_cross_plot

        # Create a container and add components
        self.container = HPlotContainer(padding=40,
                                        fill_padding=True,
                                        bgcolor="white",
                                        use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.horiz_cross_plot)
        inner_cont.add(self.color_plot)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.vert_cross_plot)

    def _metadata_changed(self, old, new):
        """Refresh both cross-sections from the current inspector selection.

        Registered as the ``"metadata_changed"`` listener of the image index.
        Neither argument is used.  NOTE(review): traits is understood to call
        two-argument handlers with ``(name, new)`` rather than
        ``(old, new)`` -- confirm before relying on either value.
        """
        selection = self._imag_index.metadata.get("selections")
        # The entry may be absent or not (yet) an (x, y) pair; there is
        # nothing to update in that case.
        if selection is None or len(selection) != 2:
            return
        x_ndx, y_ndx = selection
        # Compare against None rather than truthiness: index 0 is a valid
        # row/column and must not be skipped.
        if x_ndx is None or y_ndx is None:
            return
        data = self._image_value.data
        self.pd_horiz.set_data("horiz", data[y_ndx, :])
        self.pd_vert.set_data("vert", data[:, x_ndx])


#                    scatter_index=array([xdata[x_ndx]]),
#                    scatter_index2=array([ydata[y_ndx]]),
#                    scatter_value=array([self._image_value.data[y_ndx, x_ndx]]),
#                    scatter_value2=array([self._image_value.data[y_ndx, x_ndx]]),
#                    scatter_color=array([self._image_value.data[y_ndx, x_ndx]]),
#                    scatter_color2=array([self._image_value.data[y_ndx, x_ndx]])
#                )
#        else:
#            self.pd.update_data({"scatter_value": array([]),
#                "scatter_value2": array([]), "line_value": array([]),
#                "line_value2": array([])})

#if __name__ == "__main__":

#    filename="/Users/thomasaref/Dropbox/Dad stuff/sample3/digitizer/lt/sample3_digitizer_f_sweep_t_300mk_100nspulse.hdf5"
#
#    with h5py.File(filename, 'r') as f:
#
#        time=f["Traces"]["d - AvgTrace - t"][:]
#        Magvec=f["Traces"]["d - AvgTrace - Magvec"][:]
#        frequency=f["Data"]["Data"][:]
#    #    for name in f["Data"]:
#    #        print name
#
#    time=squeeze(time)
#    Magvec=squeeze(Magvec)
#    frequency=squeeze(frequency)
#
#    x = time[:,0]*1.0e6
#    y = frequency[0,:]/1.0e9
#    z=transpose(Magvec*1000.0)
#
#    ip=ImagePlot(xs,ys,z)
#    ip.configure_traits()

#class Image_Plot(Atom):
#    plot_control=Instance(Plot_Control)
#    xtitle=DelegatesTo('plot_control')
#    ytitle=DelegatesTo('plot_control')
#    ztitle=DelegatesTo('plot_control')
#    request_redraw=DelegatesTo('plot_control')
#    #ykeys=DelegatesTo('plot_control')
#    container = Typed(HPlotContainer)
#    color_plot = Typed(CMapImagePlot)
#    plot=Instance(Plot)
#    vertical_cross_plot = Typed(Plot)
#    horizontal_cross_plot = Typed(Plot)
#    colorbar = Typed(ColorBar)
#    pd_all = Instance(ArrayPlotData)
#    _image_index=Instance(GridDataSource)
#    _image_value=Instance(ImageData)
#    data=Dict()
#
#    pd=Instance(ArrayPlotData)
#
#    traits_view = View(Group(Item('container', editor=ComponentEditor(), show_label=False),
#                             orientation='horizontal'),
#        width=1000, height=700, resizable=True, title="Chaco Plot")
#
#    def _xtitle_changed(self):
#        self.horiz_cross_plot.x_axis.title=self.xtitle
#        self.plot.x_axis.title=self.xtitle
#
#    def _ytitle_changed(self):
#        self.vert_cross_plot.y_axis.title=self.ytitle
#        self.plot.y_axis.title=self.ytitle
#
#    def _request_redraw_fired(self):
#        self.color_plot.request_redraw()
#        self.horiz_cross_plot.request_redraw()
#        self.vert_cross_plot.request_redraw()
#
#    def __init__(self, data, plot_control):
#        super(Image_Plot, self).__init__()
#        self.plot_control=plot_control
#        z=zeros((len(data['y']['0']), len(data['x']['0'])))
#        z[:] = nan
#        for key, item in data['z'].iteritems():
#            z[int(key)]=item
#        x=data['x']['0']
#        y=data['y']['0']
#        self.pd = ArrayPlotData(z=z, x=x, y=y, horiz=z[0, :], vert=z[:, 0])
#        self.plot=Plot(self.pd, padding=50, fill_padding=True,
#                        bgcolor="white", use_backbuffer=True,  unified_draw=True)
#        xgrid, ygrid = meshgrid(x, y)
#
#        color_plot=self.plot.img_plot('z', name="img_plot", xbounds=xgrid, ybounds=ygrid)[0]
#        self._image_index = color_plot.index #GridDataSource(xdata=x, ydata=y, sort_order=("ascending","ascending"))
#        self._image_index.on_trait_change(self._metadata_changed, "metadata_changed")
#        self._image_value=color_plot.value
#        self.value_range=DataRange1D(self._image_value)
#        color_plot.color_mapper = jet(self.value_range)
#        color_plot.tools.append(PanTool(color_plot,
#                                           constrain_key="shift"))
#        color_plot.overlays.append(ZoomTool(component=color_plot,
#                                            tool_mode="box", always_on=False))
#
#        color_plot.overlays.append(LineInspector(component=color_plot,
#                                               axis='index_x',
#                                               inspect_mode="indexed",
#                                               write_metadata=True,
#                                               is_listener=True,
#                                               color="white"))
#
#        color_plot.overlays.append(LineInspector(component=color_plot,
#                                               axis='index_y',
#                                               inspect_mode="indexed",
#                                               write_metadata=True,
#                                               color="white",
#                                               is_listener=True))
#
#        cbar_index_mapper = LinearMapper(range=self.value_range)
#        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
#                                 plot=color_plot,
#                                 padding_top=color_plot.padding_top,
#                                 padding_bottom=color_plot.padding_bottom,
#                                 padding_right=40,
#                                 resizable='v',
#                                 width=30)#, ytitle="Magvec (mV)")
#
#        #create horizontal line plot
#        self.horiz_cross_plot = Plot(self.pd, resizable="h", height=100, padding=50)
#        self.horiz_cross_plot.plot(("x", "horiz"))#,
#        self.horiz_cross_plot.index_range = color_plot.index_range.x_range
#
#        #create vertical line plot
#        self.vert_cross_plot = Plot(self.pd, width = 100, orientation="v",
#                                resizable="v", padding=50, padding_bottom=250)
#        self.vert_cross_plot.plot(("y", "vert"))
#        self.vert_cross_plot.index_range = color_plot.index_range.y_range
#        #self.vert_cross_plot.x_axis.tick_label_formatter = lambda x: '%.2g'%x
#        self.color_plot=color_plot
#
#        self.container = HPlotContainer(padding=40, fill_padding=True,
#                                        bgcolor = "white", use_backbuffer=False)
#        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
#        inner_cont.add(self.horiz_cross_plot)
#        inner_cont.add(self.plot)
#        self.container.add(self.colorbar)
#        self.container.add(inner_cont)
#        self.container.add(self.vert_cross_plot)
#        #self.vert_cross_plot.y_axis.title="Frequency"
#        #self.horiz_cross_plot.x_axis.title="Time (us)"
#
#    def _metadata_changed(self, old, new):
#        if self._image_index.metadata.has_key("selections"):
#            x_ndx, y_ndx = self._image_index.metadata["selections"]
#            if y_ndx and x_ndx:
#                self.pd.set_data("horiz", self._image_value.data[y_ndx,:])
#                self.pd.set_data("vert", self._image_value.data[:,x_ndx])
#
#class Line_Plot(HasTraits):
#    plot_control=Instance(Plot_Control)
#    request_redraw=DelegatesTo('plot_control')
#    new_plot=DelegatesTo('plot_control')
#    xtitle=DelegatesTo('plot_control')
#    ytitle=DelegatesTo('plot_control')
#    title=DelegatesTo('plot_control')
#    show_legend=DelegatesTo('plot_control')
#    xyformat=DelegatesTo('plot_control')
#    plot=Instance(Plot)
#    keymap=DelegatesTo('plot_control') #Dict()
#    color_index=Int()
#    mycolors=List([ 'blue', 'red', 'green', 'purple',  'black', 'darkgray', 'cyan', 'magenta', 'orange'])
#    value_scale=DelegatesTo('plot_control')
#    index_scale=DelegatesTo('plot_control')
#    xcomplex=DelegatesTo('plot_control')
#    ycomplex=DelegatesTo('plot_control')
#
#    xkeys=DelegatesTo('plot_control')
#    zkeys=DelegatesTo('plot_control')
#    xindices=DelegatesTo('plot_control')
#    zindices=DelegatesTo('plot_control')
#    pd = Instance(ArrayPlotData)
#
#    def _value_scale_changed(self):
#         #if self.color_index!=0:
#             self.plot.value_scale = self.value_scale
#             self.plot.request_redraw()
#
#    def _index_scale_changed(self):
#         #if self.color_index!=0:
#             self.plot.index_scale = self.index_scale
#             self.plot.request_redraw()
#
#    def _show_legend_changed(self):
#        self.plot.legend.visible = self.show_legend
#        self.plot.request_redraw()
#
#    def _title_changed(self):
#        self.plot.title = self.title
#        self.plot.request_redraw()
#
#    def _xtitle_changed(self):
#        self.plot.x_axis.title=self.xtitle
#        self.plot.request_redraw()
#
#    def _ytitle_changed(self):
#        self.plot.y_axis.title=self.ytitle
#        self.plot.request_redraw()
#
#    def _request_redraw_fired(self):
#        self.plot.request_redraw()
#
#    def _new_plot_fired(self):
#        for key in self.plot.plots.keys():
#            self.remove_plot(key)
#        self.color_index=0
#        for n, name in enumerate(self.zkeys):
#            key='z'+str(name)
#            self.add_plot(key)
#
#    def _zkeys_changed(self,  name, old, new):
#        #print self.pd.list_data()
#        n=0
#        for key in self.pd.list_data():
#            if int(key[1:]) in new:
#                self.add_plot(key)
#                n=n+1
#            else:
#                self.remove_plot(key)
#
##        for key, plot in self.plot.plots.iteritems():
##            if int(key[1:]) in new:
##                if self.xyformat.t_color=="transparent" or self.xyformat.t_color==(1.0, 1.0, 1.0, 1.0) :
##                    color=self.mycolors[mod(n, len(self.mycolors))]
##                else:
##                    color=self.xyformat.t_color
##                plot[0].color=color
##                #plot[0].outline_color=self.xyformat.outline_color,
##                n=n+1
##
##            else:
##               plot[0].color="none"
#               #plot[0].outline_color="none"
#
#    def add_plot(self, key, z, xkey='x0', x=None):
#        if key not in self.plot.plots.keys() and key[0]!='x':
#            if self.xyformat.t_color=="transparent" or self.xyformat.t_color==(1.0, 1.0, 1.0, 1.0) :
#                color=self.mycolors[mod(self.color_index, len(self.mycolors))]
#            else:
#                color=self.xyformat.t_color
#
#            if x!=None:
#                self.pd.set_data(xkey, x)
#            self.pd.set_data(key, z)
#
#            #if self.color_index<len(self.xkeys):
#            #    xkey='x'+str(self.xkeys[self.color_index])
#            #else:
#            #    xkey='x'+str(self.xkeys[0])
#            self.plot.plot((xkey, key),
#                           name=key,
#                           type=self.xyformat.plot_type,
#                           line_width=self.xyformat.line_width,
#                           color=color,
#                           outline_color=self.xyformat.outline_color,
#                           marker = self.xyformat.marker,
#                           marker_size = self.xyformat.marker_size)
#            self.color_index=self.color_index+1
#
#    def remove_plot(self, key):
#        if key in self.plot.plots.keys():
#            self.plot.delplot(key)
#
#    def __init__(self, data, plot_control, *args, **kws):
#        super(Line_Plot, self).__init__(*args, **kws)
#        self.plot_control=plot_control
#        self.pd = ArrayPlotData()
#
#        for name, arr in sorted(data['z'].iteritems()):
#                self.pd.set_data('z'+str(name), arr)
#
#        for name, arr in sorted(data['x'].iteritems()):
#                self.pd.set_data('x'+str(name), arr)
#
#        plot = Plot(self.pd, padding=50, fill_padding=True,
#                        bgcolor="white", use_backbuffer=True)
#
#        # Attach some tools to the plot
#        plot.tools.append(PanTool(plot))
#        zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
#        plot.overlays.append(zoom)
#        plot.legend.tools.append(LegendTool(plot.legend, drag_button="right"))
#        self.plot=plot
#
#        for n, item in enumerate(self.zkeys):
#                key='z'+str(item)
#                if self.xyformat.t_color=="transparent" or self.xyformat.t_color==(1.0, 1.0, 1.0, 1.0) :
#                    color=self.mycolors[mod(n, len(self.mycolors))]
#                else:
#                    color=self.xyformat.t_color
#                if n<len(self.xkeys):
#                    xkey='x'+str(self.xkeys[n])
#                else:
#                    xkey='x'+str(self.xkeys[0])
#                #n=n+1
#                #self.pd.set_data(key, magphase(self._image_value.data[int(item)], self.ycomplex))
#                self.plot.plot((xkey, key),
#                                name=key,
#                                type=self.xyformat.plot_type,
#                                line_width=self.xyformat.line_width,
#                                color=color,
#                                outline_color=self.xyformat.outline_color,
#                                marker = self.xyformat.marker,
#                                marker_size = self.xyformat.marker_size)
#        self.plot.value_scale = self.value_scale
#        self.plot.index_scale= self.index_scale
#
#
#    traits_view = View(Item('plot', style='custom',editor=ComponentEditor(),
#                             show_label=False),
#                    resizable=True, title="Chaco Plot",
#                    width=800, height=700, #kind='modal',
#                    buttons=[OKButton, CancelButton]
#                    )
    def create_plot(self):
        """Build the image plot, colorbar, cursor and both cross-section plots.

        Assigns ``self._image_index``, ``self._image_value``, ``self.imgplot``,
        ``self.colorbar``, ``self.cursor``, ``self.pd``, ``self.cross_plot``,
        ``self.cross_plot2`` and the top-level ``self.column_density``
        container.  Every data source starts out empty.
        """
        # -- Create the index for the x and y axes and the range over
        # -- which they vary
        self._image_index = GridDataSource(array([]), array([]),
                                           sort_order=("ascending",
                                                       "ascending"))
        image_index_range = DataRange2D(self._image_index)

        # -- Be notified when the index metadata changes (presumably how
        # -- mouse tracking reaches this object -- TODO confirm)
        self._image_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")

        # -- Create the image values and determine their range
        self._image_value = ImageData(data=array([]), value_depth=1)
        image_value_range = DataRange1D(self._image_value)

        # Create the image plot
        self.imgplot = CMapImagePlot(
            index=self._image_index,
            value=self._image_value,
            index_mapper=GridMapper(range=image_index_range),
            color_mapper=self._cmap(image_value_range),
        )

        # Add a left axis to the plot
        left = PlotAxis(orientation='left',
                        title="axial",
                        mapper=self.imgplot.index_mapper._ymapper,
                        component=self.imgplot)
        self.imgplot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = PlotAxis(orientation='bottom',
                          title="radial",
                          mapper=self.imgplot.index_mapper._xmapper,
                          component=self.imgplot)
        self.imgplot.overlays.append(bottom)

        # Add some tools to the plot
        self.imgplot.tools.append(PanTool(self.imgplot, drag_button="right",
                                          constrain_key="shift"))

        self.imgplot.overlays.append(ZoomTool(component=self.imgplot,
                                              tool_mode="box",
                                              always_on=False))

        # Create a colorbar
        # NOTE(review): the colorbar copies imgplot's top/bottom padding
        # here, before imgplot.padding is set to 20 further down -- confirm
        # that this ordering is intended.
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=self.imgplot,
                                 padding_top=self.imgplot.padding_top,
                                 padding_bottom=self.imgplot.padding_bottom,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        # Add a cursor
        self.cursor = CursorTool(self.imgplot, drag_button="left",
                                 color="white")
        # the cursor is a rendered component so it goes in the overlays list
        self.imgplot.overlays.append(self.cursor)

        # Create the two cross plots
        self.pd = ArrayPlotData(line_index=array([]),
                                line_value=array([]),
                                scatter_index=array([]),
                                scatter_value=array([]),
                                scatter_color=array([]))

        self.cross_plot = Plot(self.pd, resizable="h")
        self.cross_plot.height = 100
        self.cross_plot.padding = 20
        self.cross_plot.plot(("line_index", "line_value"),
                             line_style="dot")
        self.cross_plot.plot(("scatter_index", "scatter_value",
                              "scatter_color"),
                             type="cmap_scatter",
                             name="dot",
                             color_mapper=self._cmap(image_value_range),
                             marker="circle",
                             marker_size=6)

        # Share the image's x range with the first cross plot.
        self.cross_plot.index_range = self.imgplot.index_range.x_range

        self.pd.set_data("line_index2", array([]))
        self.pd.set_data("line_value2", array([]))
        self.pd.set_data("scatter_index2", array([]))
        self.pd.set_data("scatter_value2", array([]))
        self.pd.set_data("scatter_color2", array([]))

        self.cross_plot2 = Plot(self.pd, width=140, orientation="v",
                                resizable="v", padding=20, padding_bottom=160)
        self.cross_plot2.plot(("line_index2", "line_value2"),
                              line_style="dot")
        self.cross_plot2.plot(("scatter_index2", "scatter_value2",
                               "scatter_color2"),
                              type="cmap_scatter",
                              name="dot",
                              color_mapper=self._cmap(image_value_range),
                              marker="circle",
                              marker_size=8)

        # Share the image's y range with the second cross plot.
        self.cross_plot2.index_range = self.imgplot.index_range.y_range

        # Create a container and add sub-containers and components
        self.column_density = HPlotContainer(padding=40, fill_padding=True,
                                             bgcolor="white",
                                             use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.cross_plot)
        self.imgplot.padding = 20
        inner_cont.add(self.imgplot)
        self.column_density.add(self.colorbar)
        self.column_density.add(inner_cont)
        self.column_density.add(self.cross_plot2)
class ImageGUI(HasTraits):

    # TO FIX : put here the last available shot
    shot = File("L:\\data\\app3\\2011\\1108\\110823\\column_5200.ascii")

    # ---------------------------------------------------------------------------
    # Traits View Definitions
    # ---------------------------------------------------------------------------

    traits_view = View(
        HSplit(
            Item(
                "shot",
                style="custom",
                editor=FileEditor(filter=["column_*.ascii"]),
                show_label=False,
                resizable=True,
                width=400,
            ),
            Item("container", editor=ComponentEditor(), show_label=False, width=800, height=800),
        ),
        width=1200,
        height=800,
        resizable=True,
        title="APPARATUS 3 :: Analyze Images",
    )

    plot_edit_view = View(Group(Item("num_levels"), Item("colormap")), buttons=["OK", "Cancel"])

    num_levels = Int(15)
    colormap = Enum(color_map_name_dict.keys())

    # ---------------------------------------------------------------------------
    # Private Traits
    # ---------------------------------------------------------------------------

    _image_index = Instance(GridDataSource)
    _image_value = Instance(ImageData)

    _cmap = Trait(jet, Callable)

    # ---------------------------------------------------------------------------
    # Public View interface
    # ---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        """Initialise the object, then build the plot components.

        All positional and keyword arguments are forwarded unchanged to the
        parent class initialiser.
        """
        super(ImageGUI, self).__init__(*args, **kwargs)
        # create_plot() assigns self.container, the item that traits_view
        # displays next to the file selector.
        self.create_plot()

    def create_plot(self):
        """Assemble the image plot, its overlays, the colorbar and cross plots.

        Assigns ``self._image_index``, ``self._image_value``, ``self.imgplot``,
        ``self.colorbar``, ``self.pd``, ``self.cross_plot``,
        ``self.cross_plot2`` and the top-level ``self.container``.  All data
        sources are created empty.  (Contour poly/line plots that used to be
        overlaid on the image are currently disabled.)
        """
        # Empty grid index plus the 2-D range derived from it.
        self._image_index = GridDataSource(
            array([]), array([]), sort_order=("ascending", "ascending")
        )
        index_range = DataRange2D(self._image_index)

        # Listen for metadata updates on the index.
        self._image_index.on_trait_change(self._metadata_changed, "metadata_changed")

        # Empty image values plus their 1-D range.
        self._image_value = ImageData(data=array([]), value_depth=1)
        value_range = DataRange1D(self._image_value)

        # The colour-mapped image itself.
        grid_mapper = GridMapper(range=index_range)
        image_colors = self._cmap(value_range)
        self.imgplot = CMapImagePlot(
            index=self._image_index,
            value=self._image_value,
            index_mapper=grid_mapper,
            color_mapper=image_colors,
        )
        overlays = self.imgplot.overlays

        # Left ("axial") and bottom ("radial") axes, in that order.
        axis_specs = (
            ("left", "axial", self.imgplot.index_mapper._ymapper),
            ("bottom", "radial", self.imgplot.index_mapper._xmapper),
        )
        for side, label, axis_mapper in axis_specs:
            overlays.append(
                PlotAxis(orientation=side, title=label, mapper=axis_mapper, component=self.imgplot)
            )

        # Box zoom, then one line inspector per image axis.
        overlays.append(ZoomTool(component=self.imgplot, tool_mode="box", always_on=False))
        for inspected_axis in ("index_x", "index_y"):
            overlays.append(
                LineInspector(
                    component=self.imgplot,
                    axis=inspected_axis,
                    inspect_mode="indexed",
                    write_metadata=True,
                    is_listener=False,
                    color="white",
                )
            )

        # The image lives in its own overlay container.
        image_container = OverlayPlotContainer(padding=20, use_backbuffer=True, unified_draw=True)
        image_container.add(self.imgplot)

        # Colorbar driven by the image value range.
        colorbar_mapper = LinearMapper(range=value_range)
        self.colorbar = ColorBar(
            index_mapper=colorbar_mapper,
            plot=self.imgplot,
            padding_top=self.imgplot.padding_top,
            padding_bottom=self.imgplot.padding_bottom,
            padding_right=40,
            resizable="v",
            width=30,
        )

        # Shared data container; the first cross plot's series come first.
        self.pd = ArrayPlotData(
            line_index=array([]),
            line_value=array([]),
            scatter_index=array([]),
            scatter_value=array([]),
            scatter_color=array([]),
        )

        # First cross plot: follows the image's x range.
        first_cross = Plot(self.pd, resizable="h")
        self.cross_plot = first_cross
        first_cross.height = 100
        first_cross.padding = 20
        first_cross.plot(("line_index", "line_value"), line_style="dot")
        first_cross.plot(
            ("scatter_index", "scatter_value", "scatter_color"),
            type="cmap_scatter",
            name="dot",
            color_mapper=self._cmap(value_range),
            marker="circle",
            marker_size=8,
        )
        first_cross.index_range = self.imgplot.index_range.x_range

        # Register the second cross plot's (empty) series.
        second_series = (
            "line_index2",
            "line_value2",
            "scatter_index2",
            "scatter_value2",
            "scatter_color2",
        )
        for series_name in second_series:
            self.pd.set_data(series_name, array([]))

        # Second cross plot: vertical, follows the image's y range.
        second_cross = Plot(
            self.pd, width=140, orientation="v", resizable="v", padding=20, padding_bottom=160
        )
        self.cross_plot2 = second_cross
        second_cross.plot(("line_index2", "line_value2"), line_style="dot")
        second_cross.plot(
            ("scatter_index2", "scatter_value2", "scatter_color2"),
            type="cmap_scatter",
            name="dot",
            color_mapper=self._cmap(value_range),
            marker="circle",
            marker_size=8,
        )
        second_cross.index_range = self.imgplot.index_range.y_range

        # Final layout: colorbar | (cross plot + image) | second cross plot.
        self.container = HPlotContainer(
            padding=40, fill_padding=True, bgcolor="white", use_backbuffer=False
        )
        stacked = VPlotContainer(padding=0, use_backbuffer=True)
        stacked.add(self.cross_plot)
        stacked.add(image_container)
        for component in (self.colorbar, stacked, self.cross_plot2):
            self.container.add(component)

    def update(self):
        """Reload the image for the current shot and refresh the plots.

        Calls ``load_imagedata()``.  When that returns ``None`` nothing is
        changed.  Otherwise the data's minimum and maximum are stored on
        ``minz``/``maxz`` and copied to the colorbar range, the image index
        and value sources are replaced, the index arrays of both cross plots
        are updated, and the container is invalidated and redrawn.
        """
        imgdata = self.load_imagedata()
        if imgdata is None:
            return

        self.minz = imgdata.min()
        self.maxz = imgdata.max()
        self.colorbar.index_mapper.range.low = self.minz
        self.colorbar.index_mapper.range.high = self.maxz

        # The image is indexed as data[y_ndx, x_ndx] (see _metadata_changed),
        # so axis 0 holds the rows (y) and axis 1 the columns (x).  The grid
        # arrays hold cell edges, hence one more entry than cells per axis.
        n_rows, n_cols = imgdata.shape[0], imgdata.shape[1]
        xs = numpy.linspace(0, n_cols, n_cols + 1)
        ys = numpy.linspace(0, n_rows, n_rows + 1)

        self._image_index.set_data(xs, ys)
        self._image_value.data = imgdata
        self.pd.set_data("line_index", xs)
        self.pd.set_data("line_index2", ys)
        self.container.invalidate_draw()
        self.container.request_redraw()

    def load_imagedata(self):
        """Parse ``self.shot`` and load the image data it points to.

        Two pieces are cut out of the path string:

        * the directory: everything after the first colon-backslash pair up
          to and including the last backslash;
        * the shot number: the text between the last underscore and the last
          ``.ascii``.

        Returns:
            Whatever ``load(directory, shotnum)`` returns, or ``None`` (after
            printing a message) when any of those markers is missing from
            the path.
        """
        shot = self.shot
        try:
            directory = shot[shot.index(":\\") + 2 : shot.rindex("\\") + 1]
            shotnum = shot[shot.rindex("_") + 1 : shot.rindex(".ascii")]
        except ValueError:
            # str.index / str.rindex raise ValueError when a marker is absent.
            print(" *** Not a valid column density path *** ")
            return None
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(directory)
        print(shotnum)
        return load(directory, shotnum)

    # ---------------------------------------------------------------------------
    # Event handlers
    # ---------------------------------------------------------------------------

    def _shot_changed(self):
        """Refresh the plots by calling ``update()``.

        NOTE(review): presumably invoked by Traits whenever the ``shot``
        attribute changes (``_<name>_changed`` naming) -- confirm against the
        class definition.
        """
        self.update()

    def _metadata_changed(self, old, new):
        """Update the cross-section plots from the current selection.

        The value ranges of both cross plots are first pinned to
        ``minz``/``maxz``.  Then, based on the ``"selections"`` entry of the
        image index metadata:

        * entry missing: the line and scatter *value* arrays are emptied;
        * entry present but either index is ``None``: nothing else changes;
        * otherwise: the selected row and column of the image become the
          line values, and the selected pixel becomes the single scatter
          point (position, value and colour) of each cross plot.

        Args:
            old: Previous metadata value (unused).
            new: New metadata value (unused).
        """
        self.cross_plot.value_range.low = self.minz
        self.cross_plot.value_range.high = self.maxz
        self.cross_plot2.value_range.low = self.minz
        self.cross_plot2.value_range.high = self.maxz

        metadata = self._image_index.metadata
        if "selections" not in metadata:
            for key in ("scatter_value", "scatter_value2",
                        "line_value", "line_value2"):
                self.pd.set_data(key, array([]))
            return

        x_ndx, y_ndx = metadata["selections"]
        # Compare against None explicitly: index 0 is a valid row/column and
        # must not be mistaken for "no selection".
        if x_ndx is None or y_ndx is None:
            return

        image = self._image_value.data
        self.pd.set_data("line_value", image[y_ndx, :])
        self.pd.set_data("line_value2", image[:, x_ndx])

        xdata, ydata = self._image_index.get_data()
        xdata, ydata = xdata.get_data(), ydata.get_data()
        self.pd.set_data("scatter_index", array([xdata[x_ndx]]))
        self.pd.set_data("scatter_index2", array([ydata[y_ndx]]))

        # Each entry gets its own one-element array holding the pixel value.
        pixel = image[y_ndx, x_ndx]
        for key in ("scatter_value", "scatter_value2",
                    "scatter_color", "scatter_color2"):
            self.pd.set_data(key, array([pixel]))

    def _colormap_changed(self):
        """Switch every colour-mapped plot to the newly selected colormap.

        The colormap factory named by ``self.colormap`` is looked up in
        ``color_map_name_dict`` and stored on ``self._cmap``.  If the object
        has no ``polyplot`` attribute, nothing else happens.  Otherwise new
        colour mappers are built for ``polyplot`` (from its own range) and
        for ``cross_plot`` and the "dot" renderers of both cross plots (all
        from the range of ``cross_plot``'s mapper), and a redraw is requested.
        """
        self._cmap = color_map_name_dict[self.colormap]
        if not hasattr(self, "polyplot"):
            return

        poly_range = self.polyplot.color_mapper.range
        self.polyplot.color_mapper = self._cmap(poly_range)

        cross_range = self.cross_plot.color_mapper.range
        self.cross_plot.color_mapper = self._cmap(cross_range)

        # FIXME: change when we decide how best to update plots using
        # the shared colormap in plot object
        for plot in (self.cross_plot, self.cross_plot2):
            plot.plots["dot"][0].color_mapper = self._cmap(cross_range)

        self.container.request_redraw()

    def _num_levels_changed(self):
        """Copy ``num_levels`` to ``polyplot.levels`` and ``lineplot.levels``.

        Values of 3 or less are ignored.
        """
        # NOTE(review): assumes self.polyplot and self.lineplot exist; the
        # lines adding them to the container are commented out in this
        # class -- confirm they are still created.
        if self.num_levels > 3:
            self.polyplot.levels = self.num_levels
            self.lineplot.levels = self.num_levels