예제 #1
0
    def __init__(self,
                 thread_count=1,
                 fast_inverse=True,
                 second_gen=False,
                 show_pysap_plots=False,
                 force_no_pysap=False):
        """
        Set up the starlet transform, relying on pySAP when it can be imported.

        :param thread_count: number of threads handed to pySAP
        :param fast_inverse: if True, the inverse transform simply sums all scales (1st generation starlets only)
        :param second_gen: if True, the second generation starlet transform is used
        :param show_pysap_plots: if True, pySAP plots are displayed by the decomposition method
        :param force_no_pysap: if True, pySAP is not loaded and the pure-python routines are used
        """
        use_pysap, pysap_module = self._load_pysap(force_no_pysap)
        self.use_pysap = use_pysap
        if not use_pysap:
            # fall back to the python implementation, but let the user know it is slower
            warnings.warn(
                "The python package pySAP is not used for starlet operations. "
                "They will be performed using (slower) python routines.")
        else:
            self._transf_class = pysap_module.load_transform(
                'BsplineWaveletTransformATrousAlgorithm')
        self._fast_inverse = fast_inverse
        self._second_gen = second_gen
        self._show_pysap_plots = show_pysap_plots
        self.interpol = Interpol()
        self.thread_count = thread_count
 def test_delete_cache(self):
     """Evaluating Interpol creates '_image_interp'; delete_cache() must remove it again."""
     grid_x, grid_y = util.make_grid(numPix=20, deltapix=1.)
     profile = Gaussian()
     surface = profile.function(grid_x, grid_y, amp=1., center_x=0., center_y=0., sigma=1.)
     pixels = util.array2image(surface)
     interpolator = Interpol()
     # the returned values are irrelevant here; the call is made for its caching side effect
     interpolator.function(grid_x,
                           grid_y,
                           image=pixels,
                           scale=1.,
                           phi_G=0.,
                           center_x=0.,
                           center_y=0.)
     assert hasattr(interpolator, '_image_interp')
     interpolator.delete_cache()
     assert not hasattr(interpolator, '_image_interp')
예제 #3
0
    def test_function(self):
        """
        Interpolating a sampled Gaussian must approximately reproduce the Gaussian,
        also when the interpolated image is moved to a different centre, and must
        return 0 far outside the image.

        :return:
        """
        x, y = util.make_grid(numPix=20, deltapix=1.)
        gauss = Gaussian()
        reference = gauss.function(x, y, amp=1., center_x=0., center_y=0., sigma=1.)
        image = util.array2image(reference)
        interp = Interpol()

        for shift_x, shift_y in [(0., 0.), (1., 0.), (1., 1.)]:
            # the interpolated image centred at (shift_x, shift_y) should match a shifted Gaussian
            expected = gauss.function(x - shift_x,
                                      y - shift_y,
                                      amp=1.,
                                      center_x=0.,
                                      center_y=0.,
                                      sigma=1.)
            kwargs_interp = {
                'image': image,
                'scale': 1.,
                'phi_G': 0.,
                'center_x': shift_x,
                'center_y': shift_y
            }
            npt.assert_almost_equal(interp.function(x, y, **kwargs_interp),
                                    expected,
                                    decimal=0)

        # far outside the image the interpolated profile must vanish
        out = interp.function(x=1000, y=0, **kwargs_interp)
        assert out == 0
