Beispiel #1
0
    def test_s2dw_harmonic(self):
        """Check the harmonic-space wavelet kernels built by s2dw_harmonic.

        For every scale j in [j_min, j_max] the expected kernel is rebuilt
        from the phi2_s2dw output as
        sqrt(max(phi2_{j+1} - phi2_j, 0)) * sqrt(2 ell + 1) / sqrt(4 pi)
        and compared with the library result. Coefficients below j_min must
        be exactly zero.
        """
        scale_int = 2
        band_lim = 1024
        n_quads = 1000
        j_min = 2
        j_max = wavelets_base.calc_j_max(band_lim, scale_int)
        n_coeffs = (j_max + 2) * band_lim

        phi2 = np.zeros(n_coeffs)
        wavelets_base.phi2_s2dw(phi2, band_lim, scale_int, n_quads)

        wav_har = np.zeros(n_coeffs)
        scale_har = np.zeros(band_lim)
        wavelets_base.s2dw_harmonic(wav_har, scale_har, band_lim, scale_int,
                                    n_quads, j_min)

        # Everything below j_min must be left at zero.
        n_low = j_min * band_lim
        np.testing.assert_equal(wav_har[:n_low], np.zeros(n_low))

        # The ell-dependent normalisation is shared by every scale.
        ell = np.arange(band_lim)
        ell_norm = np.sqrt(2 * ell + 1) / np.sqrt(4 * np.pi)

        for j in range(j_min, j_max + 1):
            lo = j * band_lim
            hi = lo + band_lim
            # Negative differences are clipped to zero before the sqrt.
            phi2_diff = phi2[hi:hi + band_lim] - phi2[lo:hi]
            expected = np.sqrt(np.clip(phi2_diff, 0, None)) * ell_norm
            np.testing.assert_almost_equal(wav_har[lo:hi], expected)
Beispiel #2
0
    def test_phi2_s2dw(self):
        """Check the phi2_s2dw integral against hand-computed values.

        phi2 is laid out in blocks of band_lim coefficients, one block per
        scale index. Its maximum must be 1 and its minimum 0. The first
        three blocks are compared with values worked out by hand, and the
        final block must be all ones.
        """
        scale_int = 2
        band_lim = 1024
        n_quads = 1000
        j_max = wavelets_base.calc_j_max(band_lim, scale_int)
        phi2 = np.zeros((j_max + 2) * band_lim)
        norm = wavelets_base.kappa_integral(1.0 / scale_int, 1.0, n_quads,
                                            scale_int)

        wavelets_base.phi2_s2dw(phi2, band_lim, scale_int, n_quads)
        self.assertAlmostEqual(np.max(phi2), 1)
        self.assertEqual(np.min(phi2), 0)

        def block(j):
            # Coefficients of phi2 belonging to scale index j.
            return phi2[j * band_lim:(j + 1) * band_lim]

        # We'll go and manually test the values for a few scales.
        # Scale index 0: only ell = 0 is one.
        expected = np.zeros(band_lim)
        expected[0] = 1
        np.testing.assert_almost_equal(block(0), expected)

        # Scale index 1: ell = 0 and ell = 1 are one.
        expected = np.zeros(band_lim)
        expected[:2] = 1
        np.testing.assert_almost_equal(block(1), expected)

        # Scale index 2: ell = 0, 1 are one; ell = 2, 3 are given by the
        # kappa integral normalised by norm.
        expected = np.zeros(band_lim)
        expected[:2] = 1
        for ell in (2, 3):
            expected[ell] = wavelets_base.kappa_integral(
                ell / 4, 1.0, n_quads, scale_int) / norm
        np.testing.assert_almost_equal(block(2), expected)

        # The last block (scale index j_max + 1) must be all ones.
        np.testing.assert_almost_equal(phi2[(j_max + 1) * band_lim:],
                                       np.ones(band_lim))
Beispiel #3
0
 def test_j_max(self):
     """calc_j_max must return the expected maximum scale for scale_int=2.

     A band limit just below a power of two (255) is included as well.
     """
     expected_j_max = {128: 7, 256: 8, 255: 8, 2048: 11}
     for band_lim, j_max in expected_j_max.items():
         self.assertEqual(wavelets_base.calc_j_max(band_lim, 2), j_max)
