def testMeanCovar(self):
    """
    Test basic functionality against RotCov2D.

    The mean and covariance obtained from `self.cov2d` (given the
    coefficients and CTF information explicitly) must agree with those
    returned by `self.bcov2d`.
    """
    # Reference mean and covariance from `self.cov2d`.
    ref_mean = self.cov2d.get_mean(
        self.coeff, ctf_fb=self.ctf_fb, ctf_idx=self.ctf_idx
    )
    ref_covar = self.cov2d.get_covar(
        self.coeff,
        mean_coeff=ref_mean,
        ctf_fb=self.ctf_fb,
        ctf_idx=self.ctf_idx,
        noise_var=self.noise_var,
    )

    # Same quantities from `self.bcov2d`.
    other_mean = self.bcov2d.get_mean()
    other_covar = self.bcov2d.get_covar(noise_var=self.noise_var)

    means_agree = np.allclose(
        ref_mean, other_mean, atol=utest_tolerance(self.dtype)
    )
    self.assertTrue(means_agree)

    covars_agree = self.blk_diag_allclose(
        ref_covar, other_covar, atol=utest_tolerance(self.dtype)
    )
    self.assertTrue(covars_agree)
def testArrayImageSourceRotsSetGet(self):
    """
    Exercise the `rots` setter/getter and the `angles` getter of
    ArrayImageSource.
    """
    # Source under test, built from the image stack only.
    source = ArrayImageSource(self.im)

    # Borrow the angles of the simulation; one (3,) triple per image.
    euler_angles = self.sim.angles
    angles_shape_ok = euler_angles.shape == (source.n, 3)
    self.assertTrue(angles_shape_ok)

    # Convert the angles to a stack of rotation matrices, shape (n, 3, 3).
    rot_mats = R.from_euler("ZYZ", euler_angles).as_matrix()
    rots_shape_ok = rot_mats.shape == (source.n, 3, 3)
    self.assertTrue(rots_shape_ok)

    # Setter.
    source.rots = rot_mats

    # Getter for the rotations must return what was assigned.
    rots_round_trip = np.allclose(
        rot_mats, source.rots, atol=utest_tolerance(self.dtype)
    )
    self.assertTrue(rots_round_trip)

    # Getter for the angles must return the angles the matrices came from.
    angles_round_trip = np.allclose(
        euler_angles, source.angles, atol=utest_tolerance(self.dtype)
    )
    self.assertTrue(angles_round_trip)
def testArrayImageSourceMeanVol(self):
    """
    Test that ArrayImageSource can be consumed by mean/volume codes.

    Checks that the estimate is consistent with the one produced from the
    Simulation source.
    """
    # Run the estimator with the Simulation source as a reference.
    sim_estimator = MeanEstimator(self.sim, self.basis, preconditioner="none")
    sim_est = sim_estimator.estimate()
    logger.info("Simulation source checkpoint")

    # Construct the source under test from the images and the sim's angles.
    src = ArrayImageSource(self.im, angles=self.sim.angles)

    # Run the same estimator, this time consuming the ArrayImageSource.
    estimator = MeanEstimator(src, self.basis, preconditioner="none")
    est = estimator.estimate()

    # Root-mean-square error between the two estimates, logged for debugging.
    rmse = np.sqrt(np.mean(np.square(est - sim_est)))
    logger.info(f"Simulation vs ArrayImageSource estimates RMSE: {rmse}")

    # The RMSE should be small.
    self.assertTrue(rmse <= utest_tolerance(self.dtype))

    # The estimates themselves should be close (virtually same inputs).
    # A 10x looser tolerance is used, on the order of the spread expected
    # when generating sim_est multiple times.
    self.assertTrue(
        np.allclose(est, sim_est, atol=10 * utest_tolerance(self.dtype))
    )
