def test_extract_source(self):
    # extract_source must pick, patch by patch, the source whose mixing
    # matrix column matches A_target, even after the sources have been
    # shuffled differently at every level.
    hgmca_analysis_maps = {
        'input_maps_dict': {}, 'analysis_type': 'hgmca', 'scale_int': 2,
        'j_min': 1, 'j_max': 9, 'band_lim': 128,
        'target_fwhm': 1.0 * self.a2r, 'output_nside': 128, 'm_level': 2}
    n_sources = 5
    n_freqs = 6
    A_truth = np.random.rand(n_sources * n_freqs).reshape(n_freqs, n_sources)
    S_truth = np.random.rand(n_sources * 10).reshape(n_sources, 10)

    # Level 0 carries a mixing matrix but no source data.
    hgmca_analysis_maps['0'] = {'A': np.copy(A_truth),
        'S': np.empty((0, 0, 0))}

    # Every deeper level gets the same sources in a random order, repeated
    # over all of its patches.
    m_level = hgmca_analysis_maps['m_level']
    for lev in range(1, m_level + 1):
        order = np.random.permutation(n_sources)
        n_patches = wavelets_hgmca.level_to_npatches(lev)
        A_shuffled = A_truth[:, order]
        S_shuffled = S_truth[order]
        hgmca_analysis_maps[str(lev)] = {
            'A': np.repeat(A_shuffled[np.newaxis, :, :], n_patches, axis=0),
            'S': np.repeat(S_shuffled[np.newaxis, :, :], n_patches, axis=0)}

    # Target the first true source; its value at frequency 0 is expected.
    extracted = hgmca_core.extract_source(hgmca_analysis_maps, A_truth[:, 0])
    expected = A_truth[0, 0] * S_truth[0, :]
    for lev in range(1, m_level + 1):
        for patch in range(wavelets_hgmca.level_to_npatches(lev)):
            np.testing.assert_almost_equal(extracted[str(lev)][patch],
                expected)
def test_int_min_rmse(self):
    # init_min_rmse must set every source patch to pinv(A) X. Every patch
    # of every data level is filled with A S_true, so the reconstruction
    # A S must match X. n_freqs > n_sources, so A is not square: a patch
    # left outside the column space of A would fail the check. A square,
    # invertible A would reproduce any X trivially.
    m_level = 4
    n_freqs = 8
    n_sources = 5
    A_shape = (n_freqs, n_sources)
    X_level = numba.typed.List()
    X_level.append(np.empty((0, 0, 0)))
    A_init = np.random.rand(A_shape[0] * A_shape[1]).reshape(A_shape)
    A_hier_list = hgmca_core.allocate_A_hier(m_level, A_shape,
        A_init=A_init)

    # Fill all patches of each data level using the mixing matrix of that
    # same level and patch.
    for level in range(1, m_level + 1):
        npatches = wavelets_hgmca.level_to_npatches(level)
        n_wavs = np.random.randint(5, 10)
        X_level.append(np.ones((npatches, n_freqs, n_wavs)))
        for patch in range(npatches):
            S_true = np.random.rand(n_sources * n_wavs).reshape(
                n_sources, n_wavs)
            X_level[level][patch] = np.dot(A_hier_list[level][patch], S_true)

    # Make sure the min_rmse code returns very small error.
    S_level = hgmca_core.allocate_S_level(m_level, X_level, n_sources)
    hgmca_core.init_min_rmse(X_level, A_hier_list, S_level)
    for level in range(1, m_level + 1):
        for patch in range(wavelets_hgmca.level_to_npatches(level)):
            np.testing.assert_almost_equal(X_level[level][patch],
                np.dot(A_hier_list[level][patch], S_level[level][patch]))
def test_allocate_A_hier(self):
    # allocate_A_hier must build one stack of patches per level, every patch
    # having shape A_shape and equal to the requested initialization.
    m_level = 4
    A_shape = (5, 5)

    def check_hierarchy(hier_list, expected):
        for lev, A_hier in enumerate(hier_list):
            self.assertEqual(len(A_hier),
                wavelets_hgmca.level_to_npatches(lev))
            self.assertTupleEqual(A_hier[0].shape, A_shape)
            for A_patch in A_hier:
                np.testing.assert_equal(A_patch, expected)

    # Default initialization: every entry is 1/sqrt(n_freqs).
    check_hierarchy(hgmca_core.allocate_A_hier(m_level, A_shape),
        np.ones(A_shape) / np.sqrt(5))

    # Explicit initialization: every patch is a copy of A_init.
    A_init = np.random.rand(A_shape[0] * A_shape[1]).reshape(A_shape)
    check_hierarchy(
        hgmca_core.allocate_A_hier(m_level, A_shape, A_init=A_init), A_init)
