def test_data_changed_events(self):
        """Adding, replacing and deleting a key each emit one matching event."""
        original = numpy.ones((3, 4))
        replacement = numpy.zeros(16)
        key = 'Grumpy'

        store = ArrayPlotData()

        # A brand-new key is reported under 'added'.
        with self.monitor_events(store) as seen:
            store.set_data(key, original)
            self.assertEqual(seen, [{'added': [key]}])

        # get_data must hand back the very object that was stored.
        self.assertIs(store.get_data(key), original)

        # Overwriting an existing key is reported under 'changed'.
        with self.monitor_events(store) as seen:
            store.set_data(key, replacement)
            self.assertEqual(seen, [{'changed': [key]}])

        # Deleting the key is reported under 'removed'.
        with self.monitor_events(store) as seen:
            store.del_data(key)
            self.assertEqual(seen, [{'removed': [key]}])
    def test_data_changed_events(self):
        """ArrayPlotData reports added/changed/removed keys via its events."""
        name = "Grumpy"
        first = numpy.ones((3, 4))
        second = numpy.zeros(16)
        plot_data = ArrayPlotData()

        def expected(kind):
            # Each operation should produce exactly one event dict.
            return [{kind: [name]}]

        with self.monitor_events(plot_data) as events:
            plot_data.set_data(name, first)
            self.assertEqual(events, expected("added"))

        # Side check: get_data returns the stored object itself, not a copy.
        fetched = plot_data.get_data(name)
        self.assertIs(fetched, first)

        with self.monitor_events(plot_data) as events:
            plot_data.set_data(name, second)
            self.assertEqual(events, expected("changed"))

        with self.monitor_events(plot_data) as events:
            plot_data.del_data(name)
            self.assertEqual(events, expected("removed"))
