forked from david-hoffman/pyotf
-
Notifications
You must be signed in to change notification settings - Fork 0
/
zernike.py
258 lines (227 loc) · 8.25 KB
/
zernike.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# zernike.py
"""
A module defining the zernike polynomials and associated functions to convert
between radial and azimuthal degree pairs and Noll's indices.
Running this file as a script will output a graph of the first 15 zernike
polynomials on the unit disk.
https://en.wikipedia.org/wiki/Zernike_polynomials
http://mathworld.wolfram.com/ZernikePolynomial.html
Copyright (c) 2016, David Hoffman
"""
import numpy as np
from scipy.special import eval_jacobi
from .utils import cart2pol
from .otf import HanserPSF
from numpy.fft import fftshift, ifftshift
# Forward mapping from the sequential index p = (m + n * (n + 2)) / 2 to
# Noll's index (https://oeis.org/A176988).  Each row below holds the n + 1
# entries of one radial degree n (n = 0 ... 14), in order of increasing
# azimuthal degree m.
noll_mapping = np.array([
    1,
    3, 2,
    5, 4, 6,
    9, 7, 8, 10,
    15, 13, 11, 12, 14,
    21, 19, 17, 16, 18, 20,
    27, 25, 23, 22, 24, 26, 28,
    35, 33, 31, 29, 30, 32, 34, 36,
    45, 43, 41, 39, 37, 38, 40, 42, 44,
    55, 53, 51, 49, 47, 46, 48, 50, 52, 54,
    65, 63, 61, 59, 57, 56, 58, 60, 62, 64, 66,
    77, 75, 73, 71, 69, 67, 68, 70, 72, 74, 76, 78,
    91, 89, 87, 85, 83, 81, 79, 80, 82, 84, 86, 88, 90,
    105, 103, 101, 99, 97, 95, 93, 92, 94, 96, 98, 100, 102, 104,
    119, 117, 115, 113, 111, 109, 107, 106, 108, 110, 112, 114, 116, 118, 120,
])

# Reverse mapping: noll_inverse[noll - 1] is the sequential index p of the
# polynomial with Noll index ``noll`` (the table above is a permutation of
# 1..120, so argsort yields its inverse).
noll_inverse = noll_mapping.argsort()

# Classical names for the first 15 Noll indices
# https://en.wikipedia.org/wiki/Zernike_polynomials
noll2name = dict(enumerate((
    "Piston",
    "Tip (lateral position) (X-Tilt)",
    "Tilt (lateral position) (Y-Tilt)",
    "Defocus (longitudinal position)",
    "Oblique astigmatism",
    "Vertical astigmatism",
    "Vertical coma",
    "Horizontal coma",
    "Vertical trefoil",
    "Oblique trefoil",
    "Primary spherical",
    "Vertical secondary astigmatism",
    "Oblique secondary astigmatism",
    "Vertical quadrafoil",
    "Oblique quadrafoil",
), start=1))

# Name -> Noll index lookup (inverse of noll2name)
name2noll = dict(zip(noll2name.values(), noll2name))
def gen_psf(params, mcoefs=[1, 0], pcoefs=[0, 0], mclass=HanserPSF):
    """Build a PSF model whose pupil is described by Zernike coefficients.

    Parameters
    ----------
    params : dict
        Keyword arguments used to instantiate ``mclass``.
    mcoefs : sequence of numbers
        First set of Zernike coefficients handed to ``ZernikeDecomposition``
        (presumably the pupil magnitude coefficients -- confirm against
        ``phaseretrieval``).  Entry ``i`` belongs to Noll index ``i + 1``.
    pcoefs : sequence of numbers
        Second set of Zernike coefficients handed to ``ZernikeDecomposition``
        (presumably the pupil phase coefficients -- confirm against
        ``phaseretrieval``).  Entry ``i`` belongs to Noll index ``i + 1``.
    mclass : class
        Model class; it is instantiated as ``mclass(**params)``.

    Returns
    -------
    model : instance of ``mclass``
        The model after ``_gen_psf`` has been called with the pupil.
    pupil : ndarray
        The complex pupil passed to the model.

    Notes
    -----
    The shorter coefficient set is zero padded to the length of the longer
    one.  The padding keeps each array's own dtype, so float coefficients
    are not truncated when the other set happens to be integer typed.
    """
    from .phaseretrieval import ZernikeDecomposition

    # np.array copies its input, so the default lists are never mutated
    mcoefs = np.array(mcoefs)
    pcoefs = np.array(pcoefs)
    # zero pad the shorter set of coefficients, preserving its own dtype
    if len(mcoefs) > len(pcoefs):
        padded = np.zeros_like(mcoefs, dtype=pcoefs.dtype)
        padded[:len(pcoefs)] = pcoefs
        pcoefs = padded
    elif len(pcoefs) > len(mcoefs):
        padded = np.zeros_like(pcoefs, dtype=mcoefs.dtype)
        padded[:len(mcoefs)] = mcoefs
        mcoefs = padded

    model = mclass(**params)
    model._gen_kr()
    r, theta = fftshift(model._kr), fftshift(model._phi)
    # rescale the radial coordinate by na / wl before evaluating the
    # polynomials (they are only non-zero for r <= 1)
    r = r / (model.na / model.wl)
    zerns = zernike(r, theta, np.arange(1, len(mcoefs) + 1))
    zd = ZernikeDecomposition(mcoefs, pcoefs, zerns)
    # NOTE(review): sphase=slice(4, None) presumably restricts the phase to
    # modes after the first four -- confirm against ZernikeDecomposition
    pupil = ifftshift(zd.complex_pupil(sphase=slice(4, None, None)))
    model._gen_psf(pupil)
    return model, pupil