def test_level_to_npatches(self):
    # Spot-check the number of patches for the first few levels.
    expected_npatches = ((0, 1), (1, 12), (2, 48), (3, 192))
    for level, npatches in expected_npatches:
        self.assertEqual(wavelets_hgmca.level_to_npatches(level), npatches)
def test_allocate_S_level(self):
    # allocate_S_level must return an empty array for the data-free level 0
    # and a (npatches, n_sources, n_wavs) array for every other level.
    m_level = 4
    n_sources = 5
    X_level = numba.typed.List()
    X_level.append(np.empty((0, 0, 0)))
    for lev in range(1, m_level + 1):
        n_wavs = np.random.randint(5, 10)
        X_level.append(
            np.ones((wavelets_hgmca.level_to_npatches(lev), 10, n_wavs)))

    S_level = hgmca_core.allocate_S_level(m_level, X_level, n_sources)
    self.assertEqual(S_level[0].size, 0)
    for lev in range(1, m_level + 1):
        expected_shape = (wavelets_hgmca.level_to_npatches(lev), n_sources,
            X_level[lev].shape[2])
        self.assertTupleEqual(S_level[lev].shape, expected_shape)
def allocate_A_hier(m_level, A_shape, A_init=None):
    """Build the per-level stack of mixing matrices used by HGMCA.

    Parameters:
        m_level (int): The deepest level of analysis. Levels 0 through
            m_level (inclusive) are allocated.
        A_shape (tuple): The shape of a single mixing matrix. Should be
            (n_freqs, n_sources).
        A_init (np.array): The matrix copied into every patch of every
            level. If None, every entry is set to 1/np.sqrt(n_freqs).

    Returns:
        [np.array,...]: A numba.typed.List with one array per level, of
        shape (npatches,) + A_shape.
    """
    if A_init is not None:
        fill_value = A_init
    else:
        n_freqs = A_shape[0]
        fill_value = np.ones(A_shape) / np.sqrt(n_freqs)

    hierarchy = numba.typed.List()
    for lev in range(m_level + 1):
        n_patch = wavelets_hgmca.level_to_npatches(lev)
        level_stack = np.zeros((n_patch, ) + A_shape)
        # Adding into a zero array copies fill_value into each patch.
        for idx in range(n_patch):
            level_stack[idx] += fill_value
        hierarchy.append(level_stack)
    return hierarchy
def allocate_S_level(m_level, X_level, n_sources):
    """Create zero-filled source matrices for every level of analysis.

    Parameters:
        m_level (int): The deepest level of analysis. Levels 0 through
            m_level (inclusive) are allocated.
        X_level ([np.array,...]): A numba.typed.List of numpy arrays
            holding the data for each level of analysis.
        n_sources (int): The number of sources.

    Returns:
        [np.array,...]: A numba.typed.List with, per level, either an empty
        (0, 0, 0) array (when that level has no data) or a zero array of
        shape (npatches, n_sources, X_level[level].shape[2]).
    """
    S_level = numba.typed.List()
    for lev in range(m_level + 1):
        X_data = X_level[lev]
        if X_data.size > 0:
            S_level.append(np.zeros((wavelets_hgmca.level_to_npatches(lev),
                n_sources, X_data.shape[2])))
        else:
            # Levels without data get an empty placeholder.
            S_level.append(np.empty((0, 0, 0)))
    return S_level
def init_min_rmse(X_level, A_hier_list, S_level):
    """Initializes the source hierarchy to the minimum RMSE solution given
    the mixing matrix hierarchy.

    For every level with data and every patch, the source matrix is set to
    pinv(A) . X, where NaNs in X are first converted to zeros (via
    helpers.nan_to_num on a scratch copy, so X_level is not modified).

    Parameters:
        X_level ([np.array,...]): A numba.typed.List of numpy arrays
            corresponding to the data for each level of analysis.
        A_hier_list ([np.array,...]): A numba.typed.List of numpy arrays
            containing the mixing matrix hierarchy for each level of
            analysis.
        S_level ([np.array,...]): A numba.typed.List of numpy arrays
            containing the source matrices for each level of analysis.
            Updated in place.
    """
    for level in range(len(X_level)):
        # Skip levels with no data
        if X_level[level].size == 0:
            continue
        # Calculating the min rmse requires removing nans, so use a scratch
        # array holding a copy of X with nans converted to 0s.
        X_temp = np.zeros(X_level[level][0].shape)
        for patch in range(wavelets_hgmca.level_to_npatches(level)):
            # Overwrite the scratch buffer with a plain copy. Resetting it
            # with `X_temp *= 0` and then adding would turn any inf left
            # from the previous patch into NaN (inf * 0) if
            # helpers.nan_to_num leaves infs in place.
            X_temp[:] = X_level[level][patch]
            helpers.nan_to_num(X_temp)
            np.dot(np.linalg.pinv(A_hier_list[level][patch]), X_temp,
                out=S_level[level][patch])