Beispiel #4
0
    def test_s2dw_wavelet_tranform(self):
        """Check s2dw_wavelet_tranform against s2let and its beam handling.

        The test runs in three stages:

        1. Transform the test map (once fresh, once with precomputed=True)
           and compare the returned metadata and the scaling / wavelet maps
           with the reference s2let C code output.
        2. Transform with two different target beams and check that, once
           each target beam is divided out, the power spectra of every
           scale agree.
        3. Transform with a different input beam and check that, once each
           input beam is multiplied back in, the power spectra agree again.

        Every set of maps written is registered with addCleanup, so the
        files are removed even when an assertion fails part way through.
        """
        input_map_path = self.root_path + 'gmca_test_full_sim_90_GHZ.fits'
        input_map = hp.read_map(input_map_path, dtype=np.float64)
        output_map_prefix = self.root_path + 's2dw_test'
        band_lim = 256
        scale_int = 3
        j_min = 1
        input_fwhm = 1e-10

        # Generate the wavelet maps using the python code.
        wavelet_dict = self.wav_class.s2dw_wavelet_tranform(
            input_map, output_map_prefix, band_lim, scale_int, j_min,
            input_fwhm, n_quads=1000)
        self.addCleanup(self._s2dw_remove_maps, wavelet_dict, j_min)
        self._s2dw_check_against_s2let(wavelet_dict, band_lim, scale_int,
                                       j_min, input_fwhm)

        # Repeat with the precomputed flag, which grabs the maps already
        # written under output_map_prefix (same prefix, so the cleanup
        # registered above covers it).
        wavelet_dict = self.wav_class.s2dw_wavelet_tranform(
            input_map, output_map_prefix, band_lim, scale_int, j_min,
            input_fwhm, n_quads=1000, precomputed=True)
        self._s2dw_check_against_s2let(wavelet_dict, band_lim, scale_int,
                                       j_min, input_fwhm)

        # Now make sure that setting a target beam behaves as expected.
        # A small (1 arcmin) input beam is assumed, and the maps are output
        # at two different resolutions.
        input_fwhm = 1
        n_target = wavelet_dict['n_scales'] + 1
        target_fwhm_big = np.ones(n_target) * 30
        target_fwhm_small = np.ones(n_target) * 20

        wavelet_dict_big = self.wav_class.s2dw_wavelet_tranform(
            input_map, self.root_path + 's2dw_test_big', band_lim,
            scale_int, j_min, input_fwhm, target_fwhm=target_fwhm_big,
            n_quads=1000)
        self.addCleanup(self._s2dw_remove_maps, wavelet_dict_big, j_min)
        np.testing.assert_equal(wavelet_dict_big['target_fwhm'],
                                target_fwhm_big)

        wavelet_dict_small = self.wav_class.s2dw_wavelet_tranform(
            input_map, self.root_path + 's2dw_test_small', band_lim,
            scale_int, j_min, input_fwhm, target_fwhm=target_fwhm_small,
            n_quads=1000)
        self.addCleanup(self._s2dw_remove_maps, wavelet_dict_small, j_min)
        np.testing.assert_equal(wavelet_dict_small['target_fwhm'],
                                target_fwhm_small)

        # With each target beam divided out, the spectra should agree.
        j_max = wavelet_dict['j_max']
        self._s2dw_assert_matching_spectra(
            wavelet_dict_big, wavelet_dict_small, target_fwhm_big[0],
            target_fwhm_small[0], scale_int, j_min, j_max, band_lim,
            invert_beam=True)

        # Now make sure that the input beam is also accounted for.
        input_fwhm2 = 2
        wavelet_dict_big2 = self.wav_class.s2dw_wavelet_tranform(
            input_map, self.root_path + 's2dw_test_big2', band_lim,
            scale_int, j_min, input_fwhm2, target_fwhm=target_fwhm_big,
            n_quads=1000)
        self.addCleanup(self._s2dw_remove_maps, wavelet_dict_big2, j_min)

        # With each input beam multiplied back in, the spectra should agree.
        self._s2dw_assert_matching_spectra(
            wavelet_dict_big, wavelet_dict_big2, input_fwhm, input_fwhm2,
            scale_int, j_min, j_max, band_lim, invert_beam=False)

    @staticmethod
    def _s2dw_remove_maps(wavelet_dict, j_min):
        """Delete the scaling map and the wavelet maps j_min..j_max."""
        os.remove(wavelet_dict['scale_map']['path'])
        for j in range(j_min, wavelet_dict['j_max'] + 1):
            os.remove(wavelet_dict['wav_%d_map' % j]['path'])

    def _s2dw_check_against_s2let(self, wavelet_dict, band_lim, scale_int,
                                  j_min, input_fwhm):
        """Check transform metadata and maps against the s2let reference.

        The reference files are named after band_lim, scale_int and j_min,
        e.g. gmca_full_90_GHZ_wav_scal_256_3_1.fits for (256, 3, 1).
        """
        j_max = wavelets_base.calc_j_max(band_lim, scale_int)

        # Check the properties we want are set.
        self.assertEqual(wavelet_dict['band_lim'], band_lim)
        self.assertEqual(wavelet_dict['scale_int'], scale_int)
        self.assertEqual(wavelet_dict['j_min'], j_min)
        self.assertEqual(wavelet_dict['j_max'], j_max)
        self.assertEqual(wavelet_dict['original_nside'], 128)
        self.assertEqual(wavelet_dict['input_fwhm'], input_fwhm)
        self.assertEqual(wavelet_dict['n_scales'], j_max - j_min + 2)
        np.testing.assert_equal(wavelet_dict['target_fwhm'],
                                np.ones(wavelet_dict['n_scales'] + 1) * 1e-10)

        ref_suffix = '%d_%d_%d' % (band_lim, scale_int, j_min)

        # Compare the scaling map to the s2let C code output.
        scaling_python = hp.read_map(wavelet_dict['scale_map']['path'],
                                     nest=True)
        scaling_s2let = hp.reorder(
            hp.read_map(self.root_path +
                        'gmca_full_90_GHZ_wav_scal_%s.fits' % ref_suffix),
            r2n=True)
        self.assertLess(np.mean(np.abs(scaling_python - scaling_s2let)),
                        0.001)

        # Repeat the same for the wavelet maps.
        for j in range(j_min, wavelet_dict['j_max'] + 1):
            wav_python = hp.read_map(wavelet_dict['wav_%d_map' % j]['path'],
                                     nest=True)
            wav_s2let = hp.reorder(
                hp.read_map(self.root_path +
                            'gmca_full_90_GHZ_wav_wav_%s_%d.fits' %
                            (ref_suffix, j)),
                r2n=True)
            self.assertLess(np.mean(np.abs(wav_python - wav_s2let)), 0.01)

    def _s2dw_windowed_cl(self, map_path, lmax, fwhm, invert_beam):
        """Return the power spectrum of a nested map with a beam applied.

        The map is reordered to ring, transformed to alm up to lmax and
        multiplied by the Gaussian beam of width fwhm (scaled by self.a2r),
        or by its inverse when invert_beam is True, before the power
        spectrum is taken. The beam array only covers ell <= lmax - 1.
        """
        alm = hp.map2alm(hp.reorder(hp.read_map(map_path, nest=True),
                                    n2r=True),
                         lmax=lmax)
        beam = hp.gauss_beam(fwhm * self.a2r, lmax=lmax - 1)
        if invert_beam:
            beam = 1 / beam
        return hp.alm2cl(hp.almxfl(alm, beam))

    @staticmethod
    def _s2dw_assert_cl_ratio_one(cl_num, cl_den, decimal=7):
        """Assert cl_num / cl_den is one wherever cl_num exceeds 1e-9.

        Smaller cls are skipped since numerical error dominates them.
        """
        keep = cl_num > 1e-9
        np.testing.assert_almost_equal(cl_num[keep] / cl_den[keep],
                                       np.ones(np.sum(keep)),
                                       decimal=decimal)

    def _s2dw_assert_matching_spectra(self, dict_a, dict_b, fwhm_a, fwhm_b,
                                      scale_int, j_min, j_max, band_lim,
                                      invert_beam):
        """Assert two transforms have equal beam-corrected power spectra.

        Each map is compared only up to the maximum ell used to write it:
        min(scale_int**j_min, band_lim) for the scaling map and
        min(scale_int**(j + 1), band_lim) for wavelet scale j. Scaling maps
        must agree to 7 decimals, wavelet maps to 2.
        """
        scale_lim = min(int(scale_int**j_min), band_lim)
        self._s2dw_assert_cl_ratio_one(
            self._s2dw_windowed_cl(dict_a['scale_map']['path'], scale_lim,
                                   fwhm_a, invert_beam),
            self._s2dw_windowed_cl(dict_b['scale_map']['path'], scale_lim,
                                   fwhm_b, invert_beam))

        for j in range(j_min, j_max + 1):
            wav_lim = min(int(scale_int**(j + 1)), band_lim)
            map_key = 'wav_%d_map' % j
            # The last cl is zero and near-zero cls are dominated by
            # numerical error; the ratio helper skips them.
            self._s2dw_assert_cl_ratio_one(
                self._s2dw_windowed_cl(dict_a[map_key]['path'], wav_lim,
                                       fwhm_a, invert_beam),
                self._s2dw_windowed_cl(dict_b[map_key]['path'], wav_lim,
                                       fwhm_b, invert_beam),
                decimal=2)