def noll2degrees(noll):
    """Convert from Noll's indices to radial degree and azimuthal degree.

    Parameters
    ----------
    noll : int or array_like of int
        Noll indices, starting at 1.  Any integer dtype (signed or
        unsigned) is accepted.

    Returns
    -------
    n, m : int or ndarray of int
        Radial and azimuthal degrees corresponding to ``noll``.

    Raises
    ------
    ValueError
        If ``noll`` is not of integer type or contains values below 1.
    IndexError
        If ``noll`` is larger than the number of tabulated indices.
    """
    noll = np.asarray(noll)
    # np.integer covers signed and unsigned integer dtypes alike
    if not np.issubdtype(noll.dtype, np.integer):
        raise ValueError("input is not integer, input = {}".format(noll))
    if not (noll > 0).all():
        raise ValueError(
            "Noll indices must be greater than 0, input = {}".format(noll))
    # need to subtract 1 from the Noll's indices because they start at 1.
    p = noll_inverse[noll - 1]
    # invert p = (m + n * (n + 2)) / 2: n is the smallest radial degree
    # whose last sequential index, n * (n + 3) / 2, is at least p
    n = np.ceil((-3 + np.sqrt(9 + 8 * p)) / 2)
    m = 2 * p - n * (n + 2)
    return n.astype(int), m.astype(int)
def degrees2noll(n, m):
    """Convert from radial and azimuthal degrees to Noll's index.

    Parameters
    ----------
    n : int or array_like of int
        Radial degree(s), must be non-negative.
    m : int or array_like of int
        Azimuthal degree(s); ``n - m`` must be even and ``|m| <= n``.

    Returns
    -------
    noll : int or ndarray of int
        The Noll index (indices) of the requested polynomial(s).

    Raises
    ------
    ValueError
        If an input is not of integer type, if ``n - m`` is odd, or if the
        pair is not a valid Zernike degree pair (``n < 0`` or ``|m| > n``).
    IndexError
        If the radial degree is beyond the tabulated mapping.
    """
    n, m = np.asarray(n), np.asarray(m)
    # check inputs; np.integer accepts every integer dtype, not only the
    # platform default that the builtin ``int`` maps to
    if not np.issubdtype(n.dtype, np.integer):
        raise ValueError(
            "Radial degree is not integer, input = {}".format(n))
    if not np.issubdtype(m.dtype, np.integer):
        raise ValueError(
            "Azimuthal degree is not integer, input = {}".format(m))
    if ((n - m) % 2).any():
        raise ValueError(
            "The difference between radial and azimuthal degree isn't mod 2")
    # without this check an invalid pair would silently be mapped onto the
    # table slot that belongs to a different polynomial
    if (n < 0).any() or (abs(m) > n).any():
        raise ValueError(
            "Radial degree must be non-negative and no smaller than the "
            "magnitude of the azimuthal degree, input = ({}, {})".format(
                n, m))
    # do the mapping; m + n * (n + 2) is even here so floor division is exact
    p = (m + n * (n + 2)) // 2
    noll = noll_mapping[p.astype(int)]
    return noll
