Beispiel #1
0
class Simulator:
    """Simulated camera, mirror, and dynamic aberrations of the system.

    CIAO can be run in simulation mode by instantiating a simulator object,
    then a sensor object using the simulator as its camera, and then a loop
    object using that sensor and the simulator in place of the mirror.
    """

    def __init__(self):
        """Build the simulated sensor geometry, mirror, and aberration bases.

        All geometry is read from the module-level ``ccfg`` configuration.
        The actuator and Zernike basis sets are loaded from
        ``ccfg.simulator_cache_directory`` when a cached copy can be read;
        otherwise they are computed and saved there.
        """
        self.frame_timer = FrameTimer('simulator')

        # Logging flag; changed through set_logging().
        self.logging = False

        # Pixel grid on which the simulated spots image and the simulated
        # wavefront are built:
        self.sy = ccfg.image_height_px
        self.sx = ccfg.image_width_px
        self.wavefront = np.zeros((self.sy, self.sx))

        # Some parameters for the spots image:
        self.dc = 100
        self.spots_range = 2000
        self.spots = np.ones((self.sy, self.sx)) * self.dc
        self.spots = self.noise(self.spots)
        self.pixel_size_m = ccfg.pixel_size_m

        # Lenslet pitch, lenslet focal length, and wavelength:
        self.lenslet_pitch_m = ccfg.lenslet_pitch_m
        self.f = ccfg.lenslet_focal_length_m
        self.L = ccfg.wavelength_m

        # Distance (in pixels) of every pixel from the image centre.
        xvec = np.arange(self.sx)
        yvec = np.arange(self.sy)
        xvec = xvec - xvec.mean()
        yvec = yvec - yvec.mean()
        XX, YY = np.meshgrid(xvec, yvec)
        d = np.sqrt(XX**2 + YY**2)

        self.beam_diameter_m = ccfg.beam_diameter_m
        self.beam_radius_m = self.beam_diameter_m / 2.0

        self.disc_diameter = 170

        # Binary disc that update() multiplies into the shifted Fourier
        # transform of the spots image.
        # NOTE(review): disc_diameter is compared against a radial distance,
        # so it effectively acts as a radius -- confirm which was intended.
        self.disc = np.zeros((self.sy, self.sx))
        self.disc[d <= self.disc_diameter] = 1.0

        # Centred pixel coordinates in meters. The builtin float is used
        # because the np.float alias no longer exists in current NumPy.
        self.X = np.arange(self.sx, dtype=float) * self.pixel_size_m
        self.Y = np.arange(self.sy, dtype=float) * self.pixel_size_m
        self.X = self.X - self.X.mean()
        self.Y = self.Y - self.Y.mean()

        self.XX, self.YY = np.meshgrid(self.X, self.Y)

        # Circular mask of the pixels lying within the beam radius.
        self.RR = np.sqrt(self.XX**2 + self.YY**2)
        self.mask = np.zeros(self.RR.shape)
        self.mask[self.RR <= self.beam_radius_m] = 1.0

        # Number of lenslets across the beam: rounding up keeps lenslets
        # that are only partially illuminated, rounding down drops them.
        use_partially_illuminated_lenslets = True
        if use_partially_illuminated_lenslets:
            d_lenslets = int(
                np.ceil(self.beam_diameter_m / self.lenslet_pitch_m))
        else:
            d_lenslets = int(
                np.floor(self.beam_diameter_m / self.lenslet_pitch_m))

        rad = float(d_lenslets) / 2.0

        # Lenslet grid indices, centred on zero.
        xx, yy = np.meshgrid(np.arange(d_lenslets), np.arange(d_lenslets))
        xx = xx - float(d_lenslets - 1) / 2.0
        yy = yy - float(d_lenslets - 1) / 2.0

        d = np.sqrt(xx**2 + yy**2)

        self.lenslet_mask = np.zeros(xx.shape, dtype=np.uint8)
        self.lenslet_mask[d <= rad] = 1
        self.n_lenslets = int(np.sum(self.lenslet_mask))

        # Lenslet centres in pixel coordinates, restricted to the lenslets
        # selected by lenslet_mask.
        self.x_lenslet_coords = xx * self.lenslet_pitch_m / self.pixel_size_m + self.sx / 2.0
        self.y_lenslet_coords = yy * self.lenslet_pitch_m / self.pixel_size_m + self.sy / 2.0
        in_pupil = np.where(self.lenslet_mask)
        self.x_lenslet_coords = self.x_lenslet_coords[in_pupil]
        self.y_lenslet_coords = self.y_lenslet_coords[in_pupil]

        self.lenslet_boxes = SearchBoxes(self.x_lenslet_coords,
                                         self.y_lenslet_coords,
                                         ccfg.search_box_half_width)

        self.mirror_mask = np.loadtxt(ccfg.mirror_mask_filename)
        self.n_actuators = int(np.sum(self.mirror_mask))

        self.command = np.zeros(self.n_actuators)

        # virtual actuator spacing in magnified or demagnified
        # plane of camera
        actuator_spacing = ccfg.beam_diameter_m / float(
            self.mirror_mask.shape[0])
        ay, ax = np.where(self.mirror_mask)
        ay = ay * actuator_spacing
        ax = ax * actuator_spacing
        ay = ay - ay.mean()
        ax = ax - ax.mean()

        # Kept on the instance so that plot_actuators() can draw them.
        self.actuator_x = ax
        self.actuator_y = ay

        # flat0 is a pristine copy of the flat, used by restore_flat().
        self.flat = np.loadtxt(ccfg.mirror_flat_filename)
        self.flat0 = np.loadtxt(ccfg.mirror_flat_filename)

        self.n_zernike_terms = ccfg.n_zernike_terms
        actuator_sigma = actuator_spacing * 1.5

        self.exposure = 10000  # microseconds

        # The cache file names are derived from a hash of everything the
        # basis sets depend on.
        key = '%d' % hash((tuple(ax), tuple(ay), actuator_sigma, tuple(
            self.X), tuple(self.Y), self.n_zernike_terms))
        key = key.replace('-', 'm')

        try:
            os.mkdir(ccfg.simulator_cache_directory)
        except OSError:
            # Best effort: typically the directory already exists. Any real
            # problem surfaces when the basis files are saved below.
            pass

        cfn = os.path.join(ccfg.simulator_cache_directory,
                           '%s_actuator_basis.npy' % key)

        try:
            self.actuator_basis = np.load(cfn)
            print('Loading cached actuator basis set...')
        except Exception:
            # Any failure to read the cache falls back to rebuilding it.
            actuator_basis = []
            print('Building actuator basis set...')
            for x, y in zip(ax, ay):
                xx = self.XX - x
                yy = self.YY - y
                # Gaussian centred on the actuator, rescaled to [0, 1].
                surf = np.exp((-(xx**2 + yy**2) / (2 * actuator_sigma**2)))
                surf = (surf - surf.min()) / (surf.max() - surf.min())
                actuator_basis.append(surf.ravel())
                plt.clf()
                plt.imshow(surf)
                plt.title('generating actuator basis\n%0.2e,%0.2e' % (x, y))
                plt.pause(.1)

            self.actuator_basis = np.array(actuator_basis)
            np.save(cfn, self.actuator_basis)

        zfn = os.path.join(ccfg.simulator_cache_directory,
                           '%s_zernike_basis.npy' % key)
        self.zernike = Zernike()
        try:
            self.zernike_basis = np.load(zfn)
            print('Loading cached zernike basis set...')
        except Exception:
            # Any failure to read the cache falls back to rebuilding it.
            zernike_basis = []
            print('Building zernike basis set...')
            for z in range(self.n_zernike_terms):
                surf = self.zernike.get_j_surface(z, self.XX, self.YY)
                zernike_basis.append(surf.ravel())

            self.zernike_basis = np.array(zernike_basis)
            np.save(zfn, self.zernike_basis)

        self.zernike_orders = np.zeros(self.n_zernike_terms)
        for j in range(self.n_zernike_terms):
            self.zernike_orders[j] = self.zernike.j2nm(j)[0]

        # Standard deviations of the random coefficients: baseline_* is used
        # once for the static error below, new_* is used for the extra error
        # drawn on every update().
        self.baseline_error_sigma = 1.0 / (1.0 + self.zernike_orders) * 10.0
        self.baseline_error_sigma[3:6] = 150.0

        self.new_error_sigma = self.zernike_orders / 10.0 / 100.0

        # don't generate any piston, tip, or tilt:
        self.baseline_error_sigma[:3] = 0.0
        self.new_error_sigma[:3] = 0.0

        self.error = self.get_error(self.baseline_error_sigma)

        self.paused = False

    def pause(self):
        """Set the paused flag."""
        self.paused = True

    def unpause(self):
        """Clear the paused flag."""
        self.paused = False

    def set_logging(self, val):
        """Store ``val`` as the logging flag."""
        self.logging = val

    def flatten(self):
        """Copy the current flat into the command vector (no update())."""
        self.command[:] = self.flat[:]

    def set_exposure(self, val):
        """Set the exposure (microseconds) used to scale get_image()."""
        self.exposure = val

    def get_exposure(self):
        """Return the current exposure."""
        return self.exposure

    def restore_flat(self):
        """Reset the flat to the copy loaded at construction time."""
        self.flat[:] = self.flat0[:]

    def set_flat(self):
        """Adopt the current command vector as the flat."""
        self.flat[:] = self.get_command()[:]

    def get_command(self):
        """Return the command vector itself (not a copy)."""
        return self.command

    def set_command(self, vec):
        """Copy ``vec`` into the command vector (no update())."""
        self.command[:] = vec[:]

    def set_actuator(self, index, value):
        """Set one element of the command vector and re-run update()."""
        self.command[index] = value
        self.update()

    def noise(self, im):
        """Return ``im`` plus Gaussian noise scaled by ``sqrt(im)``.

        ``im`` must be a 2-D array.
        """
        noiserms = np.random.randn(im.shape[0], im.shape[1]) * np.sqrt(im)
        return im + noiserms

    def get_error(self, sigma):
        """Return a random wavefront error map of shape ``(sy, sx)``.

        Zernike coefficients are drawn from a normal distribution scaled
        elementwise by ``sigma`` and combined with the Zernike basis.
        """
        coefs = np.random.randn(self.n_zernike_terms) * sigma
        return np.reshape(np.dot(coefs, self.zernike_basis),
                          (self.sy, self.sx))

    def defocus_animation(self):
        """Show 100 frames in which only Zernike term 4 is randomised."""
        err = np.zeros(self.n_zernike_terms)
        for _ in range(100):
            err[4] = np.random.randn()
            im = np.reshape(np.dot(err, self.zernike_basis),
                            (self.sy, self.sx))
            plt.clf()
            plt.imshow(im - im.min())
            plt.colorbar()
            plt.pause(.1)

    def plot_actuators(self):
        """Plot the actuator positions over the beam mask."""
        edge = self.XX.min()
        wid = self.XX.max() - edge
        plt.imshow(self.mask, extent=[edge, edge + wid, edge, edge + wid])
        plt.autoscale(False)
        # The coordinates are stored by __init__; they are not available as
        # bare names inside this method.
        plt.plot(self.actuator_x, self.actuator_y, 'ks')
        plt.show()

    def update(self):
        """Recompute the wavefront, the spots image, and the local slopes.

        The wavefront is the mirror surface (command vector times actuator
        basis) plus the static error plus a freshly drawn random error. For
        every search box the mean x and y slopes of the wavefront are
        converted to a spot displacement, and a unit spot is written into
        the spots image at the displaced position. The spots image is then
        filtered in the Fourier domain with ``self.disc``.
        """
        mirror = np.reshape(np.dot(self.command, self.actuator_basis),
                            (self.sy, self.sx))

        err = self.error + self.get_error(self.new_error_sigma)
        self.wavefront = mirror + err

        y_slope_vec = []
        x_slope_vec = []
        self.spots[:] = 0.0
        boxes = self.lenslet_boxes
        for x, y, x1, x2, y1, y2 in zip(boxes.x, boxes.y,
                                        boxes.x1, boxes.x2,
                                        boxes.y1, boxes.y2):
            subwf = self.wavefront[y1:y2 + 1, x1:x2 + 1]
            yslope = np.mean(np.diff(subwf.mean(1)))
            xslope = np.mean(np.diff(subwf.mean(0)))

            # Spot displacement in pixels: slope times the lenslet focal
            # length, divided by the pixel size.
            ycentroid = y + yslope * self.f / self.pixel_size_m
            xcentroid = x + xslope * self.f / self.pixel_size_m

            self.spots = self.interpolate_dirac(xcentroid, ycentroid,
                                                self.spots)
            x_slope_vec.append(xslope)
            y_slope_vec.append(yslope)

        self.spots = np.abs(
            np.fft.ifft2(np.fft.fftshift(np.fft.fft2(self.spots)) * self.disc))
        self.x_slopes = np.array(x_slope_vec)
        self.y_slopes = np.array(y_slope_vec)

    def get_image(self):
        """Update the simulation and return a noisy int16 spots image.

        The spots image is rescaled to span ``spots_range`` above ``dc``,
        noise is added, the result is scaled by the exposure relative to
        10000, clipped to [0, 4095], and rounded.
        """
        self.update()
        self.frame_timer.tick()

        lo = self.spots.min()
        span = self.spots.max() - lo
        if span > 0:
            normalized = (self.spots - lo) / span
        else:
            # A perfectly uniform image has nothing to stretch; avoid 0/0.
            normalized = np.zeros(self.spots.shape)
        spots = normalized * self.spots_range + self.dc

        # Float division: an integer exposure below 10000 must scale the
        # image down rather than truncate the factor to zero.
        nspots = self.noise(spots) * (self.exposure / 10000.0)
        nspots = np.clip(nspots, 0, 4095)
        nspots = np.round(nspots).astype(np.int16)
        return nspots

    def interpolate_dirac(self, x, y, frame):
        """Write a bilinearly interpolated unit spot into ``frame``.

        The four pixels surrounding the subpixel location ``(x, y)`` are
        assigned weights that sum to one. Pixels falling outside ``frame``
        are skipped; negative indices are rejected explicitly so that they
        do not wrap around to the opposite edge. ``frame`` is modified in
        place and also returned.
        """
        x1 = int(np.floor(x))
        x2 = x1 + 1
        y1 = int(np.floor(y))
        y2 = y1 + 1

        n_rows, n_cols = frame.shape[0], frame.shape[1]

        for yi in (y1, y2):
            if not 0 <= yi < n_rows:
                continue
            yweight = 1.0 - abs(yi - y)
            for xi in (x1, x2):
                if not 0 <= xi < n_cols:
                    continue
                xweight = 1.0 - abs(xi - x)
                frame[yi, xi] = yweight * xweight
        return frame

    def wavefront_to_spots(self):
        """Placeholder; currently does nothing."""
        pass

    def show_zernikes(self):
        """Display each Zernike basis surface in turn."""
        for k in range(self.n_zernike_terms):
            b = np.reshape(self.zernike_basis[k, :], (self.sy, self.sx))
            plt.clf()
            plt.imshow(b)
            plt.colorbar()
            plt.pause(.5)

    def close(self):
        """Announce that the simulator is closing."""
        print('Closing simulator.')
