def test_function(self): """ :return: """ x, y = util.make_grid(numPix=20, deltapix=1.) gauss = Gaussian() flux = gauss.function(x, y, amp=1., center_x=0., center_y=0., sigma=1.) image = util.array2image(flux) interp = Interpol() kwargs_interp = { 'image': image, 'scale': 1., 'phi_G': 0., 'center_x': 0., 'center_y': 0. } output = interp.function(x, y, **kwargs_interp) npt.assert_almost_equal(output, flux, decimal=0) flux = gauss.function(x - 1., y, amp=1., center_x=0., center_y=0., sigma=1.) kwargs_interp = { 'image': image, 'scale': 1., 'phi_G': 0., 'center_x': 1., 'center_y': 0. } output = interp.function(x, y, **kwargs_interp) npt.assert_almost_equal(output, flux, decimal=0) flux = gauss.function(x - 1., y - 1., amp=1, center_x=0., center_y=0., sigma=1.) kwargs_interp = { 'image': image, 'scale': 1., 'phi_G': 0., 'center_x': 1., 'center_y': 1. } output = interp.function(x, y, **kwargs_interp) npt.assert_almost_equal(output, flux, decimal=0) out = interp.function(x=1000, y=0, **kwargs_interp) assert out == 0
def test_delete_cache(self): x, y = util.make_grid(numPix=20, deltapix=1.) gauss = Gaussian() flux = gauss.function(x, y, amp=1., center_x=0., center_y=0., sigma=1.) image = util.array2image(flux) interp = Interpol() kwargs_interp = { 'image': image, 'scale': 1., 'phi_G': 0., 'center_x': 0., 'center_y': 0. } output = interp.function(x, y, **kwargs_interp) assert hasattr(interp, '_image_interp') interp.delete_cache() assert not hasattr(interp, '_image_interp')
class SLIT_Starlets(object): """ Decomposition of an image using the Isotropic Undecimated Walevet Transform, also known as "starlet" or "B-spline", using the 'a trous' algorithm. Astronomical data (galaxies, stars, ...) are often very sparsely represented in the starlet basis. Based on Starck et al. : https://ui.adsabs.harvard.edu/abs/2007ITIP...16..297S/abstract """ param_names = [ 'amp', 'n_scales', 'n_pixels', 'scale', 'center_x', 'center_y' ] lower_limit_default = { 'amp': [0], 'n_scales': 2, 'n_pixels': 5, 'center_x': -1000, 'center_y': -1000, 'scale': 0.000000001 } upper_limit_default = { 'amp': [1e8], 'n_scales': 20, 'n_pixels': 1e10, 'center_x': 1000, 'center_y': 1000, 'scale': 10000000000 } def __init__(self, thread_count=1, fast_inverse=True, second_gen=False, show_pysap_plots=False, force_no_pysap=False): """ Load pySAP package if found, and initialize the Starlet transform. :param thread_count: number of threads used for pySAP computations :param fast_inverse: if True, reconstruction is simply the sum of each scale (only for 1st generation starlet transform) :param second_gen: if True, uses the second generation of starlet transform :param show_pysap_plots: if True, displays pySAP plots when calling the decomposition method :param force_no_pysap: if True, does not load pySAP and computes starlet transforms in python. """ self.use_pysap, pysap = self._load_pysap(force_no_pysap) if self.use_pysap: self._transf_class = pysap.load_transform( 'BsplineWaveletTransformATrousAlgorithm') else: warnings.warn( "The python package pySAP is not used for starlet operations. " "They will be performed using (slower) python routines.") self._fast_inverse = fast_inverse self._second_gen = second_gen self._show_pysap_plots = show_pysap_plots self.interpol = Interpol() self.thread_count = thread_count def function(self, x, y, amp=None, n_scales=None, n_pixels=None, scale=1, center_x=0, center_y=0): """ 1D inverse starlet transform from starlet coefficients stored in coeffs Follows lenstronomy conventions for light profiles. :param amp: decomposition coefficients ('amp' to follow conventions in other light profile) This is an ndarray with shape (n_scales, sqrt(n_pixels), sqrt(n_pixels)) or (n_scales*n_pixels,) :param n_scales: number of decomposition scales :param n_pixels: number of pixels in a single scale :return: reconstructed signal as 1D array of shape (n_pixels,) """ if len(amp.shape) == 1: coeffs = util.array2cube(amp, n_scales, n_pixels) elif len(amp.shape) == 3: coeffs = amp else: raise ValueError( "Starlets 'amp' has not the right shape (1D or 3D arrays are supported)" ) image = self.function_2d(coeffs, n_scales, n_pixels) image = self.interpol.function(x, y, image=image, scale=scale, center_x=center_x, center_y=center_y, amp=1, phi_G=0) return image def function_2d(self, coeffs, n_scales, n_pixels): """ 2D inverse starlet transform from starlet coefficients stored in coeffs :param coeffs: decomposition coefficients, ndarray with shape (n_scales, sqrt(n_pixels), sqrt(n_pixels)) :param n_scales: number of decomposition scales :return: reconstructed signal as 2D array of shape (sqrt(n_pixels), sqrt(n_pixels)) """ if self.use_pysap and not self._second_gen: return self._inverse_transform(coeffs, n_scales, n_pixels) else: return starlets_util.inverse_transform(coeffs, fast=self._fast_inverse, second_gen=self._second_gen) def decomposition(self, image, n_scales): """ 1D starlet transform from starlet coefficients stored in coeffs :param image: 2D image to be decomposed, ndarray with shape (sqrt(n_pixels), sqrt(n_pixels)) :param n_scales: number of decomposition scales :return: reconstructed signal as 1D array of shape (n_scales*n_pixels,) """ if len(image.shape) == 1: image_2d = util.array2image(image) elif len(image.shape) == 2: image_2d = image else: raise ValueError( "image has not the right shape (1D or 2D arrays are supported for starlets decomposition)" ) return util.cube2array(self.decomposition_2d(image_2d, n_scales)) def decomposition_2d(self, image, n_scales): """ 2D starlet transform from starlet coefficients stored in coeffs :param image: 2D image to be decomposed, ndarray with shape (sqrt(n_pixels), sqrt(n_pixels)) :param n_scales: number of decomposition scales :return: reconstructed signal as 2D array of shape (n_scales, sqrt(n_pixels), sqrt(n_pixels)) """ if self.use_pysap and not self._second_gen: coeffs = self._transform(image, n_scales) else: coeffs = starlets_util.transform(image, n_scales, second_gen=self._second_gen) return coeffs def _inverse_transform(self, coeffs, n_scales, n_pixels): """reconstructs image from starlet coefficients""" self._check_transform_pysap(n_scales, n_pixels) if self._fast_inverse and not self._second_gen: # for 1st gen starlet the reconstruction can be performed by summing all scales image = np.sum(coeffs, axis=0) else: coeffs = self._coeffs2pysap(coeffs) self._transf.analysis_data = coeffs result = self._transf.synthesis() if self._show_pysap_plots: result.show() image = result.data return image def _transform(self, image, n_scales): """decomposes an image into starlets coefficients""" self._check_transform_pysap(n_scales, image.size) self._transf.data = image self._transf.analysis() if self._show_pysap_plots: self._transf.show() coeffs = self._transf.analysis_data coeffs = self._pysap2coeffs(coeffs) return coeffs def _check_transform_pysap(self, n_scales, n_pixels): """if needed, update the loaded pySAP transform to correct number of scales""" if not hasattr( self, '_transf' ) or n_scales != self._n_scales or n_pixels != self._n_pixels: self._transf = self._transf_class(nb_scale=n_scales, verbose=False, nb_procs=self.thread_count) self._n_scales = n_scales self._n_pixels = n_pixels def _pysap2coeffs(self, coeffs): """convert pySAP decomposition coefficients to numpy array""" return np.asarray(coeffs) def _coeffs2pysap(self, coeffs): """convert coefficients stored in numpy array to list required by pySAP""" coeffs_list = [] for i in range(coeffs.shape[0]): coeffs_list.append(coeffs[i, :, :]) return coeffs_list def _load_pysap(self, force_no_pysap): """load pySAP module""" if force_no_pysap: return False, None try: import pysap except ImportError: return False, None else: return True, pysap def delete_cache(self): """delete the cached interpolated image""" self.interpol.delete_cache()
def test_function(self): """ :return: """ for len_x, len_y in [(20, 20), (14, 20)]: x, y = util.make_grid(numPix=(len_x, len_y), deltapix=1.) gauss = Gaussian() flux = gauss.function(x, y, amp=1., center_x=0., center_y=0., sigma=1.) image = util.array2image(flux, nx=len_y, ny=len_x) interp = Interpol() kwargs_interp = { 'image': image, 'scale': 1., 'phi_G': 0., 'center_x': 0., 'center_y': 0. } output = interp.function(x, y, **kwargs_interp) npt.assert_equal(output, flux) flux = gauss.function(x - 1., y, amp=1., center_x=0., center_y=0., sigma=1.) kwargs_interp = { 'image': image, 'scale': 1., 'phi_G': 0., 'center_x': 1., 'center_y': 0. } output = interp.function(x, y, **kwargs_interp) npt.assert_almost_equal(output, flux, decimal=0) flux = gauss.function(x - 1., y - 1., amp=1, center_x=0., center_y=0., sigma=1.) kwargs_interp = { 'image': image, 'scale': 1., 'phi_G': 0., 'center_x': 1., 'center_y': 1. } output = interp.function(x, y, **kwargs_interp) npt.assert_almost_equal(output, flux, decimal=0) out = interp.function(x=1000, y=0, **kwargs_interp) assert out == 0 # test change of center without re-doing interpolation out = interp.function(x=0, y=0, image=image, scale=1., phi_G=0, center_x=0, center_y=0) out_shift = interp.function(x=1, y=0, image=image, scale=1., phi_G=0, center_x=1, center_y=0) assert out_shift == out # function must give a single value when evaluated at a single point assert isinstance(out, float) # test change of scale without re-doing interpolation out = interp.function(x=1., y=0, image=image, scale=1., phi_G=0, center_x=0, center_y=0) out_scaled = interp.function(x=2., y=0, image=image, scale=2, phi_G=0, center_x=0, center_y=0) assert out_scaled == out
image_reconstructed = shapeletSet.function(x, y, coeff_ngc, n_max, beta, center_x=0, center_y=0) image_reconstructed_2d = util.array2image(image_reconstructed) theta_x_high_res, theta_y_high_res = util.make_grid(numPix=numPix*high_res_factor, deltapix=deltaPix/high_res_factor) beta_x_high_res, beta_y_high_res = lensModel.ray_shooting(theta_x_high_res, theta_y_high_res, kwargs=kwargs_lens_list) source_lensed = shapeletSet.function(beta_x_high_res, beta_y_high_res, coeff_ngc, n_max, beta=.05, center_x=cen[ii], center_y=0) source_lensed = util.array2image(source_lensed) kwargs_interp = {'image': ngc_data_resized, 'center_x': 0, 'center_y': 0, 'scale': 0.005, 'phi_G':0.2} interp_light = Interpol() source_lensed_interp = interp_light.function(beta_x_high_res, beta_y_high_res, **kwargs_interp) source_lensed_interp = util.array2image(source_lensed_interp) light_model_list = ['SERSIC_ELLIPSE', 'SERSIC_ELLIPSE','NIE'] kwargs_lens_light = [ {'amp': .3, 'R_sersic': 0.04, 'n_sersic': 0.3, 'e1': 0, 'e2': 0, 'center_x': 0, 'center_y': 0}, {'amp': .01, 'R_sersic': 0.05, 'n_sersic': 0.2, 'e1': 0, 'e2': 0, 'center_x': 0, 'center_y': 0}, {'amp': .05, 'e1':.5, 'e2':.4, 's_scale':1} ] lensLightModel = LightModel(light_model_list=light_model_list) flux_lens_light = lensLightModel.surface_brightness(theta_x_high_res, theta_y_high_res, kwargs_lens_light) flux_lens_light = util.array2image(flux_lens_light) image_combined = source_lensed_interp + flux_lens_light