Example #1
	def test_extract_source(self):
		# Test that extract_source returns the correct source
		hgmca_analysis_maps = {'input_maps_dict':{},'analysis_type':'hgmca',
			'scale_int':2,'j_min':1,'j_max':9,'band_lim':128,
			'target_fwhm':1.0*self.a2r,'output_nside':128,'m_level':2}
		n_sources = 5
		n_freqs = 6
		A_truth = np.random.rand(n_sources*n_freqs).reshape(n_freqs,n_sources)
		S_truth = np.random.rand(n_sources*10).reshape(n_sources,10)
		hgmca_analysis_maps['0'] = {'A':np.copy(A_truth),
			'S':np.empty((0,0,0))}
		for level in range(1,hgmca_analysis_maps['m_level']+1):
			permute = np.random.permutation(n_sources)
			A_rand = A_truth.T[permute].T
			S_rand = S_truth[permute]
			hgmca_analysis_maps[str(level)] = {
				'A':np.repeat(A_rand[np.newaxis,:,:],
					wavelets_hgmca.level_to_npatches(level),axis=0),
				'S':np.repeat(S_rand[np.newaxis,:,:],
					wavelets_hgmca.level_to_npatches(level),axis=0)}

		A_target = A_truth[:,0]
		wav_analysis_maps = hgmca_core.extract_source(hgmca_analysis_maps,
			A_target)

		for level in range(1,hgmca_analysis_maps['m_level']+1):
			for patch in range(wavelets_hgmca.level_to_npatches(level)):
				np.testing.assert_almost_equal(
					wav_analysis_maps[str(level)][patch],
					A_truth[0,0]*S_truth[0,:])
Example #2
	def test_init_min_rmse(self):
		# Make sure the initialization returns the minimum-RMSE solution
		m_level = 4
		n_freqs = 5
		A_shape = (n_freqs,5)
		n_sources = 5
		X_level = numba.typed.List()
		X_level.append(np.empty((0,0,0)))
		A_init = np.random.rand(A_shape[0]*A_shape[1]).reshape(A_shape)
		A_hier_list = hgmca_core.allocate_A_hier(m_level,A_shape,
			A_init=A_init)
		# Initialize X in a way that will return a really clean min
		# rmse solution given the mixing matrices
		for level in range(m_level):
			npatches = wavelets_hgmca.level_to_npatches(level+1)
			n_wavs = np.random.randint(5,10)
			X_level.append(np.ones((npatches,n_freqs,n_wavs)))
			for patch in range(wavelets_hgmca.level_to_npatches(level+1)):
				S_true = np.ones((n_sources,n_wavs))
				X_level[level+1][patch] = np.dot(A_hier_list[level][patch],
					S_true)
		# Make sure the min_rmse code returns very small error.
		S_level = hgmca_core.allocate_S_level(m_level,X_level,n_sources)
		hgmca_core.init_min_rmse(X_level,A_hier_list,S_level)
		for level in range(1,m_level+1):
			for patch in range(wavelets_hgmca.level_to_npatches(level)):
				np.testing.assert_almost_equal(X_level[level][patch],
					np.dot(A_hier_list[level][patch],
						S_level[level][patch]))
		return
Example #3
	def test_allocate_A_hier(self):
		# Test that the right arrays are initialized.
		m_level = 4
		A_shape = (5,5)
		A_init = None
		A_hier_list = hgmca_core.allocate_A_hier(m_level,A_shape)

		for level, A_hier in enumerate(A_hier_list):
			self.assertEqual(len(A_hier),wavelets_hgmca.level_to_npatches(
				level))
			self.assertTupleEqual(A_hier[0].shape,A_shape)
			for A_patch in A_hier:
				np.testing.assert_equal(A_patch,
					np.ones(A_shape)/np.sqrt(5))

		# Repeat the same but with an A_init now
		A_init = np.random.rand(A_shape[0]*A_shape[1]).reshape(A_shape)
		A_hier_list = hgmca_core.allocate_A_hier(m_level,A_shape,
			A_init=A_init)
		for level, A_hier in enumerate(A_hier_list):
			self.assertEqual(len(A_hier),wavelets_hgmca.level_to_npatches(
				level))
			self.assertTupleEqual(A_hier[0].shape,A_shape)
			for A_patch in A_hier:
				np.testing.assert_equal(A_patch,A_init)
Example #4
    def test_level_to_npatches(self):
        # Manually test that a few different levels give the right number
        # of patches
        level = 0
        self.assertEqual(wavelets_hgmca.level_to_npatches(level), 1)

        level = 1
        self.assertEqual(wavelets_hgmca.level_to_npatches(level), 12)

        level = 2
        self.assertEqual(wavelets_hgmca.level_to_npatches(level), 48)

        level = 3
        self.assertEqual(wavelets_hgmca.level_to_npatches(level), 192)
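The values checked above follow a simple pattern: level 0 covers the full sky with a single patch, and every level at or above 1 has 12*4**(level-1) patches, consistent with subdividing the 12 base HEALPix pixels by a factor of 4 per level. A minimal sketch of that relation (an illustration of the pattern the test asserts, not the library implementation):

def npatches_sketch(level):
    # Level 0 is the full sky; deeper levels subdivide the 12 base pixels.
    if level == 0:
        return 1
    return 12 * 4 ** (level - 1)

# Matches the values asserted in the test above.
assert [npatches_sketch(level) for level in range(4)] == [1, 12, 48, 192]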
Example #5
	def test_allocate_S_level(self):
		m_level = 4
		n_sources = 5
		X_level = numba.typed.List()
		X_level.append(np.empty((0,0,0)))
		for level in range(m_level):
			npatches = wavelets_hgmca.level_to_npatches(level+1)
			X_level.append(np.ones((npatches,10,np.random.randint(5,10))))
		S_level = hgmca_core.allocate_S_level(m_level,X_level,n_sources)

		self.assertEqual(S_level[0].size,0)
		for level in range(1,m_level+1):
			npatches = wavelets_hgmca.level_to_npatches(level)
			self.assertTupleEqual(S_level[level].shape,(npatches,n_sources,
				X_level[level].shape[2]))
Example #6
def allocate_A_hier(m_level, A_shape, A_init=None):
    """Allocates the hierarchy of mixing matrices for HGMCA analysis

	Parameters:
		m_level (int): The deepest level of analysis to consider
		A_shape (tuple): The shape of the mixing matrix. Should be
			(n_freqs,n_sources).
		A_init (np.array): A value at which to initialize all of the
			matrices in the hierarchy. If None then matrices will be
			initialized to 1/np.sqrt(n_freqs)

	Returns:
		[np.array,...]: A numba.typed.List of numpy arrays containing the
		mixing matrix hierarchy for each level of analysis.
	"""
    # Get the initialization value
    if A_init is None:
        A = np.ones(A_shape) / np.sqrt(A_shape[0])
    else:
        A = A_init

    # Initialize all of the matrices
    A_hier_list = numba.typed.List()
    for level in range(m_level + 1):
        npatches = wavelets_hgmca.level_to_npatches(level)
        # Initialize the array and then set it to the desired value.
        A_hier = np.zeros((npatches, ) + A_shape)
        for patch in range(npatches):
            A_hier[patch] += A
        A_hier_list.append(A_hier)

    return A_hier_list
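A short usage sketch for allocate_A_hier, mirroring the test in Example #3. The import paths and the sizes are assumptions for illustration (adjust the import to wherever hgmca_core and wavelets_hgmca live in your install):

import numpy as np
from hgmca import hgmca_core, wavelets_hgmca  # assumed module layout

m_level = 2
A_shape = (6, 5)  # (n_freqs, n_sources)

# With A_init=None, every patch starts at ones/sqrt(n_freqs).
A_hier_list = hgmca_core.allocate_A_hier(m_level, A_shape)
for level, A_hier in enumerate(A_hier_list):
    # One mixing matrix per patch at this level.
    assert A_hier.shape == (wavelets_hgmca.level_to_npatches(level),) + A_shape
    np.testing.assert_equal(A_hier[0], np.ones(A_shape) / np.sqrt(A_shape[0]))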
Example #7
def allocate_S_level(m_level, X_level, n_sources):
    """Allocates the S matrices for all the levels of analysis.

	Parameters:
		m_level (int): The deepest level of analysis to consider
		X_level ([np.array,...]): A numba.typed.List of numpy arrays
			corresponding to the data for each level of analysis.
		n_sources (int): The number of sources

	Returns:
		[np.array,...]: A numba.typed.List of numpy arrays containing the
		source matrices for each level of analysis.
	"""
    # Initialize our list
    S_level = numba.typed.List()
    for level in range(m_level + 1):
        # If there is no data at that level, append an empty placeholder array.
        if X_level[level].size == 0:
            S_level.append(np.empty((0, 0, 0)))
            continue
        # Otherwise initialize an array of zeros.
        npatches = wavelets_hgmca.level_to_npatches(level)
        n_wavs = X_level[level].shape[2]
        S_level.append(np.zeros((npatches, n_sources, n_wavs)))

    return S_level
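A usage sketch for allocate_S_level together with the X_level convention used throughout these examples: an empty placeholder at level 0 and one (npatches, n_freqs, n_wavs) array per deeper level. Import paths and sizes are again assumptions for illustration:

import numba
import numpy as np
from hgmca import hgmca_core, wavelets_hgmca  # assumed module layout

m_level = 2
n_sources = 5
n_freqs, n_wavs = 6, 32

X_level = numba.typed.List()
X_level.append(np.empty((0, 0, 0)))  # no data at level 0
for level in range(1, m_level + 1):
    npatches = wavelets_hgmca.level_to_npatches(level)
    X_level.append(np.ones((npatches, n_freqs, n_wavs)))

S_level = hgmca_core.allocate_S_level(m_level, X_level, n_sources)
assert S_level[0].size == 0  # empty wherever X_level is empty
assert S_level[1].shape == (wavelets_hgmca.level_to_npatches(1), n_sources, n_wavs)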
Example #8
def init_min_rmse(X_level, A_hier_list, S_level):
    """Initializes the source hierarhcy to the minimum RMSE solution
	given the mixing matrix hierarchy

	Parameters:
		X_level ([np.array,...]): A numba.typed.List of numpy arrays
			corresponding to the data for each level of analysis.
		A_hier_list ([np.array,...]): A numba.typed.List of numpy arrays
			containing the mixing matrix hierarchy for each level of analysis.
		S_level ([np.array,...]): A numba.typed.List of numpy arrays containing
			the source matrices for each level of analysis.
	"""
    for level in range(len(X_level)):
        # Skip levels with no data
        if X_level[level].size == 0:
            continue
        # Go through each patch and calculate the min rmse. This requires
        # removing nans, so we will use a temp array to store a version
        # of X with nans converted to 0s.
        X_temp = np.zeros(X_level[level][0].shape)
        for patch in range(wavelets_hgmca.level_to_npatches(level)):
            X_temp *= 0
            X_temp += X_level[level][patch]
            helpers.nan_to_num(X_temp)
            np.dot(np.linalg.pinv(A_hier_list[level][patch]),
                   X_temp,
                   out=S_level[level][patch])
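The per-patch operation above reduces to S = pinv(A) @ X with NaNs in X zeroed first. A numpy-only illustration of that minimum-RMSE step (array sizes are made up for the example):

import numpy as np

rng = np.random.default_rng(0)
A = rng.random((6, 5))        # (n_freqs, n_sources) mixing matrix
S_true = rng.random((5, 32))  # (n_sources, n_wavs) sources
X = A @ S_true                # noiseless data for one patch

# Zero out NaNs (as helpers.nan_to_num does), then apply the pseudo-inverse.
S_min_rmse = np.linalg.pinv(A) @ np.nan_to_num(X)
np.testing.assert_almost_equal(A @ S_min_rmse, X)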
Example #9
	def test_save_load_numba_hier_list(self):
		# Create an A_hier_list and S_level and make sure that
		# the save load functions behave as expected.
		save_path = self.root_path
		m_level = 4
		# First check we get a value error if the files aren't
		# there.
		with self.assertRaises(ValueError):
			hgmca_core.load_numba_hier_list(save_path,m_level)
		# Now just make sure saving and loading preserves the identity
		# transform.
		A_shape = (5,5)
		n_sources = 5
		X_level = numba.typed.List()
		X_level.append(np.empty((0,0,0)))
		for level in range(m_level):
			npatches = wavelets_hgmca.level_to_npatches(level+1)
			X_level.append(np.ones((npatches,10,np.random.randint(5,10))))
		A_init = np.random.rand(A_shape[0]*A_shape[1]).reshape(A_shape)
		A_hier_list = hgmca_core.allocate_A_hier(m_level,A_shape,
			A_init=A_init)
		S_level = hgmca_core.allocate_S_level(m_level,X_level,n_sources)

		hgmca_core.save_numba_hier_lists(A_hier_list,S_level,save_path)
		A_test, S_test = hgmca_core.load_numba_hier_list(save_path,m_level)

		for level in range(m_level+1):
			np.testing.assert_almost_equal(A_test[level],A_hier_list[level])
			np.testing.assert_almost_equal(S_test[level],S_level[level])

		folder_path = os.path.join(save_path,'hgmca_save')
		for level in range(m_level+1):
			os.remove(os.path.join(folder_path,'A_%d.npy'%(level)))
			os.remove(os.path.join(folder_path,'S_%d.npy'%(level)))
		os.rmdir(folder_path)
Example #10
def extract_source(hgmca_analysis_maps, A_target, freq_ind=0):
    """Modifies wav_analysis_maps to only include the source closest
	in its frequency dependence to A_target

	Parameters:
		hgmca_analysis_maps (dict): A dictionary containing the information
			about the wavelet functions used for analysis, the original nside
			of the input map, and the wavelet maps that need to be transformed
			back to the original healpix space.
		A_target (np.array): An n_sources long np.array that contains the
			desired frequency scaling of the target source.
		freq_ind (int): The frequency to return the source at.
	Returns:
		(dict): A dictionary containing the information about the wavelet
		functions used for analysis, the original nside of the input map, and
		the wavelet maps that need to be transformed back to the original
		healpix space.
	"""
    wav_analysis_maps = {
        'input_maps_dict': hgmca_analysis_maps['input_maps_dict'],
        'analysis_type': 'hgmca',
        'scale_int': hgmca_analysis_maps['scale_int'],
        'j_min': hgmca_analysis_maps['j_min'],
        'j_max': hgmca_analysis_maps['j_max'],
        'band_lim': hgmca_analysis_maps['band_lim'],
        'target_fwhm': hgmca_analysis_maps['target_fwhm'],
        'output_nside': hgmca_analysis_maps['output_nside'],
        'm_level': hgmca_analysis_maps['m_level']
    }
    # Iterate through each of the levels and pick out the mixing matrix column
    # that is most similar to the target column.
    if len(A_target.shape) == 1:
        A_target = np.expand_dims(A_target, axis=-1)
    m_level = hgmca_analysis_maps['m_level']
    for level in range(m_level + 1):
        S = hgmca_analysis_maps[str(level)]['S']
        if S.size == 0:
            continue
        # Allocate the array
        wav_analysis_maps[str(level)] = np.zeros((S.shape[0], S.shape[2]))
        # Find the best match for each patch
        A_hier = hgmca_analysis_maps[str(level)]['A']
        target_match = np.argmin(np.sum(np.abs(A_hier - A_target), axis=-2),
                                 axis=1)
        # Calculate the source times the mixing matrix amplitude at freq_ind
        # for each patch.
        for patch in range(wavelets_hgmca.level_to_npatches(level)):
            wav_analysis_maps[str(level)][patch] += (
                A_hier[patch, freq_ind, target_match[patch]] *
                S[patch, target_match[patch]])

    # Return the new wav_analysis_maps
    return wav_analysis_maps
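A numpy-only illustration of the column matching performed per patch above: the mixing-matrix column closest in L1 distance to A_target is selected, and the extracted map is that column's amplitude at freq_ind times the corresponding source row. The sizes and the choice of column 2 are made up for the example:

import numpy as np

rng = np.random.default_rng(0)
A_patch = rng.random((6, 5))   # (n_freqs, n_sources)
S_patch = rng.random((5, 32))  # (n_sources, n_wavs)
A_target = A_patch[:, 2]       # pretend source 2 is the one we want
freq_ind = 0

# L1 distance between A_target and each column, then pick the closest column.
match = np.argmin(np.sum(np.abs(A_patch - A_target[:, np.newaxis]), axis=0))
extracted = A_patch[freq_ind, match] * S_patch[match]
assert match == 2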
Example #11
    def test_multifrequency_wavelet_maps_hgmca(self):
        # Test that the multifrequency maps agree with our expectations.
        # We'll do our tests using just two frequencies, and make both
        # frequencies the same map. We'll change the nside to make sure all of
        # the functionality is working as intended.
        # Check that the python and s2let output match
        input_map_path = self.root_path + 'gmca_test_full_sim_90_GHZ.fits'
        input_map_path_256 = (self.root_path +
                              'gmca_test_full_sim_90_GHZ_256.fits')
        input_maps_dict = {
            '30': {
                'band_lim': 256,
                'fwhm': 33,
                'path': input_map_path,
                'nside': 128
            },
            '44': {
                'band_lim': 512,
                'fwhm': 5,
                'path': input_map_path_256,
                'nside': 256
            }
        }
        output_maps_prefix = self.root_path + 's2dw_test'
        scale_int = 2
        j_min = 1
        max_nside = 256
        j_max = 9

        # Generate the wavelet maps using the python code
        wav_analysis_maps = self.wav_class.multifrequency_wavelet_maps(
            input_maps_dict, output_maps_prefix, scale_int, j_min)
        self.assertEqual(wav_analysis_maps['n_freqs'], 2)

        # Now directly compare the values in the grouped array to the
        # values of the original map.
        input_map = hp.read_map(input_map_path,
                                dtype=np.float64,
                                verbose=False)
        wavelet_dict = self.wav_class.s2dw_wavelet_tranform(
            input_map,
            output_maps_prefix + 't30',
            input_maps_dict['30']['band_lim'],
            scale_int,
            j_min,
            input_maps_dict['30']['fwhm'],
            target_fwhm=np.ones(9) * input_maps_dict['44']['fwhm'])

        # Go patch by patch, starting with the scaling coefficients
        offset = 0
        scale_nside = wavelets_base.get_max_nside(scale_int, j_min, max_nside)
        n_patches = wavelets_hgmca.level_to_npatches(2)
        ppp = hp.nside2npix(scale_nside) // n_patches
        scale_map = hp.read_map(wavelet_dict['scale_map']['path'],
                                nest=True,
                                verbose=False)
        for patch in range(n_patches):
            np.testing.assert_almost_equal(
                scale_map[patch * ppp:(patch + 1) * ppp],
                wav_analysis_maps['2'][patch][0][:ppp])
        offset += ppp

        # Now repeat the calculation for the rest of the wavelet scales.
        for j in range(j_min, 6):
            wav_nside = wavelets_base.get_max_nside(scale_int, j + 1,
                                                    max_nside)
            n_patches = wavelets_hgmca.level_to_npatches(2)
            ppp = hp.nside2npix(wav_nside) // n_patches
            wav_map = hp.read_map(wavelet_dict['wav_%d_map' % (j)]['path'],
                                  nest=True,
                                  verbose=False)
            for patch in range(n_patches):
                np.testing.assert_almost_equal(
                    wav_map[patch * ppp:(patch + 1) * ppp],
                    wav_analysis_maps['2'][patch][0][offset:offset + ppp])
            offset += ppp

        offset = 0
        for j in range(6, j_max):
            wav_nside = wavelets_base.get_max_nside(scale_int, j + 1,
                                                    max_nside)
            n_patches = wavelets_hgmca.level_to_npatches(3)
            ppp = hp.nside2npix(wav_nside) // n_patches
            wav_map = hp.ud_grade(hp.read_map(wavelet_dict['wav_%d_map' %
                                                           (j)]['path'],
                                              nest=True,
                                              verbose=False),
                                  wav_nside,
                                  order_in='NESTED',
                                  order_out='NESTED')
            for patch in range(n_patches):
                np.testing.assert_almost_equal(
                    wav_map[patch * ppp:(patch + 1) * ppp],
                    wav_analysis_maps['3'][patch][0][offset:offset + ppp])
            offset += ppp

        # The nside 128 map shouldn't have any signal for the wavelet scale
        # j=9.
        self.assertEqual(
            np.sum(np.isnan(wav_analysis_maps['3'][:, 0, offset:])),
            wav_analysis_maps['3'][:, 0, offset:].size)

        # Now we repeat the same tests for the 256 maps.
        input_map_256 = hp.read_map(input_map_path_256,
                                    dtype=np.float64,
                                    verbose=False)
        wavelet_dict_256 = self.wav_class.s2dw_wavelet_tranform(
            input_map_256,
            output_maps_prefix + 't44',
            input_maps_dict['44']['band_lim'],
            scale_int,
            j_min,
            input_maps_dict['44']['fwhm'],
            target_fwhm=np.ones(10) * input_maps_dict['44']['fwhm'])
        offset = 0
        scale_nside = wavelets_base.get_max_nside(scale_int, j_min, max_nside)
        n_patches = wavelets_hgmca.level_to_npatches(2)
        ppp = hp.nside2npix(scale_nside) // n_patches
        scale_map = hp.read_map(wavelet_dict_256['scale_map']['path'],
                                nest=True,
                                verbose=False)
        for patch in range(n_patches):
            np.testing.assert_almost_equal(
                scale_map[patch * ppp:(patch + 1) * ppp],
                wav_analysis_maps['2'][patch][1][:ppp])
        offset += ppp

        # Now repeat the calculation for the rest of the wavelet scales.
        for j in range(j_min, 6):
            wav_nside = wavelets_base.get_max_nside(scale_int, j + 1,
                                                    max_nside)
            n_patches = wavelets_hgmca.level_to_npatches(2)
            ppp = hp.nside2npix(wav_nside) // n_patches
            wav_map = hp.read_map(wavelet_dict_256['wav_%d_map' % (j)]['path'],
                                  nest=True,
                                  verbose=False)
            for patch in range(n_patches):
                np.testing.assert_almost_equal(
                    wav_map[patch * ppp:(patch + 1) * ppp],
                    wav_analysis_maps['2'][patch][1][offset:offset + ppp])
            offset += ppp

        offset = 0
        for j in range(6, j_max + 1):
            wav_nside = wavelets_base.get_max_nside(scale_int, j + 1,
                                                    max_nside)
            n_patches = wavelets_hgmca.level_to_npatches(3)
            ppp = hp.nside2npix(wav_nside) // n_patches
            wav_map = hp.ud_grade(hp.read_map(wavelet_dict_256['wav_%d_map' %
                                                               (j)]['path'],
                                              nest=True,
                                              verbose=False),
                                  wav_nside,
                                  order_in='NESTED',
                                  order_out='NESTED')
            for patch in range(n_patches):
                np.testing.assert_almost_equal(
                    wav_map[patch * ppp:(patch + 1) * ppp],
                    wav_analysis_maps['3'][patch][1][offset:offset + ppp])
            offset += ppp

        # Delete all the superfluous maps that have been made for testing
        os.remove(wavelet_dict['scale_map']['path'])
        for j in range(j_min, wavelet_dict['j_max'] + 1):
            os.remove(wavelet_dict['wav_%d_map' % (j)]['path'])

        os.remove(wavelet_dict['scale_map']['path'][:-16] + '30_scaling.fits')
        for j in range(j_min, wavelet_dict['j_max'] + 1):
            os.remove(wavelet_dict['wav_%d_map' % (j)]['path'][:-14] +
                      '30_wav_%d.fits' % (j))

        os.remove(wavelet_dict_256['scale_map']['path'])
        for j in range(j_min, wavelet_dict_256['j_max'] + 1):
            os.remove(wavelet_dict_256['wav_%d_map' % (j)]['path'])

        os.remove(wavelet_dict_256['scale_map']['path'][:-16] +
                  '44_scaling.fits')
        for j in range(j_min, wavelet_dict_256['j_max'] + 1):
            os.remove(wavelet_dict_256['wav_%d_map' % (j)]['path'][:-14] +
                      '44_wav_%d.fits' % (j))
Example #12
def hgmca_epoch_numba(X_level,
                      A_hier_list,
                      lam_hier,
                      A_global,
                      lam_global,
                      S_level,
                      n_epochs,
                      m_level,
                      n_iterations,
                      lam_s,
                      seed,
                      enforce_nn_A,
                      min_rmse_rate,
                      epoch_start=0):
    """ Runs the epoch loop for a given level of hgmca using numba optimization.

	For an in-depth description of the algorithm see arXiv:1910.08077.

	Parameters:
		X_level ([np.array,...]): A numba.typed.List of numpy arrays
			corresponding to the data for each level of analysis.
		A_hier_list ([np.array]): A numba.typed.List of numpy arrays containing
			the mixing matrix hierarchy for each level of analysis.
		lam_hier (np.array): A n_sources long array of the prior for each of
			the columns of the mixing matrices in A_hier. This allows for a
			source dependent prior.
		A_global (np.array): A global mixing matrix prior that will be
			applied to all matrices in the hierarchy
		lam_global (np.array): A n_sources long array of the prior for each
			of the columns of the global prior.
		S_level ([np.array,...]): A numba.typed.List of numpy arrays containing
			the source matrices for each level of analysis.
		n_epochs (int): The number of times the algorithm should pass
			through all the levels.
		m_level (int): The maximum level of analysis.
		n_iterations (int): The number of iterations of coordinate descent
			to conduct per epoch.
		lam_s (float): The lambda parameter for the sparsity l1 norm.
		seed (int): An integer to seed the random number generator.
		enforce_nn_A (bool): A boolean that determines if the mixing matrix
			will be forced to only have non-negative values.
		min_rmse_rate (int): How often the source matrix will be set to the
			minimum rmse solution. 0 will never return min_rmse within the
			gmca optimization.
		epoch_start (int): What epoch the code is starting at. Important
			for min_rmse_rate.

	Notes:
		A_hier and S_level will be updated in place.
	"""
    # We allocate the arrays that gmca will use here to avoid them being
    # reallocated
    R_i_level = numba.typed.List()
    AS_level = numba.typed.List()
    A_R_level = numba.typed.List()
    A_i_level = numba.typed.List()
    for level in range(m_level + 1):
        R_i_level.append(np.zeros(X_level[level][0].shape))
        AS_level.append(np.zeros(X_level[level][0].shape))
        A_R_level.append(np.zeros((1, X_level[level].shape[2])))
        A_i_level.append(np.zeros((X_level[level].shape[1], 1)))
    # Set the random seed.
    np.random.seed(seed)
    # Now we iterate through our graphical model using our approximate closed
    # form solution for the desired number of epochs.
    for epoch in range(epoch_start, epoch_start + n_epochs):
        # We want to iterate through the levels in random order. This should
        # theoretically speed up convergence.
        level_perm = np.random.permutation(m_level + 1)
        for level in level_perm:
            npatches = wavelets_hgmca.level_to_npatches(level)
            for patch in range(npatches):
                # Calculate the combined mixing matrix prior for this patch.
                A_prior = get_A_prior(A_hier_list, level, patch, lam_hier)
                A_prior += lam_global * A_global
                # First we deal with the relatively simple case where there are
                # no sources at this level
                if X_level[level].size == 0:
                    A_hier_list[level][patch] = A_prior
                    helpers.A_norm(A_hier_list[level][patch])
                # If there are sources at this level we must run GMCA
                else:
                    # Extract the data for the patch.
                    X_p = X_level[level][patch]
                    S_p = S_level[level][patch]
                    # For HGMCA we want to store the source signal that
                    # includes the lasso shooting subtraction of lam_s for
                    # stability of the loss function. Only in the last step do
                    # we want gmca to return the min_rmse solution.
                    if min_rmse_rate == 0:
                        ret_min_rmse = False
                    else:
                        ret_min_rmse = (((epoch + 1) % min_rmse_rate) == 0)

                    # Call gmca for the patch. Note the lam_p has already been
                    # accounted for.
                    n_sources = len(S_p)
                    gmca_core.gmca_numba(X_p,
                                         n_sources,
                                         n_iterations,
                                         A_hier_list[level][patch],
                                         S_p,
                                         A_p=A_prior,
                                         lam_p=np.ones(n_sources),
                                         enforce_nn_A=enforce_nn_A,
                                         lam_s=lam_s,
                                         ret_min_rmse=ret_min_rmse,
                                         R_i_init=R_i_level[level],
                                         AS_init=AS_level[level],
                                         A_R_init=A_R_level[level],
                                         A_i_init=A_i_level[level])
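An end-to-end sketch of driving hgmca_epoch_numba directly, following Example #14: build the X_level hierarchy, allocate the mixing matrix and source hierarchies, initialize with the minimum-RMSE solution, and run a few epochs. The import paths, array sizes, and argument values are assumptions for illustration:

import numba
import numpy as np
from hgmca import hgmca_core, helpers, wavelets_hgmca  # assumed module layout

n_freqs, n_sources, n_wavs, m_level = 6, 5, 64, 2
A_org = np.random.rand(n_freqs, n_sources)
helpers.A_norm(A_org)
S_org = np.random.rand(n_sources, n_wavs) * 100

# Data at levels 1 and 2 only; level 0 gets an empty placeholder.
X_level = numba.typed.List()
X_level.append(np.empty((0, 0, 0)))
for level in range(1, m_level + 1):
    npatches = wavelets_hgmca.level_to_npatches(level)
    X_level.append(np.tile(A_org @ S_org, (npatches, 1, 1)))

A_hier_list = hgmca_core.allocate_A_hier(m_level, (n_freqs, n_sources))
S_level = hgmca_core.allocate_S_level(m_level, X_level, n_sources)
hgmca_core.init_min_rmse(X_level, A_hier_list, S_level)

A_global = np.random.rand(n_freqs, n_sources)
helpers.A_norm(A_global)
lam_hier = np.zeros(n_sources)
lam_global = np.zeros(n_sources)

# Positional arguments match the signature above: n_epochs=5, n_iterations=30,
# lam_s=1e-6, seed=2, enforce_nn_A=True, min_rmse_rate=5.
hgmca_core.hgmca_epoch_numba(X_level, A_hier_list, lam_hier, A_global,
                             lam_global, S_level, 5, m_level, 30, 1e-6, 2,
                             True, 5)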
Example #13
	def test_hgmca_opt(self):
		# Generate a quick approximation using the hgmca_opt code and
		# make sure it gives the same results as the core hgmca code.
		input_map_path = self.root_path + 'gmca_test_full_sim_90_GHZ.fits'
		input_maps_dict = {
			'30':{'band_lim':64,'fwhm':33,'path':input_map_path,
				'nside':128},
			'44':{'band_lim':64,'fwhm':24,'path':input_map_path,
				'nside':128},
			'70':{'band_lim':64,'fwhm':14,'path':input_map_path,
				'nside':128},
			'100':{'band_lim':256,'fwhm':10,'path':input_map_path,
				'nside':128},
			'143':{'band_lim':256,'fwhm':7.1,'path':input_map_path,
				'nside':128},
			'217':{'band_lim':256,'fwhm':5.5,'path':input_map_path,
				'nside':128}}
		output_maps_prefix = self.root_path + 'hgmca_test'
		scale_int = 2
		j_min = 1
		n_freqs = len(input_maps_dict)
		wav_analysis_maps = self.wav_class.multifrequency_wavelet_maps(
			input_maps_dict,output_maps_prefix,scale_int,j_min)

		# Run hgmca on the wavelet analysis maps
		n_sources = 5
		n_epochs = 5
		n_iterations = 2
		lam_s = 1e-3
		lam_hier = np.random.rand(n_sources)
		lam_global = np.random.rand(n_sources)
		A_global = np.random.rand(n_freqs*n_sources).reshape(
			n_freqs,n_sources)
		seed = 5
		min_rmse_rate = 2
		hgmca_dict = hgmca_core.hgmca_opt(wav_analysis_maps,n_sources,n_epochs,
			lam_hier,lam_s,n_iterations,A_init=None,A_global=A_global,
			lam_global=lam_global,seed=seed,enforce_nn_A=True,
			min_rmse_rate=min_rmse_rate,verbose=True)

		# Allocate what we need for the core code.
		X_level = hgmca_core.convert_wav_to_X_level(wav_analysis_maps)
		n_freqs = wav_analysis_maps['n_freqs']
		A_shape = (n_freqs,n_sources)
		A_hier_list = hgmca_core.allocate_A_hier(self.m_level,A_shape,
			A_init=None)
		S_level = hgmca_core.allocate_S_level(self.m_level,X_level,n_sources)
		hgmca_core.init_min_rmse(X_level,A_hier_list,S_level)
		hgmca_core.hgmca_epoch_numba(X_level,A_hier_list,lam_hier,A_global,
			lam_global,S_level,n_epochs,self.m_level,n_iterations,lam_s,seed,
			True,min_rmse_rate)

		for level in range(self.m_level+1):
			for patch in range(wavelets_hgmca.level_to_npatches(level)):
				np.testing.assert_almost_equal(A_hier_list[level][patch],
					hgmca_dict[str(level)]['A'][patch])
				if S_level[level].size == 0:
					self.assertEqual(hgmca_dict[str(level)]['S'].size,0)
				else:
					np.testing.assert_almost_equal(S_level[level][patch],
						hgmca_dict[str(level)]['S'][patch])

		# Check that saving doesn't cause issues
		n_epochs = 6
		save_dict = {'save_path':self.root_path,'save_rate':2}
		hgmca_dict = hgmca_core.hgmca_opt(wav_analysis_maps,n_sources,n_epochs,
			lam_hier,lam_s,n_iterations,A_init=None,A_global=A_global,
			lam_global=lam_global,seed=seed,enforce_nn_A=True,
			min_rmse_rate=min_rmse_rate,save_dict=save_dict,verbose=True)
		# Make sure loading works as well.
		hgmca_dict = hgmca_core.hgmca_opt(wav_analysis_maps,n_sources,n_epochs,
			lam_hier,lam_s,n_iterations,A_init=None,A_global=A_global,
			lam_global=lam_global,seed=seed,enforce_nn_A=True,
			min_rmse_rate=min_rmse_rate,save_dict=save_dict,verbose=True)

		# Delete all the files we made
		folder_path = os.path.join(save_dict['save_path'],'hgmca_save')
		for level in range(self.m_level+1):
			os.remove(os.path.join(folder_path,'A_%d.npy'%(level)))
			os.remove(os.path.join(folder_path,'S_%d.npy'%(level)))
		os.rmdir(folder_path)

		for freq in input_maps_dict.keys():
			os.remove(output_maps_prefix+freq+'_scaling.fits')
			j_max = wavelets_base.calc_j_max(input_maps_dict[freq]['band_lim'],
				scale_int)
			for j in range(j_min,j_max+1):
				os.remove(output_maps_prefix+freq+'_wav_%d.fits'%(j))
Example #14
	def test_hgmca_epoch_numba(self):
		# Test that the hgmca optimization returns roughly what we would
		# expect for a small lam_s
		n_freqs = 8
		n_sources = 5
		n_wavs = 256
		m_level = 2
		lam_s = 1e-6
		lam_hier = np.zeros(n_sources)
		lam_global = np.zeros(n_sources)

		# The true values
		s_mag = 100
		A_org = np.random.rand(n_freqs*n_sources).reshape((n_freqs,n_sources))
		helpers.A_norm(A_org)
		S_org = np.random.rand(n_sources*n_wavs).reshape((n_sources,n_wavs))
		S_org *= s_mag

		# Allocate what we need
		X_level = numba.typed.List()
		X_level.append(np.empty((0,0,0)))

		# Level 1
		npatches = wavelets_hgmca.level_to_npatches(1)
		X_level.append(np.zeros((npatches,n_freqs,n_wavs)))
		X_level[1][:] += np.dot(A_org,S_org)

		# Level 2
		npatches = wavelets_hgmca.level_to_npatches(2)
		X_level.append(np.zeros((npatches,n_freqs,n_wavs)))
		X_level[2][:] += np.dot(A_org,S_org)

		# The rest
		A_hier_list = hgmca_core.allocate_A_hier(m_level,(n_freqs,n_sources))
		S_level = hgmca_core.allocate_S_level(m_level,X_level,n_sources)
		A_global = np.random.rand(n_freqs*n_sources).reshape(
			n_freqs,n_sources)
		helpers.A_norm(A_global)

		# Run hgmca
		min_rmse_rate = 5
		enforce_nn_A = True
		seed = 5
		n_epochs = 5
		n_iterations = 30
		hgmca_core.hgmca_epoch_numba(X_level,A_hier_list,lam_hier,A_global,
			lam_global,S_level,n_epochs,m_level,n_iterations,lam_s,seed,
			enforce_nn_A,min_rmse_rate)

		for level in range(1,3):
			for patch in range(wavelets_hgmca.level_to_npatches(level)):
				np.testing.assert_almost_equal(X_level[level][patch],
					np.dot(A_hier_list[level][patch],S_level[level][patch]),
					decimal=1)

		# Repeat the same but with strong priors. Start with the global prior
		lam_global = np.ones(n_sources)*1e12
		n_epochs = 2
		n_iterations = 5
		A_hier_list = hgmca_core.allocate_A_hier(m_level,(n_freqs,n_sources))
		S_level = hgmca_core.allocate_S_level(m_level,X_level,n_sources)
		hgmca_core.hgmca_epoch_numba(X_level,A_hier_list,lam_hier,A_global,
			lam_global,S_level,n_epochs,m_level,n_iterations,lam_s,seed,
			enforce_nn_A,min_rmse_rate)
		for level in range(3):
			for patch in range(wavelets_hgmca.level_to_npatches(level)):
				np.testing.assert_almost_equal(A_hier_list[level][patch],
					A_global,decimal=4)

		# The same test, but now for the hierarchy
		lam_hier = np.ones(n_sources)*1e12
		lam_global = np.zeros(n_sources)
		n_epochs = 2
		n_iterations = 5
		A_init = np.random.rand(n_freqs*n_sources).reshape((n_freqs,n_sources))
		helpers.A_norm(A_init)
		A_hier_list = hgmca_core.allocate_A_hier(m_level,(n_freqs,n_sources),
			A_init=A_init)
		S_level = hgmca_core.allocate_S_level(m_level,X_level,n_sources)
		hgmca_core.hgmca_epoch_numba(X_level,A_hier_list,lam_hier,A_global,
			lam_global,S_level,n_epochs,m_level,n_iterations,lam_s,seed,
			enforce_nn_A,min_rmse_rate)
		for level in range(3):
			for patch in range(wavelets_hgmca.level_to_npatches(level)):
				np.testing.assert_almost_equal(A_hier_list[level][patch],
					A_init,decimal=4)