예제 #1
0
    def psf_iteration(self, compute_bands=None, **kwargs_psf_iter):
        """
        iterative PSF reconstruction

        :param compute_bands: bool list, if multiple bands, this process can be limited to a subset of bands.
         Entries are tested for truthiness, so numpy boolean arrays select bands as well.
        :param kwargs_psf_iter: keyword arguments as used or available in PsfFitting.update_iterative() definition
        :return: 0, updated PSF is stored in self.multi_band_list
        """
        kwargs_model = self._updateManager.kwargs_model
        kwargs_likelihood = self._updateManager.kwargs_likelihood
        likelihood_mask_list = kwargs_likelihood.get('image_likelihood_mask_list', None)
        kwargs_pixelbased = kwargs_likelihood.get('kwargs_pixelbased', None)
        # result of best_fit(bijective=False) is handed to update_iterative as kwargs_params
        kwargs_temp = self.best_fit(bijective=False)
        if compute_bands is None:
            compute_bands = [True] * len(self.multi_band_list)

        for band_index, band in enumerate(self.multi_band_list):
            # truthiness rather than ``is True``: a numpy.bool_ True is never identical to the builtin True,
            # so an identity check silently skips bands selected through a numpy boolean array
            if not compute_bands[band_index]:
                continue
            # index 1 of each band entry holds the PSF keyword arguments
            kwargs_psf = band[1]
            image_model = SingleBandMultiModel(self.multi_band_list, kwargs_model,
                                               likelihood_mask_list=likelihood_mask_list, band_index=band_index,
                                               kwargs_pixelbased=kwargs_pixelbased)
            psf_iter = PsfFitting(image_model_class=image_model)
            band[1] = psf_iter.update_iterative(kwargs_psf, kwargs_params=kwargs_temp, **kwargs_psf_iter)
        return 0
예제 #2
0
 def psf_iteration(self, fitting_kwargs, lens_input, source_input,
                   lens_light_input, ps_input, cosmo_input):
     """
     iterative PSF reconstruction for the selected bands

     :param fitting_kwargs: dict; requires 'psf_iter_factor' and 'psf_iter_num', optionally
      'compute_bands' (bool list selecting bands, default: all bands; entries are tested for truthiness)
      and 'kwargs_psf_iter' (dict passed to PsfFitting, default: {})
     :param lens_input: lens keyword arguments, rescaled with cosmo_input before use
     :param source_input: source keyword arguments, passed through image2source_plane before use
     :param lens_light_input: lens light keyword arguments
     :param ps_input: point source keyword arguments
     :param cosmo_input: keyword arguments passed to update_lens_scaling
     :return: 0; updated PSFs are written to self.multi_band_list and self.fitting.multi_band_list
     :raises KeyError: if 'psf_iter_factor' or 'psf_iter_num' is missing from fitting_kwargs
     """
     lens_updated = self._param.update_lens_scaling(cosmo_input, lens_input)
     source_updated = self._param.image2source_plane(source_input, lens_updated)
     psf_iter_factor = fitting_kwargs['psf_iter_factor']
     psf_iter_num = fitting_kwargs['psf_iter_num']
     compute_bool = fitting_kwargs.get('compute_bands',
                                       [True] * len(self.multi_band_list))
     kwargs_psf_iter = fitting_kwargs.get('kwargs_psf_iter', {})
     for i, band in enumerate(self.multi_band_list):
         # truthiness rather than ``is True`` so numpy boolean selections are honoured
         if not compute_bool[i]:
             continue
         kwargs_data = band[0]
         kwargs_psf = band[1]
         kwargs_numerics = band[2]
         image_model = class_creator.create_image_model(
             kwargs_data=kwargs_data,
             kwargs_psf=kwargs_psf,
             kwargs_numerics=kwargs_numerics,
             kwargs_model=self.kwargs_model)
         psf_iter = PsfFitting(image_model_class=image_model,
                               kwargs_psf_iter=kwargs_psf_iter)
         kwargs_psf = psf_iter.update_iterative(kwargs_psf,
                                                lens_updated,
                                                source_updated,
                                                lens_light_input,
                                                ps_input,
                                                factor=psf_iter_factor,
                                                num_iter=psf_iter_num,
                                                verbose=self._verbose,
                                                no_break=True)
         # write the new PSF both here and into the fitting instance's band list
         band[1] = kwargs_psf
         self.fitting.multi_band_list[i][1] = kwargs_psf
     return 0