def test_save_load_numba_hier_list(self):
    # Create an A_hier_list and S_level and make sure that saving and then
    # loading them returns the same values.
    save_path = self.root_path
    m_level = 4
    folder_path = os.path.join(save_path, 'hgmca_save')

    # First check we get a value error if the files aren't there.
    with self.assertRaises(ValueError):
        hgmca_core.load_numba_hier_list(save_path, m_level)

    # Register the cleanup before anything is written. A failing assertion
    # below would otherwise leave the save folder behind, and the
    # ValueError check above would then fail on every later run.
    def _remove_saved_files():
        for level in range(m_level + 1):
            for file_name in ('A_%d.npy' % (level), 'S_%d.npy' % (level)):
                file_path = os.path.join(folder_path, file_name)
                if os.path.exists(file_path):
                    os.remove(file_path)
        if os.path.isdir(folder_path):
            os.rmdir(folder_path)
    self.addCleanup(_remove_saved_files)

    # Now make sure saving and loading preserves the values.
    A_shape = (5, 5)
    n_sources = 5
    X_level = numba.typed.List()
    X_level.append(np.empty((0, 0, 0)))
    for level in range(m_level):
        npatches = wavelets_hgmca.level_to_npatches(level + 1)
        X_level.append(np.ones((npatches, 10, np.random.randint(5, 10))))
    A_init = np.random.rand(A_shape[0] * A_shape[1]).reshape(A_shape)
    A_hier_list = hgmca_core.allocate_A_hier(m_level, A_shape,
        A_init=A_init)
    S_level = hgmca_core.allocate_S_level(m_level, X_level, n_sources)
    # Give the sources non-zero values so the round trip is not trivially
    # satisfied by all-zero arrays.
    for level in range(1, m_level + 1):
        S_level[level][:] = np.random.rand(*S_level[level].shape)

    hgmca_core.save_numba_hier_lists(A_hier_list, S_level, save_path)
    A_test, S_test = hgmca_core.load_numba_hier_list(save_path, m_level)
    for level in range(m_level + 1):
        np.testing.assert_almost_equal(A_test[level], A_hier_list[level])
        np.testing.assert_almost_equal(S_test[level], S_level[level])
def extract_source(hgmca_analysis_maps, A_target, freq_ind=0):
    """Builds wavelet maps containing only the source closest in its
    frequency dependence to A_target.

    For every level with source data, each patch's mixing matrix column
    with the smallest L1 distance to A_target is selected. The matching
    source row is then scaled by that column's entry at freq_ind.

    Parameters:
        hgmca_analysis_maps (dict): A dictionary containing the information
            about the wavelet functions used for analysis, the original
            nside of the input map, and, for each level from 0 to m_level,
            an entry str(level) -> {'A': (npatches, n_freqs, n_sources)
            array, 'S': (npatches, n_sources, n_wavs) array}. Levels whose
            'S' is empty are skipped.
        A_target (np.array): An n_freqs long array (or an (n_freqs, 1)
            column) that contains the desired frequency scaling of the
            target source.
        freq_ind (int): The frequency to return the source at.

    Returns:
        (dict): A new dictionary containing the information about the
        wavelet functions used for analysis, the original nside of the
        input map, and, for each level with source data, an
        (npatches, n_wavs) array of the extracted source that needs to be
        transformed back to the original healpix space. The input
        dictionary is not modified.
    """
    wav_analysis_maps = {
        'input_maps_dict': hgmca_analysis_maps['input_maps_dict'],
        'analysis_type': 'hgmca',
        'scale_int': hgmca_analysis_maps['scale_int'],
        'j_min': hgmca_analysis_maps['j_min'],
        'j_max': hgmca_analysis_maps['j_max'],
        'band_lim': hgmca_analysis_maps['band_lim'],
        'target_fwhm': hgmca_analysis_maps['target_fwhm'],
        'output_nside': hgmca_analysis_maps['output_nside'],
        'm_level': hgmca_analysis_maps['m_level']
    }

    # Turn the target into an (n_freqs, 1) column so it broadcasts against
    # every column of every patch's mixing matrix.
    A_target = np.asarray(A_target)
    if A_target.ndim == 1:
        A_target = np.expand_dims(A_target, axis=-1)

    m_level = hgmca_analysis_maps['m_level']
    for level in range(m_level + 1):
        S = hgmca_analysis_maps[str(level)]['S']
        if S.size == 0:
            continue
        A_hier = hgmca_analysis_maps[str(level)]['A']

        # For each patch, the index of the column most similar to the target
        # (L1 distance summed over the frequency axis).
        target_match = np.argmin(np.sum(np.abs(A_hier - A_target), axis=-2),
            axis=1)

        # Source times its mixing matrix normalization at freq_ind, for all
        # patches at once via paired integer-array indexing.
        patches = np.arange(S.shape[0])
        norm = A_hier[patches, freq_ind, target_match]
        extracted = np.zeros((S.shape[0], S.shape[2]))
        extracted += norm[:, np.newaxis] * S[patches, target_match]
        wav_analysis_maps[str(level)] = extracted

    # Return the new wav_analysis_maps
    return wav_analysis_maps