예제 #4
0
    def __init__(self, light_model_list, smoothing=0.001):
        """
        Create one light profile instance per requested model name.

        Each profile module is imported lazily, only when its model name is requested.

        :param light_model_list: list of light model names (strings such as 'GAUSSIAN' or 'SERSIC_ELLIPSE')
        :param smoothing: smoothing factor for certain models (deprecated); forwarded to the Sersic-type profiles
        :raises ValueError: if a name in light_model_list is not a supported light model
        """
        self.profile_type_list = light_model_list
        self.func_list = []
        for model_name in light_model_list:
            if model_name == 'GAUSSIAN':
                from lenstronomy.LightModel.Profiles.gaussian import Gaussian
                profile = Gaussian()
            elif model_name == 'GAUSSIAN_ELLIPSE':
                from lenstronomy.LightModel.Profiles.gaussian import GaussianEllipse
                profile = GaussianEllipse()
            elif model_name == 'ELLIPSOID':
                from lenstronomy.LightModel.Profiles.ellipsoid import Ellipsoid
                profile = Ellipsoid()
            elif model_name == 'MULTI_GAUSSIAN':
                from lenstronomy.LightModel.Profiles.gaussian import MultiGaussian
                profile = MultiGaussian()
            elif model_name == 'MULTI_GAUSSIAN_ELLIPSE':
                from lenstronomy.LightModel.Profiles.gaussian import MultiGaussianEllipse
                profile = MultiGaussianEllipse()
            elif model_name == 'SERSIC':
                from lenstronomy.LightModel.Profiles.sersic import Sersic
                profile = Sersic(smoothing=smoothing)
            elif model_name == 'SERSIC_ELLIPSE':
                from lenstronomy.LightModel.Profiles.sersic import SersicElliptic
                # sersic_major_axis_conf is a name defined outside this method (presumably module-level configuration)
                profile = SersicElliptic(smoothing=smoothing,
                                         sersic_major_axis=sersic_major_axis_conf)
            elif model_name == 'CORE_SERSIC':
                from lenstronomy.LightModel.Profiles.sersic import CoreSersic
                profile = CoreSersic(smoothing=smoothing)
            elif model_name == 'SHAPELETS':
                from lenstronomy.LightModel.Profiles.shapelets import ShapeletSet
                profile = ShapeletSet()
            elif model_name == 'SHAPELETS_POLAR':
                from lenstronomy.LightModel.Profiles.shapelets_polar import ShapeletSetPolar
                profile = ShapeletSetPolar(exponential=False)
            elif model_name == 'SHAPELETS_POLAR_EXP':
                from lenstronomy.LightModel.Profiles.shapelets_polar import ShapeletSetPolar
                profile = ShapeletSetPolar(exponential=True)
            elif model_name == 'HERNQUIST':
                from lenstronomy.LightModel.Profiles.hernquist import Hernquist
                profile = Hernquist()
            elif model_name == 'HERNQUIST_ELLIPSE':
                from lenstronomy.LightModel.Profiles.hernquist import HernquistEllipse
                profile = HernquistEllipse()
            elif model_name == 'PJAFFE':
                from lenstronomy.LightModel.Profiles.p_jaffe import PJaffe
                profile = PJaffe()
            elif model_name == 'PJAFFE_ELLIPSE':
                from lenstronomy.LightModel.Profiles.p_jaffe import PJaffe_Ellipse
                profile = PJaffe_Ellipse()
            elif model_name == 'UNIFORM':
                from lenstronomy.LightModel.Profiles.uniform import Uniform
                profile = Uniform()
            elif model_name == 'POWER_LAW':
                from lenstronomy.LightModel.Profiles.power_law import PowerLaw
                profile = PowerLaw()
            elif model_name == 'NIE':
                from lenstronomy.LightModel.Profiles.nie import NIE
                profile = NIE()
            elif model_name == 'CHAMELEON':
                from lenstronomy.LightModel.Profiles.chameleon import Chameleon
                profile = Chameleon()
            elif model_name == 'DOUBLE_CHAMELEON':
                from lenstronomy.LightModel.Profiles.chameleon import DoubleChameleon
                profile = DoubleChameleon()
            elif model_name == 'TRIPLE_CHAMELEON':
                from lenstronomy.LightModel.Profiles.chameleon import TripleChameleon
                profile = TripleChameleon()
            elif model_name == 'INTERPOL':
                from lenstronomy.LightModel.Profiles.interpolation import Interpol
                profile = Interpol()
            elif model_name == 'SLIT_STARLETS':
                from lenstronomy.LightModel.Profiles.starlets import SLIT_Starlets
                profile = SLIT_Starlets(fast_inverse=True, second_gen=False)
            elif model_name == 'SLIT_STARLETS_GEN2':
                from lenstronomy.LightModel.Profiles.starlets import SLIT_Starlets
                profile = SLIT_Starlets(second_gen=True)
            else:
                # _MODELS_SUPPORTED is defined outside this method
                raise ValueError(
                    'No light model of type %s found! Supported are the following models: %s'
                    % (model_name, _MODELS_SUPPORTED))
            self.func_list.append(profile)
        self._num_func = len(self.func_list)