Beispiel #5
0
    def test_mgmca(self):
        """Run mgmca on multifrequency wavelet maps and check the fit.

        The same simulated map is used for six frequencies with different
        beams and band limits. After the wavelet coefficients are grouped
        with multifrequency_wavelet_maps, mgmca must reproduce analysis
        groups '0' and '1' as A @ S to four decimals. The maps written are
        removed through addCleanup, so they are deleted even if an assertion
        fails.
        """
        input_map_path = self.root_path + 'gmca_test_full_sim_90_GHZ.fits'
        # (frequency, band_lim, fwhm in arcmin); all share the input map.
        freq_specs = (('30', 64, 33), ('44', 64, 24), ('70', 64, 14),
                      ('100', 256, 10), ('143', 256, 7.1), ('217', 256, 5.5))
        input_maps_dict = {
            freq: {
                'band_lim': band_lim,
                'fwhm': fwhm,
                'path': input_map_path,
                'nside': 128
            }
            for freq, band_lim, fwhm in freq_specs
        }
        output_maps_prefix = self.root_path + 'mgmca_test'
        scale_int = 2
        j_min = 1
        wav_analysis_maps = self.wav_class.multifrequency_wavelet_maps(
            input_maps_dict, output_maps_prefix, scale_int, j_min)
        self.addCleanup(self._mgmca_remove_maps, input_maps_dict,
                        output_maps_prefix, scale_int, j_min)

        # Run mgmca on the wavelet analysis maps.
        max_n_sources = 6
        n_iterations = 10
        lam_s = 0
        mgmca_dict = gmca_core.mgmca(wav_analysis_maps,
                                     max_n_sources,
                                     n_iterations,
                                     lam_s=lam_s)

        # Each analysis group must be reproduced by mixing matrix x sources.
        for group in ('0', '1'):
            np.testing.assert_almost_equal(
                wav_analysis_maps[group],
                np.dot(mgmca_dict[group]['A'], mgmca_dict[group]['S']),
                decimal=4)

    @staticmethod
    def _mgmca_remove_maps(input_maps_dict, output_maps_prefix, scale_int,
                           j_min):
        """Delete the per-frequency scaling and wavelet maps of test_mgmca."""
        for freq, freq_dict in input_maps_dict.items():
            os.remove(output_maps_prefix + freq + '_scaling.fits')
            j_max = wavelets_base.calc_j_max(freq_dict['band_lim'], scale_int)
            for j in range(j_min, j_max + 1):
                os.remove(output_maps_prefix + freq + '_wav_%d.fits' % j)
