def testIncorrectComponents(self):
    """
    Verify RIRClass2D rejects an ``fspca_components`` value that disagrees
    with the component count of the supplied ``pca_basis``, and accepts a
    value equal to that count.
    """
    # Every swappable stage is pinned to its legacy implementation.
    legacy_impls = dict(
        large_pca_implementation="legacy",
        nn_implementation="legacy",
        bispectrum_implementation="legacy",
    )

    # The fixture basis carries 400 components, so requesting 100 conflicts.
    with pytest.raises(RuntimeError, match=r"`pca_basis` components.*provided by user."):
        _ = RIRClass2D(
            self.clean_src,
            self.clean_fspca_basis,
            fspca_components=100,
            **legacy_impls,
        )

    # Requesting exactly the basis' own component count must not raise.
    matching_count = self.clean_fspca_basis.components
    _ = RIRClass2D(
        self.clean_src,
        self.clean_fspca_basis,
        fspca_components=matching_count,
        **legacy_impls,
    )
def testImplementations(self):
    """
    Test optional implementations handle bad inputs with a descriptive error.

    Covers:
    - an unknown implementation name for each swappable stage
      (nearest neighbor, large PCA, bispectrum);
    - the unsupported pairing of the legacy bispectrum with a non-legacy
      large PCA;
    - passing a ``pca_basis`` that is not an FSPCABasis.
    """
    # Nearest Neighbor component
    with pytest.raises(ValueError, match=r"Provided nn_implementation.*"):
        _ = RIRClass2D(
            self.clean_src,
            self.clean_fspca_basis,
            nn_implementation="badinput",
        )

    # Large PCA component
    with pytest.raises(ValueError, match=r"Provided large_pca_implementation.*"):
        _ = RIRClass2D(
            self.clean_src,
            self.clean_fspca_basis,
            large_pca_implementation="badinput",
        )

    # Bispectrum component
    with pytest.raises(ValueError, match=r"Provided bispectrum_implementation.*"):
        _ = RIRClass2D(
            self.clean_src,
            self.clean_fspca_basis,
            bispectrum_implementation="badinput",
        )

    # The legacy bispectrum implies the legacy large PCA (they're integrated),
    # so pairing it with the sklearn large PCA must be rejected.
    with pytest.raises(ValueError, match=r'"legacy" bispectrum_implementation implies.*'):
        _ = RIRClass2D(
            self.clean_src,
            self.clean_fspca_basis,
            bispectrum_implementation="legacy",
            large_pca_implementation="sklearn",
        )

    # RIRClass2D currently only supports an FSPCABasis as pca_basis;
    # self.basis is not an FSPCABasis, so construction must fail.
    with pytest.raises(
        RuntimeError,
        match=r"RIRClass2D has currently only been developed for pca_basis as a FSPCABasis.",
    ):
        _ = RIRClass2D(self.clean_src, self.basis)
def testRIRDevelBisp(self):
    """
    Smoke test for the "devel" bispectrum implementation.

    No ``pca_basis`` is passed, only the Source. The test passes as long
    as classification and averaging complete without raising.
    """
    settings = {
        "large_pca_implementation": "legacy",
        "nn_implementation": "legacy",
        "bispectrum_implementation": "devel",
    }
    rir = RIRClass2D(self.clean_src, **settings)

    # output() consumes the first three values returned by classify().
    leading = rir.classify()[:3]
    _ = rir.output(*leading)
def testComponentSize(self):
    """
    Check RIRClass2D raises when more bispectrum components are requested
    than there are images in the source (``src.n + 1`` here), i.e. the
    image count is too small for the requested components.

    The call also passes ``dtype=np.float64``, presumably to exercise the
    dtype-mismatch path against the fixture source (TODO confirm the
    fixture dtype). Only the RuntimeError is asserted.
    """
    with pytest.raises(RuntimeError, match=r".*Images too small for Bispectrum Components.*"):
        _ = RIRClass2D(
            self.clean_src,
            self.clean_fspca_basis,
            # One more component than there are images.
            bispectrum_components=self.clean_src.n + 1,
            dtype=np.float64,
        )
def testRIRLegacy(self):
    """
    Smoke test of the fully legacy RIR pipeline on a 100-component FSPCA
    basis. Passes as long as classification and averaging raise nothing.
    """
    # noise_var is zero, which skips eigenvalue filtering.
    fspca_basis = FSPCABasis(self.clean_src, self.basis, noise_var=0, components=100)

    rir = RIRClass2D(
        self.clean_src,
        fspca_basis,
        bispectrum_implementation="legacy",
        nn_implementation="legacy",
        large_pca_implementation="legacy",
    )

    # Only the first three classify() results feed output().
    leading = rir.classify()[:3]
    _ = rir.output(*leading)
def testRIRsk(self):
    """
    Exercise the eigenvalue-based filtering (noisy source and basis)
    together with the sklearn large-PCA and nearest-neighbor stages, the
    "devel" bispectrum, and a BFSR aligner.

    Currently only checks that nothing raises.
    """
    aligner = BFSRAlign2D(self.noisy_fspca_basis, n_angles=100, n_x_shifts=0)

    rir = RIRClass2D(
        self.noisy_src,
        self.noisy_fspca_basis,
        bispectrum_components=100,
        sample_n=42,
        n_classes=5,
        large_pca_implementation="sklearn",
        nn_implementation="sklearn",
        bispectrum_implementation="devel",
        aligner=aligner,
    )

    # This variant forwards the first four classify() results to output().
    leading = rir.classify()[:4]
    _ = rir.output(*leading)
# Let's peek at the images to make sure they're shuffled up nicely src.images(0, 10).show() # %% # Class Average # ------------- # # We use the ASPIRE ``RIRClass2D`` class to classify the images via the rotationally invariant representation (RIR) # algorithm. We then yield class averages by performing ``classify``. rir = RIRClass2D( src, fspca_components=400, bispectrum_components=300, # Compressed Features after last PCA stage. n_nbor=10, n_classes=10, large_pca_implementation="legacy", nn_implementation="legacy", bispectrum_implementation="legacy", ) classes, reflections, rotations, shifts, corr = rir.classify() # %% # Display Classes # ^^^^^^^^^^^^^^^ avgs = rir.output(classes, reflections, rotations) avgs.images(0, 10).show() # %%