def zernike(r, theta, *args, **kwargs):
    """Calculate the Zernike polynomials on the unit disk for the requested
    orders.

    Parameters
    ----------
    r : ndarray
        Radial coordinate, must be non-negative and at most two dimensional.
        The polynomials are zero wherever ``r > 1``.
    theta : ndarray
        Azimuthal coordinate.

    Args
    ----
    Noll : numeric or numeric sequence
        Noll's Indices to generate
    (n, m) : tuple of numerics or numeric sequences
        Radial and azimuthal degrees, given as one array_like whose first
        axis has length 2
    n : see above
    m : see above
        ``n`` and ``m`` may also be passed as two separate arguments of
        identical shape (scalars or 1D sequences)

    Kwargs
    ------
    norm : bool (default False)
        Do you want the output normed? (forwarded to ``_zernike``)

    Returns
    -------
    zernike : ndarray
        The zernike polynomials corresponding to Noll or (n, m) whichever are
        provided, stacked along the first axis; axes of length one are
        squeezed out.

    Raises
    ------
    ValueError
        If the number or shape of the order arguments is invalid, if ``r``
        is negative anywhere or has more than two dimensions, or if any
        ``n`` is smaller than ``|m|``.

    Example
    -------
    >>> x = np.linspace(-1, 1, 512)
    >>> xx, yy = np.meshgrid(x, x)
    >>> r, theta = cart2pol(yy, xx)
    >>> zern = zernike(r, theta, 4)  # generates the defocus zernike polynomial
    """
    if len(args) == 1:
        args = np.asarray(args[0])
        if args.ndim < 2:
            # scalar or 1D input is interpreted as Noll indices
            n, m = noll2degrees(args)
        elif args.ndim == 2:
            # 2D input is interpreted as stacked (n, m) rows
            if args.shape[0] == 2:
                n, m = args
            else:
                raise RuntimeError("This shouldn't happen")
        else:
            raise ValueError("{} is the wrong shape".format(args.shape))
    elif len(args) == 2:
        # convert separately: stacking mismatched n and m into one array
        # would fail inside numpy before the shape check below could report
        # the actual problem
        n, m = np.asarray(args[0]), np.asarray(args[1])
        if n.ndim > 1:
            raise ValueError("Radial degree has the wrong shape")
        if m.ndim > 1:
            raise ValueError("Azimuthal degree has the wrong shape")
        if n.shape != m.shape:
            raise ValueError(
                "Radial and Azimuthal degrees have different shapes")
    else:
        raise ValueError(
            "{} is an invalid number of arguments".format(len(args)))
    # make sure r and theta are arrays
    r = np.asarray(r, dtype=float)
    theta = np.asarray(theta, dtype=float)
    # make sure that r is always greater than 0
    if not (r >= 0).all():
        raise ValueError("r must always be greater or equal to 0")
    if r.ndim > 2:
        raise ValueError(
            "Input rho and theta cannot have more than two dimensions")
    # make sure that n and m are iterable
    n, m = n.ravel(), m.ravel()
    # make sure that n is always greater or equal to m
    if not (n >= abs(m)).all():
        raise ValueError("n must always be greater or equal to m")
    # return column of zernike polynomials
    return np.array([_zernike(r, theta, nn, mm, **kwargs)
                     for nn, mm in zip(n, m)]).squeeze()
def _radial_zernike(r, n, m):
"""The radial part of the zernike polynomial
Formula from http://mathworld.wolfram.com/ZernikePolynomial.html"""
rad_zern = np.zeros_like(r)
# zernike polynomials are only valid for r <= 1
valid_points = r <= 1.0
if m == 0 and n == 0:
rad_zern[valid_points] = 1
return rad_zern
rprime = r[valid_points]
# for the radial part m is always positive
m = abs(m)
# calculate the coefs
coef1 = (n + m) // 2
coef2 = (n - m) // 2
jacobi = eval_jacobi(coef2, m, 0, 1 - 2 * rprime**2)
rad_zern[valid_points] = (-1)**coef1 * rprime**m * jacobi
return rad_zern
def _zernike(r, theta, n, m, norm=False):
"""The actual function that calculates the full zernike polynomial"""
# remember if m is negative
mneg = m < 0
# going forward m is positive (Radial zernikes are only defined for
# positive m)
m = abs(m)
# if m and n aren't seperated by multiple of two then return zeros
if (n - m) % 2:
return np.zeros_like(r)
zern = _radial_zernike(r, n, m)
if mneg:
# odd zernike
zern *= np.sin(m * theta)
else:
# even zernike
zern *= np.cos(m * theta)
# calculate the normalization factor
if norm:
raise NotImplementedError
return zern
if __name__ == "__main__":
    # Script mode: draw the 15 named zernike polynomials on a 3 x 5 grid.
    from matplotlib import pyplot as plt

    # cartesian grid covering [-1, 1] x [-1, 1], converted to polar
    x = np.linspace(-1, 1, 1025)
    xx, yy = np.meshgrid(x, x)  # xy indexing is default
    r, theta = cart2pol(yy, xx)

    # one axis per named polynomial
    fig, axs = plt.subplots(3, 5, figsize=(20, 12))
    for ax, (noll_index, name) in zip(axs.ravel(), noll2name.items()):
        zern = zernike(r, theta, noll_index, norm=False)
        ax.matshow(zern, vmin=-1, vmax=1, cmap="seismic")
        # title is the classical name followed by Z with its two degrees
        degrees_label = r", $Z_{{{}}}^{{{}}}$".format(
            *noll2degrees(noll_index))
        ax.set_title(name + degrees_label)
        ax.axis("off")

    fig.tight_layout()
    plt.show()