示例#3
0
class calc(HasTraits):
    """Two-page Traits UI that fits nine (m_i, t_i) input pairs and shows
    values derived from the fit.

    Page one holds the inputs (labelled in the UI as temperatures and R
    values) and three plots; page two lists the derived quantities.
    Pressing the ``tt1`` button recomputes everything via :meth:`addele`
    and rebuilds the plots.

    NOTE(review): :meth:`addele` and the plots rely on module-level names
    defined outside this class (``firstcalc``, ``calcc``, ``e``, ``xx``,
    ``yy``, ``rg12``, ``rg13``, ``es12``, ``es13``, ``rg2``, ``es2``).
    """

    # Inputs: nine (m_i, t_i) pairs.
    m1 = Float(25.0)
    t1 = Float(2.44)
    m2 = Float(30.0)
    t2 = Float(2.08)
    m3 = Float(35.0)
    t3 = Float(1.80)
    m4 = Float(40.0)
    t4 = Float(1.56)
    m5 = Float(45.0)
    t5 = Float(1.36)
    m6 = Float(50.0)
    t6 = Float(1.20)
    m7 = Float(55.0)
    t7 = Float(1.06)
    m8 = Float(60.0)
    t8 = Float(0.93)
    m9 = Float(65.0)
    t9 = Float(0.81)
    # Computed values (xrt) for each input point, shown on page two.
    ct1 = Float
    ct2 = Float
    ct3 = Float
    ct4 = Float
    ct5 = Float
    ct6 = Float
    ct7 = Float
    ct8 = Float
    ct9 = Float
    # First and second values returned by calcc().
    rff = Float
    rss = Float
    # Display copies of the module-level constants rg12/rg13/es12/es13/rg2/es2.
    drg12 = Float
    drg13 = Float
    des12 = Float
    des13 = Float
    drg2 = Float
    des2 = Float
    # First nine entries of the module-level xx / yy arrays, rounded to 5 dp.
    dbx1 = Float
    dbx2 = Float
    dbx3 = Float
    dbx4 = Float
    dbx5 = Float
    dbx6 = Float
    dbx7 = Float
    dbx8 = Float
    dbx9 = Float
    dby1 = Float
    dby2 = Float
    dby3 = Float
    dby4 = Float
    dby5 = Float
    dby6 = Float
    dby7 = Float
    dby8 = Float
    dby9 = Float
    # Fit coefficients: r[0][0] (labelled as slope) and r[0][1] (intercept).
    dbn = Float
    dbk = Float
    # The ttt values for each point, rounded to 2 dp.
    dv1 = Float
    dv2 = Float
    dv3 = Float
    dv4 = Float
    dv5 = Float
    dv6 = Float
    dv7 = Float
    dv8 = Float
    dv9 = Float

    # Plain class attributes; addele() rebinds them on the instance before use.
    x = []
    y = []
    xrt = []
    # "Input done" button; handled by _tt1_fired.
    tt1 = Button()

    # Container holding the three side-by-side plots on page one.
    plot = Instance(HPlotContainer)

    traits_view = View(Group(Group(
        Group(
            Item(name='m1',label= u"温度1"),
            Item(name='t1',label= u"R1")
            ),
        Group(
            Item(name = 'm2',label= u"温度2"),
            Item(name = 't2',label= u"R2")
            ),
        Group(
            Item(name = 'm3',label= u"温度3"),
            Item(name = 't3',label= u"R3")
            ),
        Group(
            Item(name = 'm4',label= u"温度4"),
            Item(name = 't4',label= u"R4")
            ),
        Group(
            Item(name = 'm5',label= u"温度5"),
            Item(name = 't5',label= u"R5")
            ),
        Group(
            Item(name = 'm6',label= u"温度6"),
            Item(name = 't6',label= u"R6"),
            ),
        Group(
            Item(name = 'm7',label= u"温度7"),
            Item(name = 't7',label= u"R7"),
            ),
        Group(
            Item(name = 'm8',label= u"温度8"),
            Item(name = 't8',label= u"R8"),
            ),
        Group(
            Item(name = 'm9',label= u"温度9"),
            Item(name = 't9',label= u"R9"),
            ),
        Item('tt1', label=u"输入完成",show_label=False),
        orientation= 'horizontal',
        label = u'输入数据',
        show_border = True
        ),
        Group(Group(Item('plot', editor=ComponentEditor(), show_label=False),
            orientation = "vertical",            
            label = u'数据处理',
        show_border = True)),
        label=u'第一页'),
        Group(Group(
            Group(
                Item(name='m1',label= u"温度1"),
                Item(name='ct1',label= u"R1")
                ),
            Group(
                Item(name = 'm2',label= u"温度2"),
                Item(name = 'ct2',label= u"R2")
                ),
            Group(
                Item(name = 'm3',label= u"温度3"),
                Item(name = 'ct3',label= u"R3")
                ),
            Group(
                Item(name = 'm4',label= u"温度4"),
                Item(name = 'ct4',label= u"R4")
                ),
            Group(
                Item(name = 'm5',label= u"温度5"),
                Item(name = 'ct5',label= u"R5")
                ),
            Group(
                Item(name = 'm6',label= u"温度6"),
                Item(name = 'ct6',label= u"R6"),
                ),
            Group(
                Item(name = 'm7',label= u"温度7"),
                Item(name = 'ct7',label= u"R7"),
                ),
            Group(
                Item(name = 'm8',label= u"温度8"),
                Item(name = 'ct8',label= u"R8"),
                ),
            Group(
                Item(name = 'm9',label= u"温度9"),
                Item(name = 'ct9',label= u"R9"),
                ),
                orientation= 'horizontal',
                label = u'理论数据',
                show_border = True
                ),
            Group(
                Item(name='rff',label= u'Rf'),
                Item(name='rss',label= u'Rs'),
            ),
            Group(Group(
                Item(name='drg12',label= u'Rg12'),
                Item(name='drg13',label= u'Rg13'),
                ),
                Group(
                Item(name='des12',label= u'Es12'),
                Item(name='des13',label= u'Es13')
                ),
                Group(
                Item(name='drg2',label= u'Rg2'),
                Item(name='des2',label= u'Es2'),
                ),
            orientation= 'horizontal',
            ),
        Group(
            Group(
                Item(name='dbx1',label= u"X1"),
                Item(name='dby1',label= u"Y1")
                ),
            Group(
                Item(name = 'dbx2',label= u"X2"),
                Item(name = 'dby2',label= u"Y2")
                ),
            Group(
                Item(name = 'dbx3',label= u"X3"),
                Item(name = 'dby3',label= u"Y3")
                ),
            Group(
                Item(name = 'dbx4',label= u"X4"),
                Item(name = 'dby4',label= u"Y4")
                ),
            Group(
                Item(name = 'dbx5',label= u"X5"),
                Item(name = 'dby5',label= u"Y5")
                ),
            Group(
                Item(name = 'dbx6',label= u"X6"),
                Item(name = 'dby6',label= u"Y6"),
                ),
            Group(
                Item(name = 'dbx7',label= u"X7"),
                Item(name = 'dby7',label= u"Y7"),
                ),
            Group(
                Item(name = 'dbx8',label= u"X8"),
                Item(name = 'dby8',label= u"Y8"),
                ),
            Group(
                Item(name = 'dbx9',label= u"X9"),
                Item(name = 'dby9',label= u"Y9"),
                ),
                orientation= 'horizontal',
                label = u'最小二乘法求bn',
                show_border = True
                ),
            Group(
                Item(name='dbn',label= u'Bn(K)斜率'),
                Item(name='dbk',label= u'截距')
                ),
            Group(
            Group(
                Item(name='m1',label= u"温度1"),
                Item(name='dv1',label= u"V1")
                ),
            Group(
                Item(name = 'm2',label= u"温度2"),
                Item(name = 'dv2',label= u"V2")
                ),
            Group(
                Item(name = 'm3',label= u"温度3"),
                Item(name = 'dv3',label= u"V3")
                ),
            Group(
                Item(name = 'm4',label= u"温度4"),
                Item(name = 'dv4',label= u"V4")
                ),
            Group(
                Item(name = 'm5',label= u"温度5"),
                Item(name = 'dv5',label= u"V5")
                ),
            Group(
                Item(name = 'm6',label= u"温度6"),
                Item(name = 'dv6',label= u"V6"),
                ),
            Group(
                Item(name = 'm7',label= u"温度7"),
                Item(name = 'dv7',label= u"V7"),
                ),
            Group(
                Item(name = 'm8',label= u"温度8"),
                Item(name = 'dv8',label= u"V8"),
                ),
            Group(
                Item(name = 'm9',label= u"温度9"),
                Item(name = 'dv9',label= u"V9"),
                ),
                orientation= 'horizontal',
                label = u'温度传感器的电压-温度 关系',
                show_border = True
                ),
            label = u'第二页'
        ),
        width=800, height=600, resizable=True,
        title=u"物理实验 by:hanliumaozhi"
                    )

    def __init__(self):
        super(calc, self).__init__()
        self.addele()
        self.plotdata = ArrayPlotData(x=self.x, y=self.y, xrt=self.xrt,
                                      xx=xx, yy=yy, tt=self.ttt)
        self._build_plot()

    def _build_plot(self):
        """Create the three plots from ``self.plotdata`` and install them.

        Left: (x, y) in blue and (x, xrt) in red, each as line + scatter,
        with a legend.  Middle: (xx, yy).  Right: (x, tt).
        """
        plot1 = Plot(self.plotdata)
        plot1.plot(("x", "y"), type="line", color="blue", name='1')
        plot1.plot(("x", "y"), type="scatter", color="blue",
                   marker='circle', marker_size=2, name='1')
        plot1.plot(("x", "xrt"), type="line", color="red", name='2')
        plot1.plot(("x", "xrt"), type="scatter", color="red",
                   marker='circle', marker_size=2, name='2')
        plot2 = Plot(self.plotdata)
        plot2.plot(("xx", "yy"), type="line", color="blue")
        plot2.plot(("xx", "yy"), type="scatter", color="blue")
        plot3 = Plot(self.plotdata)
        plot3.plot(("x", "tt"), type="line", color="blue")
        plot3.plot(("x", "tt"), type="scatter", color="blue")
        self.plot = HPlotContainer(plot1, plot2, plot3)

        legend = Legend(component=plot1, padding=10, align="ur")
        legend.plots = plot1.plots
        plot1.overlays.append(legend)

    def _tt1_fired(self):
        """Handle the "input done" button: recompute and redraw."""
        self.addele()
        # set_data() replaces an existing key in place, so deleting the
        # key first is unnecessary.
        for name, value in (("x", self.x), ("y", self.y),
                            ("xrt", self.xrt), ("xx", xx), ("yy", yy),
                            ("tt", self.ttt)):
            self.plotdata.set_data(name, value)
        self._build_plot()

    def addele(self):
        """Recompute every derived quantity from the current inputs.

        * ``x``/``y``: the nine m_i / t_i inputs.
        * ``r = firstcalc(x, y)``: ``r[0][0]`` and ``r[0][1]`` are used as the
          slope and intercept (matching the UI labels of dbn/dbk).
        * ``xrt[i] = e**intercept * e**(slope * (1/(x_i + 273) - 1/298))``.
        * ``rfs = calcc(y)`` and ``ttt[i]``: see the formula below.
        * Writes the xrt values and both coefficients to ``t.txt``.
        * Copies the results into the display traits on page two.
        """
        self.x = [self.m1, self.m2, self.m3, self.m4, self.m5,
                  self.m6, self.m7, self.m8, self.m9]
        self.y = [self.t1, self.t2, self.t3, self.t4, self.t5,
                  self.t6, self.t7, self.t8, self.t9]
        self.r = firstcalc(self.x, self.y)
        slope, intercept = self.r[0][0], self.r[0][1]
        self.xrt = [(e ** intercept) * (e ** (slope * (1.0 / (m + 273) - 1.0 / 298)))
                    for m in self.x]

        self.rfs = calcc(self.y)
        rf, rs = self.rfs[0], self.rfs[1]
        y_first = self.y[0]
        self.ttt = []
        for y_i in self.y:
            # (y0 * yi) / (y0 + yi): y0 and yi combined in parallel.
            r_par = (y_first * y_i) / (y_first + y_i)
            self.ttt.append((rf / (r_par + rs))
                            * ((r_par + rf + rs) / (rg2 + rf + rs) * es2
                               - y_i / (y_first + y_i)))

        # One xrt value per line, then the two fit coefficients.
        with open("t.txt", "w") as f:
            for value in self.xrt:
                f.write(str(value))
                f.write("\n")
            f.write(str(slope))
            f.write("\n")
            f.write(str(intercept))

        self.rff = rf
        self.rss = rs
        for k in range(9):
            setattr(self, 'ct%d' % (k + 1), self.xrt[k])
        self.drg12 = rg12
        self.drg13 = rg13
        self.des12 = es12
        self.des13 = es13
        self.drg2 = rg2
        self.des2 = es2
        for k in range(9):
            setattr(self, 'dbx%d' % (k + 1), round(xx[k], 5))
        for k in range(9):
            setattr(self, 'dby%d' % (k + 1), round(yy[k], 5))
        self.dbn = slope
        self.dbk = intercept
        for k in range(9):
            setattr(self, 'dv%d' % (k + 1), round(self.ttt[k], 2))