예제 #5
0
    def test_function(self):
        """
        Check Interpol on a square and a non-square grid: exact reproduction at the
        sampling points, approximate reproduction when re-centred, zero far outside
        the image, and invariance under joint shifts of centre/position and joint
        rescaling of scale/position.

        :return:
        """
        def interp_kwargs(img, cx, cy):
            # keyword arguments shared by all grid evaluations below
            return {
                'image': img,
                'scale': 1.,
                'phi_G': 0.,
                'center_x': cx,
                'center_y': cy
            }

        for len_x, len_y in [(20, 20), (14, 20)]:
            x, y = util.make_grid(numPix=(len_x, len_y), deltapix=1.)
            gauss = Gaussian()
            flux = gauss.function(x, y, amp=1., center_x=0., center_y=0., sigma=1.)
            image = util.array2image(flux, nx=len_y, ny=len_x)

            interp = Interpol()
            # at the sampling points themselves the interpolation must be exact
            npt.assert_equal(interp.function(x, y, **interp_kwargs(image, 0., 0.)), flux)

            # moved centres are only reproduced approximately
            for dx, dy in [(1., 0.), (1., 1.)]:
                shifted = gauss.function(x - dx,
                                         y - dy,
                                         amp=1.,
                                         center_x=0.,
                                         center_y=0.,
                                         sigma=1.)
                output = interp.function(x, y, **interp_kwargs(image, dx, dy))
                npt.assert_almost_equal(output, shifted, decimal=0)

            # far outside the image the interpolated profile must vanish
            out = interp.function(x=1000, y=0, **interp_kwargs(image, 1., 1.))
            assert out == 0

        # test change of center without re-doing interpolation
        out = interp.function(x=0, y=0, image=image, scale=1., phi_G=0,
                              center_x=0, center_y=0)
        out_shift = interp.function(x=1, y=0, image=image, scale=1., phi_G=0,
                                    center_x=1, center_y=0)
        assert out_shift == out

        # function must give a single value when evaluated at a single point
        assert isinstance(out, float)

        # test change of scale without re-doing interpolation
        out = interp.function(x=1., y=0, image=image, scale=1., phi_G=0,
                              center_x=0, center_y=0)
        out_scaled = interp.function(x=2., y=0, image=image, scale=2, phi_G=0,
                                     center_x=0, center_y=0)
        assert out_scaled == out
