Ejemplo n.º 1
0
    def init_probe(self, *args, **kwargs):
        """Create the starting probe estimate and store it in ``self.probe.modes``.

        Two paths are supported:

        * If ``CXP.io.initial_probe_guess`` is anything other than the empty
          string, the probe is loaded from it via ``CXData.load``.
        * Otherwise a synthetic probe is generated on a
          ``desired_array_shape`` x ``desired_array_shape`` grid: a separable
          sinc for ``CXP.experiment.optic == 'kb'`` or a radial sinc for
          ``'zp'`` (case-insensitive). Its modulus is multiplied by a
          Gaussian envelope and given a smoothed random phase. Any other
          optic value leaves the amplitude at zero.

        In both cases mode ``i`` is the base probe divided by ``i + 1``, for
        ``CXP.reconstruction.probe_modes`` modes, and ``self.probe.normalise()``
        is called afterwards.

        ``*args`` and ``**kwargs`` are accepted but not used.

        Raises
        ------
        ValueError
            If the optic is 'kb' and ``CXP.experiment.beam_size`` does not hold
            exactly one or two entries.
        """
        n_modes = CXP.reconstruction.probe_modes

        # Compare by value: an identity test against a string literal only
        # works by accident of interning.
        if CXP.io.initial_probe_guess != '':
            probe = CXData()
            probe.load(CXP.io.initial_probe_guess)
            self.probe.modes = [CXData(data=[probe.data[0]/(i+1)]) for i in range(n_modes)]
            self.probe.normalise()
            return

        dx_s = CXP.dx_s
        p = CXP.preprocessing.desired_array_shape
        # Floor division keeps the half-size an integer grid offset.
        p2 = p//2

        # Default amplitude; stays all-zero if the optic is neither 'kb' nor 'zp'.
        probe = sp.zeros((p, p), complex)

        optic = CXP.experiment.optic.lower()
        if optic == 'kb':
            beam_size = CXP.experiment.beam_size
            if len(beam_size) == 1:
                bsx = bsy = np.round(beam_size[0]/dx_s)
            elif len(beam_size) == 2:
                bsx, bsy = np.round(beam_size[0]/dx_s), np.round(beam_size[1]/dx_s)
            else:
                # Previously fell through and failed later on an unbound bsx.
                raise ValueError('CXP.experiment.beam_size must have 1 or 2 entries '
                                 'for KB optics, got {:d}'.format(len(beam_size)))

            axis = np.arange(p)-p2
            probe = np.sinc(axis/bsx)[:,np.newaxis]*np.sinc(axis/bsy)[np.newaxis,:]

        elif optic == 'zp':
            radius = sp.hypot(*sp.ogrid[-p2:p2, -p2:p2])
            probe = np.sinc(radius/np.round(3.*CXP.experiment.beam_size[0]/(2*dx_s)))

        # Smoothed random field used as the phase profile (scaled to at most pi).
        ph_func = gauss_smooth(np.random.random(probe.shape), 10)

        # Gaussian envelope exp(-r^2 / (fwhm/2.35)^2), normalised to a peak of 1.
        fwhm = p/2.0
        radker = sp.hypot(*sp.ogrid[-p//2:p//2, -p//2:p//2])
        gaussian = exp(-1.0*(fwhm/2.35)**-2. * radker**2.0)
        gaussian /= gaussian.max()

        probe = abs(gaussian*probe)*exp(complex(0., np.pi)*ph_func/ph_func.max())

        self.probe.modes = [CXData(data=[probe/(i+1)]) for i in range(n_modes)]
        self.probe.normalise()
Ejemplo n.º 2
0
    def preprocessing(self):
        """.. method:: preprocessing()
            Run, in order, every preparation step needed before phase retrieval:

            1. create ``self.positions`` and fill it through ``ptycho_mesh``;
            2. obtain ``self.det_mod`` either from ``simulate_data`` or by
               reading/loading the measured data (plus ``self.probe_det_mod``
               when a whitefield filename is configured);
            3. allocate zero-filled ``object``, ``probe_intensity``, ``probe``
               and ``psi`` containers;
            4. seed the probe with ``init_probe``;
            5. optionally compute the STXM image and the DPC result.

        """
        # Scan positions must exist before anything else is set up.
        self.positions = CXData(name='positions', data=[])
        self.ptycho_mesh()

        if not CXP.measurement.simulate_data:
            # Measured data: either preprocess the raw frames or load a saved set.
            self.det_mod = CXData(name = 'det_mod')
            if CXP.actions.preprocess_data:
                self.det_mod.read_in_data()
            else:
                self.det_mod.load()
            if CXP.io.whitefield_filename:
                self.probe_det_mod = CXData(name='probe_det_mod')
                self.probe_det_mod.preprocess_data()
        else:
            self.simulate_data()

        # Zero-initialised reconstruction arrays.
        self.object = CXData(name='object', data=[sp.zeros((self.ob_p, self.ob_p), complex)])
        self.probe_intensity = CXData(name='probe_intensity', data=[sp.zeros((self.p, self.p))])

        self.probe = CXModal(modes=[])
        self.psi = CXModal(modes=[])

        # One probe array per mode, and one exit-wave array per mode per frame.
        for mode_idx in range(CXP.reconstruction.probe_modes):
            probe_mode = CXData(name='probe{:d}'.format(mode_idx),
                                data=[sp.zeros((self.p, self.p), complex)])
            self.probe.modes.append(probe_mode)

            psi_frames = [sp.zeros((self.p, self.p), complex) for _ in xrange(self.det_mod.len())]
            self.psi.modes.append(CXData(name='psi{:d}'.format(mode_idx), data=psi_frames))

        self.init_probe()

        # A STXM image only makes sense when there is more than one frame.
        if len(self.det_mod.data) > 1:
            self.calc_stxm_image()

        if CXP.actions.process_dpc:
            self.process_dpc()
Ejemplo n.º 3
0
    def ptycho_mesh(self):
        """
        Generate the list of ptycho scan positions and store it on ``self.positions``.

        The mesh is selected by ``CXP.measurement.ptycho_scan_mesh``:

        * ``'generate'`` with ``ptycho_scan_type`` equal to ``'cartesian'``
          (regular grid, optionally flipped left-right / up-down or with the
          two axes swapped), ``'round_roi'`` (delegated to ``round_roi``) or
          ``'list'`` (read from ``list_scan_filename``);
        * ``'supplied'``: read from ``list_scan_filename``.

        Every coordinate array is then divided by ``CXP.dx_s`` and offset by
        ``CXP.ob_p/2``, random jitter of radius
        ``CXP.reconstruction.initial_position_jitter_radius`` is added, and,
        unless ``ptycho_subpixel_shift`` is set, positions are rounded.

        Outputs
        -------
        self.positions.data : [x, y] arrays of scan positions
        self.positions.total : number of scan positions
        self.positions.correct : [x, y] reference positions; equal to a copy of
            the jittered positions, or those positions plus a random offset of
            up to a quarter of ``ppc_search_radius`` when
            ``probe_position_correction`` is enabled
        CXP.rms_rounding_error : two-element list (x, y) comparing the
            positions before and after rounding

        """
        CXP.log.info('Getting ptycho position mesh.')

        mesh_source = CXP.measurement.ptycho_scan_mesh
        # 'supplied' and 'generate'/'list' share the same file-reading path.
        read_list_file = mesh_source == 'supplied'

        if mesh_source == 'generate':
            scan_type = CXP.measurement.ptycho_scan_type
            if scan_type == 'cartesian':
                x2 = 0.5*(CXP.measurement.cartesian_scan_dims[0]-1)
                y2 = 0.5*(CXP.measurement.cartesian_scan_dims[1]-1)
                step = CXP.measurement.cartesian_step_size
                # A list (not map) so indexing works wherever map is lazy.
                grid = [step*axis for axis in np.mgrid[-x2:x2+1, -y2:y2+1]]
                self.positions.data = [grid[0].flatten(), grid[1].flatten()]
                # The flips operate on the mesh just stored in self.positions.data.
                if CXP.reconstruction.flip_mesh_lr:
                    CXP.log.info('Flip ptycho mesh left-right')
                    self.positions.data[0] = self.positions.data[0][::-1]
                if CXP.reconstruction.flip_mesh_ud:
                    CXP.log.info('Flip ptycho mesh up-down')
                    self.positions.data[1] = self.positions.data[1][::-1]
                if CXP.reconstruction.flip_fast_axis:
                    CXP.log.info('Flip ptycho mesh fast axis')
                    self.positions.data[0], self.positions.data[1] = (
                        self.positions.data[1], self.positions.data[0])
            elif scan_type == 'round_roi':
                self.positions.data = list(round_roi(CXP.measurement.round_roi_diameter,
                                                     CXP.measurement.round_roi_step_size))
            elif scan_type == 'list':
                read_list_file = True

        if read_list_file:
            # First column is x, second column is y.
            rows = np.genfromtxt(CXP.measurement.list_scan_filename)
            self.positions.data = [sp.array([row[0] for row in rows]),
                                   sp.array([row[1] for row in rows])]

        # In-place so the arrays held in self.positions.data are updated.
        for element in self.positions.data:
            element /= CXP.dx_s
            element += CXP.ob_p/2
        self.positions.total = len(self.positions.data[0])

        jit_pix = CXP.reconstruction.initial_position_jitter_radius
        search_pix = CXP.reconstruction.ppc_search_radius

        self.positions.data[0] += jit_pix * uniform(-1, 1, self.positions.total)
        self.positions.data[1] += jit_pix * uniform(-1, 1, self.positions.total)

        if CXP.reconstruction.probe_position_correction:
            self.positions.correct = [
                self.positions.data[0]+0.25*search_pix * uniform(-1, 1, self.positions.total),
                self.positions.data[1]+0.25*search_pix * uniform(-1, 1, self.positions.total),
            ]
        else:
            self.positions.correct = [self.positions.data[0].copy(), self.positions.data[1].copy()]

        # Keep references to the unrounded arrays for the error report below.
        data_copy = CXData(data=list(self.positions.data))
        if not CXP.reconstruction.ptycho_subpixel_shift:
            self.positions.data = [np.round(self.positions.data[0]), np.round(self.positions.data[1])]
            self.positions.correct = [np.round(self.positions.correct[0]), np.round(self.positions.correct[1])]

        CXP.rms_rounding_error = [
            sp.sqrt(sp.sum(abs(abs(data_copy.data[i])**2.-abs(self.positions.data[i])**2.)))
            for i in range(2)
        ]

        CXP.log.info('RMS Rounding Error (Per Position, X, Y):\t {:2.2f}, {:2.2f}'.format(CXP.rms_rounding_error[0]/len(self.positions.data[0]),
                                                                                           CXP.rms_rounding_error[1]/len(self.positions.data[1])))
Ejemplo n.º 4
0
    def simulate_data(self):
        """Simulate diffraction data and store the result in ``self.det_mod``.

        Steps:

        1. Load the sample amplitude from ``CXP.io.simulation_sample_filename[0]``,
           normalise it with ``val=0.8`` and add 0.2; if a second filename is
           given it is loaded as a phase map (normalised with ``val=pi/3``).
        2. Apply a 2-D Hamming window and embed the sample in the centre of a
           ``CXP.ob_p`` x ``CXP.ob_p`` array.
        3. Build ``CXP.reconstruction.probe_modes`` probe modes as
           Gaussian-weighted sine products with a smoothed random phase.
           The first mode always uses indices (1, 1); later modes use random,
           unused index pairs from 1..4. The modes are then normalised and
           orthogonalised.
        4. For every mode and every position in ``self.positions.correct``,
           shift the sample, crop the central window and multiply by the mode,
           collecting the results in ``self.input_psi``.
        5. Combine the moduli of the Fourier-transformed exit waves with
           ``CXModal.modal_sum`` and save the result under
           ``<base_dir>/<scan_id>/raw_data/det_mod.npy``.

        Raises
        ------
        ValueError
            If more than 16 probe modes are requested; only 16 distinct index
            pairs exist, so the search for an unused pair could never finish.
        """
        CXP.log.info('Simulating diffraction patterns.')

        n_modes = CXP.reconstruction.probe_modes
        if n_modes > 16:
            raise ValueError('Cannot simulate {:d} probe modes: only 16 distinct '
                             'mode index pairs are available'.format(n_modes))

        self.sample = CXData()
        self.sample.load(CXP.io.simulation_sample_filename[0])
        self.sample.data[0] = self.sample.data[0].astype(float)
        self.sample.normalise(val=0.8)
        self.sample.data[0]+=0.2
        self.input_probe = CXModal()
        if len(CXP.io.simulation_sample_filename)>1:
            ph = CXData()
            ph.load(CXP.io.simulation_sample_filename[1])
            ph.data[0] = ph.data[0].astype(float)
            ph.normalise(val=np.pi/3)
            self.sample.data[0] = self.sample.data[0]*exp(complex(0., 1.)*ph.data[0])

        p = self.sample.data[0].shape[0]
        # Integer half-sizes: these are used as slice bounds.
        p2 = p//2
        centre = CXP.ob_p//2

        ham_window = sp.hamming(p)[:,np.newaxis]*sp.hamming(p)[np.newaxis,:]
        sample_large = CXData(data=sp.zeros((CXP.ob_p, CXP.ob_p), complex))
        sample_large.data[0][centre-p2:centre+p2, centre-p2:centre+p2] = self.sample.data[0]*ham_window

        ker = sp.arange(0, p)
        fwhm = p/3.0
        radker = sp.hypot(*sp.ogrid[-p//2:p//2, -p//2:p//2])
        gaussian = exp(-1.0*(fwhm/2.35)**-2. * radker**2.0 )
        ortho_modes = lambda n1, n2 : gaussian*np.sin(n1*math.pi*ker/p)[:,np.newaxis]*np.sin(n2*math.pi*ker/p)[np.newaxis, :]
        # Random pair of mode indices, each drawn from 1..4.
        mode_generator = lambda : sp.floor(4*sp.random.random(2))+1

        used_modes = []
        self.input_psi = CXModal()

        for mode in range(n_modes):
            if mode==0:
                new_mode = [1,1]
            else:
                # Redraw until the pair has not been used by an earlier mode.
                new_mode = list(mode_generator())
                while new_mode in used_modes:
                    new_mode = list(mode_generator())
            used_modes.append(new_mode)
            CXP.log.info('Simulating mode {:d}: [{:d}, {:d}]'.format(mode, int(new_mode[0]), int(new_mode[1])))
            ph_func = gauss_smooth(np.random.random((p,p)), 10)
            self.input_probe.modes.append(CXData(name='probe{:d}'.format(mode),
                data=ortho_modes(new_mode[0], new_mode[1])*exp(complex(0.,np.pi)*ph_func/ph_func.max())))

        self.input_probe.normalise()
        self.input_probe.orthogonalise()

        x, y = self.positions.correct
        # Log roughly every 10% of the scan; clamp to 1 so scans with fewer
        # than 10 positions do not take a modulo by zero.
        log_every = max(1, len(x)//10)

        for mode in range(n_modes):
            self.input_psi.modes.append(CXData(name='input_psi_mode{:d}'.format(mode), data=[]))

            for i, (x_i, y_i) in enumerate(zip(x, y)):
                if i % log_every == 0:
                    CXP.log.info('Simulating diff patt {:d}'.format(i))
                # Shift the sample to this position, crop the central
                # p x p window and multiply by the current probe mode.
                tmp = (CXData.shift(sample_large, -1.0*(x_i-CXP.ob_p/2), -1.0*(y_i-CXP.ob_p/2))
                        [centre-p2:centre+p2, centre-p2:centre+p2]*
                        self.input_probe[mode][0])
                self.input_psi[mode].data.append(tmp.data[0])

        # Add modes incoherently
        self.det_mod = CXModal.modal_sum(abs(fft2(self.input_psi)))
        self.det_mod.save(path=CXP.io.base_dir+'/'+CXP.io.scan_id+'/raw_data/{:s}.npy'.format('det_mod'))