示例#4
0
class MDAViewController(BaseImageController):
    # the image data for the factor plot (including any scatter data and
    #    quiver data)
    factor_plotdata = Instance(ArrayPlotData)
    # the actual plot object
    factor_plot = Instance(BasePlotContainer)
    # the image data for the score plot (may be a parent image for scatter overlays)
    score_plotdata = Instance(ArrayPlotData)
    score_plot = Instance(BasePlotContainer)
    component_index = Int(0)
    _selected_peak = Int(0)
    _contexts = List([])
    _context = Int(-1)
    dimensionality = Int(1)
    _characteristics = List(["None", "Height", "Orientation", "Eccentricity"])
    _characteristic = Int(0)    
    _vectors = List(["None", "Shifts", "Skew"])
    _vector = Int(0)
    vector_scale = Float(1.0)   
    _can_map_peaks = Bool(False)

    def __init__(self, treasure_chest=None, data_path='/rawdata',
                 *args, **kw):
        super(MDAViewController, self).__init__(*args, **kw)
        self.factor_plotdata = ArrayPlotData()
        self.score_plotdata = ArrayPlotData()
        if treasure_chest is not None:
            self.numfiles = len(treasure_chest.list_nodes(data_path))
            self.chest = treasure_chest
            self.data_path = data_path
            # populate the list of available contexts (if any)
            if self.chest.root.mda_description.nrows>0:
                self._contexts = self.chest.root.mda_description.col('context').tolist()
                context = self.get_context_name()
                self.dimensionality = self.chest.get_node_attr('/mda_results/'+context, 
                                                             'dimensionality')
                self.update_factor_image()
                self.update_score_image()

    @on_trait_change('_context')
    def context_changed(self):
        context = self.get_context_name()
        self.dimensionality = self.chest.get_node_attr('/mda_results/'+context, 
                                                     'dimensionality')
        self.render_active_factor_image(context)
        self.render_active_score_image(context)
        
        
    def increase_selected_component(self):
        # TODO: need to measure dimensionality somehow (node attribute, or array size?)
        if self.component_index == (self.dimensionality - 1):
            self.component_index = 0
        else:
            self.component_index += 1
    
    def decrease_selected_component(self):
        if self.component_index == 0:
            self.component_index = int(self.dimensionality - 1)
        else:
            self.component_index -= 1    

    def get_context_name(self):
        return str(self._contexts[self._context])

    def get_characteristic_name(self):
        return self._characteristics[self._characteristic]

    def get_vector_name(self):
        return self._vectors[self._vector]

    def render_active_factor_image(self, context):
        if self.chest.get_node_attr('/mda_results/'+context, 'on_peaks'):
            factors = self.chest.get_node('/mda_results/'+context+'/peak_factors')
            # return average cell image (will be overlaid with peak info)
            self.factor_plotdata.set_data('imagedata', 
                                          self.chest.root.cells.average[:])
            component = factors.read(start = self.component_index,
                             stop = self.component_index+1,
                             step = 1,)[:]
            numpeaks = self.chest.root.cell_peaks.getAttr('number_of_peaks')
            index_keys = ['y%i' % i for i in xrange(numpeaks)]
            value_keys = ['x%i' % i for i in xrange(numpeaks)]

            values = np.array(component[value_keys]).view(float)
            indices = np.array(component[index_keys]).view(float)

            self.factor_plotdata.set_data('value', values)
            self.factor_plotdata.set_data('index', indices)

            if self.get_vector_name() != "None":
                field = ''
                y_invert=1
                if self.get_vector_name() == 'Shifts':
                    field = 'd'
                    y_invert=-1
                elif self.get_vector_name() == 'Skew':
                    field = 's'
                if field != '':
                    vector_x_keys = ['%sx%i' % (field, i) for i in xrange(numpeaks)]
                    vector_y_keys = ['%sy%i' % (field, i) for i in xrange(numpeaks)]
                    vector_x = np.array(component[vector_x_keys]).view(float).reshape((-1,1))
                    vector_y = y_invert*np.array(component[vector_y_keys]).view(float).reshape((-1,1))
                    vectors = np.hstack((vector_x,vector_y))
                    vectors *= self.vector_scale
                    self.factor_plotdata.set_data('vectors',vectors)
                else:
                    print "%s field not recognized for vector plots."%field
                    if 'vectors' in self.plotdata.arrays:
                        self.factor_plotdata.del_data('vectors')
                
            else:
                if 'vectors' in self.plotdata.arrays:
                    self.plotdata.del_data('vectors')
                # clear vector data
            if self.get_characteristic_name() != "None":
                color_prefix = self._characteristics[self._characteristic][0].lower()
                color_keys = ['%s%i' % (color_prefix, i) for i in xrange(numpeaks)]
                color = np.array(component[color_keys]).view(float)
                self.factor_plotdata.set_data('color', color)
            else:
                if 'color' in self.plotdata.arrays:
                    self.plotdata.del_data('color')
            self.factor_plot = self.get_scatter_quiver_plot(self.factor_plotdata,
                                                          tools=['colorbar'])
            self._can_map_peaks=True
        else:
            factors = self.chest.get_node('/mda_results/'+context+'/image_factors')
            # return current factor image (MDA on images themselves)
            self.factor_plotdata.set_data('imagedata', 
                                          factors[self.component_index,:,:])
            # return current factor image (MDA on images themselves)
            self.factor_plot = self.get_simple_image_plot(self.factor_plotdata)
            
    def render_active_score_image(self, context):
        self.score_plotdata.set_data('imagedata', self.get_active_image())
        values = self.chest.root.cell_description.read_where(
                'filename == "%s"' % self.get_active_name(),
                field='y_coordinate',)
    
        indices = self.chest.root.cell_description.read_where(
                'filename == "%s"' % self.get_active_name(),
                field='x_coordinate',)
        if self.chest.get_node_attr('/mda_results/'+context, 'on_peaks'):
            scores = self.chest.get_node('/mda_results/'+context+'/peak_scores')          
        else:
            scores = self.chest.get_node('/mda_results/'+context+'/image_scores')
        color = scores.read_where(
            'filename == "%s"' % self.get_active_name(),
            field='c%i' % self.component_index,
        )
        self.score_plotdata.set_data('index', values)
        self.score_plotdata.set_data('value', indices)
        self.score_plotdata.set_data('color', color)
        self.score_plot = self.get_scatter_overlay_plot(self.score_plotdata, title=self.get_active_name(),
                                                        tools=["colorbar","zoom","pan"])

    @on_trait_change("component_index, _characteristic, _vector, vector_scale")
    def update_factor_image(self):
        context = self.get_context_name()
        self.render_active_factor_image(context)

    @on_trait_change("selected_index, component_index")
    def update_score_image(self):
        context = self.get_context_name()
        self.render_active_score_image(context)
        
    def open_factor_save_UI(self):
        self.open_save_UI(plot_id = 'factor_plot')
    
    def open_score_save_UI(self):
        self.open_save_UI(plot_id = 'score_plot')