예제 #6
0
    def __init__(self,
                 light_model_list,
                 deflection_scaling_list=None,
                 source_redshift_list=None,
                 smoothing=0.0000001):
        """
        Instantiate one light profile object per entry of light_model_list.

        Profile modules are imported lazily, only for the model types actually requested.

        :param light_model_list: list of light model names (strings such as 'GAUSSIAN' or 'SERSIC_ELLIPSE')
        :param deflection_scaling_list: list of floats, rescales the original reduced deflection angles from the lens model
         to enable different models to be placed at different optical (redshift) distances. None means no scaling list
         is given; this constructor only stores the value (as self.deflection_scaling_list), so how None is
         interpreted is up to the code that reads that attribute
        :param source_redshift_list: list of redshifts of the model components (stored as self.redshift_list)
        :param smoothing: smoothing factor for certain models (deprecated); forwarded to the Sersic-type profiles
        :raises ValueError: if an entry of light_model_list is not a supported light model
        """
        self.profile_type_list = light_model_list
        self.deflection_scaling_list = deflection_scaling_list
        self.redshift_list = source_redshift_list
        self.func_list = []
        for profile_type in light_model_list:
            if profile_type == 'GAUSSIAN':
                from lenstronomy.LightModel.Profiles.gaussian import Gaussian
                self.func_list.append(Gaussian())
            elif profile_type == 'GAUSSIAN_ELLIPSE':
                from lenstronomy.LightModel.Profiles.gaussian import GaussianEllipse
                self.func_list.append(GaussianEllipse())
            elif profile_type == 'MULTI_GAUSSIAN':
                from lenstronomy.LightModel.Profiles.gaussian import MultiGaussian
                self.func_list.append(MultiGaussian())
            elif profile_type == 'MULTI_GAUSSIAN_ELLIPSE':
                from lenstronomy.LightModel.Profiles.gaussian import MultiGaussianEllipse
                self.func_list.append(MultiGaussianEllipse())
            elif profile_type == 'SERSIC':
                from lenstronomy.LightModel.Profiles.sersic import Sersic
                self.func_list.append(Sersic(smoothing=smoothing))
            elif profile_type == 'SERSIC_ELLIPSE':
                from lenstronomy.LightModel.Profiles.sersic import SersicElliptic
                self.func_list.append(SersicElliptic(smoothing=smoothing))
            elif profile_type == 'CORE_SERSIC':
                from lenstronomy.LightModel.Profiles.sersic import CoreSersic
                self.func_list.append(CoreSersic(smoothing=smoothing))
            elif profile_type == 'SHAPELETS':
                from lenstronomy.LightModel.Profiles.shapelets import ShapeletSet
                self.func_list.append(ShapeletSet())
            elif profile_type == 'HERNQUIST':
                from lenstronomy.LightModel.Profiles.hernquist import Hernquist
                self.func_list.append(Hernquist())
            elif profile_type == 'HERNQUIST_ELLIPSE':
                from lenstronomy.LightModel.Profiles.hernquist import HernquistEllipse
                self.func_list.append(HernquistEllipse())
            elif profile_type == 'PJAFFE':
                from lenstronomy.LightModel.Profiles.p_jaffe import PJaffe
                self.func_list.append(PJaffe())
            elif profile_type == 'PJAFFE_ELLIPSE':
                from lenstronomy.LightModel.Profiles.p_jaffe import PJaffe_Ellipse
                self.func_list.append(PJaffe_Ellipse())
            elif profile_type == 'UNIFORM':
                from lenstronomy.LightModel.Profiles.uniform import Uniform
                self.func_list.append(Uniform())
            elif profile_type == 'POWER_LAW':
                from lenstronomy.LightModel.Profiles.power_law import PowerLaw
                self.func_list.append(PowerLaw())
            elif profile_type == 'NIE':
                from lenstronomy.LightModel.Profiles.nie import NIE
                self.func_list.append(NIE())
            elif profile_type == 'CHAMELEON':
                from lenstronomy.LightModel.Profiles.chameleon import Chameleon
                self.func_list.append(Chameleon())
            elif profile_type == 'DOUBLE_CHAMELEON':
                from lenstronomy.LightModel.Profiles.chameleon import DoubleChameleon
                self.func_list.append(DoubleChameleon())
            elif profile_type == 'INTERPOL':
                from lenstronomy.LightModel.Profiles.interpolation import Interpol
                self.func_list.append(Interpol())
            else:
                # Build one formatted message: passing several positional arguments to
                # ValueError would make str(err) render them as a tuple.
                raise ValueError('No light model of type %s found!' % (profile_type,))