예제 #3
0
    def psf_iteration(self,
                      num_iter=10,
                      no_break=True,
                      stacking_method='median',
                      block_center_neighbour=0,
                      keep_psf_error_map=True,
                      psf_symmetry=1,
                      psf_iter_factor=1,
                      error_map_radius=None,
                      verbose=True,
                      compute_bands=None):
        """
        iterative PSF reconstruction

        :param num_iter: number of iterations in the process
        :param no_break: bool, if False will break the process as soon as one step leads to a worse reconstruction
         than the previous step
        :param stacking_method: string, 'median' and 'mean' supported
        :param block_center_neighbour: radius of neighbouring point source to be blocked in the reconstruction
        :param keep_psf_error_map: bool, whether or not to keep the previous psf_error_map
        :param psf_symmetry: int, number of invariant rotations in the reconstructed PSF
        :param psf_iter_factor: factor of new estimated PSF relative to the old one
         PSF_updated = (1-psf_iter_factor) * PSF_old + psf_iter_factor*PSF_new
        :param error_map_radius: float, radius (in arc seconds) of the outermost error in the PSF estimate
         (e.g. to avoid double counting of overlapping PSF errors)
        :param verbose: bool, print statements
        :param compute_bands: bool list, if multiple bands, this process can be limited to a subset of bands.
         Entries are tested for truthiness, so numpy boolean arrays select bands as well.
        :return: 0, updated PSF is stored in self.multi_band_list
        """
        kwargs_model = self._updateManager.kwargs_model
        kwargs_likelihood = self._updateManager.kwargs_likelihood
        likelihood_mask_list = kwargs_likelihood.get(
            'image_likelihood_mask_list', None)
        kwargs_pixelbased = kwargs_likelihood.get('kwargs_pixelbased', None)
        # result of best_fit(bijective=False) is handed to update_iterative as kwargs_params
        kwargs_temp = self.best_fit(bijective=False)
        if compute_bands is None:
            compute_bands = [True] * len(self.multi_band_list)

        # identical iteration settings for every band; built once outside the loop
        kwargs_update = dict(num_iter=num_iter,
                             no_break=no_break,
                             stacking_method=stacking_method,
                             block_center_neighbour=block_center_neighbour,
                             keep_psf_error_map=keep_psf_error_map,
                             psf_symmetry=psf_symmetry,
                             psf_iter_factor=psf_iter_factor,
                             error_map_radius=error_map_radius,
                             verbose=verbose)

        for band_index, band in enumerate(self.multi_band_list):
            # truthiness rather than ``is True``: a numpy.bool_ True is never identical to the builtin True
            if not compute_bands[band_index]:
                continue
            # index 1 of each band entry holds the PSF keyword arguments
            kwargs_psf = band[1]
            image_model = SingleBandMultiModel(
                self.multi_band_list,
                kwargs_model,
                likelihood_mask_list=likelihood_mask_list,
                band_index=band_index,
                kwargs_pixelbased=kwargs_pixelbased)
            psf_iter = PsfFitting(image_model_class=image_model)
            band[1] = psf_iter.update_iterative(kwargs_psf, kwargs_params=kwargs_temp, **kwargs_update)
        return 0