Beispiel #6
0
	def multifrequency_wavelet_maps(self,input_maps_dict,output_maps_prefix,
			scale_int,j_min,precomputed=False,nest=False,n_quads=1000):
		"""Creates and groups the wavelet coefficients of several maps by
		analysis level for an mgmca analysis.

		This function allows for wavelet coefficients from several frequency
		maps to be grouped for the purposes of mgmca analysis. Wavelet scales
		are grouped so that the number of frequencies in a group is constant:
		a group ends at each distinct maximum wavelet scale (j_max) among the
		input maps, and only frequencies whose j_max reaches the group's
		scales have a row in it.

		Parameters:
			input_maps_dict (dict): A dictionary that maps frequencies to
				band limits, fwhm, nside, and input map path. Units of arcmin.
			output_maps_prefix (str): The prefix that the output wavelet maps
				will be written to.
			scale_int (int): The integer used as the basis for scaling
				the wavelet functions
			j_min (int): The minimum wavelet scale to use in the decomposition
			precomputed (bool): If true, will grab paths to precomputed maps
				based on the output_maps_prefix provided.
			nest (bool): If true the input maps are in the nested
				configuration.
			n_quads (int): Using the trapezoid rule, the number of
				bins to consider for integration

		Returns:
			(dict): The analysis metadata ('input_maps_dict', 'analysis_type',
			'scale_int', 'j_min', 'j_max', 'band_lim', 'target_fwhm',
			'output_nside', 'n_freqs') plus one entry per scale group keyed
			'0', '1', ... Each group entry is an np.array of shape
			(n_freqs_in_group, n_pix) holding the nested-ordering coefficients
			of every scale in the group concatenated along the pixel axis;
			group '0' starts with the scaling coefficients.

		Notes:
			The frequencies are ordered from largest to smallest fwhm. Within
			a group, rows keep that order with the frequencies that have no
			data at the group's scales left out.
		"""
		# Order the frequencies by decreasing fwhm. Keys will be strings.
		freq_list = np.array(list(input_maps_dict.keys()))
		fwhm_list = np.array([input_maps_dict[x]['fwhm']
			for x in input_maps_dict])
		band_lim_list = np.array([input_maps_dict[x]['band_lim']
			for x in input_maps_dict])
		nside_list = np.array([input_maps_dict[x]['nside']
			for x in input_maps_dict])
		# Get the maximum wavelet scale for each map
		j_max_list = np.array([wavelets_base.calc_j_max(
			input_maps_dict[x]['band_lim'],scale_int) for x in input_maps_dict])

		order = np.argsort(fwhm_list)[::-1]
		freq_list = freq_list[order]
		fwhm_list = fwhm_list[order]
		band_lim_list = band_lim_list[order]
		nside_list = nside_list[order]
		j_max_list = j_max_list[order]

		max_j_max = np.max(j_max_list)
		max_nside = np.max(nside_list)

		# We will always target the smallest fwhm.
		target_fwhm = np.ones(2+max_j_max-j_min)*np.min(fwhm_list)

		# The wavelet analysis maps we will populate. Save the information
		# in the input_maps_dict for later reconstruction.
		wav_analysis_maps = {'input_maps_dict':input_maps_dict,
			'analysis_type':'mgmca','scale_int':scale_int,'j_min':j_min,
			'j_max':max_j_max,'band_lim':np.max(band_lim_list),
			'target_fwhm':target_fwhm,'output_nside':max_nside,
			'n_freqs':len(freq_list)}

		# Pre-allocate one array per scale group. A new group starts at the
		# first scale above the current group's j_max; its row count is the
		# number of frequencies whose j_max reaches that group.
		unique_j_max = np.unique(j_max_list)
		scale_group = 0
		n_pix = hp.nside2npix(wavelets_base.get_max_nside(scale_int,j_min,
			max_nside))
		for j in range(j_min,max_j_max+1):
			if j > unique_j_max[scale_group]:
				n_freqs = np.sum(j_max_list>=unique_j_max[scale_group])
				wav_analysis_maps[str(scale_group)] = np.zeros((n_freqs,n_pix),
					dtype=np.float64)
				n_pix = 0
				scale_group += 1
			# Add the number of pixels in this scale.
			n_pix += hp.nside2npix(wavelets_base.get_max_nside(scale_int,j+1,
				max_nside))
		# Write out the final group
		n_freqs = np.sum(j_max_list>=unique_j_max[scale_group])
		wav_analysis_maps[str(scale_group)] = np.zeros((n_freqs,n_pix),
			dtype=np.float64)

		def read_at_nside(path,nside):
			# Read a nested map and regrade it to nside, keeping nested order.
			return hp.ud_grade(hp.read_map(path,nest=True,verbose=False,
				dtype=np.float64),nside,order_in='NESTED',order_out='NESTED')

		# Now iterate through the frequency maps and populate the arrays.
		for i, freq in enumerate(freq_list):
			freq_dict = input_maps_dict[str(freq)]
			n_scales = 2+j_max_list[i]-j_min
			input_map = hp.read_map(freq_dict['path'],verbose=False,
				dtype=np.float64)
			freq_wav_dict = self.s2dw_wavelet_tranform(input_map,
				output_maps_prefix+str(freq),freq_dict['band_lim'],scale_int,
				j_min,fwhm_list[i],target_fwhm=target_fwhm[:n_scales],
				precomputed=precomputed,nest=nest,n_quads=n_quads)

			# The scaling coefficients open group 0, where every frequency
			# has a row.
			scale_group = 0
			row = i
			nside = wavelets_base.get_max_nside(scale_int,j_min,max_nside)
			n_pix = hp.nside2npix(nside)
			wav_analysis_maps[str(scale_group)][row,:n_pix] = read_at_nside(
				freq_wav_dict['scale_map']['path'],nside)

			for j in range(j_min,j_max_list[i]+1):
				if j > unique_j_max[scale_group]:
					scale_group += 1
					n_pix = 0
					# This frequency's row in the new group is the number of
					# earlier frequencies that also have data at these scales.
					# (i minus the number of excluded frequencies is only
					# correct when j_max never decreases along freq_list.)
					row = np.sum(j_max_list[:i]>=unique_j_max[scale_group])
				nside = wavelets_base.get_max_nside(scale_int,j+1,max_nside)
				dn = hp.nside2npix(nside)
				wav_analysis_maps[str(scale_group)][row,n_pix:n_pix+dn] = (
					read_at_nside(freq_wav_dict['wav_%d_map'%(j)]['path'],
					nside))
				# Update the position in the array
				n_pix += dn

		return wav_analysis_maps