예제 #7
0
class SLIT_Starlets(object):
    """
    Decomposition of an image using the Isotropic Undecimated Wavelet Transform,
    also known as "starlet" or "B-spline", using the 'a trous' algorithm.

    Astronomical data (galaxies, stars, ...) are often very sparsely represented in the starlet basis.

    Based on Starck et al. : https://ui.adsabs.harvard.edu/abs/2007ITIP...16..297S/abstract
    """
    param_names = [
        'amp', 'n_scales', 'n_pixels', 'scale', 'center_x', 'center_y'
    ]
    lower_limit_default = {
        'amp': [0],
        'n_scales': 2,
        'n_pixels': 5,
        'center_x': -1000,
        'center_y': -1000,
        'scale': 0.000000001
    }
    upper_limit_default = {
        'amp': [1e8],
        'n_scales': 20,
        'n_pixels': 1e10,
        'center_x': 1000,
        'center_y': 1000,
        'scale': 10000000000
    }

    def __init__(self,
                 thread_count=1,
                 fast_inverse=True,
                 second_gen=False,
                 show_pysap_plots=False,
                 force_no_pysap=False):
        """
        Load pySAP package if found, and initialize the Starlet transform.

        :param thread_count: number of threads used for pySAP computations
        :param fast_inverse: if True, reconstruction is simply the sum of each scale (only for 1st generation starlet transform)
        :param second_gen: if True, uses the second generation of starlet transform
        :param show_pysap_plots: if True, displays pySAP plots when calling the decomposition method
        :param force_no_pysap: if True, does not load pySAP and computes starlet transforms in python.
        """
        self.use_pysap, pysap = self._load_pysap(force_no_pysap)
        if self.use_pysap:
            self._transf_class = pysap.load_transform(
                'BsplineWaveletTransformATrousAlgorithm')
        else:
            warnings.warn(
                "The python package pySAP is not used for starlet operations. "
                "They will be performed using (slower) python routines.")
        self._fast_inverse = fast_inverse
        self._second_gen = second_gen
        self._show_pysap_plots = show_pysap_plots
        self.interpol = Interpol()
        self.thread_count = thread_count

    def function(self,
                 x,
                 y,
                 amp=None,
                 n_scales=None,
                 n_pixels=None,
                 scale=1,
                 center_x=0,
                 center_y=0):
        """
        Evaluate the image reconstructed from starlet coefficients at positions (x, y).

        The coefficients are inverse-transformed into a 2D image (see function_2d), which is then
        evaluated at (x, y) through the Interpol light profile, following lenstronomy conventions
        for light profiles.

        :param x: x-coordinate(s) at which the reconstructed image is evaluated
        :param y: y-coordinate(s) at which the reconstructed image is evaluated
        :param amp: decomposition coefficients ('amp' to follow conventions in other light profiles).
         Either an array with shape (n_scales, sqrt(n_pixels), sqrt(n_pixels)) or a flattened array
         with shape (n_scales*n_pixels,); anything numpy.asarray accepts is converted first
        :param n_scales: number of decomposition scales
        :param n_pixels: number of pixels in a single scale
        :param scale: forwarded to the interpolation as its 'scale' parameter
        :param center_x: forwarded to the interpolation as 'center_x'
        :param center_y: forwarded to the interpolation as 'center_y'
        :return: the reconstructed image interpolated at (x, y)
        :raises ValueError: if amp is None or is neither a 1D nor a 3D array
        """
        if amp is None:
            raise ValueError(
                "Starlets 'amp' must be provided (1D or 3D arrays are supported)")
        # accept lists/tuples as well as ndarrays; ndarrays are passed through unchanged
        amp = np.asarray(amp)
        if amp.ndim == 1:
            coeffs = util.array2cube(amp, n_scales, n_pixels)
        elif amp.ndim == 3:
            coeffs = amp
        else:
            raise ValueError(
                "Starlets 'amp' has not the right shape (1D or 3D arrays are supported)"
            )
        image = self.function_2d(coeffs, n_scales, n_pixels)
        return self.interpol.function(x,
                                      y,
                                      image=image,
                                      scale=scale,
                                      center_x=center_x,
                                      center_y=center_y,
                                      amp=1,
                                      phi_G=0)

    def function_2d(self, coeffs, n_scales, n_pixels):
        """
        2D inverse starlet transform from starlet coefficients stored in coeffs.

        pySAP is used for 1st generation starlets when it is available; otherwise the python
        routines of starlets_util are used.

        :param coeffs: decomposition coefficients,
         ndarray with shape (n_scales, sqrt(n_pixels), sqrt(n_pixels))
        :param n_scales: number of decomposition scales
        :param n_pixels: number of pixels in a single scale
        :return: reconstructed signal as 2D array of shape (sqrt(n_pixels), sqrt(n_pixels))
        """
        if self.use_pysap and not self._second_gen:
            return self._inverse_transform(coeffs, n_scales, n_pixels)
        return starlets_util.inverse_transform(coeffs,
                                               fast=self._fast_inverse,
                                               second_gen=self._second_gen)

    def decomposition(self, image, n_scales):
        """
        Starlet decomposition of an image, returned as a flattened array.

        :param image: image to be decomposed, either a 2D ndarray with shape (sqrt(n_pixels), sqrt(n_pixels))
         or a 1D array that util.array2image turns into such an image
        :param n_scales: number of decomposition scales
        :return: starlet coefficients as 1D array of shape (n_scales*n_pixels,)
        :raises ValueError: if image is neither 1D nor 2D
        """
        if len(image.shape) == 1:
            image_2d = util.array2image(image)
        elif len(image.shape) == 2:
            image_2d = image
        else:
            raise ValueError(
                "image has not the right shape (1D or 2D arrays are supported for starlets decomposition)"
            )
        return util.cube2array(self.decomposition_2d(image_2d, n_scales))

    def decomposition_2d(self, image, n_scales):
        """
        2D starlet transform of an image.

        :param image: 2D image to be decomposed, ndarray with shape (sqrt(n_pixels), sqrt(n_pixels))
        :param n_scales: number of decomposition scales
        :return: starlet coefficients as array of shape (n_scales, sqrt(n_pixels), sqrt(n_pixels))
        """
        if self.use_pysap and not self._second_gen:
            return self._transform(image, n_scales)
        return starlets_util.transform(image,
                                       n_scales,
                                       second_gen=self._second_gen)

    def _inverse_transform(self, coeffs, n_scales, n_pixels):
        """reconstructs image from starlet coefficients (pySAP path)"""
        if self._fast_inverse and not self._second_gen:
            # for 1st gen starlet the reconstruction can be performed by summing all scales;
            # no pySAP transform object is needed here, so none is built
            return np.sum(coeffs, axis=0)
        self._check_transform_pysap(n_scales, n_pixels)
        self._transf.analysis_data = self._coeffs2pysap(coeffs)
        result = self._transf.synthesis()
        if self._show_pysap_plots:
            result.show()
        return result.data

    def _transform(self, image, n_scales):
        """decomposes an image into starlets coefficients with pySAP"""
        self._check_transform_pysap(n_scales, image.size)
        self._transf.data = image
        self._transf.analysis()
        if self._show_pysap_plots:
            self._transf.show()
        return self._pysap2coeffs(self._transf.analysis_data)

    def _check_transform_pysap(self, n_scales, n_pixels):
        """if needed, (re)build the pySAP transform for the requested number of scales and pixels"""
        if not hasattr(
                self, '_transf'
        ) or n_scales != self._n_scales or n_pixels != self._n_pixels:
            self._transf = self._transf_class(nb_scale=n_scales,
                                              verbose=False,
                                              nb_procs=self.thread_count)
            self._n_scales = n_scales
            self._n_pixels = n_pixels

    def _pysap2coeffs(self, coeffs):
        """convert pySAP decomposition coefficients to numpy array"""
        return np.asarray(coeffs)

    def _coeffs2pysap(self, coeffs):
        """convert coefficients stored in a 3D numpy array to the list of 2D scales required by pySAP"""
        return [coeffs[i, :, :] for i in range(coeffs.shape[0])]

    def _load_pysap(self, force_no_pysap):
        """load pySAP module; returns (True, module) on success, (False, None) otherwise"""
        if force_no_pysap:
            return False, None
        try:
            import pysap
        except ImportError:
            return False, None
        else:
            return True, pysap

    def delete_cache(self):
        """delete the cached interpolated image"""
        self.interpol.delete_cache()