示例#5
0
class ProfileEditor(HasTraits):
    """ The line profile intitial guess editor class.
    This is the line profile editor module.
    It can be used to provide initial guesses to any fitting package according
    to the user's tastes.

    Usage: ProfileEditor(wave, data, errors, center)

    * wave: one-dimensional wavelength-like array, can be either
            wavelength or velocity units.
    * data: one-dimensional spectrum.
    * errors: one-dimensional noise spectrum.
    * center: Float. Central wavelength. 0 if in velocity space.
    """
    # TODO: Implement the model in a different way. Make a class for each, add
    # them as new instances, use DelegatesTo for the important parts, maybe
    # even not, maybe just use the 'object.MyInstance.attribute' notation and
    # only store the plotdata of the given model in this class...?
    # Probably needs a new way to store the parameters, though. Either the
    # Components Dict can take an extra "Kind" keyword in it, or restructure
    # the whole thing into a DataFrame object...? The latter will require a
    # major amount of work. On the other hand, it could mean much better
    # modularity.

    CompNum = Int(1)
    Components = Dict
    Locks = Dict
    # FitSettings = Dict
    CompoList = List()
    CompType = Enum(['Gauss', 'Absorption Voigt' 'Absorption Gauss'])

    x = Array
    mod_x = Array
    Feedback = Str

    sigmin = .1
    sigmax = 30.
    Sigma = Range(sigmin, sigmax)
    #Centr = Range(-100., 100., 0.)
    #Heigh = Range(0., 200000., 15)
    N = Range(1e12, 1e24, 1e13)
    b_param = Range(0., 200., 10.)

    # Define vars to regulate whether the above vars are locked
    #    down in the GUI:
    LockSigma = Bool()
    LockCentr = Bool()
    LockHeigh = Bool()
    LockConti = Bool()
    LockN = Bool()
    LockB = Bool()

    #continuum_estimate = Range(0., 2000.)
    plots = {}
    plotrange = ()
    resplot = Instance(Plot)
    Model = Array
    Resids = Property(Array, depends_on='Model')
    y = {}
    # Define buttons for interface:
    add_profile = Button(label='Add component')
    remove_profile = Button(label='Remove selected')
    Go_Button = Button(label='Fit model')
    plwin = Instance(GridContainer)
    select = Str
    line_center = Float()

    # Non-essentials, for use by outside callers:
    linesstring = Str('')
    transname = Str('')

    def _line_center_changed(self):
        """Traits change handler: rebuild all plots when ``line_center`` moves."""
        self.build_plot()

    def _get_Resids(self):
        """Getter for the ``Resids`` property.

        Interpolates ``Model`` (defined on ``mod_x``) onto the data grid
        ``x`` and returns ``(indata - model) / errs``.
        """
        # np.interp is what the removed ``scipy.interp`` alias pointed to.
        intmod = np.interp(self.x, self.mod_x, self.Model)
        return (self.indata - intmod) / self.errs

    def _Components_default(self):
        """Default value for ``Components``: a continuum and one component.

        Component entries are laid out as
        [center, sigma, height, identifier,
         center-stddev, sigma-stddev, ampl-stddev].
        The 'Contin' entry holds the continuum estimate followed by NaN
        (presumably a placeholder for its uncertainty -- confirm).
        """
        unknown = np.nan
        first_component = [0., .1, 0., 'a', unknown, unknown, unknown]
        return {'Contin': [self.continuum_estimate, unknown],
                'Comp1': first_component}

    def _CompType_default(self):
        """Default value for the ``CompType`` trait."""
        default_type = 'Gauss'
        return default_type

    def _Locks_default(self):
        """Default ``Locks``: four unlocked flags for the initial ``Comp1``."""
        return dict(Comp1=[False] * 4)

    def _CompoList_default(self):
        """Default ``CompoList``: only the initial component ``Comp1``."""
        names = ['Comp1']
        return names

    def _y_default(self):
        """Default for ``y``: a fresh, empty mapping."""
        return dict()

    def _select_default(self):
        """Default for ``select``: the initially selected component name."""
        selected = 'Comp1'
        return selected

    def build_plot(self):
        print 'Building plot...'
        fitrange = self.fitrange  # Just for convenience
        onearray = Array
        onearray = sp.ones(self.indata.shape[0])
        minuses = onearray * (-1.)

        # Define index array for fit function:
        self.mod_x = sp.arange(self.line_center - 50., self.line_center + 50.,
                               .01)
        self.Model = sp.zeros(self.mod_x.shape[0])

        # Establish continuum array in a way that opens for other, more
        #   elaborate continua.
        self.contarray = sp.ones(self.mod_x.shape[0]) * \
                self.Components['Contin'][0]
        self.y = {}

        for comp in self.CompoList:
            self.y[comp] = gauss(  # x, mu, sigma, amplitude
                self.mod_x, self.Components[comp][0] + self.line_center,
                self.Components[comp][1], self.Components[comp][2])

        self.Model = self.contarray + self.y[self.select]

        broca = BroadcasterTool()

        # Define the part of the data to show in initial view:
        plotrange = sp.where((self.x > self.line_center - 30)
                             & (self.x < self.line_center + 30))
        # Define the y axis max value in initial view (can be panned/zoomed):
        maxval = float(self.indata[fitrange].max() * 1.2)
        minval = maxval / 15.
        minval = abs(np.median(self.indata[fitrange])) * 1.5
        maxerr = self.errs[fitrange].max() * 1.3
        resmin = max(sp.absolute(self.Resids[self.fitrange]).max(), 5.) * 1.2
        cenx = sp.array([self.line_center, self.line_center])
        ceny = sp.array([-minval, maxval])
        cenz = sp.array([-maxval, maxval])
        # Gray shading of ignored ranges
        rangelist = np.array(self.rangelist)
        grayx = np.array(rangelist.flatten().repeat(2))
        grayx = np.hstack((self.x.min(), grayx, self.x.max()))
        grayy = np.ones_like(grayx) * self.indata.max() * 2.
        grayy[1::4] = -grayy[1::4]
        grayy[2::4] = -grayy[2::4]
        grayy = np.hstack((grayy[-1], grayy[:-1]))

        # Build plot of data and model
        self.plotdata = ArrayPlotData(
            wl=self.x,
            data=self.indata,
            xs=self.mod_x,
            cont=self.contarray,
            ones=onearray,
            minus=minuses,
            model=self.Model,
            errors=self.errs,
            ceny=ceny,
            cenz=cenz,
            cenx=cenx,
            Residuals=self.Resids,
            grayx=grayx,
            grayy=grayy,
        )

        # Add dynamically created components to plotdata
        for comp in self.CompoList:
            self.plotdata.set_data(comp, self.y[comp])
        olplot = GridContainer(shape=(2, 1),
                               padding=10,
                               fill_padding=True,
                               bgcolor='transparent',
                               spacing=(5, 10))
        plot = Plot(self.plotdata)
        plot.y_axis.title = 'Flux density'
        resplot = Plot(self.plotdata, tick_visible=True, y_auto=True)
        resplot.x_axis.title = u'Wavelength [Å]'
        resplot.y_axis.title = u'Residuals/std. err.'

        # Create initial plot: Spectrum data, default first component,
        #   default total line profile.

        self.comprenders = []

        self.datarender = plot.plot(('wl', 'data'),
                                    color='black',
                                    name='Data',
                                    render_style='connectedhold')

        self.contrender = plot.plot(('xs', 'cont'),
                                    color='darkgray',
                                    name='Cont')

        self.modlrender = plot.plot(('xs', 'model'),
                                    color='blue',
                                    line_width=1.6,
                                    name='Model')

        self.centrender = plot.plot(('cenx', 'ceny'),
                                    color='black',
                                    type='line',
                                    line_style='dot',
                                    name='Line center',
                                    line_width=1.)

        self.rangrender = plot.plot(
            ('grayx', 'grayy'),
            type='polygon',
            face_color='lightgray',
            edge_color='gray',
            face_alpha=0.3,
            alpha=0.3,
        )

        # There may be an arbitrary number of gaussian components, so:
        print 'Updating model'
        for comp in self.CompoList:
            self.comprenders.append(
                plot.plot(
                    ('xs', comp),
                    type='line',
                    color=Paired[self.Components[comp]
                                 [3]],  # tuple(COLOR_PALETTE[self.CompNum]),
                    line_color=Paired[self.Components[comp][
                        3]],  # tuple(COLOR_PALETTE[self.CompNum]),
                    line_style='dash',
                    name=comp))

        # Create panel with residuals:
        resplot.plot(('wl', 'Residuals'), color='black', name='Resids')
        resplot.plot(('wl', 'ones'), color='green')
        resplot.plot(('wl', 'minus'), color='green')
        resplot.plot(('cenx', 'cenz'),
                     color='red',
                     type='line',
                     line_style='dot',
                     line_width=.5)
        resplot.plot(
            ('grayx', 'grayy'),  # Yes, that one again
            type='polygon',
            face_color='lightgray',
            edge_color='gray',
            face_alpha=0.3,
            alpha=0.3,
        )
        plot.x_axis.visible = False

        # Set ranges to change automatically when plot values change.
        plot.value_range.low_setting,\
            plot.value_range.high_setting = (-minval, maxval)
        plot.index_range.low_setting,\
            plot.index_range.high_setting = (self.line_center - 30.,
                                             self.line_center + 30.)
        resplot.value_range.low_setting,\
            resplot.value_range.high_setting = (-resmin, resmin)
        resplot.index_range.low_setting,\
            resplot.index_range.high_setting = (plot.index_range.low_setting,
                                                plot.index_range.high_setting)
        #resplot.index_range = plot.index_range  # Yes or no? FIXME
        plot.overlays.append(
            ZoomTool(plot,
                     tool_mode='box',
                     drag_button='left',
                     always_on=False))

        resplot.overlays.append(
            ZoomTool(resplot,
                     tool_mode='range',
                     drag_button='left',
                     always_on=False))

        # List of renderers to tell the legend what to write
        self.plots['Contin'] = self.contrender
        self.plots['Center'] = self.centrender
        self.plots['Model'] = self.modlrender
        for i in sp.arange(len(self.comprenders)):
            self.plots[self.CompoList[i]] = self.comprenders[i]

        # Build Legend:
        legend = Legend(component=plot, padding=10, align="ur")
        legend.tools.append(LegendTool(legend, drag_button="right"))
        legend.plots = self.plots
        plot.overlays.append(legend)
        olplot.tools.append(broca)
        pan = PanTool(plot)
        respan = PanTool(resplot, constrain=True, constrain_direction='x')
        broca.tools.append(pan)
        broca.tools.append(respan)
        plot.overlays.append(ZoomTool(plot, tool_mode='box', always_on=False))
        olplot.add(plot)
        olplot.add(resplot)
        olplot.components[0].set(resizable='hv', bounds=[500, 400])
        olplot.components[1].set(resizable='h', bounds=[500, 100])
        self.plot = plot
        self.resplot = resplot
        self.plwin = olplot
        self.legend = legend
        self.plotrange = plotrange

    def __init__(self,
                 wavlens,
                 indata,
                 inerrs,
                 linecen,
                 fitrange=None,
                 fitter='lmfit',
                 crange=(-100., 100.)):
        """Set up the interactive line profile builder for one line.

        Parameters
        ----------
        wavlens : array
            Wavelength array of the data.
        indata, inerrs : array
            Data values and their errors, sampled on ``wavlens``.
        linecen : float
            Line center; component positions ('Centr') are offsets from it.
        fitrange : sequence of (min, max) pairs, optional
            Wavelength intervals to include in the fit. When None or empty,
            a single interval of +/- 30 around ``linecen`` is used.
        fitter : str
            Fit backend; only 'lmfit' is currently implemented.
        crange : sequence of two floats
            Limits of the 'Centr' slider. (A tuple default avoids a shared
            mutable default argument; any sequence still works.)
        """
        halfrange = 30.
        self.fitter = fitter
        self.x = wavlens
        wavmin = float(linecen - halfrange)
        wavmax = float(linecen + halfrange)
        # None and an empty sequence both mean "use the default window".
        # Previously None became [()], which crashed on ran[0], and the
        # default window was built from self.line_center before it was set;
        # build it from linecen instead.
        if fitrange is None or len(fitrange) == 0:
            fitrange = [(wavmin, wavmax)]
        self.rangelist = fitrange
        print('Nonzero fitranges given: {}'.format(fitrange))
        xarr = np.asarray(wavlens)
        # Indices of all data points strictly inside any of the intervals.
        inside = [np.where((xarr > ran[0]) & (xarr < ran[1]))[0]
                  for ran in fitrange]
        # np.unique sorts and drops indices repeated by overlapping ranges,
        # so no data point is counted twice in the fit.
        self.fitrange = np.unique(np.hstack(inside))
        # Now the rest of the things
        self.indata = indata
        self.add_trait('Centr', Range(min(crange), max(crange), 0.))
        # Amplitude slider limits, derived from the data spread.
        ampmin = float(-indata.std())
        ampmax = float(indata.max() + 2 * indata.std()) * 4.
        self.add_trait('Heigh', Range(ampmin, ampmax, 0.))
        # The continuum slider uses a tighter upper limit.
        self.add_trait('continuum_estimate', Range(ampmin, ampmax / 4., 0.))
        # Fit limits: sane defaults that can be changed from scripts.
        self.add_trait('ampfitmax', Range(0, ampmax, ampmax))
        self.add_trait('ampfitmin', Range(ampmin, 0, ampmin))
        self.add_trait('contfitmax', Range(0, ampmax, ampmax))
        self.add_trait('contfitmin', Range(ampmin, 0, ampmin))
        # Wavelength limits are fixed offsets (+/- 10) for now; deriving
        # them from wavmin/wavmax still needs work.
        self.add_trait('wavfitmax', Range(0, 10., 10.))
        self.add_trait('wavfitmin', Range(-10., 0, -10.))
        self.add_trait('sigfitmax', Range(0.1, 20., 20.))
        self.add_trait('sigfitmin', Range(0.1, 20, 0.1))
        self.ampmin = ampmin
        self.ampmax = ampmax
        self.errs = inerrs
        self.line_center = linecen
        self.build_plot()

    ### =======================================================================
    #     Reactive functions: What happens when buttons are pressed, parameters
    #     are changes etc.

    # Add component to model

    def _add_profile_fired(self):
        """Add a new component to the model, select it and plot it.

        The new component is named 'Comp<N>', where N is one more than the
        largest trailing number among the existing component names, so
        names stay unique past 'Comp9' and regardless of list order. It
        starts at zero offset, sigma 0.1 and zero amplitude.
        """
        def trailing_number(name):
            # Value of the trailing digits of ``name``; 0 if there are none.
            digits = name[len(name.rstrip('0123456789')):]
            return int(digits) if digits else 0

        self.CompNum += 1
        # Previously only the last character of the last name was used,
        # which turned 'Comp10' into a new, clashing 'Comp1'.
        next_num = max([trailing_number(n) for n in self.CompoList] + [0]) + 1
        name = 'Comp' + str(next_num)
        self.CompoList.append(name)
        print("Added component nr. " + name)
        # [position offset, sigma, amplitude, identifier letter ('a' for
        #  CompNum == 1), then three standard deviations filled by a fit]
        self.Components[name] = [
            0., .1, 0.,
            chr(self.CompNum + 96), np.nan, np.nan, np.nan
        ]
        self.Locks[name] = [False, False, False, False]
        self.select = name
        # Evaluate the new (selected) component's profile on the model grid.
        self.y[self.select] = gauss(self.mod_x, self.Centr + self.line_center,
                                    self.Sigma, self.Heigh)
        self.plotdata[self.select] = self.y[self.select]
        color = Paired[self.Components[name][3]]
        render = self.plot.plot(
            ('xs', self.select),
            type='line',
            line_style='dash',
            color=color,
            line_color=color,
            name=name)
        self.plots[self.select] = render
        self.legend.plots = self.plots

    def _remove_profile_fired(self):
        """Remove the currently selected component from model and plot.

        The component before it in ``CompoList`` becomes selected (the last
        one, if the first component is removed). At least one component is
        always kept; otherwise a message is printed and nothing changes.
        """
        if len(self.CompoList) <= 1:
            print('No more components to remove')
            return
        comp_idx = self.CompoList.index(self.select)
        old_name = self.select
        new_name = self.CompoList[comp_idx - 1]
        self.plot.delplot(old_name)
        self.plotdata.del_data(old_name)
        del self.y[old_name]
        del self.plots[old_name]
        del self.Components[old_name]
        del self.Locks[old_name]
        self.select = new_name
        # Report which component was removed (this used to print the
        # component count, which does not identify it).
        print('Removed component nr. ' + old_name)
        self.legend.plots = self.plots
        self.CompoList.pop(comp_idx)
        self.CompNum -= 1

    ##=========================================================================
    #    Here follows the functionality of the GO button, split up into one
    #    function per logical step, so it is easier to script this and do some
    #    non-standard tinkering like e.g. setting odd fit constraints in the
    #    model for this transition etc.
    ##=========================================================================

    def set_fit_data(self):
        """Return ``(x, data, errs)`` copies restricted to the fit range.

        Copies are made first so the caller can never modify the builder's
        own arrays. When ``self.fitrange`` is empty the full arrays are
        returned.
        """
        copies = tuple(arr.copy() for arr in (self.x, self.indata, self.errs))
        if len(self.fitrange) == 0:
            return copies
        return tuple(arr[self.fitrange] for arr in copies)

    def create_fit_param_frame(self):
        """Build ``self.tofit``, the parameter table handed to the fitter.

        One row per component, in ``self.Components`` order, followed by a
        'Contin' row. A component row holds the seven stored values
        (position offset, sigma, amplitude, identifier and the three
        standard deviations), the line center, the first three lock flags
        and the fit limits. The 'Contin' row holds only the continuum
        amplitude, the line center and the continuum lock; its other cells
        are NaN.
        """
        columns = ['Pos', 'Sigma', 'Ampl', 'Identifier', 'Pos_stddev',
                   'Sigma_stddev', 'Ampl_stddev', 'Line center', 'Lock',
                   'AmpMax', 'AmpMin', 'SigMax', 'SigMin', 'WavMax', 'WavMin']
        limits = [self.ampfitmax, self.ampfitmin, self.sigfitmax,
                  self.sigfitmin, self.wavfitmax, self.wavfitmin]
        names, rows = [], []
        for name, params in self.Components.items():
            if name == 'Contin':
                continue
            names.append(name)
            rows.append(list(params) +
                        [self.line_center, list(self.Locks[name][:3])] +
                        limits)
        contin = dict.fromkeys(columns, float('nan'))
        contin.update({'Ampl': self.Components['Contin'][0],
                       'Line center': self.line_center,
                       'Lock': self.LockConti})
        names.append('Contin')
        rows.append([contin[col] for col in columns])
        # Built in one go instead of cell by cell: DataFrame.set_value was
        # removed in pandas 1.0, and list-valued cells ('Lock') are awkward
        # to assign through .loc.
        self.tofit = pd.DataFrame(rows, index=names, columns=columns)

    def load_parameters_to_fitter(self, fitter='lmfit'):
        """Convert ``self.tofit`` into fitter parameters in ``self.params``.

        Only the 'lmfit' backend is supported; any other ``fitter`` is a
        no-op. If the lmfit wrapper cannot be imported, a message is
        printed and ``self.params`` is left untouched.
        """
        if fitter != 'lmfit':
            return
        try:
            import lmfit_wrapper as lw
        except ImportError:
            # print() call form: the Python-2 print statement is a
            # SyntaxError on Python 3 and prints the same text on Python 2.
            print('Could not import LMfit')
            return
        self.params = lw.load_params(self.tofit)

    def fit_with_lmfit(self, method='lbfgsb', conf='covar', report=True):
        """Fit the model with lmfit and load the result back into the builder.

        Parameters
        ----------
        method : str
            Minimization method passed on to ``lw.fit_it``.
        conf : str
            Currently unused; kept for backward compatibility.
        report : bool
            If True, print lmfit's fit report.

        On success ``self.Components`` is replaced by the fitted values,
        the GUI state is refreshed via ``import_model``, and ``self.result``
        and ``self.output`` hold the raw result and the fitted parameter
        table. If the lmfit wrapper cannot be imported, a message is printed
        and nothing changes.
        """
        try:
            import lmfit_wrapper as lw
        except ImportError:
            print('Could not import LMfit')
            return
        # Fit exactly the data set_fit_data returns. This used to be
        # recomputed by indexing with self.fitrange, which selects nothing
        # when the fit range is empty (set_fit_data then uses all data).
        x, data, errs = self.set_fit_data()
        result = lw.fit_it(self.params, args=(x, data, errs), method=method)
        if report:
            lw.lf.report_fit(result)
        output = lw.params_to_grism(result, output_format='df')
        output['Identifier'] = self.tofit['Identifier']
        # .loc replaces DataFrame.set_value / .ix (removed in pandas 1.0).
        output.loc['Contin', 'Identifier'] = float('nan')
        # Turn positions back into offsets from the line center (the
        # wrapper presumably returns absolute positions).
        output['Pos'] -= self.tofit['Line center']
        outdict = {}
        for name, row in output.iterrows():
            if name == 'Contin':
                outdict[name] = [row['Ampl'], row['Ampl_stddev'],
                                 row['RedChi2']]
            else:
                outdict[name] = [
                    row['Pos'], row['Sigma'], row['Ampl'], row['Identifier'],
                    row['Pos_stddev'], row['Sigma_stddev'], row['Ampl_stddev']
                ]
        self.Components = outdict
        self.import_model()
        self.result = result
        self.output = output

    def _Go_Button_fired(self):
        """Run the complete fit when the GO button is pressed.

        Steps: build the parameter table, load it into the fitter, then
        fit. Only the 'lmfit' backend exists; any other ``self.fitter``
        raises NotImplementedError after the first two steps.
        """
        print('Now fitting lines {}'.format(self.linesstring))
        # Model dict -> pandas table -> fitter parameters.
        self.create_fit_param_frame()
        self.load_parameters_to_fitter()
        if self.fitter != 'lmfit':
            raise NotImplementedError('Only LMfit backend implemented so far.')
        self.fit_with_lmfit()
        print('Successfully fitted lines {} \n \n '.format(self.linesstring))

    ##=========================================================================
    #    END of GO button functionality.
    ##=========================================================================

    # Define what to do when a new component is selected.
    def _select_changed(self):
        """Show the newly selected component in the sliders and lock boxes.

        The stored amplitude is clamped to [ampmin, ampmax] first, because
        the 'Heigh' slider only spans that range.
        """
        params = self.Components[self.select]
        locks = self.Locks[self.select]
        self.Centr = params[0]
        self.Sigma = params[1]
        clamped = max(params[2], self.ampmin)
        self.Heigh = min(self.ampmax, clamped)
        self.LockCentr, self.LockSigma, self.LockHeigh = (
            locks[0], locks[1], locks[2])
        self.plot.request_redraw()

    # Every time one of the parameters in the interactive window is changed,
    #   write the change to the parameters list of the selected component.
    #   Do this one-by-one, as it is otherwise going to mess up the
    #   creation and selection of components.

    def _Centr_changed(self):
        """Copy the 'Center' slider into the selected component; redraw."""
        params = self.Components[self.select]
        params[0] = self.Centr
        self.update_plot()

    def _Sigma_changed(self):
        """Copy the 'Sigma' slider into the selected component; redraw."""
        params = self.Components[self.select]
        params[1] = self.Sigma
        self.update_plot()

    def _Heigh_changed(self):
        """Copy the 'Strength' slider into the selected component; redraw."""
        params = self.Components[self.select]
        params[2] = self.Heigh
        self.update_plot()

    def _continuum_estimate_changed(self):
        """Store the continuum slider value in the 'Contin' entry; redraw."""
        continuum = self.Components['Contin']
        continuum[0] = self.continuum_estimate
        self.update_plot()

    def _LockCentr_changed(self):
        """Mirror the Center 'Lock' box into the selected component's locks."""
        flags = self.Locks[self.select]
        flags[0] = self.LockCentr

    def _LockSigma_changed(self):
        """Mirror the Sigma 'Lock' box into the selected component's locks."""
        flags = self.Locks[self.select]
        flags[1] = self.LockSigma

    def _LockHeigh_changed(self):
        """Mirror the Strength 'Lock' box into the selected component's locks."""
        flags = self.Locks[self.select]
        flags[2] = self.LockHeigh

    ###========================================================================
    # Define the graphical user interface

    # Main editor window: plot and parameter sliders on the left, component
    # management and the GO button on the right.
    view = View(
        Group(
            Group(
                VGroup(
                    Item('plwin',
                         editor=ComponentEditor(),
                         show_label=False,
                         springy=True),
                    # One slider + lock checkbox per parameter of the
                    # selected component; a locked slider is disabled.
                    Group(
                        HGroup(
                            Item('Centr',
                                 label='Center',
                                 enabled_when='LockCentr==False'),
                            Item('LockCentr', label='Lock'),
                        ),
                        HGroup(
                            Item('Sigma',
                                 label='Sigma',
                                 enabled_when='LockSigma==False'),
                            Item('LockSigma', label='Lock'),
                        ),
                        HGroup(
                            Item('Heigh',
                                 label=u'Strength ',
                                 enabled_when='LockHeigh==False'),
                            Item('LockHeigh', label='Lock'),
                        ),
                        HGroup(
                            Item('continuum_estimate',
                                 enabled_when='LockConti==False',
                                 label='Contin.  ',
                                 springy=True),
                            Item('LockConti', label='Lock'),
                            springy=True,
                            show_border=False,
                        ),
                        show_border=True,
                        label='Component parameters'),
                    springy=True),
                show_border=True,
            ),
            # Side panel. 'Feedback' used to be listed twice, which showed
            # the same read-only text two times; it is shown once now.
            Group(Item('add_profile'),
                  Item('remove_profile'),
                  Item('Feedback', style='readonly'),
                  Item(name='select',
                       editor=EnumEditor(name='CompoList'),
                       style='custom'),
                  Item('Go_Button'),
                  orientation='vertical',
                  show_labels=False,
                  show_border=True),
            show_border=True,
            orientation='horizontal'),
        resizable=True,
        height=700,
        width=1000,
        buttons=[UndoButton, ApplyButton, CancelButton, OKButton],
        close_result=True,
        kind='livemodal',  # Works but not perfect.
        title="Pychelle - line profile editor")

    # Secondary window for the fit limits: one row per (min, max) pair.
    compview = View(
        HGroup(Item('sigfitmin'), Item('sigfitmax')),
        HGroup(Item('ampfitmin'), Item('ampfitmax')),
        HGroup(Item('wavfitmin'), Item('wavfitmax')),
        title="Pychelle - component fit settings",
    )

    def import_model(self):
        ''' Once lpbuilder's Components dict is set; use this to set
        the state variables of the LPbuilder instance.

        ``CompoList`` becomes the sorted component names (all keys except
        'Contin'), their locks are reset, the continuum slider is loaded,
        the last component is selected and the plot is rebuilt.
        '''
        # Exclude the continuum explicitly; slicing off the last sorted key
        # assumed 'Contin' sorts last, which fails for e.g. 'Narrow'.
        self.CompoList = sorted(k for k in self.Components if k != 'Contin')
        print('{} {}'.format(self.CompoList, list(self.Components.keys())))
        self.CompNum = len(self.CompoList)
        for com in self.CompoList:
            self.Locks[com] = [False] * 4
        self.continuum_estimate = self.Components['Contin'][0]
        self.select = self.CompoList[-1]
        # Called explicitly, presumably so the sliders refresh even when
        # ``select`` already had this value and no change event fires.
        self._select_changed()
        self.build_plot()
        self.update_plot()
        print('    ')

    def update_plot(self):
        """Refresh the model curves after a parameter change.

        Re-evaluates the selected component's profile (via ``gauss``), sums
        all component profiles plus the continuum into ``self.Model``,
        pushes continuum, selected component, model and residuals into
        ``self.plotdata``, and rescales the residual panel.
        """
        self.y[self.select] = gauss(self.mod_x, self.Centr + self.line_center,
                                    self.Sigma, self.Heigh)
        # list() matters on Python 3: np.asarray(dict.values()) would give
        # a 0-d object array, on which .sum(0) fails.
        ys = np.asarray(list(self.y.values())).sum(0)
        self.contarray = np.ones(self.mod_x.shape[0]) * self.continuum_estimate
        self.Model = self.contarray + ys
        self.plotdata.set_data('cont', self.contarray)
        self.plotdata.set_data(self.select, self.y[self.select])
        self.plotdata.set_data('model', self.Model)
        self.plotdata.set_data('Residuals', self.Resids)
        # Rescale the residual panel to the new residuals.
        self.update_resid_window()

    @on_trait_change('Resids')
    def update_resid_window(self):
        """Scale the residual panel's y-axis symmetrically around zero.

        The half-height is 1.2 times the largest finite absolute residual
        inside the fit range, but never less than 1.2 * 5. An empty fit
        range means "all data points", as in ``set_fit_data``.
        """
        resids = np.asarray(self.Resids)
        if len(self.fitrange) > 0:
            resids = resids[self.fitrange]
        # Ignore NaN/inf so a single bad point cannot break the scale.
        magnitudes = np.absolute(resids[np.isfinite(resids)])
        # Guard the empty case: .max() of an empty array raises ValueError.
        largest = magnitudes.max() if magnitudes.size else 0.
        resmin = max(largest, 5.) * 1.2
        value_range = self.resplot.value_range
        value_range.low_setting, value_range.high_setting = (-resmin, resmin)
        self.resplot.request_redraw()