예제 #4
0
    def psf_iteration(self, num_iter=10, no_break=True, stacking_method='median', block_center_neighbour=0,
                      keep_psf_error_map=True, psf_symmetry=1, psf_iter_factor=1, verbose=True, compute_bands=None):
        """
        iterative PSF reconstruction

        :param num_iter: number of iterations in the process
        :param no_break: bool, if False will break the process as soon as one step leads to a worse reconstruction
         than the previous step
        :param stacking_method: string, 'median' and 'mean' supported
        :param block_center_neighbour: radius of neighbouring point source to be blocked in the reconstruction
        :param keep_psf_error_map: bool, whether or not to keep the previous psf_error_map
        :param psf_symmetry: int, number of invariant rotations in the reconstructed PSF
        :param psf_iter_factor: factor of new estimated PSF relative to the old one
         PSF_updated = (1-psf_iter_factor) * PSF_old + psf_iter_factor*PSF_new
        :param verbose: bool, print statements
        :param compute_bands: bool list, if multiple bands, this process can be limited to a subset of bands.
         Entries are tested for truthiness, so numpy boolean arrays select bands as well.
        :return: 0, updated PSF is stored in self.multi_band_list
        """
        kwargs_model = self._updateManager.kwargs_model
        kwargs_likelihood = self._updateManager.kwargs_likelihood
        likelihood_mask_list = kwargs_likelihood.get('image_likelihood_mask_list', None)
        param_class = self._param_class
        # lens parameters are rescaled with the stored cosmology first, then used for the source-plane mapping
        lens_updated = param_class.update_lens_scaling(self._cosmo_temp, self._lens_temp)
        source_updated = param_class.image2source_plane(self._source_temp, lens_updated)
        if compute_bands is None:
            compute_bands = [True] * len(self.multi_band_list)

        for band_index, band in enumerate(self.multi_band_list):
            # truthiness rather than ``is True``: a numpy.bool_ True is never identical to the builtin True
            if not compute_bands[band_index]:
                continue
            # index 1 of each band entry holds the PSF keyword arguments
            kwargs_psf = band[1]
            image_model = SingleBandMultiModel(self.multi_band_list, kwargs_model,
                                               likelihood_mask_list=likelihood_mask_list, band_index=band_index)
            psf_iter = PsfFitting(image_model_class=image_model)
            band[1] = psf_iter.update_iterative(kwargs_psf, lens_updated, source_updated,
                                                self._lens_light_temp, self._ps_temp, num_iter=num_iter,
                                                no_break=no_break, stacking_method=stacking_method,
                                                block_center_neighbour=block_center_neighbour,
                                                keep_psf_error_map=keep_psf_error_map,
                                                psf_symmetry=psf_symmetry, psf_iter_factor=psf_iter_factor,
                                                verbose=verbose)
        return 0
예제 #5
0
    def setup(self):
        """
        simulate a mock image with a known PIXEL PSF and build the PsfFitting instance under test

        Sets self.kwargs_psf (the true PSF), self.kwargs_lens, self.kwargs_source, self.kwargs_lens_light,
        self.kwargs_ps, self.imageModel (built on the simulated data) and self.psf_fitting.
        """
        # observation settings
        background_rms = 0.01  # background noise per pixel
        exposure = 100  # exposure time (arbitrary units, flux per pixel is in units #photons/exp_time unit)
        num_pix = 100  # cutout pixel size
        pixel_scale = 0.05  # pixel size in arcsec (area per pixel = pixel_scale**2)
        fwhm_true = 0.3  # full width half max of the true PSF

        data_class = Data(sim_util.data_configure_simple(num_pix, pixel_scale, exposure, background_rms))

        # true PSF: Gaussian evaluated on a 31x31 grid, normalised to unit sum
        from lenstronomy.LightModel.Profiles.gaussian import Gaussian
        sigma_true = util.fwhm2sigma(fwhm_true)
        x_kernel, y_kernel = util.make_grid(numPix=31, deltapix=0.05)
        kernel = Gaussian().function(x_kernel, y_kernel, amp=1., sigma_x=sigma_true, sigma_y=sigma_true,
                                     center_x=0, center_y=0)
        kernel /= np.sum(kernel)
        self.kwargs_psf = {'psf_type': 'PIXEL', 'kernel_point_source': util.array2image(kernel)}
        psf_class = PSF(kwargs_psf=self.kwargs_psf)

        # lens mass: 'SPEP' profile plus 'SHEAR' (external shear given by components e1, e2)
        e1_lens, e2_lens = param_util.phi_q2_ellipticity(0.2, 0.8)
        self.kwargs_lens = [
            {'theta_E': 1., 'gamma': 1.8, 'center_x': 0, 'center_y': 0, 'e1': e1_lens, 'e2': e2_lens},
            {'e1': 0.01, 'e2': 0.01},
        ]
        lens_model_class = LensModel(lens_model_list=['SPEP', 'SHEAR'])

        # lens light: spherical 'SERSIC'; source light: elliptical 'SERSIC_ELLIPSE'
        e1_source, e2_source = param_util.phi_q2_ellipticity(0.2, 0.9)
        self.kwargs_lens_light = [{'amp': 1., 'R_sersic': 0.1, 'n_sersic': 2, 'center_x': 0, 'center_y': 0}]
        lens_light_model_class = LightModel(light_model_list=['SERSIC'])
        self.kwargs_source = [{'amp': 1., 'R_sersic': .6, 'n_sersic': 7, 'center_x': 0, 'center_y': 0,
                               'e1': e1_source, 'e2': e2_source}]
        source_model_class = LightModel(light_model_list=['SERSIC_ELLIPSE'])

        # quasar point source position in the source plane and intrinsic brightness
        self.kwargs_ps = [{'ra_source': 0.0, 'dec_source': 0.0, 'source_amp': 10.}]
        point_source_class = PointSource(point_source_type_list=['SOURCE_POSITION'], fixed_magnification_list=[True])
        kwargs_numerics = {'subgrid_res': 3, 'psf_subgrid': True}

        components = (data_class, psf_class, lens_model_class, source_model_class,
                      lens_light_model_class, point_source_class)
        # simulate the mock image, store it in data_class, then build the model used by the tests
        simulation_model = ImageModel(*components, kwargs_numerics=kwargs_numerics)
        mock_image = sim_util.simulate_simple(simulation_model, self.kwargs_lens, self.kwargs_source,
                                              self.kwargs_lens_light, self.kwargs_ps)
        data_class.update_data(mock_image)
        self.imageModel = ImageModel(*components, kwargs_numerics=kwargs_numerics)

        self.psf_fitting = PsfFitting(self.imageModel)