예제 #8
0
파일: code.py 프로젝트: SSingh087/seq-pred
          # Evaluate the shapelet model (coefficients coeff_ngc, order n_max, scale beta) on the
          # (x, y) grid, centred at the origin, and reshape the result into a 2D image.
          image_reconstructed = shapeletSet.function(x, y, coeff_ngc, n_max, beta, center_x=0, center_y=0)
          image_reconstructed_2d = util.array2image(image_reconstructed)

          # Finer image-plane grid: numPix*high_res_factor pixels with pixel size deltaPix/high_res_factor,
          # ray-shot through the lens model to obtain the corresponding source-plane coordinates.
          theta_x_high_res, theta_y_high_res = util.make_grid(numPix=numPix*high_res_factor,
                                                              deltapix=deltaPix/high_res_factor)
          beta_x_high_res, beta_y_high_res = lensModel.ray_shooting(theta_x_high_res, theta_y_high_res,
                                                                      kwargs=kwargs_lens_list)
          # Lensed shapelet source. NOTE(review): beta is hard-coded to .05 here, while the evaluation
          # above uses the variable `beta` -- confirm this is intended. cen[ii] presumably comes from an
          # enclosing loop over source positions that is not visible here.
          source_lensed = shapeletSet.function(beta_x_high_res, beta_y_high_res,
                                              coeff_ngc, n_max, beta=.05,
                                              center_x=cen[ii], center_y=0)

          source_lensed = util.array2image(source_lensed)
          # Alternative source: interpolate the image ngc_data_resized directly on the ray-shot grid
          # (scale 0.005; phi_G 0.2, presumably a rotation angle -- confirm against Interpol).
          kwargs_interp = {'image': ngc_data_resized, 'center_x': 0, 'center_y': 0, 'scale': 0.005, 'phi_G':0.2}

          interp_light = Interpol()
          source_lensed_interp = interp_light.function(beta_x_high_res, beta_y_high_res, **kwargs_interp)
          source_lensed_interp = util.array2image(source_lensed_interp)

          # Lens-light model: two SERSIC_ELLIPSE components plus one NIE component.
          # NOTE(review): the NIE kwargs give no center_x/center_y; this relies on whatever defaults the
          # NIE profile has -- confirm.
          light_model_list = ['SERSIC_ELLIPSE', 'SERSIC_ELLIPSE','NIE']
          kwargs_lens_light = [
              {'amp':  .3, 'R_sersic': 0.04, 'n_sersic': 0.3, 'e1': 0, 'e2': 0, 'center_x': 0, 'center_y': 0},
              {'amp': .01, 'R_sersic': 0.05, 'n_sersic': 0.2, 'e1': 0, 'e2': 0, 'center_x': 0, 'center_y': 0},
              {'amp': .05, 'e1':.5, 'e2':.4, 's_scale':1}
          ]
          lensLightModel = LightModel(light_model_list=light_model_list)

          # Lens-light surface brightness on the same high-resolution image-plane grid. Within the lines
          # shown here, the combined image uses the interpolated source, not the shapelet `source_lensed`.
          flux_lens_light = lensLightModel.surface_brightness(theta_x_high_res, theta_y_high_res, kwargs_lens_light)
          flux_lens_light = util.array2image(flux_lens_light)
          image_combined = source_lensed_interp + flux_lens_light