コード例 #1
0
ファイル: ptysim.py プロジェクト: zhenchen16/ptypy
def simulate_basic_with_pods(ptypy_pars_tree=None, sim_pars=None, save=False):
    """
    Basic simulation built on a pod-based Ptycho instance.

    Simulation parameters start from ``DEFAULT``. They are updated with the
    ``'simulation'`` entry of ``ptypy_pars_tree`` (if a tree is given) and then
    with ``sim_pars`` (if given).

    Steps:
      1. Build ``ptypy.core.Ptycho(ptypy_pars_tree, level=1)``.
      2. Attach a simulated data source and let the model manager create
         the data containers.
      3. Add ``conv(|fw(exit)|**2, psf)`` to the diffraction view of
         every active pod.
      4. If ``detector`` is set, filter every diffraction storage through a
         ``Detector`` and optionally crop/pad the frames (and masks) to
         ``frame_size``.
      5. If ``save`` is true, hand the result to
         ``modelm.collect_diff_mask_meta``.

    Returns the Ptycho instance.
    """
    pars = DEFAULT.copy()
    tree = ptypy_pars_tree
    if tree is not None:
        pars.update(tree.get('simulation'))
    if sim_pars is not None:
        pars.update(sim_pars)

    ptycho = ptypy.core.Ptycho(tree, level=1)

    # Data source that provides (drifted / scaled / noisy) positions but
    # essentially no data of its own.
    ptycho.datasource = make_sim_datasource(
        ptycho.modelm, pars.pos_drift, pars.pos_scale, pars.pos_noise)

    ptycho.modelm.new_data()
    u.parallel.barrier()
    ptycho.print_stats()

    # Far-field intensity of each active pod, blurred by the psf to mimic
    # partial coherence (when that is not already modelled with modes).
    active_pods = (pod for _, pod in ptycho.pods.iteritems() if pod.active)
    for pod in active_pods:
        pod.diff += conv(u.abs2(pod.fw(pod.exit)), pars.psf)

    save_dtype = None
    if pars.detector is not None:
        detector = Detector(pars.detector)
        save_dtype = detector.dtype
        for storage_id, diff_storage in ptycho.diff.S.items():
            # Matching mask storage; its content is replaced below.
            mask_storage = ptycho.mask.S[storage_id]
            frames, frame_mask = detector.filter(diff_storage.data)
            if pars.frame_size is not None:
                planes = (u.expect2(pars.frame_size)
                          - u.expect2(frames.shape[-2:]))
                frames = u.crop_pad(
                    frames, planes, axes=[-2, -1]).astype(frames.dtype)
                frame_mask = u.crop_pad(
                    frame_mask, planes, axes=[-2, -1]).astype(frame_mask.dtype)
            diff_storage.fill(frames)
            mask_storage.fill(frame_mask)

    if save:
        ptycho.modelm.collect_diff_mask_meta(save=save, dtype=save_dtype)

    u.parallel.barrier()
    return ptycho
コード例 #2
0
ファイル: ptysim.py プロジェクト: aglowacki/ptypy
def simulate_basic_with_pods(ptypy_pars_tree=None, sim_pars=None, save=False):
    """
    Run a basic pod-based simulation and return the Ptycho instance.

    Parameter precedence (lowest to highest): ``DEFAULT``, the
    ``'simulation'`` entry of ``ptypy_pars_tree``, then ``sim_pars``.

    The diffraction data of every active pod is set up as the psf-convolved
    far-field intensity of its exit wave. When a detector is configured, the
    data and masks are passed through ``Detector.filter`` and optionally
    cropped/padded to ``frame_size``. With ``save`` the result is handed to
    ``modelm.collect_diff_mask_meta`` using the detector's dtype (or
    ``None`` without a detector).
    """
    sim = DEFAULT.copy()
    tree = ptypy_pars_tree
    if tree is not None:
        sim.update(tree.get('simulation'))
    if sim_pars is not None:
        sim.update(sim_pars)

    ptyc = ptypy.core.Ptycho(tree, level=1)

    # Basically empty data source: only scan positions are simulated here.
    ptyc.datasource = make_sim_datasource(ptyc.modelm, sim.pos_drift,
                                          sim.pos_scale, sim.pos_noise)

    ptyc.modelm.new_data()
    u.parallel.barrier()
    ptyc.print_stats()

    # Propagate and apply the psf (partial coherence if not done via modes).
    for _name, pod in ptyc.pods.iteritems():
        if pod.active:
            intensity = u.abs2(pod.fw(pod.exit))
            pod.diff += conv(intensity, sim.psf)

    if sim.detector is None:
        save_dtype = None
    else:
        det = Detector(sim.detector)
        save_dtype = det.dtype
        for sid, s_diff in ptyc.diff.S.items():
            # The mask storage is fetched first; its content gets overwritten.
            s_mask = ptyc.mask.S[sid]
            data, msk = det.filter(s_diff.data)
            if sim.frame_size is not None:
                hp = u.expect2(sim.frame_size) - u.expect2(data.shape[-2:])
                data = u.crop_pad(data, hp, axes=[-2, -1]).astype(data.dtype)
                msk = u.crop_pad(msk, hp, axes=[-2, -1]).astype(msk.dtype)
            s_diff.fill(data)
            s_mask.fill(msk)

    if save:
        ptyc.modelm.collect_diff_mask_meta(save=save, dtype=save_dtype)

    u.parallel.barrier()
    return ptyc