def testRegister(self):
    """
    Register two further random rotation sets to `self.rot_obj` and check
    that the two aligned results agree with each other.
    """
    # Two more sets of random rotations, distinct from self.rot_obj:
    # one generated without an explicit seed, one with seed=7.
    rots_a = Rotation.generate_random_rotations(self.num_rots, dtype=self.dtype)
    rots_b = Rotation.generate_random_rotations(
        self.num_rots, dtype=self.dtype, seed=7
    )

    # Align both sets of random rotations to rot_obj.
    aligned_a = self.rot_obj.register(rots_a)
    aligned_b = self.rot_obj.register(rots_b)

    # The mean squared error is checked in both directions.
    err_ab = aligned_a.mse(aligned_b)
    self.assertTrue(err_ab < utest_tolerance(self.dtype))
    err_ba = aligned_b.mse(aligned_a)
    self.assertTrue(err_ba < utest_tolerance(self.dtype))
def testDownsample(self):
    """
    Downsample a noise-free simulation and check the centre pixel value,
    the resulting resolution, and the resolution-normalised image norms.
    """
    # Generate a 3D map whose density decays as a Gaussian function.
    g3d = grid_3d(self.L, dtype=self.dtype)
    coords = np.array(
        [g3d["x"].flatten(), g3d["y"].flatten(), g3d["z"].flatten()]
    )
    sigma = 0.2
    vol = np.exp(-0.5 * np.sum(np.abs(coords / sigma) ** 2, axis=0)).astype(
        self.dtype
    )
    vol = np.reshape(vol, g3d["x"].shape)
    vols = Volume(vol)

    # Set noise to zero and the CTF filters to unity for the simulation.
    noise_var = 0
    noise_filter = ScalarFilter(dim=2, value=noise_var)
    sim = Simulation(
        L=self.L,
        n=self.n,
        vols=vols,
        offsets=0.0,
        amplitudes=1.0,
        # Seven identical unity filters.
        unique_filters=[ScalarFilter(dim=2, value=1) for _ in range(7)],
        noise_filter=noise_filter,
        dtype=self.dtype,
    )

    # Images before downsampling.
    imgs_org = sim.images(start=0, num=self.n)

    # Images after downsampling.
    max_resolution = 32
    sim.downsample(max_resolution)
    imgs_ds = sim.images(start=0, num=self.n)

    # Compare the centre grid point of each image before and after.
    # Assumes the centre sits at index size // 2; this reproduces the
    # previously hard-coded 32 / 16 when self.L == 64.
    center_org = self.L // 2
    center_ds = max_resolution // 2
    self.assertTrue(
        np.allclose(
            imgs_org[:, center_org, center_org],
            imgs_ds[:, center_ds, center_ds],
            atol=utest_tolerance(self.dtype),
        )
    )

    # Check the resolution: an exact integer comparison.
    self.assertEqual(imgs_ds.shape[1], max_resolution)

    # Check energy conservation after downsampling: per-image norms,
    # each divided by its resolution, must match.
    self.assertTrue(
        np.allclose(
            anorm(imgs_org.asnumpy(), axes=(1, 2)) / self.L,
            anorm(imgs_ds.asnumpy(), axes=(1, 2)) / max_resolution,
            atol=utest_tolerance(self.dtype),
        )
    )
def testMultiplication(self):
    """
    Multiplying the rotations by their inverse must give the 3x3 identity
    for every entry.
    """
    products = (self.rot_obj * self.rot_obj.invert()).matrices
    identity = np.eye(3)
    num_rots = len(self.rot_obj)
    for idx in range(num_rots):
        is_identity = np.allclose(
            identity, products[idx], atol=utest_tolerance(self.dtype)
        )
        self.assertTrue(is_identity)
def testMSE(self):
    """
    Apply a random registration to `self.rot_obj` for flag values 0 and 1
    and check that `mse` against the result is below tolerance.
    """
    # One random triple of ZYZ Euler angles (radians), as a 3x3 matrix.
    euler_zyz = [np.random.random(3)]
    rot_stack = sp_rot.from_euler("ZYZ", euler_zyz, degrees=False).as_matrix()
    registration_mat = rot_stack[0]

    for flag in (0, 1):
        registered = self.rot_obj.apply_registration(registration_mat, flag)
        error = self.rot_obj.mse(registered)
        self.assertTrue(error < utest_tolerance(self.dtype))