示例#6
0
class MDAViewController(BaseImageController):
    # the image data for the factor plot (including any scatter data and
    #    quiver data)
    factor_plotdata = Instance(ArrayPlotData)
    # the actual plot object
    factor_plot = Instance(BasePlotContainer)
    # the image data for the score plot (may be a parent image for scatter overlays)
    score_plotdata = Instance(ArrayPlotData)
    # the plot object for the score image
    score_plot = Instance(BasePlotContainer)
    # index of the MDA component (factor/score) currently displayed
    component_index = Int(0)
    # NOTE(review): not read by any of the methods visible here
    _selected_peak = Int(0)
    # names of the available MDA contexts, read from the mda_description table
    _contexts = List([])
    # index into _contexts of the active context (-1 selects the last entry)
    _context = Int(-1)
    # number of components in the active context; valid component_index
    # values are 0 .. dimensionality - 1
    dimensionality = Int(1)
    # choices for coloring the peak overlay; the lowercased first letter of
    # the chosen name is used as the field prefix
    _characteristics = List(["None", "Height", "Orientation", "Eccentricity"])
    _characteristic = Int(0)
    # choices for the vector (quiver) overlay
    _vectors = List(["None", "Shifts", "Skew"])
    _vector = Int(0)
    # multiplier applied to the overlay vectors
    vector_scale = Float(1.0)
    # set True when a peak-based factor is rendered
    _can_map_peaks = Bool(False)

    def __init__(self, treasure_chest=None, data_path='/rawdata', *args, **kw):
        """Create the controller, optionally bound to an MDA results file.

        Parameters
        ----------
        treasure_chest : file object, optional
            Presumably an open PyTables file holding the raw data below
            ``data_path`` and the MDA results; if None, only the empty
            plot data containers are created.
        data_path : str
            Node whose children are counted as the data files.
        """
        super(MDAViewController, self).__init__(*args, **kw)
        self.factor_plotdata = ArrayPlotData()
        self.score_plotdata = ArrayPlotData()
        if treasure_chest is None:
            return
        self.numfiles = len(treasure_chest.list_nodes(data_path))
        self.chest = treasure_chest
        self.data_path = data_path
        # Load the available contexts, if any MDA has been run yet.
        description = self.chest.root.mda_description
        if description.nrows <= 0:
            return
        self._contexts = description.col('context').tolist()
        context = self.get_context_name()
        self.dimensionality = self.chest.get_node_attr(
            '/mda_results/' + context, 'dimensionality')
        self.update_factor_image()
        self.update_score_image()

    @on_trait_change('_context')
    def context_changed(self):
        """Load the new context's dimensionality and re-render both plots."""
        context = self.get_context_name()
        self.dimensionality = self.chest.get_node_attr(
            '/mda_results/' + context, 'dimensionality')
        # The previous context may have had more components; keep the index
        # valid so the renders below do not index past the factors.
        if self.component_index >= self.dimensionality:
            self.component_index = 0
        self.render_active_factor_image(context)
        self.render_active_score_image(context)

    def increase_selected_component(self):
        """Select the next component, wrapping to 0 after the last one.

        Modular arithmetic also brings an index that is out of range (e.g.
        after switching to a context with fewer components) back into
        0 .. dimensionality - 1; the old equality check let it grow forever.
        """
        count = max(int(self.dimensionality), 1)
        self.component_index = (self.component_index + 1) % count

    def decrease_selected_component(self):
        """Select the previous component, wrapping from 0 to the last one.

        Like ``increase_selected_component``, an out-of-range index is
        folded back into 0 .. dimensionality - 1.
        """
        count = max(int(self.dimensionality), 1)
        self.component_index = (self.component_index - 1) % count

    def get_context_name(self):
        """Return the name of the active MDA context as a ``str``."""
        name = self._contexts[self._context]
        # PyTables string columns come back as bytes on Python 3, where
        # str() would give "b'...'" and break the node paths built from it.
        # On Python 2 bytes is str, so nothing changes there.
        if isinstance(name, bytes) and not isinstance(name, str):
            name = name.decode('utf-8')
        return str(name)

    def get_characteristic_name(self):
        """Return the selected peak characteristic ('None' means no coloring)."""
        names = self._characteristics
        return names[self._characteristic]

    def get_vector_name(self):
        """Return the selected vector overlay ('None' means no vectors)."""
        names = self._vectors
        return names[self._vector]

    def render_active_factor_image(self, context):
        """Render component ``self.component_index`` of ``context``.

        For peak-based MDA (the context's 'on_peaks' attribute is set) the
        average cell image is shown with the component's peak positions
        overlaid, optionally colored by the selected characteristic and with
        the selected vector field. Otherwise the factor image itself is
        shown.

        Parameters
        ----------
        context : str
            Name of a group below ``/mda_results`` in ``self.chest``.
        """
        context_path = '/mda_results/' + context
        if not self.chest.get_node_attr(context_path, 'on_peaks'):
            # MDA on the images themselves: show the current factor image.
            factors = self.chest.get_node(context_path + '/image_factors')
            self.factor_plotdata.set_data('imagedata',
                                          factors[self.component_index, :, :])
            self.factor_plot = self.get_simple_image_plot(self.factor_plotdata)
            # Reset it so a previously shown peak context does not leave
            # peak mapping enabled for an image context.
            self._can_map_peaks = False
            return

        factors = self.chest.get_node(context_path + '/peak_factors')
        # Background: the average cell image, overlaid with peak info.
        self.factor_plotdata.set_data('imagedata',
                                      self.chest.root.cells.average[:])
        component = factors.read(
            start=self.component_index,
            stop=self.component_index + 1,
            step=1,
        )[:]
        numpeaks = self.chest.get_node_attr('/cell_peaks', 'number_of_peaks')

        def peak_fields(prefix):
            # Fields '<prefix>0' .. '<prefix>{numpeaks-1}' of the one-row
            # record as a flat float array. Reading field by field avoids
            # multi-field indexing + .view(float), which returns padded
            # views (wrong lengths/values) on numpy >= 1.16.
            keys = ['%s%i' % (prefix, i) for i in range(numpeaks)]
            return np.array([component[key] for key in keys],
                            dtype=float).ravel()

        def drop(name):
            # Remove a stale overlay array from the factor plot data.
            if name in self.factor_plotdata.arrays:
                self.factor_plotdata.del_data(name)

        # 'value' holds the x fields, 'index' the y fields.
        self.factor_plotdata.set_data('value', peak_fields('x'))
        self.factor_plotdata.set_data('index', peak_fields('y'))

        vector_name = self.get_vector_name()
        if vector_name == "None":
            drop('vectors')
        else:
            # Field prefix and y sign per vector type; shifts are y-inverted.
            field, y_invert = {'Shifts': ('d', -1),
                               'Skew': ('s', 1)}.get(vector_name, ('', 1))
            if field:
                vector_x = peak_fields(field + 'x').reshape((-1, 1))
                vector_y = y_invert * peak_fields(field + 'y').reshape((-1, 1))
                vectors = np.hstack((vector_x, vector_y)) * self.vector_scale
                self.factor_plotdata.set_data('vectors', vectors)
            else:
                print("%s field not recognized for vector plots." % vector_name)
                drop('vectors')

        characteristic = self.get_characteristic_name()
        if characteristic == "None":
            drop('color')
        else:
            self.factor_plotdata.set_data(
                'color', peak_fields(characteristic[0].lower()))
        self.factor_plot = self.get_scatter_quiver_plot(
            self.factor_plotdata, tools=['colorbar'])
        self._can_map_peaks = True

    def render_active_score_image(self, context):
        """Show the active image with its cells colored by their score.

        Cell coordinates of the active file become a scatter overlay
        ('index' = y coordinates, 'value' = x coordinates, matching the
        factor plot), colored by score column ``c<component_index>`` from
        the peak or image scores of ``context``.
        """
        self.score_plotdata.set_data('imagedata', self.get_active_image())
        query = 'filename == "%s"' % self.get_active_name()
        cells = self.chest.root.cell_description
        y_coords = cells.read_where(query, field='y_coordinate')
        x_coords = cells.read_where(query, field='x_coordinate')
        context_path = '/mda_results/' + context
        if self.chest.get_node_attr(context_path, 'on_peaks'):
            score_node = context_path + '/peak_scores'
        else:
            score_node = context_path + '/image_scores'
        scores = self.chest.get_node(score_node)
        color = scores.read_where(query, field='c%i' % self.component_index)
        self.score_plotdata.set_data('index', y_coords)
        self.score_plotdata.set_data('value', x_coords)
        self.score_plotdata.set_data('color', color)
        self.score_plot = self.get_scatter_overlay_plot(
            self.score_plotdata,
            title=self.get_active_name(),
            tools=["colorbar", "zoom", "pan"])

    @on_trait_change("component_index, _characteristic, _vector, vector_scale")
    def update_factor_image(self):
        """Re-render the factor plot whenever one of its inputs changes."""
        self.render_active_factor_image(self.get_context_name())

    @on_trait_change("selected_index, component_index")
    def update_score_image(self):
        """Re-render the score plot for a new image or component."""
        self.render_active_score_image(self.get_context_name())

    def open_factor_save_UI(self):
        """Open the save UI (``open_save_UI``) for the factor plot."""
        target = 'factor_plot'
        self.open_save_UI(plot_id=target)

    def open_score_save_UI(self):
        self.open_save_UI(plot_id='score_plot')