示例#1
0
    def __init__(self, client_pars=None, plot_template=None, interactive=True, **kwargs):
        self.client=Client(client_pars)
        self.connect()
        # initialize data containers
        self.pr=u.Param()
        self.ob=u.Param()
        self.err=[]
        self.ferr=None
        # initialize the plotter
        from matplotlib import pyplot
        self.interactive = interactive
        self.pp = pyplot
        pyplot.interactive(interactive)
        
        #self.template_default = default_template
        self.templates = templates
        #self.template = u.Param()
        self.p = u.Param(DEFAULT)
        #self.update_plot_layout(plot_template=plot_template)

        # save as 'cmd': tuple(ticket,buffer,key)
        # will call get cmds with 'cmd', save the ticket in <ticket> and
        # save the resulting data in buffer[key]
        
        self.cmd_dct = {}
        self.server_dcts={}
示例#2
0
class Plot_Client(object):
    """
    Plotting Client for Ptypy
    Very simple for now. Does not inherit from Client but 
    holds a client instance.
    TODO:
    - maybe inherit von Client?
    """
    DEFAULT=DEFAULT
    
    def __init__(self, client_pars=None, plot_template=None, interactive=True, **kwargs):
        self.client=Client(client_pars)
        self.connect()
        # initialize data containers
        self.pr=u.Param()
        self.ob=u.Param()
        self.err=[]
        self.ferr=None
        # initialize the plotter
        from matplotlib import pyplot
        self.interactive = interactive
        self.pp = pyplot
        pyplot.interactive(interactive)
        
        #self.template_default = default_template
        self.templates = templates
        #self.template = u.Param()
        self.p = u.Param(DEFAULT)
        #self.update_plot_layout(plot_template=plot_template)

        # save as 'cmd': tuple(ticket,buffer,key)
        # will call get cmds with 'cmd', save the ticket in <ticket> and
        # save the resulting data in buffer[key]
        
        self.cmd_dct = {}
        self.server_dcts={}
        
    def connect(self):
        self.client.activate()
        # pause until connected
        while not self.client.connected:
            time.sleep(0.1)
            
    def disconnect(self):
        self.client.stop()
        
    
    def get_data(self,cmd_list,flush=True):
        c=self.client
        tlist=[]
        # get tickets for requests
        for cmd in cmd_list:
            ticket=c.get(cmd)
            tlist+=[ticket]
        # wait for the last ticket to be processed
        completed = c.wait(ticket)
        # change data for tickets
        data=[c.data[ticket] for ticket in tlist]
        #data = c.data[ticket]
        # empty buffer
        if flush: c.flush()
        return data
    # get the probe
    
    def init_client_dict(self,server_dict_list):
        # makes client dict after server dct. This can be solved nicer I guess
        c=self.client
        keys_list = self.get_data([server_dict+'.keys()' for server_dict in server_dict_list])
        
        self.server_dcts.update(dict(zip(server_dict_list,[u.Param([(key,u.Param()) for key in keys]) for keys in keys_list])))
        #print self.server_dcts
        return self.server_dcts #
        
    def initialize(self):
        pr_base='Ptycho.probe.S'
        ob_base='Ptycho.obj.S'
        runtime='Ptycho.runtime'

        dct=self.init_client_dict([pr_base,ob_base])
        self.pr=dct[pr_base] 
        self.ob=dct[ob_base] 
        storage_keys=['data','psize','center']
        for kb,base in dct.iteritems():
            for ks,storage in base.iteritems():
                if ks==str(ks):
                    kkey='\''+ks+'\''
                else:
                    kkey=str(ks)

                for skey in storage_keys:
                    self.cmd_dct[kb+'['+kkey+'].' +skey]=  [None,storage,skey]
        
        dct=self.init_client_dict([runtime]) 
        self.runtime=dct[runtime]
        for key in self.runtime.keys():
            if key==str(key):
                kkey='\''+key+'\''
            else:
                kkey=str(key)
            self.cmd_dct[runtime+'['+kkey+']']=[None,self.runtime,key]
            
    def request_data(self):
        for cmd,item in self.cmd_dct.iteritems():
            #print cmd,item[0]
            item[0]=self.client.get(cmd) 
            
    def mv_data_from_client_buffer(self):
        for cmd,item in self.cmd_dct.iteritems():
            #print cmd,item
            item[1][item[2]] = self.client.data[item[0]]
            
        self.client.flush()
        
    """
    def get_all_data(self):
        cmds=[]
        buffers=[]
        storage_keys=['data','psize','center']
        for kb,base in self.server_dicts.items():
            for ks,storage in base.items():
                if ks==str(ks):
                    kkey='\''+ks+'\''
                else:
                    kkey=str(ks)
                    
                for skey in storage_keys:
                    cmds.append(kb+'['+kkey+'].' +skey)
                    buffers.append((storage,skey))
        print cmds
        data=self.get_data(cmds)
        for d,b in zip(data,buffers):
            b[0][b[1]]=d
    """
    def update_plot_layout(self,plot_template=None,**kwargs):
        # generate the plot frame
        def simplify_aspect_ratios(sh):
            ratio= sh[1] / float(sh[0])    
            rp = 1 - int(ratio < 2./3.) + int(ratio >= 3./2.)
            if rp==0:
                sh =(4,2)
            elif rp==2:
                sh =(2,4)
            else:
                sh =(3,3)
            return sh
        # local references:
        ob=self.ob
        pr=self.pr
        ptemplate = self.p
        if plot_template is not None:
            ptemplate = plot_template
        elif ptemplate is None:
            ptemplate = 'legacy'
            
        if ptemplate is not None:
            if isinstance(ptemplate,str):
                template = self.templates.get(ptemplate)
                if template is None:
                    raise RuntimeError('Plot template not known. Look in class.templates.keys() for a template of parameters')
            elif isinstance(ptemplate,dict):
                template=ptemplate
                self.templates.update({'custom':ptemplate})
                
            self.p.update(template)
        
        #axes =[]
        self.num_shape_list=[]
        num_shape_list=self.num_shape_list
        for key in sorted(ob.keys()):
            cont=ob[key]
            #print key
            # attach a plotting dict from above
            # this will need tweaking
            cp=self.p.ob.copy()
            # determine the shape
            sh = cont.data.shape[-2:]
            if self.p.simplified_aspect_ratios:
                sh = simplify_aspect_ratios(sh)
                
            layers=cp.layers
            if layers is None:
                layers=cont.data.shape[0]
            if np.isscalar(layers):
                layers=range(layers)
            cp.layers=layers
            cp.axes_index=len(num_shape_list)
            num_shape=[len(layers)*len(cp.auto_display)+int(cp.local_error),sh]
            num_shape_list.append(num_shape)
            cont.plot = cp
            
        # per default we will use the a frame similar to the last object frame for plotting
        if np.array(self.p.plot_error).any():
            self.error_axes_index = len(num_shape_list)-1
            self.error_frame=num_shape_list[-1][0]
            num_shape_list[-1][0]+=1 # add a frame
        
        for key in sorted(pr.keys()):
            cont=pr[key]
            # attach a plotting dict from above
            # this will need tweaking
            cont.plot=self.p.pr.copy()
            # determine the shape
            sh = cont.data.shape[-2:]
            if self.p.simplified_aspect_ratios:
                sh = simplify_aspect_ratios(sh)
                
            layers=cont.plot.layers
            if layers is None:
                layers=cont.data.shape[0]
            if np.isscalar(layers):
                layers=range(layers)
            cont.plot.layers=layers
            cont.plot.axes_index=len(num_shape_list)
            num_shape=[len(layers)*len(cont.plot.auto_display),sh]
            num_shape_list.append(num_shape)
        
        axes_list,plot_fig,gs = self.create_plot_from_tile_list(1,num_shape_list,self.p.figsize)
        #plot_fig.suptitle(p.paramdict.get('scans')[0])
        #obj_axes = plot_axes_list[0]
        #pr_axes = plot_axes_list[1]
        #err_axes = [pr_axes.pop(0)]
        #local_error_axes=[obj_axes.pop(0) for i in range(lerr_axes_num)]
        sy,sx = gs.get_geometry()
        w,h,l,r,b,t = self.p.gridspecpars
        gs.update(wspace=w*sy,hspace=h*sx,left=l,right=r,bottom=b,top=t)
        self.draw()
        #plot_axes = obj_axes+pr_axes+err_axes
        plot_fig.hold(False)
        for axes in axes_list:
            for pl in axes: 
                pl.hold(False)
                self.pp.setp(pl.get_xticklabels(), fontsize=8) #doesn't do nothin
                self.pp.setp(pl.get_yticklabels(), fontsize=8)
        self.plot_fig = plot_fig
        self.axes_list = axes_list
        self.gs= gs

    def create_plot_from_tile_list(self,fignum=1,num_shape_list=[(4,(2,2))],figsize=(8,8)):        
        def fill_with_tiles(size,sh,num,figratio=16./9.):
            coords_tl=[]
            while num > 0:
                Horizontal = True
                N_h = size[1]//sh[1]
                N_v = size[0]//sh[0]
                #looking for tight fit
                if num<=N_v and np.abs(N_h-num) >= np.abs(N_v-num):
                    Horizontal = False
                elif num<=N_h and np.abs(N_h-num) <= np.abs(N_v-num):
                    Horizontal = True
                elif size[0]==0 or size[1]/float(size[0]+0.00001) > figratio:
                    Horizontal = True
                else:
                    Horizontal = False
                     
                if Horizontal:
                    N=N_h
                    a=size[1]%sh[1]
                    coords=[(size[0],int(ii*sh[1])+a) for ii in range(N)]
                    size[0]+=sh[0]
                else:
                    N=N_v
                    a=size[0]%sh[0]
                    coords=[(int(ii*sh[0])+a,size[1]) for ii in range(N)]
                    size[1]+=sh[1]
                    
                num -=N
                coords_tl+=coords
                coords_tl.sort()
                
            return coords_tl, size
        
        coords_list=[]
        fig_aspect_ratio = figsize[0]/float(figsize[1])
        size=[0,0] 
        # determine frame thickness
        aa= np.array([sh[0]*sh[1] for N,sh in num_shape_list])
        N,bigsh = num_shape_list[np.argmax(aa)]
        frame=int(0.2*min(bigsh))
        for N,sh in num_shape_list:
            nsh=np.array(sh)+frame
            coords, size  = fill_with_tiles(size,nsh,N,fig_aspect_ratio)
            coords_list.append(coords)

        gs=gridspec.GridSpec(size[0],size[1])
        fig = self.pp.figure(fignum)
        fig.clf()
        
        mag=min(figsize[0]/float(size[1]),figsize[1]/float(size[0]))
        figsize=(size[1]*mag,size[0]*mag)
        fig.set_size_inches(figsize,forward=True)
        space =0.1*size[0]
        gs.update(wspace=0.1*size[0],hspace=0.12*size[0],left=0.07,right=0.95,bottom=0.05,top=0.93) #this is still a stupid hardwired parameter
        axes_list=[]
        for (N,sh),coords in zip(num_shape_list,coords_list):
            axes_list.append([fig.add_subplot(gs[co[0]+frame//2:co[0]+frame//2+sh[0],co[1]+frame//2:co[1]+frame//2+sh[1]]) for co in coords]) 
    
        return axes_list,fig,gs

    def draw(self):
        if self.interactive:
            #self.plot_fig.canvas.draw()
            self.pp.draw()
            time.sleep(0.1)
        else:
            self.pp.show()
            
    def plot_error(self):
        if np.array(self.p.plot_error).any():
            # get axis            
            try:
                axis=self.axes_list[self.error_axes_index][self.error_frame]
                # get runtime info
                error = np.array([info['error'].sum(0) for info in self.runtime.iter_info])
                err_fmag = error[:,0]
                err_phot = error[:,1]
                err_exit = error[:,2]
                axis.hold(False)
                fmag = err_fmag/np.max(err_fmag)
                axis.plot(fmag,label='err_fmag %2.2f%% of %.2e' % (fmag[-1]*100,np.max(err_fmag)))
                axis.hold(True)
                phot = err_phot/np.max(err_phot)
                axis.plot(phot,label='err_phot %2.2f%% of %.2e' % (phot[-1]*100,np.max(err_phot)))
                ex = err_exit/np.max(err_exit)
                axis.plot(ex,label='err_exit %2.2f%% of %.2e' % (ex[-1]*100,np.max(err_exit)))
                axis.legend(loc=1,fontsize=10) #('err_fmag %.2e' % np.max(err_fmag),'err_phot %.2e' % np.max(err_phot),'err_exit %.2e' % np.max(err_exit)),
                self.pp.setp(axis.get_xticklabels(), fontsize=10)
                self.pp.setp(axis.get_yticklabels(), fontsize=10)
            except:
                pass
            
    def plot_storage(self,storage,title="",typ='obj'):
        # get plotting paramters
        pp=storage.plot
        axes=self.axes_list[pp.axes_index]
        weight=pp.get('weight')
        # plotting mask for ramp removal
        sh=storage.data.shape[-2:]
        x,y=np.indices(sh)-np.reshape(np.array(sh)//2,(len(sh),)+len(sh)*(1,))
        mask= (x**2+y**2 < 0.1*min(sh)**2)
        pp.mask=mask
        # cropping
        crop=np.array(sh)*np.array(pp.crop)//2
        crop=-crop.astype(int)
        data=u.crop_pad(storage.data,crop,axes=[-2,-1])
        plot_mask=u.crop_pad(mask,crop,axes=[-2,-1])
        for ii,ind in enumerate([(l,a) for l in pp.layers for a in pp.auto_display]):
            #print ii, ind
            if ii >= len(axes):
                break
            # get the layer
            dat=data[ind[0]]
            if ind[1]=='p' or ind[1]=='c':
                if pp.rm_pr:                    
                    if weight is None:
                        ndat = U.rmphaseramp(dat, np.abs(dat) * plot_mask.astype(float))
                        mean_ndat = (ndat*plot_mask).sum() / plot_mask.sum()
                    else:
                        ndat = U.rmphaseramp(dat, np.abs(dat) * weight)
                        mean_ndat = (ndat*weight).sum() / weight.sum()
                else:
                    ndat=dat.copy()
                    mean_ndat = (ndat*plot_mask).sum() / plot_mask.sum()
            else:
                ndat=dat.copy()
            
            if typ=='obj':
                mm = np.mean(np.abs(ndat*plot_mask)**2)
                info = 'T=%.2f' % mm
            else:
                mm = np.sum(np.abs(ndat)**2)
                info = 'P=%1.1e' % mm
                
            if ind[1]=='c':
                dat_i = U.imsave(np.flipud(ndat))
                if not axes[ii].images:
                    axes[ii].imshow(dat_i)
                    self.pp.setp(axes[ii].get_xticklabels(), fontsize=8)
                    self.pp.setp(axes[ii].get_yticklabels(), fontsize=8)
                else:
                    axes[ii].images[0].set_data(dat_i)            
                axes[ii].set_title('%s#%d (C)\n%s' % (title,ind[0],info),size=12)
                continue
                
            if ind[1]=='p':        
                d = np.angle(ndat / mean_ndat)
                ttl = '%s#%d (P)' % (title,ind[0]) #% (ind[0],ind[1])
                cmap = self.pp.get_cmap(pp.cmaps[1])
                clims = pp.clims[1]
            elif ind[1]=='a':    
                d = np.abs(ndat)
                ttl = '%s#%d (M)\n%s' % (title,ind[0],info)
                cmap = self.pp.get_cmap(pp.cmaps[0])
                clims = pp.clims[0]
            
            vmin = d[plot_mask].min() if clims is None else clims[0]
            vmax = d[plot_mask].max() if clims is None else clims[1]
            if not axes[ii].images:
                axes[ii].imshow(d,vmin=vmin, vmax=vmax,cmap=cmap)
                self.pp.setp(axes[ii].get_xticklabels(), fontsize=8)
                self.pp.setp(axes[ii].get_yticklabels(), fontsize=8)
            else:
                axes[ii].images[0].set_data(d)
                axes[ii].images[0].set_clim(vmin=vmin, vmax=vmax)
            axes[ii].set_title(ttl,size=12)
            #ii+=1
    
    def plot_all(self):
        for key,storage in self.pr.items():
            #print key
            self.plot_storage(storage,str(key),'pr')
        for key,storage in self.ob.items():
            #print key
            self.plot_storage(storage,str(key),'obj')
        self.plot_error()
        
    def loop_plot(self, timeout=0.1):
        self.initialize()
        self.request_data()
        while len(self.client.pending): time.sleep(timeout)
        self.mv_data_from_client_buffer()
        self.update_plot_layout()
        self.request_data()
        while True:
            if len(self.client.pending) != 0:
                # at this point the gui should be able to react again
                U.pause(timeout)
            else:
                self.mv_data_from_client_buffer()
                self.plot_all()
                self.draw()
                self.request_data()