Beispiel #7
0
    def multifrequency_wavelet_maps(self,
                                    input_maps_dict,
                                    output_maps_prefix,
                                    scale_int,
                                    j_min,
                                    precomputed=False,
                                    nest=False,
                                    n_quads=1000):
        """Creates and groups the wavelet coefficients of several maps by
        analysis level for an hgmca analysis.

        This function allows for wavelet coefficients from several frequency
        maps to be grouped for the purposes of hgmca analysis. The scaling
        coefficients and each wavelet scale are assigned an analysis level by
        get_analysis_level, and their pixels are split across the patches of
        that level.

        Parameters:
            input_maps_dict (dict): A dictionary that maps frequencies to
                band limits, fwhm, nside, and input map path. Units of arcmin.
            output_maps_prefix (str): The prefix that the output wavelet maps
                will be written to.
            scale_int (int): The integer used as the basis for scaling
                the wavelet functions
            j_min (int): The minimum wavelet scale to use in the decomposition
            precomputed (bool): If true, will grab paths to precomputed maps
                based on the output_maps_prefix provided.
            nest (bool): If true the input maps are in the nested
                configuration.
            n_quads (int): Using the trapezoid rule, the number of
                bins to consider for integration

        Returns:
            (dict): The analysis metadata ('input_maps_dict',
            'analysis_type', 'scale_int', 'j_min', 'j_max', 'band_lim',
            'target_fwhm', 'output_nside', 'n_freqs', 'm_level') plus the
            per-level arrays created by allocate_analysis_arrays, keyed
            str(level) and filled here as [patch, frequency, pixel].

        Notes:
            The frequencies are ordered from largest to smallest fwhm. For
            hgmca the analysis groups are governed by the level of
            subdivision conducted on the data, which is set by the nside of
            the analysis. Unlike mgmca, the arrays are not reshaped by the
            number of frequencies available: entries for a frequency with no
            data at a scale are not written here and keep the value given by
            allocate_analysis_arrays (presumably nan, which the hgmca code
            treats as carrying no constraining power -- confirm).
        """
        # Order the frequencies by decreasing fwhm. Keys will be strings.
        freq_list = np.array(list(input_maps_dict.keys()))
        fwhm_list = np.array(
            [input_maps_dict[x]['fwhm'] for x in input_maps_dict])
        band_lim_list = np.array(
            [input_maps_dict[x]['band_lim'] for x in input_maps_dict])
        nside_list = np.array(
            [input_maps_dict[x]['nside'] for x in input_maps_dict])
        # Maximum wavelet scale for each map.
        j_max_list = np.array([
            wavelets_base.calc_j_max(input_maps_dict[x]['band_lim'],
                                     scale_int) for x in input_maps_dict
        ])

        order = np.argsort(fwhm_list)[::-1]
        freq_list = freq_list[order]
        fwhm_list = fwhm_list[order]
        band_lim_list = band_lim_list[order]
        nside_list = nside_list[order]
        j_max_list = j_max_list[order]
        n_freqs = len(freq_list)
        max_j_max = np.max(j_max_list)

        # We will always target the smallest fwhm.
        target_fwhm = np.ones(2 + max_j_max - j_min) * np.min(fwhm_list)

        # The wavelet analysis maps we will populate. Save the information
        # in the input_maps_dict for later reconstruction.
        wav_analysis_maps = {
            'input_maps_dict': input_maps_dict,
            'analysis_type': 'hgmca',
            'scale_int': scale_int,
            'j_min': j_min,
            'j_max': max_j_max,
            'band_lim': np.max(band_lim_list),
            'target_fwhm': target_fwhm,
            'output_nside': np.max(nside_list),
            'n_freqs': n_freqs
        }

        # Get the largest nside that will be considered and therefore the
        # actual largest level that will be used (this will be less than
        # or equal to the maximum level specified).
        max_nside = wavelets_base.get_max_nside(scale_int, max_j_max + 1,
                                                np.max(nside_list))
        m_level = nside_to_level(max_nside, self.m_level)
        wav_analysis_maps['m_level'] = m_level

        self.allocate_analysis_arrays(wav_analysis_maps, scale_int, j_min,
                                      max_j_max, m_level, max_nside, n_freqs)

        # Analysis level of each coefficient set. Index 0 is the scaling
        # coefficients and index k >= 1 is wavelet scale j_min + k - 1.
        # Working with indices (not j values) keeps j_min == 0 unambiguous.
        wav_level = self.get_analysis_level(scale_int, j_min, max_j_max,
                                            m_level, max_nside)

        # Now go through each input frequency map and populate the wavelet
        # map arrays.
        for freq_i, freq in enumerate(freq_list):
            freq_dict = input_maps_dict[str(freq)]
            n_scales = 2 + j_max_list[freq_i] - j_min
            input_map = hp.read_map(freq_dict['path'],
                                    verbose=False,
                                    dtype=np.float64)
            freq_wav_dict = self.s2dw_wavelet_tranform(
                input_map,
                output_maps_prefix + str(freq),
                freq_dict['band_lim'],
                scale_int,
                j_min,
                fwhm_list[freq_i],
                target_fwhm=target_fwhm[:n_scales],
                precomputed=precomputed,
                nest=nest,
                n_quads=n_quads)

            for level in range(m_level + 1):
                scale_inds = np.flatnonzero(wav_level == level)
                # No wavelet scales are analyzed at this level.
                if scale_inds.size == 0:
                    continue
                n_patches = level_to_npatches(level)
                level_maps = wav_analysis_maps[str(level)]

                # Number of pixels per patch already filled at this level.
                offset = 0
                for scale_ind in scale_inds:
                    if scale_ind == 0:
                        # Scaling coefficients.
                        nside = wavelets_base.get_max_nside(
                            scale_int, j_min, max_nside)
                        map_path = freq_wav_dict['scale_map']['path']
                    else:
                        j = j_min + scale_ind - 1
                        # This scale does not exist for this frequency.
                        if j > j_max_list[freq_i]:
                            continue
                        nside = wavelets_base.get_max_nside(
                            scale_int, j + 1, max_nside)
                        map_path = freq_wav_dict['wav_%d_map' % j]['path']
                    wav_map_freq = hp.ud_grade(hp.read_map(map_path,
                                                           nest=True,
                                                           verbose=False,
                                                           dtype=np.float64),
                                               nside,
                                               order_in='NESTED',
                                               order_out='NESTED')
                    n_pix_patch = hp.nside2npix(nside) // n_patches
                    n_used = n_patches * n_pix_patch
                    # Patch p receives the p-th contiguous run of
                    # n_pix_patch nested pixels.
                    level_maps[:n_patches, freq_i,
                               offset:offset + n_pix_patch] = (
                                   wav_map_freq[:n_used].reshape(
                                       n_patches, n_pix_patch))
                    offset += n_pix_patch

        return wav_analysis_maps