예제 #6
0
class TestImageModel(object):
    """
    tests the PsfFitting routines: single PSF update, iterative PSF update and point source masking
    """

    @staticmethod
    def _gaussian_kernel(fwhm, num_pix=31, delta_pix=0.05):
        """
        Gaussian pixel kernel normalised to unit sum

        :param fwhm: full width at half maximum of the Gaussian (same units as delta_pix)
        :param num_pix: number of pixels per side of the kernel grid
        :param delta_pix: pixel spacing of the kernel grid
        :return: 2d kernel array
        """
        from lenstronomy.LightModel.Profiles.gaussian import Gaussian
        sigma = util.fwhm2sigma(fwhm)
        x_grid, y_grid = util.make_grid(numPix=num_pix, deltapix=delta_pix)
        kernel = Gaussian().function(x_grid, y_grid, amp=1., sigma_x=sigma, sigma_y=sigma,
                                     center_x=0, center_y=0)
        kernel /= np.sum(kernel)
        return util.array2image(kernel)

    @staticmethod
    def _squared_errors(kernel_old, kernel_new, kernel_true):
        """summed squared deviation from kernel_true of (kernel_old, kernel_new)"""
        return np.sum((kernel_old - kernel_true) ** 2), np.sum((kernel_new - kernel_true) ** 2)

    def setup_method(self):
        """simulate a mock image with a known PIXEL PSF and build the PsfFitting instance under test"""
        # data specifics
        sigma_bkg = 0.01  # background noise per pixel
        exp_time = 100  # exposure time (arbitrary units, flux per pixel is in units #photons/exp_time unit)
        numPix = 100  # cutout pixel size
        deltaPix = 0.05  # pixel size in arcsec (area per pixel = deltaPix**2)
        fwhm = 0.3  # full width half max of PSF

        kwargs_data = sim_util.data_configure_simple(numPix, deltaPix, exp_time, sigma_bkg)
        data_class = Data(kwargs_data)

        # PSF specification: the true PSF that the tests try to recover
        self.kwargs_psf = {'psf_type': 'PIXEL', 'kernel_point_source': self._gaussian_kernel(fwhm)}
        psf_class = PSF(kwargs_psf=self.kwargs_psf)

        # 'SHEAR': external shear with components e1, e2
        kwargs_shear = {'e1': 0.01, 'e2': 0.01}
        phi, q = 0.2, 0.8
        e1, e2 = param_util.phi_q2_ellipticity(phi, q)
        kwargs_spemd = {'theta_E': 1., 'gamma': 1.8, 'center_x': 0, 'center_y': 0, 'e1': e1, 'e2': e2}

        lens_model_list = ['SPEP', 'SHEAR']
        self.kwargs_lens = [kwargs_spemd, kwargs_shear]
        lens_model_class = LensModel(lens_model_list=lens_model_list)
        # list of light profiles (for lens and source)
        # 'SERSIC': spherical Sersic profile
        kwargs_sersic = {'amp': 1., 'R_sersic': 0.1, 'n_sersic': 2, 'center_x': 0, 'center_y': 0}
        # 'SERSIC_ELLIPSE': elliptical Sersic profile
        phi, q = 0.2, 0.9
        e1, e2 = param_util.phi_q2_ellipticity(phi, q)
        kwargs_sersic_ellipse = {'amp': 1., 'R_sersic': .6, 'n_sersic': 7, 'center_x': 0, 'center_y': 0,
                                 'e1': e1, 'e2': e2}

        lens_light_model_list = ['SERSIC']
        self.kwargs_lens_light = [kwargs_sersic]
        lens_light_model_class = LightModel(light_model_list=lens_light_model_list)
        source_model_list = ['SERSIC_ELLIPSE']
        self.kwargs_source = [kwargs_sersic_ellipse]
        source_model_class = LightModel(light_model_list=source_model_list)
        # quasar point source position in the source plane and intrinsic brightness
        self.kwargs_ps = [{'ra_source': 0.0, 'dec_source': 0.0, 'source_amp': 10.}]
        point_source_class = PointSource(point_source_type_list=['SOURCE_POSITION'], fixed_magnification_list=[True])
        kwargs_numerics = {'subgrid_res': 3, 'psf_subgrid': True}
        imageModel = ImageModel(data_class, psf_class, lens_model_class, source_model_class,
                                lens_light_model_class, point_source_class, kwargs_numerics=kwargs_numerics)
        image_sim = sim_util.simulate_simple(imageModel, self.kwargs_lens, self.kwargs_source,
                                             self.kwargs_lens_light, self.kwargs_ps)
        data_class.update_data(image_sim)
        # model used by the tests is built after the simulated image has been stored in data_class
        self.imageModel = ImageModel(data_class, psf_class, lens_model_class, source_model_class,
                                     lens_light_model_class, point_source_class, kwargs_numerics=kwargs_numerics)

        self.psf_fitting = PsfFitting(self.imageModel)

    # pytest >= 8 only runs ``setup_method``; the nose-style name is kept for direct callers
    setup = setup_method

    def test_update_psf(self):
        kwargs_psf = {'psf_type': 'PIXEL', 'kernel_point_source': self._gaussian_kernel(fwhm=0.5)}

        kwargs_psf_iter = {'stacking_method': 'median'}
        kwargs_psf_return, improved_bool, error_map = self.psf_fitting.update_psf(
            kwargs_psf, self.kwargs_lens, self.kwargs_source, self.kwargs_lens_light, self.kwargs_ps,
            **kwargs_psf_iter)
        assert improved_bool
        diff_old, diff_new = self._squared_errors(kwargs_psf['kernel_point_source'],
                                                  kwargs_psf_return['kernel_point_source'],
                                                  self.kwargs_psf['kernel_point_source'])
        assert diff_old > diff_new

    def test_update_iterative(self):
        kwargs_psf = {'psf_type': 'PIXEL', 'kernel_point_source': self._gaussian_kernel(fwhm=0.5)}
        kernel_true = self.kwargs_psf['kernel_point_source']

        kwargs_psf_iter = {'stacking_method': 'median'}
        kwargs_psf_new = self.psf_fitting.update_iterative(kwargs_psf, self.kwargs_lens, self.kwargs_source,
                                                           self.kwargs_lens_light, self.kwargs_ps,
                                                           **kwargs_psf_iter)
        diff_old, diff_new = self._squared_errors(kwargs_psf['kernel_point_source'],
                                                  kwargs_psf_new['kernel_point_source'], kernel_true)
        assert diff_old > diff_new
        assert diff_new < 0.01

        kwargs_psf_new = self.psf_fitting.update_iterative(kwargs_psf, self.kwargs_lens, self.kwargs_source,
                                                           self.kwargs_lens_light, self.kwargs_ps, num_iter=3,
                                                           no_break=True)
        diff_old, diff_new = self._squared_errors(kwargs_psf['kernel_point_source'],
                                                  kwargs_psf_new['kernel_point_source'], kernel_true)
        assert diff_old > diff_new
        assert diff_new < 0.01

    def test_mask_point_source(self):
        ra_image, dec_image, amp = self.imageModel.PointSource.point_source_list(self.kwargs_ps, self.kwargs_lens)
        x_grid, y_grid = self.imageModel.Data.coordinates
        x_grid = util.image2array(x_grid)
        y_grid = util.image2array(y_grid)
        radius = 0.5
        mask_point_source = self.psf_fitting.mask_point_source(ra_image, dec_image, x_grid, y_grid, radius, i=0)
        assert mask_point_source[10, 10] == 1
