def calibrate(self, U, images, fringe, wavelength, cam_pixel_size,
              cam_serial='', dname='', dm_serial='', dmplot_txs=(0, 0, 0),
              dm_transform=SquareRoot.name, hash1='', n_radial=25, alpha=.75,
              lambda1=5e-3, status_cb=False):
    """Fit the actuator-to-Zernike influence matrix and a control matrix.

    Phases are extracted (in a process pool) from every interferogram in
    ``images``. A linear model ``phase ~= zfA1 @ H @ u`` is fitted by least
    squares, where ``u`` is a column of ``U`` and ``zfA1`` (from
    ``self._make_zfAs()``) maps Zernike coefficients to phase samples.
    The method then computes a control matrix ``C``, optionally
    regularised, and the flattening command ``uflat = -C @ z0``. Nothing
    is returned; all results and the calibration metadata are stored as
    attributes of ``self``.

    Parameters
    ----------
    U : ndarray, shape (nu, ns)
        Actuator commands; column ``i`` is the command applied when
        ``images[i]`` was acquired.
    images : ndarray
        Stack of ``ns`` interferograms indexed along the first axis.
    fringe :
        Fringe-analysis object. Must provide ``get_unit_aperture()`` and
        ``mask``, and is passed to ``PhaseExtract``.
    wavelength, cam_pixel_size, cam_serial, dname, dm_serial, dmplot_txs,
    dm_transform, hash1 :
        Metadata, stored unchanged on ``self``.
    n_radial : int
        Maximum radial order of the Zernike basis (``RZern(n_radial)``).
    alpha : float
        If > 0, taper width of the radial cosine window used to weight
        the Tikhonov-style regularisation. If <= 0, ``C`` is the plain
        pseudo-inverse of ``H``.
    lambda1 : float
        Regularisation strength (only used when ``alpha > 0``).
    status_cb : callable or False
        If truthy, called with human-readable progress strings.

    Attributes set
    --------------
    cart, zfm, fringe, shape, H (nk x nu), mvaf (per-sample variance
    accounted for, in percent), phi0, z0, uflat, C, and the metadata
    parameters above.
    """
    if status_cb:
        status_cb('Computing Zernike polynomials ...')
    t1 = time()
    nu, ns = U.shape
    xx, yy, shape = fringe.get_unit_aperture()
    assert (xx.shape == shape)
    assert (yy.shape == shape)
    cart = RZern(n_radial)
    cart.make_cart_grid(xx, yy)
    LOG.info(
        f'calibrate(): Computing Zernike polynomials {time() - t1:.1f}')

    if status_cb:
        status_cb('Computing masks ...')
    t1 = time()
    zfm = cart.matrix(np.isfinite(cart.ZZ[:, 0]))
    self.cart = cart
    self.zfm = zfm
    zfA1, _, mask = self._make_zfAs()
    LOG.info(f'calibrate(): Computing masks {time() - t1:.1f}')

    # TODO remove me
    # Sanity check: the computed mask must be the complement of the unit
    # disk on the aperture grid.
    mask1 = np.sqrt(xx**2 + yy**2) >= 1.
    assert (np.allclose(mask, mask1))
    assert (np.allclose(fringe.mask, mask1))

    if status_cb:
        status_cb('Computing phases 00.00% ...')
    t1 = time()

    def make_progress():
        # Throttle status updates to at most one every 1.5 s, plus the
        # final one (pc > 99). The timestamp is refreshed only when a
        # message is emitted. Refreshing it on every call would suppress
        # every intermediate update whenever results arrive faster than
        # 1.5 s apart.
        last = time()

        def f(pc):
            nonlocal last
            t = time()
            if t - last > 1.5 or pc > 99:
                last = t
                status_cb(f'Computing phases {pc:05.2f}% ...')
        return f

    extractor = PhaseExtract(fringe)
    frames = [images[i, ...] for i in range(ns)]
    with Pool() as p:
        if status_cb:
            # Ordered imap so progress can be reported as results arrive.
            chunksize = max(4, ns // (4 * cpu_count()))
            progress_fun = make_progress()
            phases = []
            for i, phi in enumerate(p.imap(extractor, frames, chunksize), 1):
                phases.append(phi)
                progress_fun(100 * i / ns)
        else:
            phases = p.map(extractor, frames)

    phases = np.array(phases)
    # inds0: measurements used as the reference phase; inds1: the rest.
    inds0 = fix_principal_val(U, phases)
    inds1 = np.setdiff1d(np.arange(ns), inds0)
    assert (np.allclose(np.arange(ns), np.sort(np.hstack((inds0, inds1)))))
    phi0 = phases[inds0, :].mean(axis=0)
    # Zernike coefficients of the reference phase (normal equations).
    z0 = lstsq(np.dot(zfA1.T, zfA1), np.dot(zfA1.T, phi0), rcond=None)[0]
    phases -= phi0.reshape(1, -1)
    LOG.info(f'calibrate(): Computing phases {time() - t1:.1f}')

    if status_cb:
        status_cb('Computing least-squares matrices ...')
    t1 = time()
    # sum_{i in inds1} u_i u_i^T and sum_{i in inds1} phi_i u_i^T,
    # computed as two matrix products instead of a loop of rank-one
    # updates (accumulated in float64 as before).
    Ui = np.asarray(U[:, inds1], dtype=float)
    uiuiT = Ui @ Ui.T
    phiiuiT = np.asarray(phases[inds1, :], dtype=float).T @ Ui.T
    LOG.info(
        f'calibrate(): Computing least-squares matrices {time() - t1:.1f}')

    if status_cb:
        status_cb('Solving least-squares ...')
    t1 = time()
    # H = (zfA1^T zfA1)^-1 (zfA1^T phiiuiT) (uiuiT)^-1, computed with two
    # Cholesky factorisations. Both Gram matrices must be positive
    # definite, otherwise cholesky raises.
    A = np.dot(zfA1.T, zfA1)
    C = np.dot(zfA1.T, phiiuiT)
    B = uiuiT
    U1 = cholesky(A, lower=False, overwrite_a=True)
    Y = solve_triangular(U1, C, trans='T', lower=False)
    D = solve_triangular(U1, Y, trans='N', lower=False)
    U2 = cholesky(B, lower=False, overwrite_a=True)
    YT = solve_triangular(U2, D.T, trans='T', lower=False)
    XT = solve_triangular(U2, YT, trans='N', lower=False)
    H = XT.T

    def vaf(y, ye):
        # Variance accounted for (percent), computed along axis 1.
        return 100 * (1 - np.var(y - ye, axis=1) / np.var(y, axis=1))

    mvaf = vaf(phases.T, zfA1 @ H @ U)
    LOG.info(f'calibrate(): Solving least-squares {time() - t1:.1f}')

    if status_cb:
        status_cb('Applying regularisation ...')
    t1 = time()
    if alpha > 0.:
        # weighted least squares
        rr = np.sqrt(xx**2 + yy**2)
        # Radial window: 1 for r < 1 - alpha/2, cosine taper down to 0
        # at r = 1, and 0 outside the unit disk.
        win = .5 * (1 + np.cos(np.pi * ((2 * rr / (alpha) - 2 / alpha + 1))))
        win[rr < 1 - alpha / 2] = 1
        win[rr >= 1] = 0
        stds = np.zeros(nu)
        for i in range(nu):
            # Measurement in which actuator i has its largest command.
            # Identical to the first column equal to U.max() whenever
            # actuator i reaches it; unlike that lookup it cannot raise
            # IndexError when it does not.
            ind = np.argmax(U[i, :])
            stds[i] = np.std(phases[ind] * win[zfm])
        # Normalise to [0, 1]; actuators with a larger windowed response
        # receive a smaller penalty lambda1 * (1 - stds[i]).
        stds -= stds.min()
        stds /= stds.max()
        assert (stds.min() == 0.)
        assert (stds.max() == 1.)
        C = np.dot(pinv(lambda1 * np.diag(1 - stds) + np.dot(H.T, H)), H.T)
    else:
        C = np.linalg.pinv(H)
    uflat = -np.dot(C, z0)

    self.fringe = fringe
    self.shape = shape
    self.H = H
    self.mvaf = mvaf
    self.phi0 = phi0
    self.z0 = z0
    self.uflat = uflat
    self.C = C
    self.alpha = alpha
    self.lambda1 = lambda1
    self.wavelength = wavelength
    self.dm_serial = dm_serial
    self.dm_transform = dm_transform
    self.cam_pixel_size = cam_pixel_size
    self.cam_serial = cam_serial
    self.dmplot_txs = dmplot_txs
    self.dname = dname
    self.hash1 = hash1
    LOG.info(f'calibrate(): Applying regularisation {time() - t1:.1f}')
class ZernikePanel(QWidget):
    """Interactive editor and phase plot for a vector of Zernike coefficients.

    The current coefficients live in ``self.z``. The panel shows the phase
    they produce on a 128x128 grid, plus one row per displayed mode: a
    label, an editable mode name and a slider. Moving a slider updates
    ``self.z``, redraws the plot and, if set, calls ``callback(self.z)``.
    """

    def_pars = {'zernike_labels': {}, 'shown_modes': 21}

    def __init__(self, wavelength, n_radial, z0=None, callback=None,
                 pars=None, parent=None):
        """Build the panel.

        Parameters
        ----------
        wavelength : float
            Used to convert radians to the alternative display units via
            ``wavelength / (2 * pi)``; presumably in nm, given the 'nm'
            unit label (TODO confirm).
        n_radial : int
            Maximum radial order passed to ``RZern``.
        z0 : array_like or None
            Initial coefficients, copied as floats. Zeros if None. Must
            have ``RZern(n_radial).nk`` entries.
        callback : callable or None
            Called with ``self.z`` after each plot update.
        pars : dict or None
            Saved parameters merged (deep-copied) over ``def_pars``.
        parent : QWidget or None
            Qt parent widget.
        """
        super().__init__(parent=parent)
        self.log = logging.getLogger(self.__class__.__name__)
        if pars is None:
            pars = {}
        self.pars = {**deepcopy(self.def_pars), **deepcopy(pars)}
        self.units = 'rad'
        self.status = None
        self.mul = 1.0
        self.figphi = None
        self.ax = None
        self.im = None
        self.cb = None
        self.shape = (128, 128)
        self.P = 1

        self.rzern = RZern(n_radial)
        dd = np.linspace(-1, 1, self.shape[0])
        xv, yv = np.meshgrid(dd, dd)
        self.rzern.make_cart_grid(xv, yv)
        self.rad_to_nm = wavelength / (2 * np.pi)
        self.callback = callback
        self.zernike_rows = []

        if z0 is None:
            self.z = np.zeros(self.rzern.nk)
        else:
            # Copy as float: slider values are written into self.z in place
            # and would be truncated if z0 had an integer dtype.
            self.z = np.array(z0, dtype=float)
        assert (self.rzern.nk == self.z.size)

        group_phase = QGroupBox('phase')
        lay_phase = QGridLayout()
        group_phase.setLayout(lay_phase)
        self.figphi = FigureCanvas(Figure(figsize=(2, 2)))
        self.ax = self.figphi.figure.add_subplot(1, 1, 1)
        phi = self.rzern.matrix(self.rzern.eval_grid(np.dot(self.P, self.z)))
        self.im = self.ax.imshow(phi, origin='lower')
        self.cb = self.figphi.figure.colorbar(self.im)
        self.cb.locator = ticker.MaxNLocator(nbins=5)
        self.cb.update_ticks()
        self.ax.axis('off')
        self.status = QLabel('')
        lay_phase.addWidget(self.figphi, 0, 0)
        lay_phase.addWidget(self.status, 1, 0)

        def nmodes():
            # Number of mode rows to show initially, capped at nk.
            return min(self.pars['shown_modes'], self.rzern.nk)

        bot = QGroupBox('Zernike')
        lay_zern = QGridLayout()
        bot.setLayout(lay_zern)
        labzm = QLabel('shown modes')
        lezm = QLineEdit(str(nmodes()))
        lezm.setMaximumWidth(50)
        lezmval = MyQIntValidator(1, self.rzern.nk)
        lezmval.setFixup(nmodes())
        lezm.setValidator(lezmval)

        brad = QCheckBox('rad')
        brad.setChecked(True)
        breset = QPushButton('reset')
        lay_zern.addWidget(labzm, 0, 0)
        lay_zern.addWidget(lezm, 0, 1)
        lay_zern.addWidget(brad, 0, 2)
        lay_zern.addWidget(breset, 0, 3)

        scroll = QScrollArea()
        lay_zern.addWidget(scroll, 1, 0, 1, 5)
        scroll.setWidget(QWidget())
        scrollLayout = QGridLayout(scroll.widget())
        scroll.setWidgetResizable(True)

        # Factories bind ind/le/i per row (avoids late-binding closures).
        def make_hand_slider(ind):
            def f(r):
                self.z[ind] = r
                self.update_phi_plot()
            return f

        def make_hand_lab(le, i):
            def f():
                self.pars['zernike_labels'][str(i)] = le.text()
            return f

        def default_zernike_name(i, n, m):
            # i is the 1-based mode index; m the azimuthal frequency.
            if i == 1:
                return 'piston'
            elif i == 2:
                return 'tip'
            elif i == 3:
                return 'tilt'
            elif i == 4:
                return 'defocus'
            elif m == 0:
                return 'spherical'
            elif abs(m) == 1:
                return 'coma'
            elif abs(m) == 2:
                return 'astigmatism'
            elif abs(m) == 3:
                return 'trefoil'
            elif abs(m) == 4:
                return 'quadrafoil'
            elif abs(m) == 5:
                return 'pentafoil'
            else:
                return ''

        def make_update_zernike_rows():
            def f(mynk=None):
                # Grow or shrink the list of mode rows to exactly mynk
                # entries (no-op when mynk is None).
                if mynk is None:
                    mynk = len(self.zernike_rows)
                ntab = self.rzern.ntab
                mtab = self.rzern.mtab
                if len(self.zernike_rows) < mynk:
                    for i in range(len(self.zernike_rows), mynk):
                        lab = QLabel(
                            f'Z<sub>{i + 1}</sub> ' +
                            f'Z<sub>{ntab[i]}</sub><sup>{mtab[i]}</sup>')
                        slider = RelSlider(self.z[i], make_hand_slider(i))

                        if str(i) in self.pars['zernike_labels'].keys():
                            zname = self.pars['zernike_labels'][str(i)]
                        else:
                            zname = default_zernike_name(
                                i + 1, ntab[i], mtab[i])
                            self.pars['zernike_labels'][str(i)] = zname
                        lbn = QLineEdit(zname)
                        lbn.setMaximumWidth(120)
                        hand_lab = make_hand_lab(lbn, i)
                        lbn.editingFinished.connect(hand_lab)

                        scrollLayout.addWidget(lab, i, 0)
                        scrollLayout.addWidget(lbn, i, 1)
                        slider.add_to_layout(scrollLayout, i, 2)

                        self.zernike_rows.append(
                            (lab, slider, lbn, hand_lab))

                    assert (len(self.zernike_rows) == mynk)

                elif len(self.zernike_rows) > mynk:
                    for i in range(len(self.zernike_rows) - 1, mynk - 1, -1):
                        lab, slider, lbn, hand_lab = self.zernike_rows.pop()

                        scrollLayout.removeWidget(lab)
                        slider.remove_from_layout(scrollLayout)
                        scrollLayout.removeWidget(lbn)

                        lbn.editingFinished.disconnect(hand_lab)
                        lab.setParent(None)
                        lbn.setParent(None)

                    assert (len(self.zernike_rows) == mynk)
            return f

        self.update_zernike_rows = make_update_zernike_rows()

        def reset_fun():
            # Zero the coefficients in place (self.z is shared with the
            # slider handlers).
            self.z[:] = 0.
            self.update_gui_controls()
            self.update_phi_plot()

        def change_nmodes():
            try:
                ival = int(lezm.text())
            except ValueError:
                ival = 0
            # Explicit range check; assert-based validation is stripped
            # under `python -O`.
            if not 0 < ival <= self.rzern.nk:
                lezm.setText(str(len(self.zernike_rows)))
                return

            if ival != len(self.zernike_rows):
                self.update_zernike_rows(ival)
                self.update_phi_plot()
                lezm.setText(str(len(self.zernike_rows)))

        def make_hand_units():
            # Checked -> radians; unchecked -> wavelength-scaled units.
            def f(b):
                if b:
                    self.units = 'rad'
                    self.mul = 1.0
                else:
                    self.units = 'nm'
                    self.mul = self.rad_to_nm
                self.update_phi_plot()
            return f

        self.update_zernike_rows(nmodes())

        brad.stateChanged.connect(make_hand_units())
        breset.clicked.connect(reset_fun)
        lezm.editingFinished.connect(change_nmodes)

        splitv = QSplitter(Qt.Vertical)
        top = QSplitter(Qt.Horizontal)
        top.addWidget(group_phase)
        splitv.addWidget(top)
        splitv.addWidget(bot)
        self.top = top
        self.bot = bot
        l1 = QGridLayout()
        l1.addWidget(splitv)
        self.setLayout(l1)
        self.lezm = lezm

    def save_parameters(self, merge=None):
        """Return a dict of the panel parameters layered over ``merge``.

        ``shown_modes`` is set to the number of rows currently shown. The
        panel's own parameters are deep-copied so callers cannot mutate
        the panel's state through the returned dict.
        """
        base = {} if merge is None else merge
        d = {**base, **deepcopy(self.pars)}
        d['shown_modes'] = len(self.zernike_rows)
        return d

    def load_parameters(self, d):
        """Replace the panel parameters with ``d`` and rebuild the mode rows.

        All rows are removed and recreated so that mode names from ``d``
        are picked up.
        """
        self.pars = {**deepcopy(self.def_pars), **deepcopy(d)}
        nmodes = min(self.pars['shown_modes'], self.rzern.nk)
        self.pars['shown_modes'] = nmodes
        # Update the text without triggering change_nmodes.
        self.lezm.blockSignals(True)
        self.lezm.setText(str(nmodes))
        self.lezm.blockSignals(False)
        self.update_zernike_rows(0)
        self.update_zernike_rows(nmodes)

    def update_gui_controls(self):
        """Push the values in ``self.z`` into the sliders without firing them."""
        for i, t in enumerate(self.zernike_rows):
            slider = t[1]
            slider.block()
            slider.set_value(self.z[i])
            slider.unblock()

    def update_phi_plot(self, run_callback=True):
        """Redraw the phase plot and status line from ``self.z``.

        The status line shows the min/max over finite samples, PV and the
        coefficient norm (RMS), scaled by the current unit multiplier.
        If ``run_callback`` is true and a callback is set, it is called
        with ``self.z``.
        """
        phi = self.mul * self.rzern.matrix(
            self.rzern.eval_grid(np.dot(self.P, self.z)))
        # Samples outside the unit disk are non-finite; ignore them.
        inner = phi[np.isfinite(phi)]
        min1 = inner.min()
        max1 = inner.max()
        rms = self.mul * norm(self.z)
        self.status.setText(
            '{} [{: 03.2f} {: 03.2f}] {: 03.2f} PV {: 03.2f} RMS'.format(
                self.units, min1, max1, max1 - min1, rms))
        self.im.set_data(phi)
        self.im.set_clim(min1, max1)
        self.figphi.figure.canvas.draw()

        if self.callback and run_callback:
            self.callback(self.z)