def test_multifrequency_wavelet_maps_hgmca(self):
    # Test that the multifrequency maps agree with our expectations.
    # We use just two frequencies, both built from the same simulated sky,
    # but at different nsides (128 and 256) to make sure all of the
    # functionality is working as intended.
    # Check that the python and s2let output match
    input_map_path = self.root_path + 'gmca_test_full_sim_90_GHZ.fits'
    input_map_path_256 = (self.root_path +
        'gmca_test_full_sim_90_GHZ_256.fits')
    input_maps_dict = {
        '30': {
            'band_lim': 256,
            'fwhm': 33,
            'path': input_map_path,
            'nside': 128
        },
        '44': {
            'band_lim': 512,
            'fwhm': 5,
            'path': input_map_path_256,
            'nside': 256
        }
    }
    output_maps_prefix = self.root_path + 's2dw_test'
    scale_int = 2
    j_min = 1
    max_nside = 256
    j_max = 9

    # Generate the wavelet maps using the python code
    wav_analysis_maps = self.wav_class.multifrequency_wavelet_maps(
        input_maps_dict, output_maps_prefix, scale_int, j_min)
    self.assertEqual(wav_analysis_maps['n_freqs'], 2)

    # Now directly compare the values in the grouped array to the
    # values of the original map. Build a reference transform of the
    # '30' map, smoothed to the fwhm of the '44' map.
    input_map = hp.read_map(input_map_path, dtype=np.float64, verbose=False)
    wavelet_dict = self.wav_class.s2dw_wavelet_tranform(
        input_map, output_maps_prefix + 't30',
        input_maps_dict['30']['band_lim'], scale_int, j_min,
        input_maps_dict['30']['fwhm'],
        target_fwhm=np.ones(9) * input_maps_dict['44']['fwhm'])

    # Go patch by patch, starting with the scaling coefficients.
    # wav_analysis_maps[level][patch][freq] concatenates the pixels of
    # every scale stored at that level; offset tracks where the current
    # scale starts within that row. ppp is pixels per patch.
    offset = 0
    scale_nside = wavelets_base.get_max_nside(scale_int, j_min, max_nside)
    n_patches = wavelets_hgmca.level_to_npatches(2)
    ppp = hp.nside2npix(scale_nside) // n_patches
    scale_map = hp.read_map(wavelet_dict['scale_map']['path'], nest=True,
        verbose=False)
    for patch in range(n_patches):
        np.testing.assert_almost_equal(
            scale_map[patch * ppp:(patch + 1) * ppp],
            wav_analysis_maps['2'][patch][0][:ppp])
    offset += ppp

    # Now repeat the calculation for the rest of the wavelet scales.
    # Scales j_min..5 are expected in level 2 (frequency index 0).
    for j in range(j_min, 6):
        wav_nside = wavelets_base.get_max_nside(scale_int, j + 1, max_nside)
        n_patches = wavelets_hgmca.level_to_npatches(2)
        ppp = hp.nside2npix(wav_nside) // n_patches
        wav_map = hp.read_map(wavelet_dict['wav_%d_map' % (j)]['path'],
            nest=True, verbose=False)
        for patch in range(n_patches):
            np.testing.assert_almost_equal(
                wav_map[patch * ppp:(patch + 1) * ppp],
                wav_analysis_maps['2'][patch][0][offset:offset + ppp])
        offset += ppp

    # Scales 6..j_max-1 are expected in level 3; the offset restarts
    # because each level has its own concatenated row. The reference map
    # is regraded to wav_nside before comparing.
    offset = 0
    for j in range(6, j_max):
        wav_nside = wavelets_base.get_max_nside(scale_int, j + 1, max_nside)
        n_patches = wavelets_hgmca.level_to_npatches(3)
        ppp = hp.nside2npix(wav_nside) // n_patches
        wav_map = hp.ud_grade(hp.read_map(
            wavelet_dict['wav_%d_map' % (j)]['path'], nest=True,
            verbose=False), wav_nside, order_in='NESTED',
            order_out='NESTED')
        for patch in range(n_patches):
            np.testing.assert_almost_equal(
                wav_map[patch * ppp:(patch + 1) * ppp],
                wav_analysis_maps['3'][patch][0][offset:offset + ppp])
        offset += ppp

    # The nside 128 map shouldn't have any signal for the wavelet scale
    # j=9, so everything past the last compared scale must be NaN.
    self.assertEqual(
        np.sum(np.isnan(wav_analysis_maps['3'][:, 0, offset:])),
        wav_analysis_maps['3'][:, 0, offset:].size)

    # Now we repeat the same tests for the 256 maps (frequency index 1),
    # which also carry the j=9 scale.
    input_map_256 = hp.read_map(input_map_path_256, dtype=np.float64,
        verbose=False)
    wavelet_dict_256 = self.wav_class.s2dw_wavelet_tranform(
        input_map_256, output_maps_prefix + 't44',
        input_maps_dict['44']['band_lim'], scale_int, j_min,
        input_maps_dict['44']['fwhm'],
        target_fwhm=np.ones(10) * input_maps_dict['44']['fwhm'])
    offset = 0
    scale_nside = wavelets_base.get_max_nside(scale_int, j_min, max_nside)
    n_patches = wavelets_hgmca.level_to_npatches(2)
    ppp = hp.nside2npix(scale_nside) // n_patches
    scale_map = hp.read_map(wavelet_dict_256['scale_map']['path'],
        nest=True, verbose=False)
    for patch in range(n_patches):
        np.testing.assert_almost_equal(
            scale_map[patch * ppp:(patch + 1) * ppp],
            wav_analysis_maps['2'][patch][1][:ppp])
    offset += ppp

    # Now repeat the calculation for the rest of the wavelet scales.
    for j in range(j_min, 6):
        wav_nside = wavelets_base.get_max_nside(scale_int, j + 1, max_nside)
        n_patches = wavelets_hgmca.level_to_npatches(2)
        ppp = hp.nside2npix(wav_nside) // n_patches
        wav_map = hp.read_map(wavelet_dict_256['wav_%d_map' % (j)]['path'],
            nest=True, verbose=False)
        for patch in range(n_patches):
            np.testing.assert_almost_equal(
                wav_map[patch * ppp:(patch + 1) * ppp],
                wav_analysis_maps['2'][patch][1][offset:offset + ppp])
        offset += ppp

    offset = 0
    for j in range(6, j_max + 1):
        wav_nside = wavelets_base.get_max_nside(scale_int, j + 1, max_nside)
        n_patches = wavelets_hgmca.level_to_npatches(3)
        ppp = hp.nside2npix(wav_nside) // n_patches
        wav_map = hp.ud_grade(hp.read_map(
            wavelet_dict_256['wav_%d_map' % (j)]['path'], nest=True,
            verbose=False), wav_nside, order_in='NESTED',
            order_out='NESTED')
        for patch in range(n_patches):
            np.testing.assert_almost_equal(
                wav_map[patch * ppp:(patch + 1) * ppp],
                wav_analysis_maps['3'][patch][1][offset:offset + ppp])
        offset += ppp

    # Delete all the superfluous maps that have been made for testing.
    # First the reference transforms, then the maps written by
    # multifrequency_wavelet_maps. The latter paths are derived by cutting
    # the 't30_scaling.fits' (16 chars) / 't30_wav_<j>.fits' (14 chars)
    # suffix off the reference paths. NOTE(review): this presumably relies
    # on the reference files being named prefix + 't30_...' and on j being
    # a single digit. Confirm against s2dw_wavelet_tranform's naming.
    os.remove(wavelet_dict['scale_map']['path'])
    for j in range(j_min, wavelet_dict['j_max'] + 1):
        os.remove(wavelet_dict['wav_%d_map' % (j)]['path'])
    os.remove(wavelet_dict['scale_map']['path'][:-16] + '30_scaling.fits')
    for j in range(j_min, wavelet_dict['j_max'] + 1):
        os.remove(wavelet_dict['wav_%d_map' % (j)]['path'][:-14] +
            '30_wav_%d.fits' % (j))
    os.remove(wavelet_dict_256['scale_map']['path'])
    for j in range(j_min, wavelet_dict_256['j_max'] + 1):
        os.remove(wavelet_dict_256['wav_%d_map' % (j)]['path'])
    os.remove(wavelet_dict_256['scale_map']['path'][:-16] +
        '44_scaling.fits')
    for j in range(j_min, wavelet_dict_256['j_max'] + 1):
        os.remove(wavelet_dict_256['wav_%d_map' % (j)]['path'][:-14] +
            '44_wav_%d.fits' % (j))