예제 #7
0
class TestPSFIterationOld(object):
    """
    tests the PsfFitting routines update_psf / update_iterative with 'new_procedure': False
    """

    @staticmethod
    def _gaussian_kernel(fwhm, num_pix=31, delta_pix=0.05):
        """
        Gaussian pixel kernel normalised to unit sum

        :param fwhm: full width at half maximum of the Gaussian (same units as delta_pix)
        :param num_pix: number of pixels per side of the kernel grid
        :param delta_pix: pixel spacing of the kernel grid
        :return: 2d kernel array
        """
        from lenstronomy.LightModel.Profiles.gaussian import Gaussian
        sigma = util.fwhm2sigma(fwhm)
        x_grid, y_grid = util.make_grid(numPix=num_pix, deltapix=delta_pix)
        kernel = Gaussian().function(x_grid,
                                     y_grid,
                                     amp=1.,
                                     sigma=sigma,
                                     center_x=0,
                                     center_y=0)
        kernel /= np.sum(kernel)
        return util.array2image(kernel)

    @staticmethod
    def _squared_errors(kernel_old, kernel_new, kernel_true):
        """summed squared deviation from kernel_true of (kernel_old, kernel_new)"""
        return np.sum((kernel_old - kernel_true)**2), np.sum((kernel_new - kernel_true)**2)

    def setup_method(self):
        """simulate a mock image with a known PIXEL PSF and build the PsfFitting instance under test"""
        # data specifics
        sigma_bkg = 0.01  # background noise per pixel
        exp_time = 100  # exposure time (arbitrary units, flux per pixel is in units #photons/exp_time unit)
        numPix = 100  # cutout pixel size
        deltaPix = 0.05  # pixel size in arcsec (area per pixel = deltaPix**2)
        fwhm = 0.3  # full width half max of PSF

        kwargs_data = sim_util.data_configure_simple(numPix, deltaPix,
                                                     exp_time, sigma_bkg)
        data_class = ImageData(**kwargs_data)

        # PSF specification: the true PSF that the tests try to recover, with an all-zero error map
        kernel_point_source = self._gaussian_kernel(fwhm)
        psf_error_map = np.zeros_like(kernel_point_source)
        self.kwargs_psf = {
            'psf_type': 'PIXEL',
            'kernel_point_source': kernel_point_source,
            'psf_error_map': psf_error_map
        }

        psf_class = PSF(**self.kwargs_psf)

        # 'SHEAR': external shear with components gamma1, gamma2
        kwargs_shear = {
            'gamma1': 0.01,
            'gamma2': 0.01
        }
        phi, q = 0.2, 0.8
        e1, e2 = param_util.phi_q2_ellipticity(phi, q)
        kwargs_spemd = {
            'theta_E': 1.,
            'gamma': 1.8,
            'center_x': 0,
            'center_y': 0,
            'e1': e1,
            'e2': e2
        }

        lens_model_list = ['SPEP', 'SHEAR']
        self.kwargs_lens = [kwargs_spemd, kwargs_shear]
        lens_model_class = LensModel(lens_model_list=lens_model_list)
        # list of light profiles (for lens and source)
        # 'SERSIC': spherical Sersic profile
        kwargs_sersic = {
            'amp': 1.,
            'R_sersic': 0.1,
            'n_sersic': 2,
            'center_x': 0,
            'center_y': 0
        }
        # 'SERSIC_ELLIPSE': elliptical Sersic profile
        phi, q = 0.2, 0.9
        e1, e2 = param_util.phi_q2_ellipticity(phi, q)
        kwargs_sersic_ellipse = {
            'amp': 1.,
            'R_sersic': .6,
            'n_sersic': 7,
            'center_x': 0,
            'center_y': 0,
            'e1': e1,
            'e2': e2
        }

        lens_light_model_list = ['SERSIC']
        self.kwargs_lens_light = [kwargs_sersic]
        lens_light_model_class = LightModel(
            light_model_list=lens_light_model_list)
        source_model_list = ['SERSIC_ELLIPSE']
        self.kwargs_source = [kwargs_sersic_ellipse]
        source_model_class = LightModel(light_model_list=source_model_list)
        # quasar point source position in the source plane and intrinsic brightness
        self.kwargs_ps = [
            {
                'ra_source': 0.0,
                'dec_source': 0.0,
                'source_amp': 10.
            }
        ]
        point_source_class = PointSource(
            point_source_type_list=['SOURCE_POSITION'],
            fixed_magnification_list=[True])

        kwargs_numerics = {
            'supersampling_factor': 3,
            'supersampling_convolution': False,
            'compute_mode': 'regular',
            'point_source_supersampling_factor': 3
        }
        imageModel = ImageModel(data_class,
                                psf_class,
                                lens_model_class,
                                source_model_class,
                                lens_light_model_class,
                                point_source_class,
                                kwargs_numerics=kwargs_numerics)
        image_sim = sim_util.simulate_simple(imageModel, self.kwargs_lens,
                                             self.kwargs_source,
                                             self.kwargs_lens_light,
                                             self.kwargs_ps)
        data_class.update_data(image_sim)
        # linear-fit model used by the tests is built after the simulated image has been stored in data_class
        self.imageModel = ImageLinearFit(data_class,
                                         psf_class,
                                         lens_model_class,
                                         source_model_class,
                                         lens_light_model_class,
                                         point_source_class,
                                         kwargs_numerics=kwargs_numerics)

        self.psf_fitting = PsfFitting(self.imageModel)
        self.kwargs_params = {
            'kwargs_lens': self.kwargs_lens,
            'kwargs_source': self.kwargs_source,
            'kwargs_lens_light': self.kwargs_lens_light,
            'kwargs_ps': self.kwargs_ps
        }

    # pytest >= 8 only runs ``setup_method``; the nose-style name is kept for direct callers
    setup = setup_method

    def test_update_psf(self):
        kwargs_psf = {
            'psf_type': 'PIXEL',
            'kernel_point_source': self._gaussian_kernel(fwhm=0.5)
        }

        kwargs_psf_iter = {
            'stacking_method': 'median',
            'error_map_radius': 0.5,
            'new_procedure': False
        }

        kwargs_psf_return, improved_bool, error_map = self.psf_fitting.update_psf(
            kwargs_psf, self.kwargs_params, **kwargs_psf_iter)
        assert improved_bool
        diff_old, diff_new = self._squared_errors(kwargs_psf['kernel_point_source'],
                                                  kwargs_psf_return['kernel_point_source'],
                                                  self.kwargs_psf['kernel_point_source'])
        assert diff_old > diff_new

    def test_update_iterative(self):
        kernel_point_source = self._gaussian_kernel(fwhm=0.5)
        kwargs_psf = {
            'psf_type': 'PIXEL',
            'kernel_point_source': kernel_point_source,
            'kernel_point_source_init': kernel_point_source
        }
        kwargs_psf_iter = {
            'stacking_method': 'median',
            'psf_symmetry': 2,
            'psf_iter_factor': 0.2,
            'block_center_neighbour': 0.1,
            'error_map_radius': 0.5,
            'new_procedure': False,
            'no_break': False,
            'verbose': True,
            'keep_psf_error_map': False
        }
        kernel_true = self.kwargs_psf['kernel_point_source']

        # the point source amplitude is removed so that it has to be solved for during the iteration
        kwargs_params = copy.deepcopy(self.kwargs_params)
        kwargs_ps = kwargs_params['kwargs_ps']
        del kwargs_ps[0]['source_amp']
        kwargs_psf_new = self.psf_fitting.update_iterative(
            kwargs_psf, kwargs_params, **kwargs_psf_iter)
        diff_old, diff_new = self._squared_errors(kwargs_psf['kernel_point_source'],
                                                  kwargs_psf_new['kernel_point_source'], kernel_true)
        assert diff_old > diff_new
        assert diff_new < 0.01
        assert 'psf_error_map' in kwargs_psf_new

        kwargs_psf_new = self.psf_fitting.update_iterative(
            kwargs_psf,
            kwargs_params,
            num_iter=3,
            no_break=True,
            keep_psf_error_map=True)
        diff_old, diff_new = self._squared_errors(kwargs_psf['kernel_point_source'],
                                                  kwargs_psf_new['kernel_point_source'], kernel_true)
        assert diff_old > diff_new
        assert diff_new < 0.01