コード例 #3
0
    def bw(self, W):
        """
        Return the backward-propagated version of wavefront ``W``.

        The transform is ``isc * post_ifft * ifft(pre_ifft * w)``. When any
        entry of ``self.crop_pad`` is non-zero, ``W`` is first run through
        ``u.crop_pad(W, self.crop_pad)``. The result is then run through
        ``u.crop_pad(..., -self.crop_pad)`` to undo that adjustment.
        """
        adjust = (self.crop_pad != 0).any()
        w = u.crop_pad(W, self.crop_pad) if adjust else W
        w = self.isc * self.post_ifft * self.ifft(self.pre_ifft * w)
        return u.crop_pad(w, -self.crop_pad) if adjust else w
コード例 #4
0
ファイル: geometry.py プロジェクト: aglowacki/ptypy
 def bw(self, W):
     """
     Backward-propagate wavefront ``W`` and return the result.

     Computes ``isc * post_ifft * ifft(pre_ifft * w)``. If ``self.crop_pad``
     has a non-zero entry, the input is adjusted with
     ``u.crop_pad(W, self.crop_pad)`` beforehand. The output is adjusted back
     with ``u.crop_pad(..., -self.crop_pad)`` afterwards.
     """
     padded = (self.crop_pad != 0).any()
     if padded:
         wave = u.crop_pad(W, self.crop_pad)
     else:
         wave = W

     wave = self.isc * self.post_ifft * self.ifft(self.pre_ifft * wave)

     if not padded:
         return wave
     return u.crop_pad(wave, -self.crop_pad)