def hgmca_epoch_numba(X_level, A_hier_list, lam_hier, A_global, lam_global,
    S_level, n_epochs, m_level, n_iterations, lam_s, seed, enforce_nn_A,
    min_rmse_rate, epoch_start=0):
    """ Runs the epoch loop for a given level of hgmca using numba
    optimization.

    For an in depth description of the algorithm see arxiv 1910.08077

    Parameters:
        X_level (np.array): A numba.typed.List of numpy arrays corresponding
            to the data for each level of analysis.
        A_hier_list ([np.array]): A numba.typed.List of numpy arrays
            containing the mixing matrix hierarchy for each level of
            analysis.
        lam_hier (np.array): A n_sources long array of the prior for each of
            the columns of the mixing matrices in A_hier. This allows for a
            source dependent prior.
        A_global (np.array): A global mixing matrix prior that will be
            applied to all matrices in the hierarchy
        lam_global (np.array): A n_sources long array of the prior for each
            of the columns of the global prior.
        S_level ([np.array,...]): A numba.typed.List of numpy arrays
            containing the source matrices for each level of analysis.
        n_epochs (int): The number of times the algorithm should pass
            through all the levels.
        m_level (int): The maximum level of analysis.
        n_iterations (int): The number of iterations of coordinate descent
            to conduct per epoch.
        lam_s (float): The lambda parameter for the sparsity l1 norm.
        seed (int): An integer to seed the random number generator.
        enforce_nn_A (bool): A boolean that determines if the mixing matrix
            will be forced to only have non-negative values.
        min_rmse_rate (int): How often the source matrix will be set to the
            minimum rmse solution. 0 will never return min_rmse within the
            gmca optimization.
        epoch_start (int): What epoch the code is starting at. Important
            for min_rmse_rate.

    Notes:
        A_hier and S_level will be updated in place.
    """
    # Allocate the scratch arrays gmca uses once per level so they are not
    # reallocated for every patch. The shapes come from the level's
    # (npatches, n_freqs, n_wavs) dimensions rather than from
    # X_level[level][0]. Levels without data are np.empty((0, 0, 0)), and
    # indexing their first patch is out of bounds (only tolerated when
    # bounds checking is disabled).
    R_i_level = numba.typed.List()
    AS_level = numba.typed.List()
    A_R_level = numba.typed.List()
    A_i_level = numba.typed.List()
    for level in range(m_level + 1):
        n_freqs_level = X_level[level].shape[1]
        n_wavs_level = X_level[level].shape[2]
        R_i_level.append(np.zeros((n_freqs_level, n_wavs_level)))
        AS_level.append(np.zeros((n_freqs_level, n_wavs_level)))
        A_R_level.append(np.zeros((1, n_wavs_level)))
        A_i_level.append(np.zeros((n_freqs_level, 1)))

    # Set the random seed.
    np.random.seed(seed)

    # Now we iterate through our graphical model using our approximate
    # closed form solution for the desired number of epochs.
    for epoch in range(epoch_start, epoch_start + n_epochs):
        # For HGMCA we want to store the source signal that includes the
        # lasso shooting subtraction of lam_s for stability of the loss
        # function. Only every min_rmse_rate epochs do we want gmca to
        # return the min_rmse solution. This only depends on the epoch, so
        # it is decided once per epoch.
        if min_rmse_rate == 0:
            ret_min_rmse = False
        else:
            ret_min_rmse = (((epoch + 1) % min_rmse_rate) == 0)

        # We want to iterate through the levels in random order. This
        # should theoretically speed up convergence.
        level_perm = np.random.permutation(m_level + 1)
        for level in level_perm:
            npatches = wavelets_hgmca.level_to_npatches(level)
            for patch in range(npatches):
                # Calculate the mixing matrix prior from the hierarchy and
                # add the global prior.
                A_prior = get_A_prior(A_hier_list, level, patch, lam_hier)
                A_prior += lam_global * A_global

                # First we deal with the relatively simple case where there
                # are no sources at this level: the matrix is just the
                # normalized prior.
                if X_level[level].size == 0:
                    A_hier_list[level][patch] = A_prior
                    helpers.A_norm(A_hier_list[level][patch])
                # If there are sources at this level we must run GMCA
                else:
                    # Extract the data for the patch.
                    X_p = X_level[level][patch]
                    S_p = S_level[level][patch]
                    # Call gmca for the patch. Note the lam_p has already
                    # been accounted for in A_prior.
                    n_sources = len(S_p)
                    gmca_core.gmca_numba(X_p, n_sources, n_iterations,
                        A_hier_list[level][patch], S_p, A_p=A_prior,
                        lam_p=np.ones(n_sources), enforce_nn_A=enforce_nn_A,
                        lam_s=lam_s, ret_min_rmse=ret_min_rmse,
                        R_i_init=R_i_level[level], AS_init=AS_level[level],
                        A_R_init=A_R_level[level],
                        A_i_init=A_i_level[level])