def testPolarBasis2DAdjoint(self):
    """
    The evaluate function should be the adjoint operator of evaluate_t.

    Namely, if A = evaluate, B = evaluate_t, and B = A^t, we will have
    (y, A*x) = (A^t*y, x) = (B*y, x).
    """
    nrad = self.basis.nrad
    ntheta = self.basis.ntheta

    # Build x: random values, keep the first half of the theta columns
    # (averaged with their conjugate) and append the conjugate of that half.
    x = randn(self.basis.count, seed=self.seed).astype(self.dtype)
    x = m_reshape(x, (nrad, ntheta))
    first_half = x[:, : ntheta // 2]
    x = 1 / 2 * first_half + 1 / 2 * first_half.conj()
    x = np.concatenate((x, x.conj()), axis=1)
    x = m_reshape(x, (nrad * ntheta,))

    # A * x
    x_t = self.basis.evaluate(x).asnumpy()

    # Random y with one entry per element of the basis size `sz`,
    # and B * y.
    num_pixels = np.prod(self.basis.sz)
    y = randn(num_pixels, seed=self.seed).astype(self.dtype)
    y_img = Image(m_reshape(y, self.basis.sz)[np.newaxis, :])
    y_t = self.basis.evaluate_t(y_img)  # RCOPT

    # (y, A*x) versus the real part of (B*y, x)
    lhs = np.dot(y, m_reshape(x_t, (num_pixels,)))
    rhs = np.real(np.dot(y_t, x))

    logging.debug(
        f"lhs: {lhs} rhs: {rhs} absdiff: {np.abs(lhs-rhs)} atol: {utest_tolerance(self.dtype)}"
    )

    adjoint_holds = np.isclose(lhs, rhs, atol=utest_tolerance(self.dtype))
    self.assertTrue(adjoint_holds)
def testRegisterRots(self):
    """
    Apply a known registration (matrix and flag) to `self.rot_obj` and
    check that `find_registration` recovers both.
    """
    true_q = Rotation.generate_random_rotations(1, dtype=self.dtype)[0]
    for true_flag in (0, 1):
        registered = self.rot_obj.apply_registration(true_q, true_flag)
        est_q, est_flag = self.rot_obj.find_registration(registered)

        # The flag must match; the matrix is only compared when it does.
        recovered = np.allclose(est_flag, true_flag) and np.allclose(
            est_q, true_q, atol=utest_tolerance(self.dtype)
        )
        self.assertTrue(recovered)
def testGetCWFCoeffsClean(self):
    """
    CWF coefficients of the clean coefficients with noise_var=0 must match
    the stored reference array.
    """
    reference_path = os.path.join(
        DATA_DIR, "clean70SRibosome_cov2d_cwf_coeff_clean.npy"
    )
    expected = np.load(reference_path)

    # The result is kept on the instance, as in the other CWF tests.
    self.coeff_cwf_clean = self.cov2d.get_cwf_coeffs(self.coeff_clean, noise_var=0)

    matches = np.allclose(
        expected, self.coeff_cwf_clean, atol=utest_tolerance(self.dtype)
    )
    self.assertTrue(matches)
def blk_diag_allclose(self, blk_diag_a, blk_diag_b, atol=None):
    """
    Compare two block-wise iterables with `np.allclose`, block by block.

    :param blk_diag_a: First iterable of blocks.
    :param blk_diag_b: Second iterable of blocks.
    :param atol: Absolute tolerance passed to `np.allclose`. When None,
        `utest_tolerance(self.dtype)` is used.
    :return: True when both inputs have the same number of blocks and
        every pair of corresponding blocks is close; False otherwise.
    """
    from itertools import zip_longest

    if atol is None:
        atol = utest_tolerance(self.dtype)

    # A plain `zip` stops at the shorter input, which would let inputs
    # with a different number of blocks compare as equal. Pad with a
    # private sentinel so a length mismatch is detected.
    missing = object()
    for blk_a, blk_b in zip_longest(blk_diag_a, blk_diag_b, fillvalue=missing):
        if blk_a is missing or blk_b is missing:
            return False
        if not np.allclose(blk_a, blk_b, atol=atol):
            return False
    return True
def testExpandEval(self):
    """
    Round trip images through the FSPCA basis (expand, then evaluate) and
    check that the reconstruction is close to the input images.
    """
    coef = self.fspca_basis.expand_from_image_basis(self.imgs)
    recon = self.fspca_basis.evaluate_to_image_basis(coef)

    # Root-mean-square error between the original and reconstructed stacks.
    rmse = np.sqrt(np.mean(np.square(self.imgs.asnumpy() - recon.asnumpy())))
    logger.info(f"FSPCA Expand Eval Image Round Trip RMSE: {rmse}")

    self.assertTrue(rmse < utest_tolerance(self.dtype))
def testCTFScale(self):
    """
    Evaluating a scaled CTFFilter on correspondingly scaled frequencies
    must reproduce the unscaled evaluation.
    """
    scale_value = 2.5

    base_filter = CTFFilter(defocus_u=1.5e4, defocus_v=1.5e4)
    unscaled = base_filter.evaluate(self.omega)

    # Scaling a CTFFilter scales the pixel size, which cancels out
    # a corresponding scaling in omega.
    scaled_filter = base_filter.scale(scale_value)
    rescaled = scaled_filter.evaluate(self.omega * scale_value)

    agree = np.allclose(unscaled, rescaled, atol=utest_tolerance(self.dtype))
    self.assertTrue(agree)
def testScaledFilter(self):
    """
    A ScaledFilter wrapping a CTFFilter, evaluated on scaled frequencies,
    must reproduce the wrapped filter's unscaled evaluation.
    """
    scale_value = 2.5
    base_filter = CTFFilter(defocus_u=1.5e4, defocus_v=1.5e4)
    unscaled = base_filter.evaluate(self.omega)

    # ScaledFilter scales the pixel size, which cancels out
    # a corresponding scaling in omega.
    wrapped = ScaledFilter(base_filter, scale_value)
    rescaled = wrapped.evaluate(self.omega * scale_value)

    agree = np.allclose(unscaled, rescaled, atol=utest_tolerance(self.dtype))
    self.assertTrue(agree)
def testGetCWFCoeffsIdentityCTF(self):
    """
    CWF coefficients computed from `self.coeff` without passing any CTF
    arguments must match the stored "noCTF" reference array.
    """
    reference_path = os.path.join(
        DATA_DIR, "clean70SRibosome_cov2d_cwf_coeff_noCTF.npy"
    )
    expected = np.load(reference_path)

    # The result is kept on the instance.
    self.coeff_cwf_noCTF = self.cov2d.get_cwf_coeffs(
        self.coeff, noise_var=self.noise_var
    )

    matches = np.allclose(
        expected, self.coeff_cwf_noCTF, atol=utest_tolerance(self.dtype)
    )
    self.assertTrue(matches)
def testEstShifts(self):
    """
    Estimated shifts must match the stored reference array.
    """
    # The estimation of rotations needs to be rerun explicitly first.
    self.orient_est.estimate_rotations()

    # Shifts are estimated inside a Random(0) context
    # (presumably to fix the random state; confirm against `Random`).
    with Random(0):
        self.est_shifts = self.orient_est.estimate_shifts()

    expected = np.load(os.path.join(DATA_DIR, "orient_est_shifts.npy"))
    matches = np.allclose(
        expected, self.est_shifts, atol=utest_tolerance(self.dtype)
    )
    self.assertTrue(matches)
def testAutoBasis(self):
    """
    Make sure a basis is automatically created if not specified.

    A BatchedRotCov2D built from the source alone must give the same
    covariance as `self.bcov2d`.
    """
    no_basis_cov2d = BatchedRotCov2D(self.src)

    covar_reference = self.bcov2d.get_covar()
    covar_no_basis = no_basis_cov2d.get_covar()

    agree = self.blk_diag_allclose(
        covar_reference, covar_no_basis, atol=utest_tolerance(self.dtype)
    )
    self.assertTrue(agree)
def testFFBBasis2DExpand(self):
    """
    Expanding the stored 8x8 input in the basis must reproduce the stored
    expected coefficients.
    """
    input_path = os.path.join(DATA_DIR, "ffbbasis2d_xcoeff_in_8_8.npy")
    x_in = np.load(input_path).T  # RCOPT
    coeffs = self.basis.expand(x_in.astype(self.dtype))

    # Only index 0 along the last axis of the reference file is compared.
    expected_path = os.path.join(DATA_DIR, "ffbbasis2d_vcoeff_out_exp_8_8.npy")
    expected = np.load(expected_path)[..., 0]

    matches = np.allclose(coeffs, expected, atol=utest_tolerance(self.dtype))
    self.assertTrue(matches)
def testAutoMean(self):
    """
    Make sure `get_covar` automatically calls get_mean if needed.

    Neither call below passes a mean; the two covariances must agree.
    """
    ref_covar = self.cov2d.get_covar(
        self.coeff, ctf_fb=self.ctf_fb, ctf_idx=self.ctf_idx
    )
    other_covar = self.bcov2d.get_covar()

    agree = self.blk_diag_allclose(
        ref_covar, other_covar, atol=utest_tolerance(self.dtype)
    )
    self.assertTrue(agree)
def testFBBasis2DEvaluate(self):
    """
    Evaluating a fixed coefficient vector must reproduce the stored
    reference evaluation.
    """
    # Fixed input coefficients (34 values).
    coeff_values = [
        1.07338590e-01, 1.23690941e-01, 6.44482039e-03, -5.40484306e-02,
        -4.85304586e-02, 1.09852144e-02, 3.87838396e-02, 3.43796455e-02,
        -6.43284705e-03, -2.86677145e-02, -1.42313328e-02, -2.25684091e-03,
        -3.31840727e-02, -2.59706174e-03, -5.91919887e-04, -9.97433028e-03,
        9.19123928e-04, 1.19891589e-03, 7.49154982e-03, 6.18865229e-03,
        -8.13265715e-04, -1.30715655e-02, -1.44160603e-02, 2.90379956e-03,
        2.37066082e-02, 4.88805735e-03, 1.47870707e-03, 7.63376018e-03,
        -5.60619559e-03, 1.05165081e-02, 3.30510143e-03, -3.48652120e-03,
        -4.23228797e-04, 1.40484061e-02,
    ]
    coeffs = np.array(coeff_values, dtype=self.dtype)

    evaluated = self.basis.evaluate(coeffs)

    # The reference file is transposed on load.
    expected = np.load(
        os.path.join(DATA_DIR, "fbbasis_evaluation_8_8.npy")
    ).T  # RCOPT

    matches = np.allclose(evaluated, expected, atol=utest_tolerance(self.dtype))
    self.assertTrue(matches)
def testRotate(self):
    """
    Trivial test of rotation in FSPCA Basis.

    Rotating the coefficients by pi and evaluating back must match a flip
    of each original image. Also covers to_real and to_complex conversions
    in FSPCA Basis.
    """
    coef = self.fspca_basis.expand_from_image_basis(self.imgs)

    # Rotate by pi in coefficient space, then return to image space.
    coef_rotated = self.fspca_basis.rotate(coef, radians=np.pi)
    imgs_rotated = self.fspca_basis.evaluate_to_image_basis(coef_rotated)

    for idx, original in enumerate(self.imgs):
        residual = np.flip(original) - imgs_rotated[idx]
        rmse = np.sqrt(np.mean(np.square(residual)))
        # Tolerance is relaxed by a factor of 10 here.
        self.assertTrue(rmse < 10 * utest_tolerance(self.dtype))
def testCWFCoeffCleanCTF(self):
    """
    Test case of clean images (coeff_clean and noise_var=0)
    while using a non Identity CTF.

    This case may come up when a developer switches between
    clean and dirty images.
    """
    # --- CWF coefficients using the Cov2D base class ---
    ref_mean = self.cov2d.get_mean(
        self.coeff, ctf_fb=self.ctf_fb, ctf_idx=self.ctf_idx
    )
    ref_covar = self.cov2d.get_covar(
        self.coeff,
        ctf_fb=self.ctf_fb,
        ctf_idx=self.ctf_idx,
        noise_var=self.noise_var,
        make_psd=True,
    )
    # Note noise_var=0 here, unlike in the covariance call above.
    ref_cwf = self.cov2d.get_cwf_coeffs(
        self.coeff,
        self.ctf_fb,
        self.ctf_idx,
        mean_coeff=ref_mean,
        covar_coeff=ref_covar,
        noise_var=0,
    )

    # --- CWF coefficients using the Batched Cov2D class ---
    batched_mean = self.bcov2d.get_mean()
    batched_covar = self.bcov2d.get_covar(noise_var=self.noise_var, make_psd=True)
    batched_cwf = self.bcov2d.get_cwf_coeffs(
        self.coeff,
        self.ctf_fb,
        self.ctf_idx,
        batched_mean,
        batched_covar,
        noise_var=0,
    )

    agree = self.blk_diag_allclose(
        ref_cwf, batched_cwf, atol=utest_tolerance(self.dtype)
    )
    self.assertTrue(agree)
def testZeroMean(self):
    """
    Make sure the covariance works with zero mean (pure second moment).
    """
    # An all-zero mean with one entry per basis coefficient.
    zero_mean = np.zeros((self.basis.count,), dtype=self.dtype)

    ref_covar = self.cov2d.get_covar(
        self.coeff, mean_coeff=zero_mean, ctf_fb=self.ctf_fb, ctf_idx=self.ctf_idx
    )
    other_covar = self.bcov2d.get_covar(mean_coeff=zero_mean)

    agree = self.blk_diag_allclose(
        ref_covar, other_covar, atol=utest_tolerance(self.dtype)
    )
    self.assertTrue(agree)
def testFBBasis2DEvaluate_t(self):
    """
    Applying `evaluate_t` to the stored 8x8 input must reproduce the fixed
    expected coefficients.
    """
    data_path = os.path.join(DATA_DIR, "fbbasis_coefficients_8_8.npy")
    v = np.load(data_path).T  # RCOPT

    # While FB can accept arrays, it is preferable to pass Image instances
    # to FB2D and FFB2D.
    img = Image(v.astype(self.dtype))
    coeffs = self.basis.evaluate_t(img)

    # Expected coefficients (34 values).
    expected = [
        0.10761825, 0.12291151, 0.00836345, -0.0619454,
        -0.0483326, 0.01053718, 0.03977641, 0.03420101,
        -0.0060131, -0.02970658, -0.0151334, -0.00017575,
        -0.03987446, -0.00257069, -0.0006621, -0.00975174,
        0.00108047, 0.00072022, 0.00753342, 0.00604493,
        0.00024362, -0.01711248, -0.01387371, 0.00112805,
        0.02407385, 0.00376325, 0.00081128, 0.00951368,
        -0.00557536, 0.01087579, 0.00255393, -0.00525156,
        -0.00839695, 0.00802198,
    ]

    matches = np.allclose(coeffs, expected, atol=utest_tolerance(self.dtype))
    self.assertTrue(matches)
def testRotate(self):
    """
    Check rotation and reflection performed in FFB coefficient space
    against the same operations performed on the Cartesian image.

    Three cases are covered: a quarter-turn rotation, a 2*pi rotation
    (identity), and a reflection (compared against `np.flipud`).
    """
    # Low res (8x8) had problems; better with odd (7x7), but still not
    # good, so a higher res test image is used: a clean image generated
    # from a real data volume.
    v = Volume(
        np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype(
            np.float64
        )
    )
    src = Simulation(L=v.resolution, n=1, vols=v, dtype=v.dtype)

    # Extract; this is the original image to transform.
    x1 = src.images(0, 1)

    # Rotate 90 degrees in cartesian coordinates.
    x2 = Image(np.rot90(x1.asnumpy(), axes=(1, 2)))

    # Express both images in an FB basis.
    basis = FFBBasis2D((x1.res,) * 2, dtype=x1.dtype)
    coef_orig = basis.evaluate_t(x1)
    coef_rot90_ref = basis.evaluate_t(x2)

    # Reflect in the FB basis space.
    coef_refl = basis.rotate(coef_orig, 0, refl=[True])

    # Rotate in the FB basis space: a full turn and a quarter turn.
    coef_full_turn = basis.rotate(coef_orig, 2 * np.pi)
    coef_quarter = basis.rotate(coef_orig, -np.pi / 2)

    # Evaluate back into cartesian.
    y1 = basis.evaluate(coef_quarter)
    y2 = basis.evaluate(coef_rot90_ref)
    y3 = basis.evaluate(coef_full_turn)
    y4 = basis.evaluate(coef_refl)

    # Rotate 90
    self.assertTrue(np.allclose(y1[0], y2[0], atol=1e-4))

    # 2*pi Identity
    self.assertTrue(np.allclose(x1[0], y3[0], atol=utest_tolerance(self.dtype)))

    # Refl (flipped using flipud)
    self.assertTrue(np.allclose(np.flipud(x1[0]), y4[0], atol=1e-4))
def testFBBasis2DExpand(self):
    """
    Expanding the stored 8x8 input in the basis must reproduce the fixed
    expected coefficients.
    """
    data_path = os.path.join(DATA_DIR, "fbbasis_coefficients_8_8.npy")
    v = np.load(data_path).T  # RCOPT
    coeffs = self.basis.expand(v.astype(self.dtype))

    # Expected coefficients (34 values).
    expected = [
        0.10733859, 0.12369094, 0.00644482, -0.05404843,
        -0.04853046, 0.01098521, 0.03878384, 0.03437965,
        -0.00643285, -0.02866771, -0.01423133, -0.00225684,
        -0.03318407, -0.00259706, -0.00059192, -0.00997433,
        0.00091912, 0.00119892, 0.00749155, 0.00618865,
        -0.00081327, -0.01307157, -0.01441606, 0.00290380,
        0.02370661, 0.00488806, 0.00147871, 0.00763376,
        -0.00560620, 0.01051651, 0.00330510, -0.00348652,
        -0.00042323, 0.01404841,
    ]

    matches = np.allclose(coeffs, expected, atol=utest_tolerance(self.dtype))
    self.assertTrue(matches)
def testCWFCoeff(self):
    """
    CWF coefficients from the Cov2D base class and from the Batched Cov2D
    class must agree when both use `self.noise_var`.
    """
    # --- CWF coefficients using the Cov2D base class ---
    ref_mean = self.cov2d.get_mean(
        self.coeff, ctf_fb=self.ctf_fb, ctf_idx=self.ctf_idx
    )
    ref_covar = self.cov2d.get_covar(
        self.coeff,
        ctf_fb=self.ctf_fb,
        ctf_idx=self.ctf_idx,
        noise_var=self.noise_var,
        make_psd=True,
    )
    ref_cwf = self.cov2d.get_cwf_coeffs(
        self.coeff,
        self.ctf_fb,
        self.ctf_idx,
        mean_coeff=ref_mean,
        covar_coeff=ref_covar,
        noise_var=self.noise_var,
    )

    # --- CWF coefficients using the Batched Cov2D class ---
    batched_mean = self.bcov2d.get_mean()
    batched_covar = self.bcov2d.get_covar(noise_var=self.noise_var, make_psd=True)
    batched_cwf = self.bcov2d.get_cwf_coeffs(
        self.coeff,
        self.ctf_fb,
        self.ctf_idx,
        batched_mean,
        batched_covar,
        noise_var=self.noise_var,
    )

    agree = self.blk_diag_allclose(
        ref_cwf, batched_cwf, atol=utest_tolerance(self.dtype)
    )
    self.assertTrue(agree)
def testShrinkers(self, shrinker):
    """
    Test that every known shrinker runs and yields the reference
    covariance, and that an unsupported shrinker argument raises with a
    specific message.

    :param shrinker: Shrinker value passed through `covar_est_opt`.
    """
    # Unsupported inputs must raise; nothing else is checked for them.
    if shrinker in self.bad_shrinker_inputs:
        with raises(AssertionError, match="Unsupported shrink method"):
            _ = self.cov2d.get_covar(
                self.coeff_clean, covar_est_opt={"shrinker": shrinker}
            )
        return

    reference_path = os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npy")
    reference = np.load(reference_path, allow_pickle=True)

    covar_coeff = self.cov2d.get_covar(
        self.coeff_clean, covar_est_opt={"shrinker": shrinker}
    )

    # Compare each reference entry with the entry at the same index.
    for idx, ref_mat in enumerate(reference.tolist()):
        matches = np.allclose(
            ref_mat, covar_coeff[idx], atol=utest_tolerance(self.dtype)
        )
        self.assertTrue(matches)
def testFBBasis3DExpand(self):
    """
    Expanding the stored 8x8x8 input in the basis must reproduce the fixed
    expected coefficients.
    """
    data_path = os.path.join(DATA_DIR, "hbbasis_coefficients_8_8_8.npy")
    v = np.load(data_path).T
    coeffs = self.basis.expand(v.astype(self.dtype))

    # Expected coefficients (98 values).
    expected = [
        +0.10743660, +0.12346847, +0.00684837, -0.05410818,
        -0.04840195, +0.01071116, +0.03878536, +0.03437083,
        -0.00638332, -0.02865552, -0.01425294, -0.00223313,
        -0.03317134, -0.00261654, -0.00056954, -0.00997264,
        +0.00091569, +0.00123042, +0.00754713, +0.00606669,
        -0.00043233, -0.01306626, -0.01443522, +0.00301968,
        +0.02375521, +0.00477979, +0.00166319, +0.00780333,
        -0.00601982, +0.01052385, +0.00328666, -0.00336805,
        -0.00070688, +0.01409127, +0.00127259, -0.01289172,
        +0.00234488, -0.00630249, +0.00117541, -0.02974037,
        +0.00108834, -0.00823955, +0.00340772, -0.00471875,
        +0.00266391, -0.00789639, +0.00093529, +0.00160710,
        -0.00011925, -0.00817443, +0.00046713, -0.01357463,
        +0.00145920, +0.01452459, -0.00267202, +0.00535952,
        -0.00322100, +0.00092083, -0.00075300, +0.00509418,
        +0.00193687, -0.00483399, -0.00204537, +0.00338492,
        +0.00111248, +0.00194841, +0.00174416, -0.00814324,
        -0.00839777, +0.00199974, -0.00196156, -0.00014695,
        -0.00245317, +0.00109957, -0.00146145, +0.00015149,
        +0.00415232, +0.00121810, +0.00066095, +0.00166167,
        -0.00231911, -0.00025819, -0.00086808, -0.00074656,
        +0.00110445, +0.00285573, -0.00014959, +0.00093241,
        +0.00051144, +0.00004805, +0.00250166, +0.00059104,
        +0.00066592, +0.00019188, -0.00079074, -0.00068995,
        -0.00087668, +0.00052913,
    ]

    matches = np.allclose(coeffs, expected, atol=utest_tolerance(self.dtype))
    self.assertTrue(matches)
def testRadialCTFFilterMultiplierGrid(self):
    """
    The product of two identical RadialCTFFilter objects, evaluated on an
    8x8 grid, must equal the element-wise square of the reference table.
    """
    product_filter = RadialCTFFilter(defocus=2.5e4) * RadialCTFFilter(defocus=2.5e4)
    result = product_filter.evaluate_grid(8, dtype=self.dtype)

    self.assertEqual(result.shape, (8, 8))

    # Reference table rows. Rows 5, 6 and 7 of the table repeat
    # rows 3, 2 and 1 respectively, so each distinct row is written once.
    row0 = [
        0.461755701877834, -0.995184514498978, 0.063120922443392, 0.833250206225063,
        0.961464660252150, 0.833250206225063, 0.063120922443392, -0.995184514498978,
    ]
    row1 = [
        -0.995184514498978, 0.626977423649552, 0.799934516166400, 0.004814348317439,
        -0.298096205735759, 0.004814348317439, 0.799934516166400, 0.626977423649552,
    ]
    row2 = [
        0.063120922443392, 0.799934516166400, -0.573061561512667, -0.999286510416273,
        -0.963805291282899, -0.999286510416273, -0.573061561512667, 0.799934516166400,
    ]
    row3 = [
        0.833250206225063, 0.004814348317439, -0.999286510416273, -0.633095739808868,
        -0.368890743119366, -0.633095739808868, -0.999286510416273, 0.004814348317439,
    ]
    row4 = [
        0.961464660252150, -0.298096205735759, -0.963805291282899, -0.368890743119366,
        -0.070000000000000, -0.368890743119366, -0.963805291282899, -0.298096205735759,
    ]
    reference_table = np.array([row0, row1, row2, row3, row4, row3, row2, row1])

    # The filter under test is a product of two equal filters,
    # hence the table is squared.
    expected = reference_table ** 2

    matches = np.allclose(result, expected, atol=utest_tolerance(self.dtype))
    self.assertTrue(matches)