def test_zernike_edges(num):
    """Check that zernike is insensitive to the numeric type of an edge radius.

    ``num`` is presumably supplied by a parametrize decorator (edge radii such
    as 0 and 1). It is evaluated once as a ``float`` and once as an ``int``
    for a random (n, m) pair and a random angle in [-pi, pi), and both
    results must compare equal.
    """
    n, m = choose_random_nm()
    # 2*pi*u equals u*2*pi exactly, because scaling by 2 is exact in floating point
    theta = 2 * np.pi * np.random.rand() - np.pi
    from_float = zernike(float(num), theta, n, m)
    from_int = zernike(int(num), theta, n, m)
    assert from_float == from_int, f"theta, n, m = {theta}, {n}, {m}"
def test_norm_sum():
    """Test that the RMS of a sum of normalized zernikes adds in quadrature.

    Each normalized term ``c * Z`` should have RMS ``|c|`` over the unit disk.
    The RMS of ``c0 * Z_a + c1 * Z_b`` for two different modes should be the
    square root of the sum of the *squared* coefficients,
    ``sqrt(c0**2 + c1**2)``.
    """
    # set up coordinates
    x = np.linspace(-1, 1, 2048)
    xx, yy = np.meshgrid(x, x)  # xy indexing is default
    r, theta = cart2pol(yy, xx)
    in_disk = r <= 1
    # Draw the coefficients from a private, seeded generator. The values are
    # identical to np.random.seed(12345); np.random.randn(2), but NumPy's
    # global random state is left untouched, since other tests in this
    # module draw from it.
    rng = np.random.RandomState(12345)
    c0, c1 = rng.randn(2)
    astig_zern = c0 * zernike(r, theta, 2, -2, norm=True)
    # (n, m) = (3, -3) is trefoil, not spherical aberration (which is (4, 0))
    trefoil_zern = c1 * zernike(r, theta, 3, -3, norm=True)
    np.testing.assert_allclose(
        abs(c0), np.sqrt((astig_zern[in_disk] ** 2).mean()), atol=1e-3, rtol=1e-3
    )
    np.testing.assert_allclose(
        abs(c1), np.sqrt((trefoil_zern[in_disk] ** 2).mean()), atol=1e-3, rtol=1e-3
    )
    np.testing.assert_allclose(
        np.sqrt(c0**2 + c1**2),
        np.sqrt(((astig_zern + trefoil_zern)[in_disk] ** 2).mean()),
        atol=1e-3,
        rtol=1e-3,
    )
def test_zernike_return_shape():
    """Evaluating a zernike on a 2D polar grid must return an array of the grid's shape."""
    axis = np.linspace(-1, 1, 512)
    grid_x, grid_y = np.meshgrid(axis, axis)
    radius, angle = cart2pol(grid_y, grid_x)
    expected_shape = radius.shape
    # single-index call form (no explicit m) — presumably a Noll/ANSI index
    surface = zernike(radius, angle, 10)
    assert surface.shape == expected_shape
def test_zernike_zero():
    """Make sure zernike returns only finite values at r = 0.5.

    A random (n, m) pair and a random angle in [-pi, pi) are drawn. The
    failure message records all inputs so a failing case can be reproduced.
    """
    # NOTE(review): the test name suggests r = 0, but the radius checked is
    # 0.5. Confirm which radius was intended.
    n, m = choose_random_nm()
    r = 0.5
    theta = np.random.rand() * 2 * np.pi - np.pi
    result = zernike(r, theta, n, m)
    assert np.isfinite(result).all(), f"r, theta, n, m = {r}, {theta}, {n}, {m}"
def test_odd_nm():
    """Make sure that n and m separated by an odd number give zeros.

    By the standard definition, the radial polynomial R_n^m vanishes when
    n - m is odd. zernike is therefore expected to return 0 at every sample
    point.
    """
    # presumably True requests a pair with odd n - m; confirm against choose_random_nm
    n, m = choose_random_nm(True)
    theta = np.random.rand(100) * 2 * np.pi - np.pi
    # r is drawn from [0, 2), so radii outside the unit disk (r > 1) are
    # checked as well as the normal 0 <= r <= 1 range
    r = np.random.rand(100) * 2
    assert (zernike(r, theta, n, m) == 0).all(), f"theta, n, m = {theta}, {n}, {m}"
def test_norm():
    """Test that normalization works.

    Every (n, m) key in ``degrees2name`` is evaluated with ``norm=True`` on a
    dense grid. Its RMS over the unit disk must equal 1 within an
    order-dependent tolerance.
    """
    # set up coordinates
    x = np.linspace(-1, 1, 2048)
    xx, yy = np.meshgrid(x, x)  # xy indexing is default
    r, theta = cart2pol(yy, xx)
    in_disk = r <= 1  # loop-invariant mask of samples inside the unit disk
    # check every named mode, in deterministic (n, m) order
    for (n, m), v in sorted(degrees2name.items()):
        zern = zernike(r, theta, n, m, norm=True)
        # The tolerance grows tenfold per radial order, presumably because
        # higher orders oscillate faster and the fixed grid samples them less
        # accurately.
        tol = 10.0 ** (n - 6)
        np.testing.assert_allclose(
            1.0,
            np.sqrt((zern[in_disk] ** 2).mean()),
            err_msg=f"{v} failed!",
            atol=tol,
            rtol=tol,
        )
def test_r_theta_dims():
    """Coordinate arrays with more than two dimensions must be rejected with ValueError."""
    volume = np.ones((3, 3, 3))
    with pytest.raises(ValueError):
        zernike(volume, volume, 10)
def test_n_lt_m():
    """Requesting an azimuthal frequency m larger than the radial order n must raise ValueError."""
    radius, angle = 0.5, 0.0
    order, frequency = 4, 5
    with pytest.raises(ValueError):
        zernike(radius, angle, order, frequency)
def test_zernike_errors(test_input):
    """Each invalid positional-argument tuple passed to zernike must raise ValueError.

    ``test_input`` is presumably supplied by a parametrize decorator that is
    not visible here.
    """
    with pytest.raises(ValueError):
        zernike(*test_input)