コード例 #5
0
 def plot_storage(self,storage,title="",typ='obj'):
     """
     Draw the layers of ``storage`` into the axes selected by ``storage.plot``.

     One axis of ``self.axes_list[pp.axes_index]`` is used per
     (layer, display mode) pair from ``pp.layers`` x ``pp.auto_display``,
     until the axes run out. Display modes:

     * ``'a'`` -- amplitude; colormap ``pp.cmaps[0]``, limits ``pp.clims[0]``
     * ``'p'`` -- phase relative to the (masked or weighted) mean; colormap
       ``pp.cmaps[1]``, limits ``pp.clims[1]``. The phase ramp is removed
       first when ``pp.rm_pr`` is set.
     * ``'c'`` -- complex image rendered via ``U.imsave``

     Any other mode is skipped and its axis is left untouched.

     Parameters
     ----------
     storage : object
         Must provide ``data`` and a ``plot`` parameter set.
     title : str
         Prefix of every axis title.
     typ : str
         ``'obj'`` annotates with ``T`` = mean of ``|data * mask|**2``;
         anything else annotates with ``P`` = sum of ``|data|**2``.

     Notes
     -----
     Sets ``storage.plot.mask`` to the circular mask computed here.
     """
     # get plotting parameters
     pp = storage.plot
     axes = self.axes_list[pp.axes_index]
     weight = pp.get('weight')

     # Circular mask around the frame centre (r**2 < 0.1 * min(sh)**2), used
     # for phase-ramp removal, the phase reference and the colour limits.
     sh = storage.data.shape[-2:]
     x, y = np.indices(sh) - np.reshape(np.array(sh) // 2, (len(sh),) + len(sh) * (1,))
     mask = (x**2 + y**2 < 0.1 * min(sh)**2)
     pp.mask = mask

     # Cropping: negative values passed to u.crop_pad (presumably pixels
     # removed per side -- confirm against u.crop_pad).
     crop = np.array(sh) * np.array(pp.crop) // 2
     crop = -crop.astype(int)
     data = u.crop_pad(storage.data, crop, axes=[-2, -1])
     # Force bool so that d[plot_mask] below is boolean (not integer) indexing
     # even if crop_pad changes the dtype.
     plot_mask = u.crop_pad(mask, crop, axes=[-2, -1]).astype(bool)

     for ii, (layer, mode) in enumerate([(l, a) for l in pp.layers for a in pp.auto_display]):
         if ii >= len(axes):
             break
         dat = data[layer]
         if mode == 'p' or mode == 'c':
             if pp.rm_pr:
                 if weight is None:
                     ndat = U.rmphaseramp(dat, np.abs(dat) * plot_mask.astype(float))
                     mean_ndat = (ndat * plot_mask).sum() / plot_mask.sum()
                 else:
                     # NOTE(review): weight is not cropped like dat/plot_mask,
                     # so this presumably needs pp.crop == 0 or a weight of the
                     # cropped shape -- confirm.
                     ndat = U.rmphaseramp(dat, np.abs(dat) * weight)
                     mean_ndat = (ndat * weight).sum() / weight.sum()
             else:
                 ndat = dat.copy()
                 mean_ndat = (ndat * plot_mask).sum() / plot_mask.sum()
         else:
             ndat = dat.copy()

         if typ == 'obj':
             mm = np.mean(np.abs(ndat * plot_mask)**2)
             info = 'T=%.2f' % mm
         else:
             mm = np.sum(np.abs(ndat)**2)
             info = 'P=%1.1e' % mm

         if mode == 'c':
             dat_i = U.imsave(np.flipud(ndat))
             if not axes[ii].images:
                 axes[ii].imshow(dat_i)
                 self.pp.setp(axes[ii].get_xticklabels(), fontsize=8)
                 self.pp.setp(axes[ii].get_yticklabels(), fontsize=8)
             else:
                 axes[ii].images[0].set_data(dat_i)
             axes[ii].set_title('%s#%d (C)\n%s' % (title, layer, info), size=12)
             continue

         if mode == 'p':
             d = np.angle(ndat / mean_ndat)
             ttl = '%s#%d (P)' % (title, layer)
             cmap = self.pp.get_cmap(pp.cmaps[1])
             clims = pp.clims[1]
         elif mode == 'a':
             d = np.abs(ndat)
             ttl = '%s#%d (M)\n%s' % (title, layer, info)
             cmap = self.pp.get_cmap(pp.cmaps[0])
             clims = pp.clims[0]
         else:
             # Unknown display mode: skip it rather than falling through with
             # d/ttl/cmap/clims unbound (first pass) or left over from the
             # previous iteration (which would re-plot stale data).
             continue

         vmin = d[plot_mask].min() if clims is None else clims[0]
         vmax = d[plot_mask].max() if clims is None else clims[1]
         if not axes[ii].images:
             axes[ii].imshow(d, vmin=vmin, vmax=vmax, cmap=cmap)
             self.pp.setp(axes[ii].get_xticklabels(), fontsize=8)
             self.pp.setp(axes[ii].get_yticklabels(), fontsize=8)
         else:
             axes[ii].images[0].set_data(d)
             axes[ii].images[0].set_clim(vmin=vmin, vmax=vmax)
         axes[ii].set_title(ttl, size=12)
コード例 #6
0
    def read(self, scan=None, **kwargs):
        """\
        Read in the data of one scan from its nexus file.

        Loads the raw frames (``self.exp``), the frames as float
        (``self.data``), the motor positions (``self.motors``, ``None`` if they
        cannot be loaded), the command and the label. It then sets
        ``self.experimentID``: from ``self.p.experimentID`` if given, otherwise
        from the nexus file.

        If ``self.p.dpsize`` is ``None``, full frames are kept. Otherwise
        ``self.data`` and any existing ``self.mask``/``self.flat``/``self.dark``
        are cropped/padded to ``dpsize`` around ``self.p.ctr``. When ``ctr`` is
        ``None``, the centre of mass of the (masked) summed frames is used.
        ``self.p.dpsize`` is replaced by its ``expect2``-normalised form.

        Parameters
        ----------
        scan : optional
            Scan identifier passed to ``self.get_nexus_file``; defaults to
            ``self.p.scan``.
        **kwargs
            Accepted for interface compatibility; unused.

        TODO: (maybe?) MPI to avoid loading all data in a single process for large scans.
        """
        scan = scan if scan is not None else self.p.scan
        logger.info('Processing scan number %s' % str(scan))

        self.scan = self.get_nexus_file(scan)
        logger.debug('Data will be read from path: %s' % self.scan)

        self.exp = load(self.scan, self.nxs.frame)
        try:
            self.motors = load(self.scan, self.nxs.motors)
        except Exception:
            # Motor positions are optional.
            self.motors = None
        self.command = load(self.scan, self.nxs.command)
        # NOTE(review): reads the same entry as self.exp a second time;
        # self.exp.astype(float) may suffice if load() returns an ndarray.
        self.data = load(self.scan, self.nxs.frame).astype(float)
        self.label = load(self.scan, self.nxs.label)[0]

        if self.p.experimentID is None:
            try:
                experimentID = load(self.scan, self.nxs.experiment)[0]
            except Exception:
                logger.debug(
                    'Could not find experiment ID from nexus file %s.' %
                    self.scan)
                # NOTE(review): `base_path` is not defined in this method; it
                # must be a module-level name or this fallback raises NameError.
                experimentID = os.path.split(base_path[:-1])[1]
            logger.debug('experimentID: "%s" (automatically set).' %
                         experimentID)
        else:
            experimentID = self.p.experimentID
            logger.debug('experimentID: "%s".' % experimentID)
        # Store the ID in both cases (the automatically found one used to be
        # computed, logged and then discarded).
        self.experimentID = experimentID

        dpsize = self.p.dpsize
        ctr = self.p.ctr

        sh = self.data.shape[-2:]
        fullframe = False
        if dpsize is None:
            dpsize = sh
            logger.debug(
                'Full frames (%d x %d) will be saved (so no recentering).' %
                (sh))
            fullframe = True
        # Normalised 2-element size; also used for the crop limits below so
        # that scalar, list and tuple dpsize values all work.
        dpsize_2d = expect2(dpsize)
        self.p.dpsize = dpsize_2d

        f = self.data
        if not fullframe:
            # Compute center of mass
            if self.mask is None:
                ctr_auto = mass_center(f.sum(0))
            else:
                ctr_auto = mass_center(f.sum(0) * self.mask)
            logger.debug('Center of mass: %s' % str(ctr_auto))
            # Check for center position.
            # TODO: interactive center selection (ctr == 'inter') is not
            # implemented.
            if ctr is None:
                ctr = ctr_auto
                logger.debug('Using center: (%d, %d)' % (ctr[0], ctr[1]))
            else:
                logger.debug(
                    'Using center: (%d, %d) - I would have guessed it is (%d, %d)'
                    % (ctr[0], ctr[1], ctr_auto[0], ctr_auto[1]))

            self.dpsize = dpsize
            self.ctr = np.array(ctr)
            lim_inf = -np.ceil(self.ctr - dpsize_2d / 2.).astype(int)
            lim_sup = np.ceil(self.ctr + dpsize_2d / 2.).astype(int) - np.array(sh)
            hplane_list = [(lim_inf[0], lim_sup[0]), (lim_inf[1], lim_sup[1])]
            logger.debug('Going from %s to %s (hplane_list = %s)' %
                         (str(sh), str(dpsize), str(hplane_list)))
            if self.mask is not None:
                self.mask = u.crop_pad(self.mask, hplane_list).astype(bool)
            if self.flat is not None:
                # Pad the flat field with ones rather than zeros.
                self.flat = u.crop_pad(self.flat, hplane_list, fillpar=1.)
            if self.dark is not None:
                self.dark = u.crop_pad(self.dark, hplane_list)
            if self.data is not None:
                self.data = u.crop_pad(self.data, hplane_list)
コード例 #7
0
ファイル: I13DLS.py プロジェクト: aglowacki/ptypy
    def read(self, scan=None, **kwargs):
        """\
        Read in the data of one scan from its nexus file.

        Loads the raw frames (``self.exp``), the frames as float
        (``self.data``), the motor positions (``self.motors``, ``None`` if they
        cannot be loaded), the command and the label. It then sets
        ``self.experimentID``: from ``self.p.experimentID`` if given, otherwise
        from the nexus file.

        If ``self.p.dpsize`` is ``None``, full frames are kept. Otherwise
        ``self.data`` and any existing ``self.mask``/``self.flat``/``self.dark``
        are cropped/padded to ``dpsize`` around ``self.p.ctr``. When ``ctr`` is
        ``None``, the centre of mass of the (masked) summed frames is used.
        ``self.p.dpsize`` is replaced by its ``expect2``-normalised form.

        Parameters
        ----------
        scan : optional
            Scan identifier passed to ``self.get_nexus_file``; defaults to
            ``self.p.scan``.
        **kwargs
            Accepted for interface compatibility; unused.

        TODO: (maybe?) MPI to avoid loading all data in a single process for large scans.
        """
        scan = scan if scan is not None else self.p.scan
        logger.info('Processing scan number %s' % str(scan))

        self.scan = self.get_nexus_file(scan)
        logger.debug('Data will be read from path: %s' % self.scan)

        self.exp = load(self.scan, self.nxs.frame)
        try:
            self.motors = load(self.scan, self.nxs.motors)
        except Exception:
            # Motor positions are optional.
            self.motors = None
        self.command = load(self.scan, self.nxs.command)
        # NOTE(review): reads the same entry as self.exp a second time;
        # self.exp.astype(float) may suffice if load() returns an ndarray.
        self.data = load(self.scan, self.nxs.frame).astype(float)
        self.label = load(self.scan, self.nxs.label)[0]

        if self.p.experimentID is None:
            try:
                experimentID = load(self.scan, self.nxs.experiment)[0]
            except Exception:
                logger.debug('Could not find experiment ID from nexus file %s.' % self.scan)
                # NOTE(review): `base_path` is not defined in this method; it
                # must be a module-level name or this fallback raises NameError.
                experimentID = os.path.split(base_path[:-1])[1]
            logger.debug('experimentID: "%s" (automatically set).' % experimentID)
        else:
            experimentID = self.p.experimentID
            logger.debug('experimentID: "%s".' % experimentID)
        # Store the ID in both cases (the automatically found one used to be
        # computed, logged and then discarded).
        self.experimentID = experimentID

        dpsize = self.p.dpsize
        ctr = self.p.ctr

        sh = self.data.shape[-2:]
        fullframe = False
        if dpsize is None:
            dpsize = sh
            logger.debug('Full frames (%d x %d) will be saved (so no recentering).' % (sh))
            fullframe = True
        # Normalised 2-element size; also used for the crop limits below so
        # that scalar, list and tuple dpsize values all work.
        dpsize_2d = expect2(dpsize)
        self.p.dpsize = dpsize_2d

        f = self.data
        if not fullframe:
            # Compute center of mass
            if self.mask is None:
                ctr_auto = mass_center(f.sum(0))
            else:
                ctr_auto = mass_center(f.sum(0) * self.mask)
            logger.debug('Center of mass: %s' % str(ctr_auto))
            # Check for center position.
            # TODO: interactive center selection (ctr == 'inter') is not
            # implemented.
            if ctr is None:
                ctr = ctr_auto
                logger.debug('Using center: (%d, %d)' % (ctr[0], ctr[1]))
            else:
                logger.debug('Using center: (%d, %d) - I would have guessed it is (%d, %d)' % (ctr[0], ctr[1], ctr_auto[0], ctr_auto[1]))

            self.dpsize = dpsize
            self.ctr = np.array(ctr)
            lim_inf = -np.ceil(self.ctr - dpsize_2d / 2.).astype(int)
            lim_sup = np.ceil(self.ctr + dpsize_2d / 2.).astype(int) - np.array(sh)
            hplane_list = [(lim_inf[0], lim_sup[0]), (lim_inf[1], lim_sup[1])]
            logger.debug('Going from %s to %s (hplane_list = %s)' % (str(sh), str(dpsize), str(hplane_list)))
            if self.mask is not None:
                self.mask = u.crop_pad(self.mask, hplane_list).astype(bool)
            if self.flat is not None:
                # Pad the flat field with ones rather than zeros.
                self.flat = u.crop_pad(self.flat, hplane_list, fillpar=1.)
            if self.dark is not None:
                self.dark = u.crop_pad(self.dark, hplane_list)
            if self.data is not None:
                self.data = u.crop_pad(self.data, hplane_list)