def test_hgmca_opt(self):
    # hgmca_opt should give exactly what running the core hgmca routines by
    # hand with the same settings gives; saving and loading must also run
    # without errors.
    input_map_path = self.root_path + 'gmca_test_full_sim_90_GHZ.fits'
    freq_settings = [('30', 64, 33), ('44', 64, 24), ('70', 64, 14),
        ('100', 256, 10), ('143', 256, 7.1), ('217', 256, 5.5)]
    input_maps_dict = {
        freq: {'band_lim': band_lim, 'fwhm': fwhm, 'path': input_map_path,
            'nside': 128}
        for freq, band_lim, fwhm in freq_settings}
    output_maps_prefix = self.root_path + 'hgmca_test'
    scale_int = 2
    j_min = 1
    n_freqs = len(input_maps_dict)
    wav_analysis_maps = self.wav_class.multifrequency_wavelet_maps(
        input_maps_dict, output_maps_prefix, scale_int, j_min)

    # HGMCA settings shared by every call below.
    n_sources = 5
    n_epochs = 5
    n_iterations = 2
    lam_s = 1e-3
    lam_hier = np.random.rand(n_sources)
    lam_global = np.random.rand(n_sources)
    A_global = np.random.rand(n_freqs * n_sources).reshape(
        n_freqs, n_sources)
    seed = 5
    min_rmse_rate = 2
    opt_kwargs = dict(A_init=None, A_global=A_global, lam_global=lam_global,
        seed=seed, enforce_nn_A=True, min_rmse_rate=min_rmse_rate,
        verbose=True)
    hgmca_dict = hgmca_core.hgmca_opt(wav_analysis_maps, n_sources, n_epochs,
        lam_hier, lam_s, n_iterations, **opt_kwargs)

    # Reproduce the run with the core routines.
    X_level = hgmca_core.convert_wav_to_X_level(wav_analysis_maps)
    n_freqs = wav_analysis_maps['n_freqs']
    A_shape = (n_freqs, n_sources)
    A_hier_list = hgmca_core.allocate_A_hier(self.m_level, A_shape,
        A_init=None)
    S_level = hgmca_core.allocate_S_level(self.m_level, X_level, n_sources)
    hgmca_core.init_min_rmse(X_level, A_hier_list, S_level)
    hgmca_core.hgmca_epoch_numba(X_level, A_hier_list, lam_hier, A_global,
        lam_global, S_level, n_epochs, self.m_level, n_iterations, lam_s,
        seed, True, min_rmse_rate)

    for level in range(self.m_level + 1):
        level_out = hgmca_dict[str(level)]
        level_has_sources = S_level[level].size != 0
        for patch in range(wavelets_hgmca.level_to_npatches(level)):
            np.testing.assert_almost_equal(A_hier_list[level][patch],
                level_out['A'][patch])
            if level_has_sources:
                np.testing.assert_almost_equal(S_level[level][patch],
                    level_out['S'][patch])
            else:
                self.assertEqual(level_out['S'].size, 0)

    # Check that saving doesn't cause issues (first call), and that loading
    # works as well (second call).
    n_epochs = 6
    save_dict = {'save_path': self.root_path, 'save_rate': 2}
    for _ in range(2):
        hgmca_dict = hgmca_core.hgmca_opt(wav_analysis_maps, n_sources,
            n_epochs, lam_hier, lam_s, n_iterations, save_dict=save_dict,
            **opt_kwargs)

    # Delete all the files we made
    folder_path = os.path.join(save_dict['save_path'], 'hgmca_save')
    for level in range(self.m_level + 1):
        os.remove(os.path.join(folder_path, 'A_%d.npy' % (level)))
        os.remove(os.path.join(folder_path, 'S_%d.npy' % (level)))
    os.rmdir(folder_path)
    for freq, freq_info in input_maps_dict.items():
        os.remove(output_maps_prefix + freq + '_scaling.fits')
        j_max = wavelets_base.calc_j_max(freq_info['band_lim'], scale_int)
        for j in range(j_min, j_max + 1):
            os.remove(output_maps_prefix + freq + '_wav_%d.fits' % (j))
