Example #1
0
    def calibrate(self,
                  U,
                  images,
                  fringe,
                  wavelength,
                  cam_pixel_size,
                  cam_serial='',
                  dname='',
                  dm_serial='',
                  dmplot_txs=(0, 0, 0),
                  dm_transform=SquareRoot.name,
                  hash1='',
                  n_radial=25,
                  alpha=.75,
                  lambda1=5e-3,
                  status_cb=False):
        """Fit a linear actuator-to-Zernike model from a set of interferograms.

        A phase map is extracted from every image with
        ``PhaseExtract(fringe)`` (in a process pool). The samples selected by
        ``fix_principal_val`` define a reference phase ``phi0`` (Zernike
        coefficients ``z0``). The remaining samples are used to fit the
        interaction matrix ``H`` minimising
        ``sum_i ||phi_i - zfA1 @ H @ u_i||^2``. Finally a control matrix
        ``C`` (Tikhonov-regularised when ``alpha > 0``, otherwise
        ``pinv(H)``) and a flattening command ``uflat = -C @ z0`` are
        computed.

        Parameters
        ----------
        U : ndarray, shape (nu, ns)
            Actuator commands; column ``i`` is paired with ``images[i]``.
        images : ndarray
            Interferograms indexed along the first axis (``images[i, ...]``).
        fringe
            Fringe-analysis object. It provides the unit aperture
            (``get_unit_aperture``) and ``mask``, and is passed to
            ``PhaseExtract``.
        wavelength, cam_pixel_size, cam_serial, dname, dm_serial,
        dmplot_txs, dm_transform, hash1
            Metadata stored unchanged on ``self``.
        n_radial : int
            Radial order of the Zernike basis (``RZern(n_radial)``).
        alpha : float
            Width of the radial cosine taper used to weight the
            regularisation. ``alpha <= 0`` disables regularisation.
        lambda1 : float
            Regularisation strength.
        status_cb : callable or False
            Optional callback that receives human-readable progress strings.

        Side effects
        ------------
        Sets ``cart``, ``zfm``, ``fringe``, ``shape``, ``H``, ``mvaf``,
        ``phi0``, ``z0``, ``uflat``, ``C``, ``alpha``, ``lambda1`` and the
        metadata attributes on ``self``.
        """

        if status_cb:
            status_cb('Computing Zernike polynomials ...')
        t1 = time()
        nu, ns = U.shape
        xx, yy, shape = fringe.get_unit_aperture()
        assert (xx.shape == shape)
        assert (yy.shape == shape)
        cart = RZern(n_radial)
        cart.make_cart_grid(xx, yy)
        LOG.info(
            f'calibrate(): Computing Zernike polynomials {time() - t1:.1f}')

        if status_cb:
            status_cb('Computing masks ...')
        t1 = time()
        # Boolean map of the grid points where the Zernike basis is finite.
        zfm = cart.matrix(np.isfinite(cart.ZZ[:, 0]))
        # Stored before calling _make_zfAs(), which presumably reads them.
        self.cart = cart
        self.zfm = zfm
        zfA1, zfA2, mask = self._make_zfAs()
        LOG.info(f'calibrate(): Computing masks {time() - t1:.1f}')

        # TODO remove me
        # Sanity check: both masks must select exactly the points on or
        # outside the unit circle.
        mask1 = np.sqrt(xx**2 + yy**2) >= 1.
        assert (np.allclose(mask, mask1))
        assert (np.allclose(fringe.mask, mask1))

        if status_cb:
            status_cb('Computing phases 00.00% ...')
        t1 = time()

        def make_progress():
            # Throttled progress reporter. It emits at most one status update
            # every 1.5 s, and always once past 99%. The timestamp is
            # refreshed only when a report is emitted. Refreshing it on every
            # call would measure the gap between results rather than between
            # reports, and would suppress progress entirely when results
            # arrive quickly.
            last_report = time()

            def f(pc):
                nonlocal last_report
                t = time()
                if t - last_report > 1.5 or pc > 99:
                    last_report = t
                    status_cb(f'Computing phases {pc:05.2f}% ...')

            return f

        with Pool() as p:
            if status_cb:
                # imap yields results in input order as they become
                # available, which allows progress reporting. Each chunk
                # holds at least 4 images.
                chunksize = max(4, ns // (4 * cpu_count()))
                phases = []
                progress_fun = make_progress()
                for i, phi in enumerate(
                        p.imap(PhaseExtract(fringe),
                               [images[i, ...] for i in range(ns)], chunksize),
                        1):
                    phases.append(phi)
                    progress_fun(100 * i / ns)
            else:
                phases = p.map(PhaseExtract(fringe),
                               [images[i, ...] for i in range(ns)])
            phases = np.array(phases)

        # inds0 (chosen by fix_principal_val) define the reference phase; the
        # remaining samples inds1 are used to fit H.
        inds0 = fix_principal_val(U, phases)
        inds1 = np.setdiff1d(np.arange(ns), inds0)
        assert (np.allclose(np.arange(ns), np.sort(np.hstack((inds0, inds1)))))
        phi0 = phases[inds0, :].mean(axis=0)
        # Zernike coefficients of the reference phase via normal equations.
        z0 = lstsq(np.dot(zfA1.T, zfA1), np.dot(zfA1.T, phi0), rcond=None)[0]
        # From here on all phases are relative to the reference phase.
        phases -= phi0.reshape(1, -1)
        LOG.info(f'calibrate(): Computing phases {time() - t1:.1f}')

        if status_cb:
            status_cb('Computing least-squares matrices ...')
        t1 = time()
        # Accumulate B = sum_i u_i u_i^T and sum_i phi_i u_i^T over the
        # fitting samples.
        nphi = phases.shape[1]
        uiuiT = np.zeros((nu, nu))
        phiiuiT = np.zeros((nphi, nu))
        for i in inds1:
            uiuiT += np.dot(U[:, [i]], U[:, [i]].T)
            phiiuiT += np.dot(phases[[i], :].T, U[:, [i]].T)
        LOG.info(
            f'calibrate(): Computing least-squares matrices {time() - t1:.1f}')
        if status_cb:
            status_cb('Solving least-squares ...')
        t1 = time()
        # The minimiser of sum_i ||phi_i - zfA1 H u_i||^2 is
        # H = A^-1 C B^-1 with A = zfA1^T zfA1 and C = zfA1^T sum_i phi_i u_i^T.
        # Both inverses are applied through Cholesky factors (A = U1^T U1,
        # B = U2^T U2) with two triangular solves each.
        A = np.dot(zfA1.T, zfA1)
        C = np.dot(zfA1.T, phiiuiT)
        B = uiuiT
        U1 = cholesky(A, lower=False, overwrite_a=True)
        Y = solve_triangular(U1, C, trans='T', lower=False)
        D = solve_triangular(U1, Y, trans='N', lower=False)
        U2 = cholesky(B, lower=False, overwrite_a=True)
        YT = solve_triangular(U2, D.T, trans='T', lower=False)
        XT = solve_triangular(U2, YT, trans='N', lower=False)
        H = XT.T

        def vaf(y, ye):
            # Variance accounted for (in %) for each row of y.
            return 100 * (1 - np.var(y - ye, axis=1) / np.var(y, axis=1))

        # VAF per phase point, computed over all measurements.
        mvaf = vaf(phases.T, zfA1 @ H @ U)
        LOG.info(f'calibrate(): Solving least-squares {time() - t1:.1f}')

        if status_cb:
            status_cb('Applying regularisation ...')
        t1 = time()
        if alpha > 0.:
            # weighted least squares
            # Radial cosine taper: 1 for rr < 1 - alpha/2, falling smoothly
            # to 0 at rr = 1, and 0 for rr >= 1.
            rr = np.sqrt(xx**2 + yy**2)
            win = .5 * (1 + np.cos(np.pi * ((2 * rr /
                                             (alpha) - 2 / alpha + 1))))
            win[rr < 1 - alpha / 2] = 1
            win[rr >= 1] = 0

            # For each actuator, take the first sample in which its command
            # equals the global maximum of U. Measure the spread of the
            # windowed phase there, then rescale the spreads to [0, 1].
            # NOTE(review): np.where(...)[0][0] raises IndexError if an
            # actuator never reaches U.max(). The rescaling divides by zero if
            # all spreads are equal; the asserts below then fail.
            stds = np.zeros(nu)
            for i in range(nu):
                ind = np.where(U[i, :] == U.max())[0][0]
                stds[i] = np.std(phases[ind] * win[zfm])
            stds -= stds.min()
            stds /= stds.max()
            assert (stds.min() == 0.)
            assert (stds.max() == 1.)

            # Tikhonov-style control matrix. Actuators with the smallest
            # windowed response get the largest penalty (lambda1); the one
            # with the largest gets none.
            C = np.dot(pinv(lambda1 * np.diag(1 - stds) + np.dot(H.T, H)), H.T)
        else:
            C = np.linalg.pinv(H)
        # Command that, according to the fitted model, cancels z0.
        uflat = -np.dot(C, z0)

        self.fringe = fringe
        self.shape = shape

        self.H = H
        self.mvaf = mvaf
        self.phi0 = phi0
        self.z0 = z0
        self.uflat = uflat
        self.C = C
        self.alpha = alpha
        self.lambda1 = lambda1

        self.wavelength = wavelength
        self.dm_serial = dm_serial
        self.dm_transform = dm_transform
        self.cam_pixel_size = cam_pixel_size
        self.cam_serial = cam_serial
        self.dmplot_txs = dmplot_txs
        self.dname = dname
        self.hash1 = hash1

        LOG.info(f'calibrate(): Applying regularisation {time() - t1:.1f}')
Example #2
0
class ZernikePanel(QWidget):
    """Interactive panel for editing a vector of Zernike coefficients.

    The panel shows the phase evaluated from the current coefficients
    (``self.z``) as an image with a min/max/PV/RMS status line. It also has
    one row per shown mode, made of a label, an editable mode name and a
    slider. Moving a slider updates ``self.z``, redraws the phase and, if a
    callback was given, calls it with ``self.z``.
    """

    # Defaults merged under user-supplied parameters. 'zernike_labels' maps
    # the stringified 0-based mode index to a display name; 'shown_modes' is
    # the number of slider rows requested.
    def_pars = {'zernike_labels': {}, 'shown_modes': 21}

    def __init__(self,
                 wavelength,
                 n_radial,
                 z0=None,
                 callback=None,
                 pars=None,
                 parent=None):
        """Build the panel.

        Parameters
        ----------
        wavelength
            Used only to convert radians to the 'nm' display unit
            (``wavelength / (2 * pi)``); presumably given in nm — confirm.
        n_radial
            Radial order passed to ``RZern``.
        z0
            Optional initial coefficient vector (copied). Its size must equal
            ``RZern(n_radial).nk``. Defaults to zeros.
        callback
            Optional callable, called with ``self.z`` after plot updates.
        pars
            Optional parameter dict, deep-copied and merged over
            ``def_pars``. ``None`` means no overrides.
        parent
            Qt parent widget.
        """
        super().__init__(parent=parent)
        self.log = logging.getLogger(self.__class__.__name__)

        # Deep copies keep def_pars and the caller's dict untouched when the
        # labels are later edited through the GUI.
        self.pars = {**deepcopy(self.def_pars), **deepcopy(pars or {})}
        self.units = 'rad'
        self.status = None
        self.mul = 1.0  # display scale factor: 1 for rad, rad_to_nm for nm
        self.figphi = None
        self.ax = None
        self.im = None
        self.cb = None
        self.shape = (128, 128)
        # Matrix applied to z before evaluation; the scalar 1 is identity.
        self.P = 1

        # Evaluate the Zernike basis on a square grid covering [-1, 1]^2.
        self.rzern = RZern(n_radial)
        dd = np.linspace(-1, 1, self.shape[0])
        xv, yv = np.meshgrid(dd, dd)
        self.rzern.make_cart_grid(xv, yv)
        self.rad_to_nm = wavelength / (2 * np.pi)
        self.callback = callback
        # One (label, slider, name edit, name handler) tuple per shown mode.
        self.zernike_rows = []

        if z0 is None:
            self.z = np.zeros(self.rzern.nk)
        else:
            self.z = z0.copy()
        assert (self.rzern.nk == self.z.size)

        # Phase plot and status line.
        group_phase = QGroupBox('phase')
        lay_phase = QGridLayout()
        group_phase.setLayout(lay_phase)
        self.figphi = FigureCanvas(Figure(figsize=(2, 2)))
        self.ax = self.figphi.figure.add_subplot(1, 1, 1)
        phi = self.rzern.matrix(self.rzern.eval_grid(np.dot(self.P, self.z)))
        self.im = self.ax.imshow(phi, origin='lower')
        self.cb = self.figphi.figure.colorbar(self.im)
        self.cb.locator = ticker.MaxNLocator(nbins=5)
        self.cb.update_ticks()
        self.ax.axis('off')
        self.status = QLabel('')
        lay_phase.addWidget(self.figphi, 0, 0)
        lay_phase.addWidget(self.status, 1, 0)

        def nmodes():
            # Requested number of rows, capped at the number of modes.
            return min(self.pars['shown_modes'], self.rzern.nk)

        # Controls: number of shown modes, rad/nm toggle, reset button.
        bot = QGroupBox('Zernike')
        lay_zern = QGridLayout()
        bot.setLayout(lay_zern)
        labzm = QLabel('shown modes')
        lezm = QLineEdit(str(nmodes()))
        lezm.setMaximumWidth(50)
        lezmval = MyQIntValidator(1, self.rzern.nk)
        lezmval.setFixup(nmodes())
        lezm.setValidator(lezmval)

        brad = QCheckBox('rad')
        brad.setChecked(True)
        breset = QPushButton('reset')
        lay_zern.addWidget(labzm, 0, 0)
        lay_zern.addWidget(lezm, 0, 1)
        lay_zern.addWidget(brad, 0, 2)
        lay_zern.addWidget(breset, 0, 3)

        # Scrollable area that holds one row per shown mode.
        scroll = QScrollArea()
        lay_zern.addWidget(scroll, 1, 0, 1, 5)
        scroll.setWidget(QWidget())
        scrollLayout = QGridLayout(scroll.widget())
        scroll.setWidgetResizable(True)

        def make_hand_slider(ind):
            # Slider handler: store the new coefficient and redraw.
            def f(r):
                self.z[ind] = r
                self.update_phi_plot()

            return f

        def make_hand_lab(le, i):
            # Name-edit handler: persist the edited mode name in pars.
            def f():
                self.pars['zernike_labels'][str(i)] = le.text()

            return f

        def default_zernike_name(i, n, m):
            # Conventional name for the 1-based mode index i with azimuthal
            # order m. Assumes RZern orders modes like Noll (1 piston,
            # 2/3 tip/tilt, 4 defocus) — confirm.
            if i == 1:
                return 'piston'
            elif i == 2:
                return 'tip'
            elif i == 3:
                return 'tilt'
            elif i == 4:
                return 'defocus'
            elif m == 0:
                return 'spherical'
            elif abs(m) == 1:
                return 'coma'
            elif abs(m) == 2:
                return 'astigmatism'
            elif abs(m) == 3:
                return 'trefoil'
            elif abs(m) == 4:
                return 'quadrafoil'
            elif abs(m) == 5:
                return 'pentafoil'
            else:
                return ''

        def make_update_zernike_rows():
            def f(mynk=None):
                # Add or remove rows so that exactly ``mynk`` are shown
                # (default: keep the current count).
                if mynk is None:
                    mynk = len(self.zernike_rows)
                ntab = self.rzern.ntab
                mtab = self.rzern.mtab
                if len(self.zernike_rows) < mynk:
                    for i in range(len(self.zernike_rows), mynk):
                        lab = QLabel(
                            f'Z<sub>{i + 1}</sub> ' +
                            f'Z<sub>{ntab[i]}</sub><sup>{mtab[i]}</sup>')
                        slider = RelSlider(self.z[i], make_hand_slider(i))

                        # self.pars is read here on every call because
                        # load_parameters() replaces it.
                        labels = self.pars['zernike_labels']
                        if str(i) in labels:
                            zname = labels[str(i)]
                        else:
                            zname = default_zernike_name(
                                i + 1, ntab[i], mtab[i])
                            labels[str(i)] = zname
                        lbn = QLineEdit(zname)
                        lbn.setMaximumWidth(120)
                        hand_lab = make_hand_lab(lbn, i)
                        lbn.editingFinished.connect(hand_lab)

                        scrollLayout.addWidget(lab, i, 0)
                        scrollLayout.addWidget(lbn, i, 1)
                        slider.add_to_layout(scrollLayout, i, 2)

                        self.zernike_rows.append((lab, slider, lbn, hand_lab))

                    assert (len(self.zernike_rows) == mynk)

                elif len(self.zernike_rows) > mynk:
                    # Remove rows from the end until mynk remain.
                    for _ in range(len(self.zernike_rows) - mynk):
                        lab, slider, lbn, hand_lab = self.zernike_rows.pop()

                        scrollLayout.removeWidget(lab)
                        slider.remove_from_layout(scrollLayout)
                        scrollLayout.removeWidget(lbn)

                        lbn.editingFinished.disconnect(hand_lab)
                        lab.setParent(None)
                        lbn.setParent(None)

                    assert (len(self.zernike_rows) == mynk)

            return f

        self.update_zernike_rows = make_update_zernike_rows()

        def reset_fun():
            # Zero all coefficients in place. Assignment is used rather than
            # multiplication by 0 so that NaN entries are also cleared and
            # integer arrays do not raise a casting error.
            self.z[:] = 0.
            self.update_gui_controls()
            self.update_phi_plot()

        def change_nmodes():
            # Validate explicitly rather than with ``assert``: asserts are
            # stripped under ``python -O``, which would let an out-of-range
            # value index past the Zernike tables.
            try:
                ival = int(lezm.text())
            except ValueError:
                ival = 0
            if not 0 < ival <= self.rzern.nk:
                lezm.setText(str(len(self.zernike_rows)))
                return

            if ival != len(self.zernike_rows):
                self.update_zernike_rows(ival)
                self.update_phi_plot()
                lezm.setText(str(len(self.zernike_rows)))

        def make_hand_units():
            # Checkbox handler: a checked box shows radians, unchecked nm.
            def f(b):
                if b:
                    self.units = 'rad'
                    self.mul = 1.0
                else:
                    self.units = 'nm'
                    self.mul = self.rad_to_nm
                self.update_phi_plot()

            return f

        self.update_zernike_rows(nmodes())

        brad.stateChanged.connect(make_hand_units())
        breset.clicked.connect(reset_fun)
        lezm.editingFinished.connect(change_nmodes)

        splitv = QSplitter(Qt.Vertical)
        top = QSplitter(Qt.Horizontal)
        top.addWidget(group_phase)
        splitv.addWidget(top)
        splitv.addWidget(bot)
        self.top = top
        self.bot = bot
        l1 = QGridLayout()
        l1.addWidget(splitv)
        self.setLayout(l1)
        self.lezm = lezm

    def save_parameters(self, merge=None):
        """Return the panel parameters merged over ``merge``.

        Keys from ``self.pars`` take precedence over those in ``merge``.
        ``shown_modes`` is set to the number of rows currently displayed.
        ``self.pars`` is deep-copied, so mutating the result (for example its
        ``zernike_labels``) does not affect the panel.
        """
        d = {**(merge or {}), **deepcopy(self.pars)}
        d['shown_modes'] = len(self.zernike_rows)
        return d

    def load_parameters(self, d):
        """Replace the panel parameters with ``d`` merged over ``def_pars``.

        ``shown_modes`` is capped at the number of available modes. All
        rows are rebuilt so that labels from ``d`` are applied.
        """
        self.pars = {**deepcopy(self.def_pars), **deepcopy(d)}
        nmodes = min(self.pars['shown_modes'], self.rzern.nk)
        self.pars['shown_modes'] = nmodes
        # Update the text without triggering the edit handler.
        self.lezm.blockSignals(True)
        self.lezm.setText(str(nmodes))
        self.lezm.blockSignals(False)
        # Drop every row first, then recreate them with the new labels.
        self.update_zernike_rows(0)
        self.update_zernike_rows(nmodes)

    def update_gui_controls(self):
        """Set each shown slider to its coefficient without firing handlers."""
        for i, row in enumerate(self.zernike_rows):
            slider = row[1]
            slider.block()
            slider.set_value(self.z[i])
            slider.unblock()

    def update_phi_plot(self, run_callback=True):
        """Redraw the phase image and status line from ``self.z``.

        If ``run_callback`` is true and a callback was given, the callback is
        called with ``self.z`` afterwards.
        """
        phi = self.mul * self.rzern.matrix(
            self.rzern.eval_grid(np.dot(self.P, self.z)))
        # Only finite samples are used for the statistics and colour limits;
        # the non-finite ones presumably lie outside the pupil.
        inner = phi[np.isfinite(phi)]
        min1 = inner.min()
        max1 = inner.max()
        # RMS taken as the coefficient norm, which presumes orthonormal
        # Zernike modes.
        rms = self.mul * norm(self.z)
        self.status.setText(
            '{} [{: 03.2f} {: 03.2f}] {: 03.2f} PV {: 03.2f} RMS'.format(
                self.units, min1, max1, max1 - min1, rms))
        self.im.set_data(phi)
        self.im.set_clim(min1, max1)
        self.figphi.figure.canvas.draw()

        if self.callback and run_callback:
            self.callback(self.z)