Beispiel #2
0
def simulateShackHartmannReconstruction(seed=None, comparison_plot_type='bar'):
    """A test function that uses Zernike modes to generate a random simulated wavefront, and then
    uses Zernike derivatives to estimate the original coefficients.

    Args:
        seed: optional value passed to ``np.random.seed`` before the random
            coefficients are drawn, so that runs can be compared. With the
            default ``None`` the generator is left untouched.
        comparison_plot_type: ``'bar'`` (default) or ``'scatter'``; selects
            how simulated and reconstructed coefficients are compared.

    Two figures are written into ``./images`` (created if missing):
    ``slope_comparison.png`` and ``shack_hartmann_reconstruction_simulation.png``.
    """
    import os

    imageDirectory = './images'
    if not os.path.isdir(imageDirectory):
        os.makedirs(imageDirectory)

    pupilDiameterM = 10e-3

    pupilRadiusM = pupilDiameterM / 2.0
    unity = 1.0

    z = Zernike()

    nZT = 1+2+3+4+5+6+7

    # Set up some aligned vectors of Zernike index (j), order (n),
    # frequency (m), and tuples of order and frequency (nm).
    jVec = range(nZT)
    nmVec = np.array([z.j2nm(x) for x in jVec])
    nVec = np.array([z.j2nm(x)[0] for x in jVec])
    mVec = np.array([z.j2nm(x)[1] for x in jVec])

    # Generate some random Zernike coefficients. Optionally seed the
    # generator to permit iterative comparison of output. Zero the first
    # three modes (piston, tip, tilt).
    if seed is not None:
        np.random.seed(seed)

    cVec = np.random.randn(len(jVec))/10000.0
    cVec[:3] = 0.0 # zero tip, tilt, and piston

    # We'll make a densely sampled wavefront using the modes and
    # coefficients specified above.
    N = 1000
    xx,yy = np.meshgrid(np.linspace(-unity,unity,N),np.linspace(-unity,unity,N))
    mask = np.zeros((N,N)) # circular mask to define unit pupil
    d = np.sqrt(xx**2 + yy**2)
    mask[np.where(d<unity)] = 1
    wavefront = np.zeros(mask.shape) # matrix to accumulate wavefront error

    # We build up the simulated wavefront by summing the modes of
    # interest. The simulated wavefront has units of pupil radius; to
    # convert to meters, we have to multiply by the pupil radius in
    # meters.
    for (n,m),coef in zip(nmVec,cVec):
        print('Adding %0.6f Z(%d,%d) to simulated wavefront.'%(coef,n,m))
        wavefront = wavefront + coef*z.getSurface(n,m,xx,yy,'h')*pupilRadiusM

    def coord2index(coord):
        # Index of the sample whose unit-pupil coordinate is closest to coord.
        test = xx[0,:]
        return np.argmin(np.abs(test-coord))


    # In order to compute the slope at each subaperture, we need to know
    # the pixel size. N samples spanning the pupil are separated by N-1
    # intervals.
    pixelSizeM = pupilDiameterM/(N-1)

    # Now we'll define a lenslet array as a masked 20x20 grid. Since we
    # want the edges of the outer lenslets to abut the pupil edge, let's
    # start by enumerating the 21 edge coordinates in the unit line:
    nLensletsAcross = 20
    lensletEdges = np.linspace(-unity,unity,nLensletsAcross+1)
    # The centers can now be specified as the average between the leftmost
    # 20 coordinates and the rightmost 20 coordinates:
    lensletCenters = (lensletEdges[1:]+lensletEdges[:-1])/2.0
    lensletXX,lensletYY = np.meshgrid(lensletCenters,lensletCenters)
    lensletD = np.sqrt(lensletXX**2+lensletYY**2)
    lensletMask = np.zeros(lensletD.shape)
    lensletMask[np.where(lensletD<unity)]=1

    # Pull out the x and y coordinates of the lenslets into vectors, and
    # we'll iterate through these to build corresponding vectors of local
    # slopes:
    lensletXXVector = lensletXX[np.where(lensletMask)]
    lensletYYVector = lensletYY[np.where(lensletMask)]

    # Estimate the local slope in each subaperture.
    # NOTE(review): the lenslet pitch in unit coordinates is
    # 2/nLensletsAcross, so a half-width of .5/nLensletsAcross covers only
    # the central half of each subaperture -- confirm this is intended.
    xSlopes = []
    ySlopes = []
    for x,y in zip(lensletXXVector,lensletYYVector):
        print('Estimating slope at (%0.2f,%0.2f).'%(x,y))
        x1i = coord2index(x-.5/nLensletsAcross)
        x2i = coord2index(x+.5/nLensletsAcross)
        y1i = coord2index(y-.5/nLensletsAcross)
        y2i = coord2index(y+.5/nLensletsAcross)
        subap = wavefront[y1i:y2i,x1i:x2i]
        dSubapX_dPx = np.diff(subap,axis=1).mean()/pixelSizeM
        dSubapY_dPx = np.diff(subap,axis=0).mean()/pixelSizeM
        xSlopes.append(dSubapX_dPx)
        ySlopes.append(dSubapY_dPx)

    xSlopes = np.array(xSlopes)
    ySlopes = np.array(ySlopes)
    nLenslets = len(xSlopes)


    # Now that we have a simulated wavefront and simulated slope
    # measurements, let's reconstruct!  Reconstruction consists of:
    #
    #   1) Using zernike.Zernike.getSurface to calculate the x
    #      and y derivatives of each Zernike mode;
    #
    #   2) Using np.where and lensletMask to order the unmasked
    #      derivatives the same way they're ordered in lensletXXVector and
    #      lensletYYVector above;
    #
    #   3) Assembling these derivatives into a matrix A such that A[n,m]
    #      contains the partial x-derivative of the mth mode at the
    #      location corresponding to the nth lenslet. A[n+k,m] is the
    #      corresponding partial y-derivative, where k is the number of
    #      lenslets.
    #
    #   4) Inverting A and multiplying it by the local slopes gives us
    #      reconstructed Zernike coefficients, which can then be compared
    #      with our randomly chosen initial values.

    A = np.zeros([nLenslets*2,nZT])

    for j in jVec:
        iOrder,iFreq = z.j2nm(j)
        print('Calculating Zernike derivatives for Z(%d,%d).'%(iOrder,iFreq))
        dzdx = z.getSurface(iOrder,iFreq,lensletXX,lensletYY,'dx')
        dzdy = z.getSurface(iOrder,iFreq,lensletXX,lensletYY,'dy')
        dzdx = dzdx[np.where(lensletMask)]
        dzdy = dzdy[np.where(lensletMask)]
        A[:nLenslets,j] = dzdx
        A[nLenslets:,j] = dzdy

    slopes = np.hstack((xSlopes,ySlopes))

    # Sidebar:
    # Note that once the matrix A is built, we can trivially compute the
    # slopes at our desired locations by simply computing the dot product
    # of A and our Zernike coefficient vector cVec.
    # Let's do this for fun:
    dotSlopes = np.dot(A,cVec)
    plt.figure(figsize=(12,6))
    plt.plot(slopes,dotSlopes,'ks')
    plt.xlabel('slope estimated by local curvature')
    plt.ylabel('slope calculated by Zernike derivatives')
    plt.savefig('./images/slope_comparison.png')


    # Back to the reconstruction. Now we use SVD to invert A and
    # multiply the resulting inverse by the slopes to determine the
    # modal coefficients:
    B = np.linalg.pinv(A)
    rVec = np.dot(B,slopes)


    # Using the reconstructed coefficients, we can reconstruct the
    # wavefront.
    rWavefront = np.zeros(wavefront.shape)
    for (n,m),coef in zip(nmVec,rVec):
        rWavefront = rWavefront + coef*z.getSurface(n,m,xx,yy,'h')*pupilRadiusM


    # Now we plot the results.
    fig = plt.figure(figsize=(12,12))

    fig.add_subplot(2,2,1)
    plt.imshow(wavefront*mask)
    for x,y in zip(lensletXXVector,lensletYYVector):
        plt.plot(coord2index(x),coord2index(y),'ks')

    plt.axis('image')
    plt.axis('tight')
    plt.title('simulated wavefront and sampling locations')

    fig.add_subplot(2,2,2)
    xax = np.arange(len(cVec))
    if comparison_plot_type=='scatter':
        ph1=plt.plot(xax,cVec,'gs')[0]
        ph1.set_markersize(8)
        ph1.set_markeredgecolor('k')
        ph1.set_markerfacecolor('g')
        ph1.set_markeredgewidth(2)

        ph2=plt.plot(xax+.35,rVec,'bs')[0]
        ph2.set_markersize(8)
        ph2.set_markeredgecolor('k')
        # Blue face so the reconstructed markers differ from the simulated
        # (green) ones.
        ph2.set_markerfacecolor('b')
        ph2.set_markeredgewidth(2)

        # RMS difference over the modes that were not zeroed.
        err = np.sqrt(np.mean((cVec[3:]-rVec[3:])**2))
        textXLim = plt.gca().get_xlim()
        textYLim = plt.gca().get_ylim()
        plt.text(textXLim[1],textYLim[0],r'$\epsilon=%0.2e$'%err,ha='right',va='bottom')

        plt.legend(['simulated','reconstructed'])

    elif comparison_plot_type=='bar':
        bar_width = .45
        plt.bar(xax,cVec,bar_width,color='g',label='simulated')
        plt.bar(xax + bar_width,rVec,bar_width,color='b',label='reconstructed')
        plt.legend()
        plt.tight_layout()

    plt.title('reconstructed zernike terms')
    plt.ylabel('coefficient')

    # Label each tick with its Zernike mode. Tick positions are floats and
    # may lie outside the range of modes, so they are converted to integer
    # indices explicitly, and ticks without a matching mode get an empty
    # label to keep labels aligned with ticks.
    barAxes = plt.gca()
    barXLim = barAxes.get_xlim()
    xticks = barAxes.get_xticks()
    xticklabels = []
    for xtick in xticks:
        j = int(round(xtick))
        if j == xtick and 0 <= j < nZT:
            n = nVec[j]
            m = mVec[j]
            xticklabels.append('$Z^{%d}_{%d}$'%(m,n))
        else:
            xticklabels.append('')

    # Pin the tick positions so the labels stay attached to them, then
    # restore the limits in case pinning the ticks widened the view.
    barAxes.set_xticks(xticks)
    barAxes.set_xticklabels(xticklabels)
    barAxes.set_xlim(barXLim)

    ax = fig.add_subplot(2,2,3,projection='3d')
    ax.plot_wireframe(1e3*xx*pupilRadiusM,1e3*yy*pupilRadiusM,1e6*wavefront*mask,rstride=50,cstride=50,color='k')
    ax.view_init(elev=43., azim=68)
    plt.title('simulated wavefront')
    plt.xlabel('pupil x (mm)')
    plt.ylabel('pupil y (mm)')
    ax.set_zlabel(r'height ($\mu m$)')

    ax = fig.add_subplot(2,2,4,projection='3d')
    ax.plot_wireframe(1e3*xx*pupilRadiusM,1e3*yy*pupilRadiusM,1e6*rWavefront*mask,rstride=50,cstride=50,color='k')
    ax.view_init(elev=43., azim=68)
    plt.title('reconstructed wavefront')
    plt.xlabel('pupil x (mm)')
    plt.ylabel('pupil y (mm)')
    ax.set_zlabel(r'height ($\mu m$)')

    plt.savefig('./images/shack_hartmann_reconstruction_simulation.png')