Beispiel #8
0
	def test_hgmca_opt(self):
		"""Check hgmca_core.hgmca_opt against the lower-level core routines.

		Builds wavelet analysis maps for six frequencies (all read from the
		same input map), runs hgmca_core.hgmca_opt, then repeats the same
		optimization by hand with convert_wav_to_X_level / allocate_A_hier /
		allocate_S_level / init_min_rmse / hgmca_epoch_numba. It asserts that
		the per-level, per-patch A and S matrices agree. It then runs
		hgmca_opt twice more with a save_dict to exercise saving and loading.

		Every file the test writes is removed through unittest cleanups, so
		the files are deleted even when an assertion fails partway through.
		"""
		input_map_path = self.root_path + 'gmca_test_full_sim_90_GHZ.fits'
		input_maps_dict = {
			'30':{'band_lim':64,'fwhm':33,'path':input_map_path,
				'nside':128},
			'44':{'band_lim':64,'fwhm':24,'path':input_map_path,
				'nside':128},
			'70':{'band_lim':64,'fwhm':14,'path':input_map_path,
				'nside':128},
			'100':{'band_lim':256,'fwhm':10,'path':input_map_path,
				'nside':128},
			'143':{'band_lim':256,'fwhm':7.1,'path':input_map_path,
				'nside':128},
			'217':{'band_lim':256,'fwhm':5.5,'path':input_map_path,
				'nside':128}}
		output_maps_prefix = self.root_path + 'hgmca_test'
		scale_int = 2
		j_min = 1
		n_freqs = len(input_maps_dict)

		def remove_existing(file_paths,dir_path=None):
			# Cleanup helper: delete whichever of the files exist, then the
			# directory if one is given and present. Missing files are
			# skipped so that a failure earlier in the test does not add a
			# second, unrelated error from the cleanup.
			for path in file_paths:
				if os.path.isfile(path):
					os.remove(path)
			if dir_path is not None and os.path.isdir(dir_path):
				os.rmdir(dir_path)

		# The maps multifrequency_wavelet_maps is expected to write: one
		# scaling map plus one map per wavelet scale j_min..j_max for each
		# frequency.
		wav_map_paths = []
		for freq in input_maps_dict:
			wav_map_paths.append(output_maps_prefix+freq+'_scaling.fits')
			j_max = wavelets_base.calc_j_max(input_maps_dict[freq]['band_lim'],
				scale_int)
			wav_map_paths.extend(output_maps_prefix+freq+'_wav_%d.fits'%(j)
				for j in range(j_min,j_max+1))
		# Register the cleanup before the maps are written, so it runs even
		# if a later step fails.
		self.addCleanup(remove_existing,wav_map_paths)
		wav_analysis_maps = self.wav_class.multifrequency_wavelet_maps(
			input_maps_dict,output_maps_prefix,scale_int,j_min)
		# The analysis maps should contain every input frequency.
		self.assertEqual(wav_analysis_maps['n_freqs'],n_freqs)

		# Run hgmca on the wavelet analysis maps
		n_sources = 5
		n_epochs = 5
		n_iterations = 2
		lam_s = 1e-3
		lam_hier = np.random.rand(n_sources)
		lam_global = np.random.rand(n_sources)
		A_global = np.random.rand(n_freqs*n_sources).reshape(
			n_freqs,n_sources)
		seed = 5
		min_rmse_rate = 2
		hgmca_dict = hgmca_core.hgmca_opt(wav_analysis_maps,n_sources,n_epochs,
			lam_hier,lam_s,n_iterations,A_init=None,A_global=A_global,
			lam_global=lam_global,seed=seed,enforce_nn_A=True,
			min_rmse_rate=min_rmse_rate,verbose=True)

		# Reproduce the same optimization with the core functions, using the
		# same inputs and seed as the hgmca_opt call above.
		X_level = hgmca_core.convert_wav_to_X_level(wav_analysis_maps)
		A_shape = (n_freqs,n_sources)
		A_hier_list = hgmca_core.allocate_A_hier(self.m_level,A_shape,
			A_init=None)
		S_level = hgmca_core.allocate_S_level(self.m_level,X_level,n_sources)
		hgmca_core.init_min_rmse(X_level,A_hier_list,S_level)
		# NOTE(review): the positional True presumably mirrors
		# enforce_nn_A=True above -- confirm against the
		# hgmca_epoch_numba signature.
		hgmca_core.hgmca_epoch_numba(X_level,A_hier_list,lam_hier,A_global,
			lam_global,S_level,n_epochs,self.m_level,n_iterations,lam_s,seed,
			True,min_rmse_rate)

		for level in range(self.m_level+1):
			for patch in range(wavelets_hgmca.level_to_npatches(level)):
				np.testing.assert_almost_equal(A_hier_list[level][patch],
					hgmca_dict[str(level)]['A'][patch])
				if S_level[level].size == 0:
					self.assertEqual(hgmca_dict[str(level)]['S'].size,0)
				else:
					np.testing.assert_almost_equal(S_level[level][patch],
						hgmca_dict[str(level)]['S'][patch])

		# Check that saving doesn't cause issues
		n_epochs = 6
		save_dict = {'save_path':self.root_path,'save_rate':2}
		folder_path = os.path.join(save_dict['save_path'],'hgmca_save')
		save_file_paths = [os.path.join(folder_path,'%s_%d.npy'%(mat,level))
			for level in range(self.m_level+1) for mat in ('A','S')]
		self.addCleanup(remove_existing,save_file_paths,folder_path)
		hgmca_dict = hgmca_core.hgmca_opt(wav_analysis_maps,n_sources,n_epochs,
			lam_hier,lam_s,n_iterations,A_init=None,A_global=A_global,
			lam_global=lam_global,seed=seed,enforce_nn_A=True,
			min_rmse_rate=min_rmse_rate,save_dict=save_dict,verbose=True)
		# Make sure loading works as well.
		hgmca_dict = hgmca_core.hgmca_opt(wav_analysis_maps,n_sources,n_epochs,
			lam_hier,lam_s,n_iterations,A_init=None,A_global=A_global,
			lam_global=lam_global,seed=seed,enforce_nn_A=True,
			min_rmse_rate=min_rmse_rate,save_dict=save_dict,verbose=True)

		# The save runs should have written one A and one S file per level,
		# and every wavelet map should be on disk. The original test checked
		# this implicitly through os.remove; check it explicitly here, since
		# the cleanups tolerate missing files.
		for path in save_file_paths + wav_map_paths:
			self.assertTrue(os.path.isfile(path),path)