def test_hgmca_epoch_numba(self):
    # Test that the hgmca optimization returns roughly what we would
    # expect for a small lam_s, and that very strong global or hierarchical
    # priors pin the mixing matrices to the prior.
    n_freqs = 8
    n_sources = 5
    n_wavs = 256
    m_level = 2
    lam_s = 1e-6
    lam_hier = np.zeros(n_sources)
    lam_global = np.zeros(n_sources)

    # The true values
    s_mag = 100
    A_org = np.random.rand(n_freqs * n_sources).reshape((n_freqs, n_sources))
    helpers.A_norm(A_org)
    S_org = np.random.rand(n_sources * n_wavs).reshape((n_sources, n_wavs))
    S_org *= s_mag

    # Allocate what we need. Level 0 has no data; every patch of levels 1
    # and 2 holds the same noiseless mixture A_org S_org.
    X_level = numba.typed.List()
    X_level.append(np.empty((0, 0, 0)))
    # Level 1
    npatches = wavelets_hgmca.level_to_npatches(1)
    X_level.append(np.zeros((npatches, n_freqs, n_wavs)))
    X_level[1][:] += np.dot(A_org, S_org)
    # Level 2
    npatches = wavelets_hgmca.level_to_npatches(2)
    X_level.append(np.zeros((npatches, n_freqs, n_wavs)))
    X_level[2][:] += np.dot(A_org, S_org)
    # The rest
    A_hier_list = hgmca_core.allocate_A_hier(m_level, (n_freqs, n_sources))
    S_level = hgmca_core.allocate_S_level(m_level, X_level, n_sources)
    A_global = np.random.rand(n_freqs * n_sources).reshape(
        n_freqs, n_sources)
    helpers.A_norm(A_global)

    # Run hgmca
    min_rmse_rate = 5
    enforce_nn_A = True
    seed = 5
    n_epochs = 5
    n_iterations = 30
    hgmca_core.hgmca_epoch_numba(X_level, A_hier_list, lam_hier, A_global,
        lam_global, S_level, n_epochs, m_level, n_iterations, lam_s, seed,
        enforce_nn_A, min_rmse_rate)
    for level in range(1, 3):
        for patch in range(wavelets_hgmca.level_to_npatches(level)):
            np.testing.assert_almost_equal(X_level[level][patch],
                np.dot(A_hier_list[level][patch], S_level[level][patch]),
                decimal=1)

    # Repeat the same but with strong priors. Start with the global prior.
    lam_global = np.ones(n_sources) * 1e12
    n_epochs = 2
    n_iterations = 5
    A_hier_list = hgmca_core.allocate_A_hier(m_level, (n_freqs, n_sources))
    S_level = hgmca_core.allocate_S_level(m_level, X_level, n_sources)
    hgmca_core.hgmca_epoch_numba(X_level, A_hier_list, lam_hier, A_global,
        lam_global, S_level, n_epochs, m_level, n_iterations, lam_s, seed,
        enforce_nn_A, min_rmse_rate)
    for level in range(3):
        for patch in range(wavelets_hgmca.level_to_npatches(level)):
            np.testing.assert_almost_equal(A_hier_list[level][patch],
                A_global, decimal=4)

    # The same test, but now for the hierarchy. A_init is column
    # normalized, like A_global above, because hgmca_epoch_numba
    # renormalizes matrices (helpers.A_norm) when it updates them.
    lam_hier = np.ones(n_sources) * 1e12
    lam_global = np.zeros(n_sources)
    n_epochs = 2
    n_iterations = 5
    A_init = np.random.rand(n_freqs * n_sources).reshape((n_freqs, n_sources))
    helpers.A_norm(A_init)
    A_hier_list = hgmca_core.allocate_A_hier(m_level, (n_freqs, n_sources),
        A_init=A_init)
    S_level = hgmca_core.allocate_S_level(m_level, X_level, n_sources)
    # Run the optimization so the hierarchical prior is actually exercised;
    # without this call the check below only compares the freshly allocated
    # matrices with A_init.
    hgmca_core.hgmca_epoch_numba(X_level, A_hier_list, lam_hier, A_global,
        lam_global, S_level, n_epochs, m_level, n_iterations, lam_s, seed,
        enforce_nn_A, min_rmse_rate)
    for level in range(3):
        for patch in range(wavelets_hgmca.level_to_npatches(level)):
            np.testing.assert_almost_equal(A_hier_list[level][patch],
                A_init, decimal=4)