def test_source(self):
        """Check that ModelPlot.source returns 10 entries along its first axis.

        Covers image_orientation=True, image_orientation=False, and an
        explicit center without passing image_orientation.
        """
        bands = [[self.kwargs_data, self.kwargs_psf, self.kwargs_numerics]]
        lens_plot = ModelPlot(bands,
                              self.kwargs_model,
                              self.kwargs_params,
                              arrow_size=0.02,
                              cmap_string="gist_heat")
        variants = (
            {'image_orientation': True},
            {'image_orientation': False},
            {'center': [0, 0]},
        )
        for extra_kwargs in variants:
            source, _ = lens_plot.source(band_index=0,
                                         numPix=10,
                                         deltaPix=0.1,
                                         **extra_kwargs)
            assert len(source) == 10
 def test_reconstruction_all_bands(self):
     """Check a two-band joint-linear model yields a 2 x 3 grid of axes."""
     band = [self.kwargs_data, self.kwargs_psf, self.kwargs_numerics]
     bands = [list(band) for _ in range(2)]
     lens_plot = ModelPlot(bands,
                           self.kwargs_model,
                           self.kwargs_params,
                           arrow_size=0.02,
                           cmap_string="gist_heat",
                           multi_band_type='joint-linear',
                           bands_compute=[True, True])
     _, axes_grid = lens_plot.reconstruction_all_bands()
     assert len(axes_grid) == 2
     assert len(axes_grid[0]) == 3
     plt.close()
    def test_joint_linear(self):
        """Smoke-test single-axis plotting in joint-linear mode.

        Two identical bands are used, and only the first one is computed.
        """
        band = [self.kwargs_data, self.kwargs_psf, self.kwargs_numerics]
        lens_plot = ModelPlot([list(band), list(band)],
                              self.kwargs_model,
                              self.kwargs_params,
                              arrow_size=0.02,
                              cmap_string="gist_heat",
                              multi_band_type='joint-linear',
                              bands_compute=[True, False])

        # Keyword set shared by the data, model and convergence panels.
        panel_kwargs = dict(numPix=10,
                            deltaPix_source=0.1,
                            v_min=None,
                            v_max=None,
                            with_caustics=False,
                            caustic_color='yellow',
                            fsize=15,
                            plot_scale='linear')
        plot_calls = [
            (lens_plot.data_plot, panel_kwargs),
            (lens_plot.model_plot, panel_kwargs),
            (lens_plot.convergence_plot, panel_kwargs),
            (lens_plot.normalized_residual_plot, {}),
            (lens_plot.magnification_plot, {}),
            (lens_plot.decomposition_plot, {}),
        ]
        for plot_fn, extra_kwargs in plot_calls:
            _, axis = plt.subplots(1, 1, figsize=(4, 4))
            plot_fn(ax=axis, **extra_kwargs)
            plt.close()
    def test_source_plot(self):
        """Smoke-test ModelPlot.source_plot on a supplied axis (single band)."""
        lens_plot = ModelPlot(
            [[self.kwargs_data, self.kwargs_psf, self.kwargs_numerics]],
            self.kwargs_model,
            self.kwargs_params,
            arrow_size=0.02,
            cmap_string="gist_heat")

        _, axis = plt.subplots(1, 1, figsize=(4, 4))
        source_kwargs = {
            'numPix': 10,
            'deltaPix_source': 0.1,
            'v_min': None,
            'v_max': None,
            'with_caustics': False,
            'caustic_color': 'yellow',
            'fsize': 15,
            'plot_scale': 'linear',
        }
        lens_plot.source_plot(ax=axis, **source_kwargs)
        plt.close()
Beispiel #5
0
    def get_model_plot(self,
                       lens_name,
                       model_id=None,
                       kwargs_result=None,
                       band_index=0,
                       data_cmap='cubehelix'):
        """
        Get the `ModelPlot` instance from lenstronomy for the lens.

        :param lens_name: name of the lens
        :type lens_name: `str`
        :param model_id: model run identifier. Used to load the saved outputs
            when `kwargs_result` is not provided.
        :type model_id: `str`
        :param kwargs_result: lenstronomy `kwargs_result` dictionary. If
            provided, it will be used to plot the model, otherwise the model
            will be plotted from the saved/loaded outputs for `lens_name` and
            `model_id`.
        :type kwargs_result: `dict`
        :param band_index: index of the band whose image sets the returned
            `v_max`. The `ModelPlot` itself is built from all bands.
        :type band_index: `int`
        :param data_cmap: colormap for image, reconstruction, and source plots
        :type data_cmap: `str` or `matplotlib.colors.Colormap`
        :return: `ModelPlot` instance, and the base-10 logarithm of the
            maximum pixel value of the image in band `band_index`
        :rtype: `obj`, `float`
        :raises ValueError: if neither `model_id` nor `kwargs_result` is
            provided
        """
        if model_id is None and kwargs_result is None:
            raise ValueError('Either the `model_id` or the `kwargs_result` '
                             'needs to be provided!')

        if kwargs_result is None:
            # load_output is expected to populate self.kwargs_result for this
            # lens / model run.
            self.load_output(lens_name, model_id)
            kwargs_result = self.kwargs_result

        multi_band_list_out = self.get_kwargs_data_joint(
            lens_name)['multi_band_list']

        config = ModelConfig(settings=self.model_settings)

        mask = config.get_masks()
        kwargs_model = config.get_kwargs_model()

        # log10 of the brightest pixel (presumably used by callers as the
        # upper limit of a log colour scale). NOTE: a non-positive image
        # maximum yields -inf/nan here, with a NumPy RuntimeWarning.
        image_data = multi_band_list_out[band_index][0]['image_data']
        v_max = np.log10(image_data.max())

        model_plot = ModelPlot(multi_band_list_out,
                               kwargs_model,
                               kwargs_result,
                               arrow_size=0.02,
                               cmap_string=data_cmap,
                               likelihood_mask_list=mask,
                               multi_band_type='multi-linear')

        return model_plot, v_max
    def test_lensModelPlot(self):
        """Smoke-test ModelPlot's whole-figure and single-axis plot methods."""
        bands = [[self.kwargs_data, self.kwargs_psf, self.kwargs_numerics]]

        single_band_plot = ModelPlot(bands,
                                     self.kwargs_model,
                                     self.kwargs_params,
                                     arrow_size=0.02,
                                     cmap_string="gist_heat",
                                     multi_band_type='single-band')
        single_band_plot.plot_main(with_caustics=True)
        plt.close()

        # Exercise passing a Colormap object rather than a colormap name.
        colormap = plt.get_cmap('gist_heat')
        lens_plot = ModelPlot(bands,
                              self.kwargs_model,
                              self.kwargs_params,
                              arrow_size=0.02,
                              cmap_string=colormap)

        for whole_figure_plot in (lens_plot.plot_separate,
                                  lens_plot.plot_subtract_from_data_all):
            whole_figure_plot()
            plt.close()

        single_axis_calls = [
            (lens_plot.deflection_plot, {'with_caustics': True, 'axis': 1}),
            (lens_plot.subtract_from_data_plot, {}),
            (lens_plot.deflection_plot, {'with_caustics': True, 'axis': 0}),
            (lens_plot.error_map_source_plot, {'numPix': 100,
                                               'deltaPix_source': 0.01,
                                               'with_caustics': True}),
            (lens_plot.absolute_residual_plot, {}),
            (lens_plot.plot_extinction_map, {}),
        ]
        for plot_fn, extra_kwargs in single_axis_calls:
            _, axis = plt.subplots(1, 1, figsize=(4, 4))
            plot_fn(ax=axis, **extra_kwargs)
            plt.close()
Beispiel #7
0
def _default_source_params(fix_n=None):
    """Return the default prior for a single elliptical Sersic host component.

    The result is ``[init, sigma, fixed, lower, upper]``, each a one-element
    list of keyword dicts. When *fix_n* is given, the Sersic index is listed
    as a fixed parameter with that value. Otherwise it may vary in [0.3, 7].
    """
    n_free = fix_n is None
    kwargs_source_init = {'R_sersic': 0.3, 'n_sersic': 2. if n_free else fix_n,
                          'e1': 0., 'e2': 0., 'center_x': 0., 'center_y': 0.}
    kwargs_source_sigma = {'n_sersic': 0.5 if n_free else 0.001, 'R_sersic': 0.1,
                           'e1': 0.1, 'e2': 0.1, 'center_x': 0.1, 'center_y': 0.1}
    fixed_source = {} if n_free else {'n_sersic': fix_n}
    kwargs_lower_source = {'e1': -0.5, 'e2': -0.5, 'R_sersic': 0.01,
                           'n_sersic': 0.3 if n_free else fix_n,
                           'center_x': -10, 'center_y': -10}
    kwargs_upper_source = {'e1': 0.5, 'e2': 0.5, 'R_sersic': 3.,
                           'n_sersic': 7. if n_free else fix_n,
                           'center_x': 10, 'center_y': 10}
    return [[kwargs_source_init], [kwargs_source_sigma], [fixed_source],
            [kwargs_lower_source], [kwargs_upper_source]]


def _fitting_kwargs_list(deep_seed, run_mcmc):
    """Return the lenstronomy fitting sequence (PSO, then optionally MCMC).

    :param deep_seed: False, True or 'very_deep'; selects the PSO/MCMC sizes.
        Because dict lookup uses equality, 0 and 1 behave like False and True.
    :param run_mcmc: append the MCMC stage when truthy.
    :raises ValueError: for any other *deep_seed* value.
    """
    # deep_seed -> (PSO n_particles, PSO n_iterations, MCMC n_run)
    presets = {False: (50, 50, 10), True: (100, 80, 15), 'very_deep': (150, 150, 20)}
    try:
        n_particles, n_iterations, n_run = presets[deep_seed]
    except (KeyError, TypeError):
        raise ValueError("deep_seed must be False, True or 'very_deep', "
                         "got {0!r}".format(deep_seed)) from None
    fitting_kwargs_list = [
        ['PSO', {'sigma_scale': 0.8, 'n_particles': n_particles, 'n_iterations': n_iterations}],
    ]
    if run_mcmc:
        fitting_kwargs_list.append(
            ['MCMC', {'n_burn': 10, 'n_run': n_run, 'walkerRatio': 50, 'sigma_scale': .1}])
    return fitting_kwargs_list


def _finish_figure(pltshow):
    """Close the current figure when *pltshow* is 0, otherwise show it."""
    if pltshow == 0:
        plt.close()
    else:
        plt.show()


def fit_galaxy(galaxy_im, psf_ave, psf_std=None, source_params=None,
               background_rms=0.04, pix_sz=0.08, exp_time=300., fix_n=None,
               image_plot=True, corner_plot=True, deep_seed=False,
               galaxy_msk=None, galaxy_std=None, flux_corner_plot=False,
               tag=None, no_MCMC=False, pltshow=1, return_Chisq=False,
               dump_result=False, pso_diag=False):
    '''
    Fit a galaxy (host) image with elliptical Sersic profile(s) using
    lenstronomy: a PSO run, optionally followed by MCMC.

    Parameters
    ----------
        galaxy_im: 2D image to fit. len(galaxy_im) is used as the cutout size
            (presumably a square cutout).
        psf_ave: PSF kernel ('PIXEL' PSF type), or None for psf_type 'NONE'.
        psf_std: PSF error map, optional.
        source_params: [init, sigma, fixed, lower, upper] lists of Sersic
            kwargs. Defaults to one elliptical Sersic component (see fix_n).
        background_rms: background noise per pixel, default 0.04.
        pix_sz: pixel scale, default 0.08.
        exp_time: exposure time, default 300.
        fix_n: if given, the default prior fixes the Sersic index to this value.
        image_plot: plot data, model and normalized residuals.
        corner_plot: corner plot of the MCMC samples (needs MCMC).
        deep_seed: False, True or 'very_deep'; larger settings run more
            PSO/MCMC steps.
        galaxy_msk: likelihood mask for the image.
        galaxy_std: noise map of the image, optional.
        flux_corner_plot: corner plot of per-component host flux computed from
            the MCMC samples (needs MCMC).
        tag: name prefix for saved plots and for the pickle file.
        no_MCMC: if True, only the PSO stage is run.
        pltshow: 0 closes figures instead of showing them.
        return_Chisq: also return the reduced chi^2.
        dump_result: pickle the results to '<tag>.pkl' (requires tag).
        pso_diag: plot the chain diagnostics of the first fitting stage.

    Return
    --------
        source_result, image_host, noise_map[, reduced_Chisq]

    Raises
    --------
        ValueError: invalid deep_seed, or dump_result without tag.
    '''
    run_mcmc = not no_MCMC
    # Validate cheap options before any expensive work, so a bad argument
    # does not surface only after the fit has finished.
    fitting_kwargs_list = _fitting_kwargs_list(deep_seed, run_mcmc)
    if dump_result and tag is None:
        raise ValueError('`tag` is required when `dump_result` is True '
                         '(it names the pickle file).')

    numPix = len(galaxy_im)  # cutout size in pixels
    deltaPix = pix_sz
    kwargs_numerics = {'supersampling_factor': 1, 'supersampling_convolution': False}

    if source_params is None:
        source_params = _default_source_params(fix_n)
    kwargs_params = {'source_model': source_params}

    kwargs_data = sim_util.data_configure_simple(numPix, deltaPix, exp_time, background_rms, inverse=True)
    data_class = ImageData(**kwargs_data)
    if psf_ave is not None:
        kwargs_psf = {'psf_type': 'PIXEL', 'kernel_point_source': psf_ave}
    else:
        kwargs_psf = {'psf_type': 'NONE'}
    psf_class = PSF(**kwargs_psf)
    data_class.update_data(galaxy_im)

    light_model_list = ['SERSIC_ELLIPSE'] * len(source_params[0])
    lightModel = LightModel(light_model_list=light_model_list)
    kwargs_model = {'source_light_model_list': light_model_list}
    kwargs_constraints = {}
    kwargs_likelihood = {'check_bounds': True,  # out-of-bounds parameters get a penalty
                         'source_marg': False,
                         'check_positive_flux': True,
                         'image_likelihood_mask_list': [galaxy_msk]}

    # These keys are added after ImageData/PSF were built, so they only reach
    # multi_band_list (used by FittingSequence and ModelPlot), not
    # data_class/psf_class.
    kwargs_data['image_data'] = galaxy_im
    if galaxy_std is not None:
        kwargs_data['noise_map'] = galaxy_std
    if psf_std is not None:
        kwargs_psf['psf_error_map'] = psf_std

    multi_band_list = [[kwargs_data, kwargs_psf, kwargs_numerics]]
    kwargs_data_joint = {'multi_band_list': multi_band_list, 'multi_band_type': 'multi-linear'}
    fitting_seq = FittingSequence(kwargs_data_joint, kwargs_model, kwargs_constraints,
                                  kwargs_likelihood, kwargs_params)

    start_time = time.time()
    chain_list = fitting_seq.fit_sequence(fitting_kwargs_list)
    kwargs_result = fitting_seq.best_fit()
    ps_result = kwargs_result['kwargs_ps']
    source_result = kwargs_result['kwargs_source']
    samples_mcmc, param_mcmc = [], []
    if run_mcmc:
        _, samples_mcmc, param_mcmc, _ = chain_list[1]
    end_time = time.time()
    print(end_time - start_time, 'total time needed for computation')
    print('============ CONGRATULATION, YOUR JOB WAS SUCCESSFUL ================ ')

    # Linear inversion at the best fit.
    imageModel = ImageModel(data_class, psf_class, source_model_class=lightModel,
                            kwargs_numerics=kwargs_numerics)
    imageLinearFit = ImageLinearFit(data_class=data_class, psf_class=psf_class,
                                    source_model_class=lightModel,
                                    kwargs_numerics=kwargs_numerics)
    image_reconstructed, error_map, _, _ = imageLinearFit.image_linear_solve(
        kwargs_source=source_result, kwargs_ps=ps_result)

    modelPlot = ModelPlot(multi_band_list, kwargs_model, kwargs_result,
                          arrow_size=0.02, cmap_string="gist_heat",
                          likelihood_mask_list=[galaxy_msk])

    if pso_diag:
        chain_plot.plot_chain_list(chain_list, 0)
        _finish_figure(pltshow)

    reduced_Chisq = imageLinearFit.reduced_chi2(image_reconstructed, error_map)
    if image_plot:
        f, axes = plt.subplots(1, 3, figsize=(16, 16), sharex=False, sharey=False)
        modelPlot.data_plot(ax=axes[0])
        modelPlot.model_plot(ax=axes[1])
        modelPlot.normalized_residual_plot(ax=axes[2], v_min=-6, v_max=6)
        f.tight_layout()
        if tag is not None:
            f.savefig('{0}_fitted_image.pdf'.format(tag))
        _finish_figure(pltshow)

    image_host = [imageModel.source_surface_brightness(source_result, de_lensed=True,
                                                       unconvolved=False, k=i)
                  for i in range(len(source_result))]

    # len() works for both lists and NumPy arrays, whereas `samples == []`
    # would compare elementwise on an array.
    if corner_plot and run_mcmc and len(samples_mcmc) > 0:
        plot = corner.corner(samples_mcmc, labels=param_mcmc, show_titles=True)
        if tag is not None:
            plot.savefig('{0}_para_corner.pdf'.format(tag))
        _finish_figure(pltshow)

    if flux_corner_plot and run_mcmc:
        param = Param(kwargs_model, kwargs_fixed_source=source_params[2], **kwargs_constraints)
        n_components = len(source_params[0])
        labels_new = ["host{0} flux".format(i) for i in range(n_components)]
        mcmc_new_list = []
        for i, sample in enumerate(samples_mcmc):
            kwargs_out = param.args2kwargs(sample)
            kwargs_light_source_out = kwargs_out['kwargs_source']
            # NOTE(review): the returned image is discarded. The call is
            # presumably kept for setting the linear amplitudes inside
            # kwargs_light_source_out before the fluxes are summed — verify.
            imageLinearFit.image_linear_solve(kwargs_source=kwargs_light_source_out,
                                              kwargs_ps=kwargs_out['kwargs_ps'])
            fluxs = [np.sum(imageModel.source_surface_brightness(kwargs_light_source_out,
                                                                 unconvolved=False, k=j))
                     for j in range(n_components)]
            mcmc_new_list.append(fluxs)
            if i > 0 and i % 1000 == 0:
                print(len(samples_mcmc), "MCMC samplers in total, finished translate:", i)
        plot = corner.corner(mcmc_new_list, labels=labels_new, show_titles=True)
        if tag is not None:
            plot.savefig('{0}_HOSTvsQSO_corner.pdf'.format(tag))
        _finish_figure(pltshow)

    if galaxy_std is None:
        noise_map = np.sqrt(data_class.C_D + np.abs(error_map))
    else:
        noise_map = np.sqrt(galaxy_std ** 2 + np.abs(error_map))

    if dump_result:
        if flux_corner_plot and run_mcmc:
            trans_paras = [source_params[2], mcmc_new_list, labels_new,
                           'source_params[2], mcmc_new_list, labels_new']
        else:
            trans_paras = []
        best_fit = [source_result, image_host, 'source_result, image_host']
        chain_list_result = [chain_list, 'chain_list']
        with open(tag + '.pkl', 'wb') as pkl_file:
            pickle.dump([best_fit, chain_list_result, trans_paras], pkl_file)

    if return_Chisq:
        return source_result, image_host, noise_map, reduced_Chisq
    return source_result, image_host, noise_map
        )
        fix_setting = [fixed_lens, fixed_source, fixed_lens_light, None, None]
        # sampler_type, samples_mcmc, param_mcmc, dist_mcmc  = chain_list[-1]
        mcmc_new_list = []
        pickle.dump([
            multi_band_list, kwargs_model, kwargs_result, chain_list,
            fix_setting, mcmc_new_list
        ], open(folder + savename, 'wb'))
    #Print fitting result:
    multi_band_list, kwargs_model, kwargs_result, chain_list, fix_setting, _ = pickle.load(
        open(folder + savename, 'rb'))
    fixed_lens, fixed_source, fixed_lens_light, fixed_ps, fixed_cosmo = fix_setting
    labels_new = [r"$\gamma$", r"$e1$", r"$e2$"]
    modelPlot = ModelPlot(multi_band_list,
                          kwargs_model,
                          kwargs_result,
                          arrow_size=0.02,
                          cmap_string="gist_heat")
    f, axes = modelPlot.plot_main()
    f.show()
    f, axes = modelPlot.plot_separate()
    f.show()
    f, axes = modelPlot.plot_subtract_from_data_all()
    f.show()

    sampler_type, samples_mcmc, param_mcmc, dist_mcmc = chain_list[-1]
    # for i in range(len(chain_list)):
    #     chain_plot.plot_chain_list(chain_list, i)

    param = Param(kwargs_model,
                  fixed_lens,
Beispiel #9
0
    def model_plot(self, save_plot=False, show_plot=True):
        """
        Show the fitting plot based on lenstronomy.Plots.model_plot.ModelPlot.

        The 3 x 3 grid is laid out as follows:

        - top row: data, model and normalized residuals (range -6..6);
        - middle row: three decomposition panels;
        - bottom row: data minus point source, data minus host galaxy, and
          data minus both.

        :param save_plot: if True, save the figure to
            '<self.savename>_model.pdf'
        :param show_plot: if True, show the figure; otherwise close it
        """
        modelPlot = ModelPlot(
            self.fitting_specify_class.kwargs_data_joint['multi_band_list'],
            self.fitting_specify_class.kwargs_model,
            self.kwargs_result,
            arrow_size=0.02,
            cmap_string="gist_heat",
            likelihood_mask_list=self.fitting_specify_class.
            kwargs_likelihood['image_likelihood_mask_list'])

        f, axes = plt.subplots(3,
                               3,
                               figsize=(16, 16),
                               sharex=False,
                               sharey=False)
        modelPlot.data_plot(ax=axes[0, 0], text="Data")
        modelPlot.model_plot(ax=axes[0, 1])
        modelPlot.normalized_residual_plot(ax=axes[0, 2], v_min=-6, v_max=6)

        modelPlot.decomposition_plot(ax=axes[1, 0],
                                     text='Host galaxy',
                                     lens_light_add=True,
                                     unconvolved=True)
        modelPlot.decomposition_plot(ax=axes[1, 1],
                                     text='Host galaxy convolved',
                                     lens_light_add=True)
        modelPlot.decomposition_plot(ax=axes[1, 2],
                                     text='All components convolved',
                                     source_add=True,
                                     lens_light_add=True,
                                     point_source_add=True)

        modelPlot.subtract_from_data_plot(ax=axes[2, 0],
                                          text='Data - Point Source',
                                          point_source_add=True)
        modelPlot.subtract_from_data_plot(ax=axes[2, 1],
                                          text='Data - host galaxy',
                                          lens_light_add=True)
        modelPlot.subtract_from_data_plot(
            ax=axes[2, 2],
            text='Data - host galaxy - Point Source',
            lens_light_add=True,
            point_source_add=True)
        f.tight_layout()
        if save_plot:
            # Save this figure explicitly rather than pyplot's "current" one.
            f.savefig('{0}_model.pdf'.format(self.savename))
        if show_plot:
            plt.show()
        else:
            plt.close(f)
 def plot_modeling(self,
                   kwargs_result,
                   center=None,
                   deltaPix_s=0.03,
                   numPix_s=None,
                   text_source='',
                   data_index=0,
                   text='sys',
                   img_name='sys',
                   font_size=25,
                   scale_size=0.1,
                   fig_close=False,
                   likelihood_mask_list=None):
     """
     Show the modeling result.

     For each band, plot the observed data, the reconstructed image and the
     normalized residual map (range -6..6). Then plot the reconstructed
     source on a separate figure. Two files are written:
     ``img_name + 'residual.pdf'`` and ``img_name + 'source.pdf'``.

     :param kwargs_result: modeling results
     :param center: center of the source-plane plot; defaults to [0, 0]
     :param deltaPix_s: pixel scale in the source plane
     :param numPix_s: number of pixels in the source plane; defaults to the
         first-axis size of the first band's image
     :param text_source: string appended to the "Source" label
     :param data_index: currently unused; kept for backward compatibility
     :param text: string appended to the "Observed" and "Modeled" labels
     :param img_name: prefix of the saved image file names
     :param font_size: font size passed to the plot functions
     :param scale_size: passed through to ``source_plot``
     :param fig_close: if True, close each figure after saving it
     :param likelihood_mask_list: likelihood masks passed to ``ModelPlot``
     :return: None
     """
     if center is None:
         center = [0, 0]
     model_plot = ModelPlot(self.multi_band_list,
                            self.kwargs_model,
                            kwargs_result,
                            arrow_size=0.02,
                            cmap_string="gist_heat",
                            multi_band_type=self.multi_band_type,
                            likelihood_mask_list=likelihood_mask_list)
     num_bands = len(self.kwargs_data_joint['multi_band_list'])
     figsize = (22, 18) if num_bands > 1 else (22, 6)
     f, axes = plt.subplots(num_bands, 3, figsize=figsize)
     # plt.subplots returns a 1-D row of axes when num_bands == 1; wrap it so
     # every band is handled as one row of three axes.
     axes_rows = axes if num_bands > 1 else [axes]
     for band_index, (ax1, ax2, ax3) in enumerate(axes_rows):
         model_plot.data_plot(ax=ax1,
                              band_index=band_index,
                              text='Observed' + text,
                              font_size=font_size)
         model_plot.model_plot(ax=ax2,
                               image_names=True,
                               band_index=band_index,
                               font_size=font_size,
                               text='Modeled' + text)
         model_plot.normalized_residual_plot(ax=ax3,
                                             v_min=-6,
                                             v_max=6,
                                             band_index=band_index,
                                             font_size=font_size)
     f.savefig(img_name + 'residual.pdf', bbox_inches='tight')
     if fig_close:
         plt.close(f)

     if numPix_s is None:
         numPix_s = self.kwargs_data_joint['multi_band_list'][0][0][
             'image_data'].shape[0]
     # NOTE(review): the source is drawn for the last band (data_index is
     # not consulted); confirm this is intended for multi-band fits.
     source_band = num_bands - 1
     f_s, axes_s = plt.subplots(1, 1, figsize=(9, 6))
     model_plot.source_plot(ax=axes_s,
                            deltaPix_source=deltaPix_s,
                            numPix=numPix_s,
                            center=center,
                            band_index=source_band,
                            scale_size=scale_size,
                            font_size=font_size,
                            text="Source" + text_source,
                            plot_scale='log',
                            v_min=-5,
                            with_caustics=True)
     f_s.savefig(img_name + 'source.pdf')
     if fig_close:
         plt.close(f_s)
Beispiel #11
0
                          'Ddt_sampling': True
                                  }
    index = folder.split('idx')[1].split('_')[0]
    # save_file = save_pkl_folder+'idx{0}_ID'.format(index)+ID+'_'+savename
    save_file = folder[:-1]
    multi_band_list, kwargs_model, kwargs_result_best, chain_list, fix_setting, mcmc_new_list = pickle.load(open(save_file,'rb'))
    # fixed_lens, fixed_source, fixed_lens_light, fixed_ps, fixed_cosmo = fix_setting
    
    lens_data = np.ones([99,99])
    lens_mask = cr_mask(lens_data, 'normal_mask.reg')    
    framesize = len(multi_band_list[0][0]['image_data'])  #81
    ct = int((len(lens_data) - framesize)/2)
    lens_mask = (1-lens_mask)[ct:-ct,ct:-ct]    

    labels_new = [r"$\gamma$", r"$D_{\Delta t}$","H$_0$" ]
    modelPlot = ModelPlot(multi_band_list, kwargs_model, kwargs_result_best, arrow_size=0.02, cmap_string="gist_heat", 
                          likelihood_mask_list= [lens_mask])
    f, axes = modelPlot.plot_main()
    plt.show()
    # f, axes = modelPlot.plot_separate()_
    # f.show()
    # f, axes = modelPlot.plot_subtract_from_data_all()
    # f.show()
    # multi_band_list = fitting_seq.multi_band_list
    # kwargs_psf_updated = multi_band_list[0][1]
    # f, axes = chain_plot.psf_iteration_compare(kwargs_psf_updated)
    # f.show()

#    for i in range(len(chain_list)):
#        chain_plot.plot_chain_list(chain_list, i)
#    plt.close()
    truths=[para_s[0][0]['gamma'],TD_distance, 73.907]	
def main():
    """Run forward-modeling MCMC inference of D_dt on a set of test lenses.

    Configuration comes from the command line (``script_utils.parse_inference_args``)
    and the test / training / Baobab config files it points to. For each
    selected lens, the image is fit with ``ForwardModelingPosterior.run_mcmc``
    and the following are written into ``test_cfg.out_dir``:

    * ``D_dt_dict_XXXX.npy`` -- D_dt samples, inference time and true D_dt;
    * forward-modeling comparison plots
      (``plotting_utils.plot_forward_modeling_comparisons``);
    * optionally the MCMC samples CSV, the D_dt histogram, the MCMC chain
      plot and a corner plot, depending on the ``test_cfg.export`` flags;
    * ``realized_time_delays.csv`` -- the realized time-delay table read from
      ``test_cfg.error_model.realized_time_delays``, written back at the end.

    Raises:
        ValueError: if lens indices are given both in the test config and on
            the command line, if no lens is selected at all, or if the time
            delays were not generated with Baobab and the truth table has no
            ``abcd_ordering_i`` column.
        RuntimeError: if the test loader yields no batch, e.g. because the
            largest requested lens index is not smaller than the test-set size
            (the loader uses ``drop_last=True``).
    """
    args = script_utils.parse_inference_args()
    test_cfg = TestConfig.from_file(args.test_config_file_path)
    baobab_cfg = BaobabConfig.from_file(test_cfg.data.test_baobab_cfg_path)
    cfg = TrainValConfig.from_file(test_cfg.train_val_config_file_path)
    # Set device and default data type
    device = torch.device(test_cfg.device_type)
    if device.type == 'cuda':
        torch.set_default_tensor_type('torch.cuda.' + cfg.data.float_type)
    else:
        torch.set_default_tensor_type('torch.' + cfg.data.float_type)
    script_utils.seed_everything(test_cfg.global_seed)

    ############
    # Data I/O #
    ############
    # Define test data and loader. Pixels are left unscaled and the targets
    # unnormalized (zero mean, unit std) because the truth table is only used
    # for the true parameter values, not for BNN training.
    test_data = XYData(
        is_train=False,
        Y_cols=cfg.data.Y_cols,
        float_type=cfg.data.float_type,
        define_src_pos_wrt_lens=cfg.data.define_src_pos_wrt_lens,
        rescale_pixels=False,
        rescale_pixels_type=None,
        log_pixels=False,
        add_pixel_noise=cfg.data.add_pixel_noise,
        eff_exposure_time={"TDLMC_F160W": test_cfg.data.eff_exposure_time},
        train_Y_mean=np.zeros((1, len(cfg.data.Y_cols))),
        train_Y_std=np.ones((1, len(cfg.data.Y_cols))),
        train_baobab_cfg_path=cfg.data.train_baobab_cfg_path,
        val_baobab_cfg_path=test_cfg.data.test_baobab_cfg_path,
        for_cosmology=True)
    master_truth = test_data.Y_df
    master_truth = metadata_utils.add_qphi_columns(master_truth)
    master_truth = metadata_utils.add_gamma_psi_ext_columns(master_truth)
    # Figure out which lenses to run inference on. Lens indices may come from
    # the test config OR from a text file given on the command line, not both.
    if test_cfg.data.lens_indices is None:
        if args.lens_indices_path is None:
            # Test on all n_test lenses in the test set
            n_test = test_cfg.data.n_test
            lens_range = range(n_test)
        else:
            # Test on the lens indices in a text file at the specified path,
            # one integer per line; blank lines are ignored.
            with open(args.lens_indices_path, "r") as f:
                lens_range = [int(line.strip()) for line in f if line.strip()]
            n_test = len(lens_range)
            print("Performing H0 inference on {:d} specified lenses...".format(
                n_test))
    else:
        if args.lens_indices_path is None:
            # Test on the lens indices specified in the test config file
            lens_range = test_cfg.data.lens_indices
            n_test = len(lens_range)
            print("Performing H0 inference on {:d} specified lenses...".format(
                n_test))
        else:
            raise ValueError(
                "Specific lens indices were specified in both the test config file and the command-line argument."
            )
    if len(lens_range) == 0:
        raise ValueError("No lens indices were selected for inference.")
    # A single batch large enough to cover every requested index is loaded;
    # images are then picked out of it by lens index.
    batch_size = max(lens_range) + 1
    test_loader = DataLoader(test_data,
                             batch_size=batch_size,
                             shuffle=False,
                             drop_last=True)
    # Output directory into which the H0 histograms and H0 samples will be saved
    out_dir = test_cfg.out_dir
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print("Destination folder path: {:s}".format(out_dir))
    else:
        warnings.warn("Destination folder already exists.")

    ################
    # Compile data #
    ################
    # Image data: only the first batch is needed (see batch_size above).
    with torch.no_grad():
        first_batch = next(iter(test_loader), None)
        if first_batch is None:
            raise RuntimeError(
                "The test loader yielded no batch: the largest requested lens "
                "index ({:d}) must be smaller than the number of test "
                "lenses.".format(batch_size - 1))
        X = first_batch[0].to(device)
    X = X.detach().cpu().numpy()

    #############
    # MCMC loop #
    #############
    # The lens-equation solver searches the full image cutout.
    kwargs_lens_eqn_solver = dict(
        min_distance=0.05,
        search_window=baobab_cfg.instrument['pixel_scale'] *
        baobab_cfg.image['num_pix'],
        num_iter_max=200)
    fm_posterior = ForwardModelingPosterior(
        kwargs_lens_eqn_solver=kwargs_lens_eqn_solver,
        astrometric_sigma=test_cfg.image_position_likelihood.sigma,
        supersampling_factor=baobab_cfg.numerics.supersampling_factor)
    # Get H0 samples for each system
    if not test_cfg.time_delay_likelihood.baobab_time_delays:
        if 'abcd_ordering_i' not in master_truth:
            raise ValueError(
                "If the time delay measurements were not generated using Baobab, the user must specify the order of image positions in which the time delays are listed, in order of increasing dec."
            )

    total_progress = tqdm(total=n_test)
    realized_time_delays = pd.read_csv(
        test_cfg.error_model.realized_time_delays, index_col=None)
    # For each lens system...
    for i, lens_i in enumerate(lens_range):
        ###########################
        # Relevant data and prior #
        ###########################
        data_i = master_truth.iloc[lens_i].copy()
        lcdm = LCDM(z_lens=data_i['z_lens'],
                    z_source=data_i['z_src'],
                    flat=True)
        # Time delays are stored relative to the first image, so there is one
        # more image than delay.
        measured_td_wrt0 = np.array(
            literal_eval(
                realized_time_delays.iloc[lens_i]['measured_td_wrt0']))
        n_img = len(measured_td_wrt0) + 1
        fm_posterior.set_kwargs_data_joint(
            image=X[lens_i, 0, :, :],
            measured_td=measured_td_wrt0,
            measured_td_sigma=test_cfg.time_delay_likelihood.sigma,
            survey_object_dict=baobab_cfg.survey_object_dict,
            eff_exposure_time=test_cfg.data.eff_exposure_time,
        )
        # Update solver according to number of lensed images
        if test_cfg.numerics.solver_type == 'NONE':
            fm_posterior.kwargs_constraints['solver_type'] = 'NONE'
        else:
            fm_posterior.kwargs_constraints[
                'solver_type'] = 'PROFILE_SHEAR' if n_img == 4 else 'ELLIPSE'
        fm_posterior.kwargs_constraints['num_point_source_list'] = [n_img]
        true_D_dt = lcdm.D_dt(H_0=data_i['H0'], Om0=0.3)
        # Pull truth param values and initialize walkers there
        if test_cfg.numerics.initialize_walkers_to_truth:
            fm_posterior.kwargs_lens_init = metadata_utils.get_kwargs_lens_mass(
                data_i)
            fm_posterior.kwargs_lens_light_init = metadata_utils.get_kwargs_lens_light(
                data_i)
            fm_posterior.kwargs_source_init = metadata_utils.get_kwargs_src_light(
                data_i)
            fm_posterior.kwargs_ps_init = metadata_utils.get_kwargs_ps_lensed(
                data_i)
            fm_posterior.kwargs_special_init = dict(D_dt=true_D_dt)

        ###########################
        # MCMC posterior sampling #
        ###########################
        lens_i_start_time = time.time()
        chain_list_mcmc, kwargs_result_mcmc = fm_posterior.run_mcmc(
            test_cfg.numerics.mcmc)
        lens_i_end_time = time.time()
        inference_time = (lens_i_end_time - lens_i_start_time) / 60.0  # min

        #############################
        # Plotting the MCMC samples #
        #############################
        # sampler_type : 'EMCEE'
        # samples_mcmc : np.array of shape `[n_mcmc_eval, n_params]`
        # param_mcmc : list of str of length n_params, the parameter names
        sampler_type, samples_mcmc, param_mcmc, _ = chain_list_mcmc[0]
        # Index [2] of each kwargs_params entry holds the fixed parameters.
        new_samples_mcmc = mcmc_utils.postprocess_mcmc_chain(
            kwargs_result_mcmc,
            samples_mcmc,
            fm_posterior.kwargs_model,
            fm_posterior.kwargs_params['lens_model'][2],
            fm_posterior.kwargs_params['point_source_model'][2],
            fm_posterior.kwargs_params['source_model'][2],
            fm_posterior.kwargs_params['special'][2],
            fm_posterior.kwargs_constraints,
            kwargs_fixed_lens_light=fm_posterior.kwargs_params[
                'lens_light_model'][2],
            verbose=False)
        model_plot = ModelPlot(fm_posterior.multi_band_list,
                               fm_posterior.kwargs_model,
                               kwargs_result_mcmc,
                               arrow_size=0.02,
                               cmap_string="gist_heat")
        plotting_utils.plot_forward_modeling_comparisons(model_plot, out_dir)

        # Plot D_dt histogram
        D_dt_samples = new_samples_mcmc[
            'D_dt'].values  # may contain negative values
        data_i['D_dt'] = true_D_dt
        # Export D_dt samples for this lens
        lens_inference_dict = dict(
            D_dt_samples=D_dt_samples,  # kappa_ext=0 for these samples
            inference_time=inference_time,
            true_D_dt=true_D_dt,
        )
        lens_inference_dict_save_path = os.path.join(
            out_dir, 'D_dt_dict_{0:04d}.npy'.format(lens_i))
        np.save(lens_inference_dict_save_path, lens_inference_dict)
        # Optionally export the MCMC samples
        if test_cfg.export.mcmc_samples:
            mcmc_samples_path = os.path.join(
                out_dir, 'mcmc_samples_{0:04d}.csv'.format(lens_i))
            new_samples_mcmc.to_csv(mcmc_samples_path, index=None)
        # Optionally export the D_dt histogram
        if test_cfg.export.D_dt_histogram:
            cleaned_D_dt_samples = h0_utils.remove_outliers_from_lognormal(
                D_dt_samples, 3)
            _ = plotting_utils.plot_D_dt_histogram(cleaned_D_dt_samples,
                                                   lens_i,
                                                   true_D_dt,
                                                   save_dir=out_dir)
        # Optionally export the plot of MCMC chain
        if test_cfg.export.mcmc_chain:
            mcmc_chain_path = os.path.join(
                out_dir, 'mcmc_chain_{0:04d}.png'.format(lens_i))
            plotting_utils.plot_mcmc_chain(chain_list_mcmc, mcmc_chain_path)
        # Optionally export posterior cornerplot of select lens model parameters with D_dt
        if test_cfg.export.mcmc_corner:
            mcmc_corner_path = os.path.join(
                out_dir, 'mcmc_corner_{0:04d}.png'.format(lens_i))
            plotting_utils.plot_mcmc_corner(
                new_samples_mcmc[test_cfg.export.mcmc_cols],
                data_i[test_cfg.export.mcmc_cols],
                test_cfg.export.mcmc_col_labels, mcmc_corner_path)
        total_progress.update(1)
        gc.collect()
    realized_time_delays.to_csv(os.path.join(out_dir,
                                             'realized_time_delays.csv'),
                                index=None)
    total_progress.close()
Beispiel #13
0
def _fit_qso_multiband_default_source_params(fix_n=None):
    """Return the default host prior used by ``fit_qso_multiband``.

    The prior is one elliptical Sersic component, returned as
    ``[init, sigma, fixed, lower, upper]`` (each a list with one dict).
    If ``fix_n`` is given, the Sersic index is held fixed at that value;
    otherwise it is free within [0.3, 7].
    """
    if fix_n is None:
        # Nothing fixed: the Sersic index is a free parameter.
        fixed_source = [{}]
        kwargs_source_init = [{'R_sersic': 0.3, 'n_sersic': 2., 'e1': 0., 'e2': 0., 'center_x': 0., 'center_y': 0.}]
        kwargs_source_sigma = [{'n_sersic': 0.5, 'R_sersic': 0.5, 'e1': 0.1, 'e2': 0.1, 'center_x': 0.1, 'center_y': 0.1}]
        kwargs_lower_source = [{'e1': -0.5, 'e2': -0.5, 'R_sersic': 0.1, 'n_sersic': 0.3, 'center_x': -10, 'center_y': -10}]
        kwargs_upper_source = [{'e1': 0.5, 'e2': 0.5, 'R_sersic': 3., 'n_sersic': 7., 'center_x': 10, 'center_y': 10}]
    else:
        fixed_source = [{'n_sersic': fix_n}]
        kwargs_source_init = [{'R_sersic': 0.3, 'n_sersic': fix_n, 'e1': 0., 'e2': 0., 'center_x': 0., 'center_y': 0.}]
        kwargs_source_sigma = [{'n_sersic': 0.001, 'R_sersic': 0.5, 'e1': 0.1, 'e2': 0.1, 'center_x': 0.1, 'center_y': 0.1}]
        kwargs_lower_source = [{'e1': -0.5, 'e2': -0.5, 'R_sersic': 0.1, 'n_sersic': fix_n, 'center_x': -10, 'center_y': -10}]
        kwargs_upper_source = [{'e1': 0.5, 'e2': 0.5, 'R_sersic': 3, 'n_sersic': fix_n, 'center_x': 10, 'center_y': 10}]
    return [kwargs_source_init, kwargs_source_sigma, fixed_source, kwargs_lower_source, kwargs_upper_source]


def _fit_qso_multiband_default_ps_params(reference_image):
    """Return the default single point-source prior used by ``fit_qso_multiband``.

    The point source starts at (0, 0) with half the total flux of
    ``reference_image``. It is returned as ``[init, sigma, fixed, lower, upper]``.
    """
    point_amp = reference_image.sum() / 2.
    kwargs_ps_init = [{'ra_image': [0.0], 'dec_image': [0.0], 'point_amp': [point_amp]}]
    kwargs_ps_sigma = [{'ra_image': [0.01], 'dec_image': [0.01]}]
    fixed_ps = [{}]
    kwargs_lower_ps = [{'ra_image': [-10], 'dec_image': [-10]}]
    kwargs_upper_ps = [{'ra_image': [10], 'dec_image': [10]}]
    return [kwargs_ps_init, kwargs_ps_sigma, fixed_ps, kwargs_lower_ps, kwargs_upper_ps]


def _fit_qso_multiband_fitting_kwargs(n_bands, deep_seed=False, no_MCMC=False):
    """Return the FittingSequence stage list used by ``fit_qso_multiband``.

    The stages are:
    1. PSO on the first band only.
    2. ``align_images`` for the remaining bands.
    3. PSO on all bands.
    4. Unless ``no_MCMC`` is set, a short MCMC.

    ``deep_seed`` enlarges the PSO swarms, the alignment effort and the MCMC
    length.
    """
    first_band_only = [True] + [False] * (n_bands - 1)
    other_bands_only = [False] + [True] * (n_bands - 1)
    all_bands = [True] * n_bands
    if deep_seed:
        n_pso_first, n_align, n_pso_all, n_burn, n_run = 150, 20, 150, 20, 40
    else:
        n_pso_first, n_align, n_pso_all, n_burn, n_run = 80, 10, 100, 10, 20
    fitting_kwargs_list = [
        ['PSO', {'sigma_scale': 0.8, 'n_particles': n_pso_first, 'n_iterations': 60, 'compute_bands': first_band_only}],
        ['align_images', {'n_particles': n_align, 'n_iterations': n_align, 'compute_bands': other_bands_only}],
        ['PSO', {'sigma_scale': 0.8, 'n_particles': n_pso_all, 'n_iterations': 200, 'compute_bands': all_bands}],
    ]
    if not no_MCMC:
        fitting_kwargs_list.append(['MCMC', {'n_burn': n_burn, 'n_run': n_run, 'walkerRatio': 50, 'sigma_scale': .1}])
    return fitting_kwargs_list


def fit_qso_multiband(QSO_im_list, psf_ave_list, psf_std_list=None, source_params=None, ps_param=None,
                      background_rms_list=None, pix_sz=0.168,
                      exp_time=300., fix_n=None, image_plot=True, corner_plot=True,
                      flux_ratio_plot=True, deep_seed=False, fixcenter=False, QSO_msk_list=None,
                      QSO_std_list=None, tag=None, no_MCMC=False, pltshow=1, new_band_seq=None):
    '''
    Jointly fit a multi-band QSO cutout with Sersic host(s) plus one point source.

    All bands share one light model and one unlensed point-source model. They
    are fit together with lenstronomy's ``FittingSequence`` using
    ``'multi-linear'`` band coupling, in these stages:
    1. PSO on the first band.
    2. ``align_images`` for the other bands.
    3. PSO on all bands.
    4. Unless ``no_MCMC`` is set, a short MCMC.

    Parameter
    --------
        QSO_im_list: list of QSO images, one per band; the cutout size is
            taken from the first image.
        psf_ave_list: list of PSF kernels (pixel PSFs), one per band.
        psf_std_list: PSF noise per band; currently unused.
        source_params: host prior ``[init, sigma, fixed, lower, upper]``; by
            default a single elliptical Sersic profile (see ``fix_n``).
        ps_param: point-source prior ``[init, sigma, fixed, lower, upper]``;
            by default one point source at (0, 0).
        background_rms_list: background noise per pixel for each band;
            defaults to 0.04 for every band.
        pix_sz: pixel scale passed to ``sim_util.data_configure_simple``.
        exp_time: exposure time; default 300.
        fix_n: if given, fix the Sersic index of the default host prior.
        image_plot: if True, plot data/model/residual/decomposition per band.
            The corner and flux-ratio plots are only produced when this is True.
        corner_plot: if True (and MCMC is run), corner plot of the MCMC chain
            for the first band.
        flux_ratio_plot: if True (and MCMC is run), corner plot of QSO flux,
            host flux and point-source position for each band.
        deep_seed: if True, more PSO particles and MCMC steps are used.
        fixcenter: if True, add the ``joint_source_with_point_source``
            constraint ``[[0, 0]]`` (host component 0 tied to point source 0).
        QSO_msk_list: likelihood masks.
        QSO_std_list: noise maps per band. If given, they are put into the
            data kwargs used by the fit and in the returned error maps.
        tag: name tag used as prefix of the saved plots; nothing is saved if None.
        no_MCMC: if True, skip the MCMC stage.
        pltshow: 0 closes each figure, anything else calls ``plt.show()``.
        new_band_seq: band labels used in saved file names; default 0..n-1.

    Return
    --------
        source_result_list, ps_result_list, image_ps_list, image_host_list,
        errp_list, shift_RADEC_list, fitting_seq
    '''
    n_bands = len(QSO_im_list)
    # data specifics need to set up based on the data situation
    if background_rms_list is None:
        background_rms_list = [0.04] * n_bands  # background noise per pixel (Gaussian)
    numPix = len(QSO_im_list[0])  # cutout pixel size
    deltaPix = pix_sz
    psf_type = 'PIXEL'  # 'gaussian', 'pixel', 'NONE'
    kernel_list = psf_ave_list
    if new_band_seq is None:
        new_band_seq = range(n_bands)

    # One independent dict per band; ``[d] * n`` would make every band share
    # (and mutate) the same dict.
    kwargs_numerics_list = [{'supersampling_factor': 1, 'supersampling_convolution': False}
                            for _ in range(n_bands)]

    if source_params is None:
        source_params = _fit_qso_multiband_default_source_params(fix_n)
    if ps_param is None:
        ps_param = _fit_qso_multiband_default_ps_params(QSO_im_list[0])
    # The fixed point-source kwargs are needed again for the flux-ratio Param
    # object; take them from ps_param so user-supplied priors work too.
    fixed_ps = ps_param[2]

    kwargs_params = {'source_model': source_params,
                     'point_source_model': ps_param}

    #==============================================================================
    # Doing the QSO fitting
    #==============================================================================
    light_model_list = ['SERSIC_ELLIPSE'] * len(source_params[0])
    lightModel = LightModel(light_model_list=light_model_list)
    point_source_list = ['UNLENSED']
    pointSource = PointSource(point_source_type_list=point_source_list)

    kwargs_data_list, data_class_list = [], []
    kwargs_psf_list, psf_class_list = [], []
    imageModel_list, image_band_list = [], []
    for i in range(n_bands):
        kwargs_data_i = sim_util.data_configure_simple(numPix, deltaPix, exp_time, background_rms_list[i], inverse=True)
        data_class_i = ImageData(**kwargs_data_i)
        kwargs_psf_i = {'psf_type': psf_type, 'kernel_point_source': kernel_list[i]}
        psf_class_i = PSF(**kwargs_psf_i)
        data_class_i.update_data(QSO_im_list[i])
        # The data kwargs (used by FittingSequence) get the image and the
        # optional noise map only after the ImageData object was built, so
        # data_class_i's noise model does not include QSO_std_list.
        kwargs_data_i['image_data'] = QSO_im_list[i]
        if QSO_std_list is not None:
            kwargs_data_i['noise_map'] = QSO_std_list[i]
        kwargs_data_list.append(kwargs_data_i)
        data_class_list.append(data_class_i)
        kwargs_psf_list.append(kwargs_psf_i)
        psf_class_list.append(psf_class_i)
        imageModel_list.append(ImageModel(data_class_i, psf_class_i, source_model_class=lightModel,
                                          point_source_class=pointSource, kwargs_numerics=kwargs_numerics_list[i]))
        image_band_list.append([kwargs_data_i, kwargs_psf_i, kwargs_numerics_list[i]])
    multi_band_list = image_band_list

    kwargs_model = {'source_light_model_list': light_model_list,
                    'point_source_model_list': point_source_list
                    }

    kwargs_constraints = {'num_point_source_list': [1]}
    if fixcenter:
        kwargs_constraints['joint_source_with_point_source'] = [[0, 0]]

    # NOTE(review): the mask list is wrapped once more here, giving a
    # one-element list for any number of bands. If QSO_msk_list is already a
    # per-band list this looks wrong for multi-band fits -- confirm against
    # the lenstronomy version in use.
    kwargs_likelihood = {'check_bounds': True,  # out-of-bound parameters return a penalty
                         'source_marg': False,  # whether to fully invert the covariance matrix for marginalization
                         'check_positive_flux': True,
                         'image_likelihood_mask_list': [QSO_msk_list]
                         }

    kwargs_data_joint = {'multi_band_list': multi_band_list, 'multi_band_type': 'multi-linear'}  # 'single-band', 'multi-linear', 'joint-linear'
    fitting_seq = FittingSequence(kwargs_data_joint, kwargs_model, kwargs_constraints, kwargs_likelihood, kwargs_params)

    fitting_kwargs_list = _fit_qso_multiband_fitting_kwargs(n_bands, deep_seed=deep_seed, no_MCMC=no_MCMC)

    start_time = time.time()
    chain_list, param_list, samples_mcmc, param_mcmc, dist_mcmc = fitting_seq.fit_sequence(fitting_kwargs_list)
    lens_result, source_result, lens_light_result, ps_result, cosmo_temp = fitting_seq.best_fit()
    end_time = time.time()
    print(end_time - start_time, 'total time needed for computation')
    print('============ CONGRATULATION, YOUR JOB WAS SUCCESSFUL ================ ')

    source_result_list, ps_result_list = [], []
    image_reconstructed_list, error_map_list, image_ps_list, image_host_list, shift_RADEC_list = [], [], [], [], []
    imageLinearFit_list = []
    for k in range(n_bands):
        # Linear inversion of the best fit for this band.
        # NOTE(review): the original comment says the kwargs are updated by
        # this call (linear amplitudes set in place) -- confirm in lenstronomy.
        imageLinearFit_k = ImageLinearFit(data_class_list[k], psf_class_list[k], source_model_class=lightModel,
                                          point_source_class=pointSource, kwargs_numerics=kwargs_numerics_list[k])
        image_reconstructed_k, error_map_k, _, _ = imageLinearFit_k.image_linear_solve(kwargs_source=source_result, kwargs_ps=ps_result)
        imageLinearFit_list.append(imageLinearFit_k)

        imageModel_k = imageModel_list[k]
        # The plot mask is the band's likelihood mask (the original passed the
        # QSO image itself here, which weights residuals by flux).
        likelihood_mask_k = None if QSO_msk_list is None else QSO_msk_list[k]
        modelPlot = ModelPlot(multi_band_list[k], kwargs_model, lens_result, source_result,
                              lens_light_result, ps_result, arrow_size=0.02, cmap_string="gist_heat",
                              likelihood_mask=likelihood_mask_k)
        print("source_result", 'for', k, source_result)
        image_host_k = [imageModel_k.source_surface_brightness(source_result, de_lensed=True, unconvolved=False, k=i)
                        for i in range(len(source_result))]
        image_ps_k = imageModel_k.point_source(ps_result)

        image_reconstructed_list.append(image_reconstructed_k)
        source_result_list.append(source_result)
        ps_result_list.append(ps_result)
        error_map_list.append(error_map_k)
        image_ps_list.append(image_ps_k)
        image_host_list.append(image_host_k)
        # align_images stores per-band astrometric shifts in the data kwargs.
        band_kwargs_data = fitting_seq.multi_band_list[k][0]
        if 'ra_shift' in band_kwargs_data:
            shift_RADEC_list.append([band_kwargs_data['ra_shift'], band_kwargs_data['dec_shift']])
        else:
            shift_RADEC_list.append([0, 0])

        if image_plot:
            f, axes = plt.subplots(3, 3, figsize=(16, 16), sharex=False, sharey=False)
            modelPlot.data_plot(ax=axes[0, 0], text="Data")
            modelPlot.model_plot(ax=axes[0, 1])
            modelPlot.normalized_residual_plot(ax=axes[0, 2], v_min=-6, v_max=6)

            modelPlot.decomposition_plot(ax=axes[1, 0], text='Host galaxy', source_add=True, unconvolved=True)
            modelPlot.decomposition_plot(ax=axes[1, 1], text='Host galaxy convolved', source_add=True)
            modelPlot.decomposition_plot(ax=axes[1, 2], text='All components convolved', source_add=True, lens_light_add=True, point_source_add=True)

            modelPlot.subtract_from_data_plot(ax=axes[2, 0], text='Data - Point Source', point_source_add=True)
            modelPlot.subtract_from_data_plot(ax=axes[2, 1], text='Data - host galaxy', source_add=True)
            modelPlot.subtract_from_data_plot(ax=axes[2, 2], text='Data - host galaxy - Point Source', source_add=True, point_source_add=True)
            f.tight_layout()
            if tag is not None:
                f.savefig('{0}_fitted_image_band{1}.pdf'.format(tag, new_band_seq[k]))
            if pltshow == 0:
                plt.close()
            else:
                plt.show()

            if corner_plot and not no_MCMC and k == 0:
                # here the (non-converged) MCMC chain of the non-linear parameters
                if len(samples_mcmc) > 0:
                    plot = corner.corner(samples_mcmc, labels=param_mcmc, show_titles=True)
                    if tag is not None:
                        plot.savefig('{0}_para_corner.pdf'.format(tag))
                    if pltshow == 0:
                        plt.close()
                    else:
                        plt.show()
            if flux_ratio_plot and not no_MCMC:
                param = Param(kwargs_model, kwargs_fixed_source=source_params[2], kwargs_fixed_ps=fixed_ps, **kwargs_constraints)
                mcmc_new_list = []
                labels_new = [r"Quasar flux", r"host_flux", r"source_x", r"source_y"]
                # Use the last tenth of the chain (integer division: indices
                # must be ints on Python 3).
                n_keep = len(samples_mcmc) // 10
                offset = n_keep * 9
                # transform the parameter position of the MCMC chain in a lenstronomy convention with keyword arguments
                for i in range(n_keep):
                    kwargs_lens_out, kwargs_light_source_out, kwargs_light_lens_out, kwargs_ps_out, kwargs_cosmo = param.getParams(samples_mcmc[offset + i])
                    # Called for its effect on the kwargs (see NOTE above);
                    # the reconstructed image itself is not used.
                    imageLinearFit_list[k].image_linear_solve(kwargs_source=kwargs_light_source_out, kwargs_ps=kwargs_ps_out)

                    image_ps = imageModel_k.point_source(kwargs_ps_out)
                    flux_quasar = np.sum(image_ps)
                    image_disk = imageModel_k.source_surface_brightness(kwargs_light_source_out, de_lensed=True, unconvolved=False, k=0)
                    flux_disk = np.sum(image_disk)
                    # Single point source: take its (only) position as a scalar
                    # so every row of the corner input has 4 scalar entries.
                    source_x = np.atleast_1d(kwargs_ps_out[0]['ra_image'])[0]
                    source_y = np.atleast_1d(kwargs_ps_out[0]['dec_image'])[0]
                    if flux_disk > 0:
                        mcmc_new_list.append([flux_quasar, flux_disk, source_x, source_y])
                if mcmc_new_list:
                    plot = corner.corner(mcmc_new_list, labels=labels_new, show_titles=True)
                    if tag is not None:
                        plot.savefig('{0}_HOSTvsQSO_corner_band{1}.pdf'.format(tag, new_band_seq[k]))
                    if pltshow == 0:
                        plt.close()
                    else:
                        plt.show()

    # Total error per pixel: data noise (from the ImageData model, or the user
    # noise map) combined with the linear-solve error map.
    errp_list = []
    for k in range(n_bands):
        if QSO_std_list is None:
            errp_list.append(np.sqrt(data_class_list[k].C_D + np.abs(error_map_list[k])))
        else:
            errp_list.append(np.sqrt(QSO_std_list[k] ** 2 + np.abs(error_map_list[k])))
    return source_result_list, ps_result_list, image_ps_list, image_host_list, errp_list, shift_RADEC_list, fitting_seq
Beispiel #14
0
def _fit_qso_default_source_params(fix_n=None):
    """Return the default single-Sersic host prior used by ``fit_qso``.

    Parameters
    ----------
    fix_n : float or None
        ``None`` leaves the Sersic index free (start 2, bounds [0.3, 7]);
        any other value fixes the index to that value.

    Returns
    -------
    list
        ``[init, sigma, fixed, lower, upper]``, each a one-element list of
        parameter dicts, as used for ``kwargs_params['source_model']``.
    """
    if fix_n is None:
        fixed_source = [{}]  # nothing fixed: the Sersic index is fitted
        kwargs_source_init = [{'R_sersic': 0.3, 'n_sersic': 2., 'e1': 0., 'e2': 0., 'center_x': 0., 'center_y': 0.}]
        kwargs_source_sigma = [{'n_sersic': 0.5, 'R_sersic': 0.5, 'e1': 0.1, 'e2': 0.1, 'center_x': 0.1, 'center_y': 0.1}]
        kwargs_lower_source = [{'e1': -0.5, 'e2': -0.5, 'R_sersic': 0.1, 'n_sersic': 0.3, 'center_x': -10, 'center_y': -10}]
        kwargs_upper_source = [{'e1': 0.5, 'e2': 0.5, 'R_sersic': 3., 'n_sersic': 7., 'center_x': 10, 'center_y': 10}]
    else:
        fixed_source = [{'n_sersic': fix_n}]
        kwargs_source_init = [{'R_sersic': 0.3, 'n_sersic': fix_n, 'e1': 0., 'e2': 0., 'center_x': 0., 'center_y': 0.}]
        kwargs_source_sigma = [{'n_sersic': 0.001, 'R_sersic': 0.5, 'e1': 0.1, 'e2': 0.1, 'center_x': 0.1, 'center_y': 0.1}]
        kwargs_lower_source = [{'e1': -0.5, 'e2': -0.5, 'R_sersic': 0.1, 'n_sersic': fix_n, 'center_x': -10, 'center_y': -10}]
        kwargs_upper_source = [{'e1': 0.5, 'e2': 0.5, 'R_sersic': 3, 'n_sersic': fix_n, 'center_x': 10, 'center_y': 10}]
    return [kwargs_source_init, kwargs_source_sigma, fixed_source,
            kwargs_lower_source, kwargs_upper_source]


def _fit_qso_default_ps_params(QSO_im):
    """Return the default single point-source prior used by ``fit_qso``.

    The point source starts at (0, 0) with half of the summed image flux as
    ``point_amp`` and is bounded to +/-0.6 in ``ra_image`` / ``dec_image``.

    Returns
    -------
    list
        ``[init, sigma, fixed, lower, upper]`` as used for
        ``kwargs_params['point_source_model']``.
    """
    kwargs_ps_init = [{'ra_image': [0.0], 'dec_image': [0.0], 'point_amp': [QSO_im.sum() / 2.]}]
    kwargs_ps_sigma = [{'ra_image': [0.05], 'dec_image': [0.05]}]
    fixed_ps = [{}]
    kwargs_lower_ps = [{'ra_image': [-0.6], 'dec_image': [-0.6]}]
    kwargs_upper_ps = [{'ra_image': [0.6], 'dec_image': [0.6]}]
    return [kwargs_ps_init, kwargs_ps_sigma, fixed_ps, kwargs_lower_ps, kwargs_upper_ps]


def fit_qso(QSO_im, psf_ave, psf_std=None, source_params=None, ps_param=None, background_rms=0.04, pix_sz=0.168,
            exp_time=300., fix_n=None, image_plot=True, corner_plot=True, supersampling_factor=2,
            flux_ratio_plot=False, deep_seed=False, fixcenter=False, QSO_msk=None, QSO_std=None,
            tag=None, no_MCMC=False, pltshow=1, return_Chisq=False, dump_result=False, pso_diag=False):
    '''
    Quick fit of a QSO image with (by default) one Sersic host + one point source (PSF).

    Parameters
    --------
        QSO_im: 2D array of the QSO cutout (its length sets numPix).
        psf_ave: The PSF kernel image ('PIXEL' PSF).
        psf_std: Optional PSF error map, passed as 'psf_error_map'.
        source_params: Host prior [init, sigma, fixed, lower, upper]. None -> one
            default Sersic profile; [] -> fit point source(s) only, no Sersic light.
        ps_param: Point-source prior in the same layout. None -> one source at (0, 0).
        background_rms: Background noise per pixel, default 0.04.
        pix_sz: Pixel scale, default 0.168.
        exp_time: Exposure time, default 300.
        fix_n: If not None, fix the default Sersic index to this value.
        deep_seed: If True, run longer PSO and MCMC stages.
        fixcenter: If True, tie each host centre to its point source.
        QSO_msk: Likelihood mask passed to lenstronomy.
        QSO_std: Optional per-pixel noise map, used as 'noise_map'.
        no_MCMC: If True, only run the PSO stage.
        pltshow: 0 closes figures instead of showing them.
        tag: Name prefix for saved plots and (required with dump_result) the pickle.
        dump_result: If True, pickle the results to '<tag>.pkl'.
        return_Chisq: If True, also return the reduced chi^2.

    Return
    --------
        source_result, ps_result, image_ps, image_host, noise_map
        (+ reduced_Chisq when return_Chisq is True).

    Raises
    --------
        ValueError: if dump_result is True but tag is None.
    '''
    # Fail before the (possibly long) fit rather than after it.
    if dump_result and tag is None:
        raise ValueError("dump_result=True requires a 'tag' to name the output pickle file.")

    numPix = len(QSO_im)  # cutout size in pixels
    deltaPix = pix_sz
    psf_type = 'PIXEL'
    kwargs_numerics = {'supersampling_factor': supersampling_factor, 'supersampling_convolution': False}

    if source_params is None:
        source_params = _fit_qso_default_source_params(fix_n)
    if ps_param is None:
        ps_param = _fit_qso_default_ps_params(QSO_im)

    # An empty source_params means "point source(s) only".
    has_host = len(source_params) > 0
    n_host = len(source_params[0]) if has_host else 0
    n_ps = len(ps_param[0])
    kwargs_fixed_source = source_params[2] if has_host else []
    kwargs_fixed_ps = ps_param[2]

    #==============================================================================
    # Doing the QSO fitting
    #==============================================================================
    kwargs_data = sim_util.data_configure_simple(numPix, deltaPix, exp_time, background_rms, inverse=True)
    data_class = ImageData(**kwargs_data)
    kwargs_psf = {'psf_type': psf_type, 'kernel_point_source': psf_ave}
    psf_class = PSF(**kwargs_psf)
    data_class.update_data(QSO_im)

    point_source_list = ['UNLENSED'] * n_ps
    pointSource = PointSource(point_source_type_list=point_source_list)

    kwargs_constraints = {'num_point_source_list': [1] * n_ps}
    if fixcenter:
        kwargs_constraints['joint_source_with_point_source'] = [[i, i] for i in range(n_ps)]

    if not has_host:
        kwargs_params = {'point_source_model': ps_param}
        lightModel = None
        kwargs_model = {'point_source_model_list': point_source_list}
        imageModel = ImageModel(data_class, psf_class, point_source_class=pointSource, kwargs_numerics=kwargs_numerics)
        kwargs_likelihood = {'check_bounds': True,  # out-of-bounds proposals get a penalty
                             'image_likelihood_mask_list': [QSO_msk]}
    else:
        kwargs_params = {'source_model': source_params,
                         'point_source_model': ps_param}
        light_model_list = ['SERSIC_ELLIPSE'] * n_host
        lightModel = LightModel(light_model_list=light_model_list)
        kwargs_model = {'source_light_model_list': light_model_list,
                        'point_source_model_list': point_source_list}
        imageModel = ImageModel(data_class, psf_class, source_model_class=lightModel,
                                point_source_class=pointSource, kwargs_numerics=kwargs_numerics)
        kwargs_likelihood = {'check_bounds': True,  # out-of-bounds proposals get a penalty
                             'source_marg': False,  # do not marginalise over the linear amplitudes
                             'check_positive_flux': True,
                             'image_likelihood_mask_list': [QSO_msk]}

    kwargs_data['image_data'] = QSO_im
    if QSO_std is not None:
        kwargs_data['noise_map'] = QSO_std
    if psf_std is not None:
        kwargs_psf['psf_error_map'] = psf_std
    multi_band_list = [[kwargs_data, kwargs_psf, kwargs_numerics]]

    kwargs_data_joint = {'multi_band_list': multi_band_list, 'multi_band_type': 'multi-linear'}
    fitting_seq = FittingSequence(kwargs_data_joint, kwargs_model, kwargs_constraints, kwargs_likelihood, kwargs_params)

    if deep_seed:
        fitting_kwargs_list = [
            ['PSO', {'sigma_scale': 0.8, 'n_particles': 250, 'n_iterations': 250}],
            ['MCMC', {'n_burn': 100, 'n_run': 200, 'walkerRatio': 10, 'sigma_scale': .1}],
        ]
    else:
        fitting_kwargs_list = [
            ['PSO', {'sigma_scale': 0.8, 'n_particles': 100, 'n_iterations': 60}],
            ['MCMC', {'n_burn': 10, 'n_run': 10, 'walkerRatio': 50, 'sigma_scale': .1}],
        ]
    run_mcmc = not no_MCMC
    if not run_mcmc:
        fitting_kwargs_list = fitting_kwargs_list[:1]  # PSO only

    start_time = time.time()
    chain_list = fitting_seq.fit_sequence(fitting_kwargs_list)
    kwargs_result = fitting_seq.best_fit()
    ps_result = kwargs_result['kwargs_ps']
    source_result = kwargs_result['kwargs_source']
    if run_mcmc:
        _, samples_mcmc, param_mcmc, _ = chain_list[1]

    end_time = time.time()
    print(end_time - start_time, 'total time needed for computation')
    print('============ CONGRATULATION, YOUR JOB WAS SUCCESSFUL ================ ')
    imageLinearFit = ImageLinearFit(data_class=data_class, psf_class=psf_class,
                                    source_model_class=lightModel,
                                    point_source_class=pointSource,
                                    kwargs_numerics=kwargs_numerics)
    # Linear inversion; the amplitudes in the kwargs are updated afterwards.
    image_reconstructed, error_map, _, _ = imageLinearFit.image_linear_solve(kwargs_source=source_result, kwargs_ps=ps_result)
    modelPlot = ModelPlot(multi_band_list, kwargs_model, kwargs_result,
                          arrow_size=0.02, cmap_string="gist_heat", likelihood_mask_list=[QSO_msk])
    # NOTE: the linear solver before and after ModelPlot could give different results for very faint sources.
    image_host = [imageModel.source_surface_brightness(source_result, de_lensed=True, unconvolved=False, k=i)
                  for i in range(len(source_result))]
    image_ps = [imageModel.point_source(ps_result, k=i) for i in range(len(ps_result))]

    if pso_diag:
        chain_plot.plot_chain_list(chain_list, 0)
        if pltshow == 0:
            plt.close()
        else:
            plt.show()

    reduced_Chisq = imageLinearFit.reduced_chi2(image_reconstructed, error_map)
    if image_plot:
        f, axes = plt.subplots(3, 3, figsize=(16, 16), sharex=False, sharey=False)
        modelPlot.data_plot(ax=axes[0, 0], text="Data")
        modelPlot.model_plot(ax=axes[0, 1])
        modelPlot.normalized_residual_plot(ax=axes[0, 2], v_min=-6, v_max=6)

        modelPlot.decomposition_plot(ax=axes[1, 0], text='Host galaxy', source_add=True, unconvolved=True)
        modelPlot.decomposition_plot(ax=axes[1, 1], text='Host galaxy convolved', source_add=True)
        modelPlot.decomposition_plot(ax=axes[1, 2], text='All components convolved', source_add=True, lens_light_add=True, point_source_add=True)

        modelPlot.subtract_from_data_plot(ax=axes[2, 0], text='Data - Point Source', point_source_add=True)
        modelPlot.subtract_from_data_plot(ax=axes[2, 1], text='Data - host galaxy', source_add=True)
        modelPlot.subtract_from_data_plot(ax=axes[2, 2], text='Data - host galaxy - Point Source', source_add=True, point_source_add=True)

        f.tight_layout()
        if tag is not None:
            f.savefig('{0}_fitted_image.pdf'.format(tag))
        if pltshow == 0:
            plt.close()
        else:
            plt.show()

    # Corner plot of the (possibly non-converged) MCMC chain of the non-linear parameters.
    # len() instead of `== []`: comparing a 2-D array with [] raises in recent NumPy.
    if corner_plot and run_mcmc and len(samples_mcmc) > 0:
        plot = corner.corner(samples_mcmc, labels=param_mcmc, show_titles=True)
        if tag is not None:
            plot.savefig('{0}_para_corner.pdf'.format(tag))
        plt.close()

    if flux_ratio_plot and run_mcmc:
        param = Param(kwargs_model, kwargs_fixed_source=kwargs_fixed_source, kwargs_fixed_ps=kwargs_fixed_ps, **kwargs_constraints)
        mcmc_new_list = []
        if len(ps_param[2]) == 1:
            labels_new = ["Quasar flux"]
        else:
            labels_new = ["Quasar{0} flux".format(i) for i in range(len(ps_param[2]))]
        labels_new += ["host{0} flux".format(i) for i in range(n_host)]
        n_samples = len(samples_mcmc)
        # Translate at most the last 10000 samples into component fluxes.
        for i in range(max(0, n_samples - 10000), n_samples):
            kwargs_out = param.args2kwargs(samples_mcmc[i])
            kwargs_light_source_out = kwargs_out['kwargs_source']
            kwargs_ps_out = kwargs_out['kwargs_ps']
            # Solve for the linear amplitudes (updates the kwargs in place).
            imageLinearFit.image_linear_solve(kwargs_source=kwargs_light_source_out, kwargs_ps=kwargs_ps_out)
            if len(ps_param[0]) == 1:
                flux_quasar = [np.sum(imageModel.point_source(kwargs_ps_out))]
            else:
                flux_quasar = [np.sum(imageModel.point_source(kwargs_ps_out, k=j))
                               for j in range(len(ps_param[0]))]
            fluxs = [np.sum(imageModel.source_surface_brightness(kwargs_light_source_out, unconvolved=False, k=j))
                     for j in range(n_host)]
            mcmc_new_list.append(flux_quasar + fluxs)
            if i > 0 and i % 1000 == 0:
                print(n_samples, "MCMC samplers in total, finished translate:", i)
        plot = corner.corner(mcmc_new_list, labels=labels_new, show_titles=True)
        if tag is not None:
            plot.savefig('{0}_HOSTvsQSO_corner.pdf'.format(tag))
        if pltshow == 0:
            plt.close()
        else:
            plt.show()

    if QSO_std is None:
        noise_map = np.sqrt(data_class.C_D + np.abs(error_map))
    else:
        noise_map = np.sqrt(QSO_std**2 + np.abs(error_map))

    if dump_result:
        if flux_ratio_plot and run_mcmc:
            trans_paras = [mcmc_new_list, labels_new, 'mcmc_new_list, labels_new']
        else:
            trans_paras = []
        picklename = tag + '.pkl'
        best_fit = [source_result, image_host, ps_result, image_ps, 'source_result, image_host, ps_result, image_ps']
        chain_list_result = [chain_list, 'chain_list']
        classes = data_class, psf_class, lightModel, pointSource
        material = multi_band_list, kwargs_model, kwargs_result, QSO_msk, kwargs_fixed_source, kwargs_fixed_ps, kwargs_constraints, kwargs_numerics, classes
        with open(picklename, 'wb') as pickle_file:
            pickle.dump([best_fit, chain_list_result, trans_paras, material], pickle_file)

    if return_Chisq:
        return source_result, ps_result, image_ps, image_host, noise_map, reduced_Chisq
    return source_result, ps_result, image_ps, image_host, noise_map
Beispiel #15
0
    def plot_final_qso_fit(self,
                           if_annuli=False,
                           show_plot=True,
                           arrows=False,
                           save_plot=False,
                           target_ID=None):
        """
        Plot the compact fitting result, if a QSO is fitted.

        Builds the summed point-source image (``self.image_ps_list``) and the
        summed host image (``self.image_host_list``). It then passes data,
        model, data minus point source and normalized residual to
        ``total_compare``.

        Parameters
        ----------
        if_annuli, arrows, show_plot : forwarded to ``total_compare``.
        save_plot : bool
            If True, save the figure to ``self.savename + "_qso_final_plot.pdf"``.
        target_ID : str or None
            Label for the figure; defaults to ``'target_ID'``.
        """
        spec = self.fitting_specify_class
        data = spec.kwargs_data['image_data']
        mask_list = spec.kwargs_likelihood['image_likelihood_mask_list']
        if 'psf_error_map' in spec.kwargs_psf:
            # Add the error-map term from the linear solve to the data noise in quadrature.
            modelPlot = ModelPlot(
                spec.kwargs_data_joint['multi_band_list'],
                spec.kwargs_model,
                self.kwargs_result,
                arrow_size=0.02,
                cmap_string="gist_heat",
                likelihood_mask_list=mask_list)
            _, psf_error_map, _, _ = modelPlot._imageModel.image_linear_solve(
                inv_bool=True, **self.kwargs_result)
            noise = np.sqrt(spec.kwargs_data['noise_map']**2 +
                            np.abs(psf_error_map[0]))
        else:
            noise = spec.kwargs_data['noise_map']

        if target_ID is None:
            target_ID = 'target_ID'

        # Start from zeros shaped like the data so an empty component list
        # (e.g. no point source fitted) does not raise IndexError.
        ps_list = self.image_ps_list
        ps_image = np.zeros_like(data)
        for ps in ps_list:
            ps_image = ps_image + ps
        galaxy_list = self.image_host_list
        galaxy_image = np.zeros_like(data)
        for galaxy in galaxy_list:
            galaxy_image = galaxy_image + galaxy

        model = ps_image + galaxy_image
        data_removePSF = data - ps_image
        norm_residual = (data - model) / noise
        flux_list_2d = [data, model, data_removePSF, norm_residual]
        label_list_2d = [
            'data', 'model', 'data-Point Source', 'normalized residual'
        ]
        flux_list_1d = [data, model, ps_image, galaxy_image]
        label_list_1d = [
            'data', 'model', 'Point Source',
            '{0} galaxy(s)'.format(len(galaxy_list))
        ]
        fig = total_compare(flux_list_2d,
                            label_list_2d,
                            flux_list_1d,
                            label_list_1d,
                            deltaPix=spec.deltaPix,
                            zp=self.zp,
                            if_annuli=if_annuli,
                            arrows=arrows,
                            show_plot=show_plot,
                            mask_image=mask_list[0],
                            target_ID=target_ID)
        if save_plot:
            fig.savefig(self.savename + "_qso_final_plot.pdf")
        if show_plot:
            plt.show()
        else:
            plt.close()
    def test_raise(self):
        """Each case below must raise ValueError inside its ``assertRaises`` block.

        All cases use a 10x10 single-band setup with one GAUSSIAN source.
        They exercise:
        1. ModelPlot with a 'NONE' PSF;
        2. source_plot with plot_scale='bad';
        3. _select_band on a band excluded via bands_compute=[False];
        4. source_plot with plot_scale='wrong'.
        """

        def gaussian_setup(**extra_data_kwargs):
            """Return fresh (kwargs_data, kwargs_model, kwargs_params) for one Gaussian source."""
            data = sim_util.data_configure_simple(numPix=10,
                                                  deltaPix=1,
                                                  background_rms=1,
                                                  **extra_data_kwargs)
            model = {'source_light_model_list': ['GAUSSIAN']}
            params = {
                'kwargs_lens': [],
                'kwargs_source': [{
                    'amp': 1,
                    'sigma': 1,
                    'center_x': 0,
                    'center_y': 0
                }],
                'kwargs_ps': [],
                'kwargs_lens_light': []
            }
            return data, model, params

        def draw_source(plotter, scale):
            """Draw the source on a new 4x4 figure with the given plot_scale, then close it."""
            _, axis = plt.subplots(1, 1, figsize=(4, 4))
            plotter.source_plot(ax=axis,
                                numPix=10,
                                deltaPix_source=0.1,
                                v_min=None,
                                v_max=None,
                                with_caustics=False,
                                caustic_color='yellow',
                                fsize=15,
                                plot_scale=scale)
            plt.close()

        with self.assertRaises(ValueError):
            data, model, params = gaussian_setup()
            ModelPlot(multi_band_list=[[data, {'psf_type': 'NONE'}, {}]],
                      kwargs_model=model,
                      kwargs_params=params,
                      arrow_size=0.02,
                      cmap_string="gist_heat")

        with self.assertRaises(ValueError):
            data, model, params = gaussian_setup()
            plotter = ModelPlot(multi_band_list=[[data, {}, {}]],
                                kwargs_model=model,
                                kwargs_params=params,
                                arrow_size=0.02,
                                cmap_string="gist_heat")
            draw_source(plotter, 'bad')

        with self.assertRaises(ValueError):
            data, model, params = gaussian_setup()
            plotter = ModelPlot(multi_band_list=[[data, {'psf_type': 'NONE'}, {}]],
                                kwargs_model=model,
                                kwargs_params=params,
                                bands_compute=[False],
                                arrow_size=0.02,
                                cmap_string="gist_heat")
            plotter._select_band(band_index=0)

        with self.assertRaises(ValueError):
            data, model, params = gaussian_setup(exposure_time=1)
            plotter = ModelPlot(multi_band_list=[[data, {'psf_type': 'NONE'}, {}]],
                                kwargs_model=model,
                                kwargs_params=params,
                                bands_compute=[True],
                                arrow_size=0.02,
                                cmap_string="gist_heat")
            draw_source(plotter, 'wrong')
Beispiel #17
0
        if count == 0:
            QSO_img = multi_band_list[0][0]['image_data']
            plt.imshow(QSO_img, origin='low', norm=LogNorm())
            for i in range(len(source_result)):
                obj_x, obj_y = len(QSO_img) / 2 - source_result[i][
                    'center_x'] / pix_scale, len(
                        QSO_img) / 2 + source_result[i]['center_y'] / pix_scale
                plt.text(obj_x,
                         obj_y,
                         "obj{0}".format(i),
                         fontsize=15,
                         color='k')
            plt.show()
        modelPlot = ModelPlot(multi_band_list,
                              kwargs_model,
                              kwargs_result,
                              arrow_size=0.02,
                              cmap_string="gist_heat",
                              likelihood_mask_list=[QSO_msk])
        f, axes = plt.subplots(3,
                               3,
                               figsize=(16, 16),
                               sharex=False,
                               sharey=False)
        modelPlot.data_plot(ax=axes[0, 0], text="Data")
        modelPlot.model_plot(ax=axes[0, 1])
        modelPlot.normalized_residual_plot(ax=axes[0, 2], v_min=-6, v_max=6)

        modelPlot.decomposition_plot(ax=axes[1, 0],
                                     text='Host galaxy',
                                     source_add=True,
                                     unconvolved=True)
        ds = (x0 - kwargs_ps['ra_image'])**2 + (y0 - kwargs_ps['dec_image'])**2
        if ds.min() < 0.01:
            x_s.append(x[i])
            y_s.append(y[i])
    y_grid, x_grid = np.indices(
        (framesize, framesize))  #with higher resolution 60*6
    # for i in range(len(x_s)):
    #     lens_mask[np.sqrt((y_grid-y_s[i])**2 + (x_grid-x_s[i])**2) <4] = 0

    # plt.imshow(multi_band_list[0][0]['noise_map'],origin='lower')
    # plt.show()
    # plt.imshow(lens_mask,origin='lower')
    # plt.show()
    modelPlot = ModelPlot(multi_band_list,
                          kwargs_model,
                          kwargs_result,
                          arrow_size=0.02,
                          likelihood_mask_list=[lens_mask])
    logL = modelPlot._imageModel.likelihood_data_given_model(source_marg=False,
                                                             linear_prior=None,
                                                             **kwargs_result)
    n_data = modelPlot._imageModel.num_data_evaluate
    chisq = -logL * 2 / n_data

    #!!!
    kwargs_result['kwargs_lens'][0]['gamma'] = np.median(chain_list[-1][1][:,
                                                                           0])

    if chisq < 0.0:
        print(folder[-4:-1], round(np.median(H0_list), 3))
        f, axes = modelPlot.plot_main()
def _caustic_offset_grid(caustic_coords, box_f):
    """Return a 0.1-step grid over the caustic extent along one axis, shifted by ``box_f``.

    The range runs from ``min - min*box_f[0]`` to ``max + max*box_f[1]``.
    Points that round to +/-0.1 are dropped.
    NOTE(review): for a negative minimum the lower bound moves *toward* zero;
    confirm this asymmetry is intended.
    """
    lower = min(caustic_coords) - min(caustic_coords) * box_f[0]
    upper = max(caustic_coords) + max(caustic_coords) * box_f[1]
    grid = np.arange(lower, upper, 0.1)
    return grid[abs(grid.round(1)) != 0.1]


def make_lensmodel(lens_info, theta_E, source_info, box_f):
    """Fit the lens light with a Sersic profile, build an SIE lens and pick a source offset.

    Parameters
    ----------
    lens_info : dict
        Needs ``'image'`` (2D lens cutout), ``'psf'`` (pixel PSF kernel) and
        ``'deltapix'`` (pixel scale).
    theta_E : float
        Einstein radius of the SIE lens.
    source_info : dict
        Needs ``'image'`` and ``'deltapix'``. Used only to set the window and
        start scale of the critical-curve search.
    box_f : sequence of two floats
        Fractional shifts of the lower/upper caustic bounds used when drawing
        the random source offset.

    Returns
    -------
    dict
        Contains ``lens_light_model_list``, ``kwargs_light_lens`` (best-fit
        Sersic), ``lens_light_model_class``, ``kwargs_lens_list`` (SIE sharing
        the light's ellipticity and centre), ``kwargs_data_lens`` and
        ``source_shift`` ([x, y]).
    """
    # --- lens data ---
    lens_image = lens_info['image']
    deltapix_lens = lens_info['deltapix']
    background_rms = background_rms_image(5, lens_image)
    exposure_time = 100
    kwargs_data_lens = sim_util.data_configure_simple(len(lens_image),
                                                      deltapix_lens,
                                                      exposure_time,
                                                      background_rms)
    kwargs_data_lens['image_data'] = lens_image
    kwargs_psf_lens = {
        'psf_type': 'PIXEL',
        'pixel_size': deltapix_lens,
        'kernel_point_source': lens_info['psf']
    }

    # --- lens light model: one Sersic, fitted with PSO ---
    lens_light_model_list = ['SERSIC_ELLIPSE']
    lens_light_model_class = LightModel(light_model_list=lens_light_model_list)
    kwargs_model = {'lens_light_model_list': lens_light_model_list}
    kwargs_numerics_galfit = {'supersampling_factor': 1}
    kwargs_data_joint = {
        'multi_band_list': [[kwargs_data_lens, kwargs_psf_lens, kwargs_numerics_galfit]],
        'multi_band_type': 'multi-linear'
    }
    lens_light_params = [
        [{'R_sersic': .1, 'n_sersic': 4, 'e1': 0, 'e2': 0, 'center_x': 0, 'center_y': 0}],         # init
        [{'n_sersic': 0.5, 'R_sersic': 0.2, 'e1': 0.1, 'e2': 0.1, 'center_x': 0.1, 'center_y': 0.1}],  # sigma
        [{}],                                                                                      # fixed
        [{'e1': -0.5, 'e2': -0.5, 'R_sersic': 0.01, 'n_sersic': 0.5, 'center_x': -10, 'center_y': -10}],  # lower
        [{'e1': 0.5, 'e2': 0.5, 'R_sersic': 10, 'n_sersic': 8, 'center_x': 10, 'center_y': 10}],   # upper
    ]
    kwargs_params = {'lens_light_model': lens_light_params}
    fitting_seq = FittingSequence(kwargs_data_joint, kwargs_model,
                                  {}, {'check_bounds': True},
                                  kwargs_params)
    fitting_seq.fit_sequence([['PSO', {'sigma_scale': 1., 'n_particles': 50, 'n_iterations': 50}]])
    kwargs_result = fitting_seq.best_fit()
    kwargs_light_lens = kwargs_result['kwargs_lens_light'][0]

    # --- SIE lens aligned with the fitted light ---
    kwargs_lens_list = [{
        'theta_E': theta_E,
        'e1': kwargs_light_lens['e1'],
        'e2': kwargs_light_lens['e2'],
        'center_x': kwargs_light_lens['center_x'],
        'center_y': kwargs_light_lens['center_y']
    }]
    lensModel = LensModel(['SIE'])
    lme = LensModelExtensions(lensModel)

    # --- random source position near the caustic ---
    x_crit_list, y_crit_list = lme.critical_curve_tiling(
        kwargs_lens_list,
        compute_window=len(source_info['image']) * source_info['deltapix'],
        start_scale=source_info['deltapix'],
        max_order=10)
    if len(x_crit_list) > 2 and len(y_crit_list) > 2:
        x_caustic_list, y_caustic_list = lensModel.ray_shooting(
            x_crit_list, y_crit_list, kwargs_lens_list)
        xsamp = _caustic_offset_grid(x_caustic_list, box_f)
        ysamp = _caustic_offset_grid(y_caustic_list, box_f)
        if len(xsamp) == 0 or len(ysamp) == 0:
            x_shift, y_shift = 0.15, 0.15
        else:
            # y is drawn before x, as in the original random-number sequence.
            y_shift = rand.sample(list(ysamp), 1)[0]
            x_shift = rand.sample(list(xsamp), 1)[0]
    else:
        x_shift, y_shift = -0.15, 0.15

    solver = LensEquationSolver(lensModel)
    theta_ra, theta_dec = solver.image_position_from_source(
        x_shift, y_shift, kwargs_lens_list)
    if len(theta_ra) <= 1:
        x_shift, y_shift = -0.2, -0.2
    # Compare with the full Einstein radius. int(theta_E) truncated e.g. 0.8
    # to 0, which made every drawn offset "too large".
    if abs(x_shift) >= theta_E or abs(y_shift) >= theta_E:
        x_shift, y_shift = 0.3, -0.3

    return {
        'lens_light_model_list': list(lens_light_model_list),
        'kwargs_light_lens': [kwargs_light_lens],
        'lens_light_model_class': lens_light_model_class,
        'kwargs_lens_list': kwargs_lens_list,
        'kwargs_data_lens': kwargs_data_lens,
        'source_shift': [x_shift, y_shift]
    }
            #    delta_fermat_12 = fermat_pot[0] - fermat_pot[2]
            gamma = kwargs_result['kwargs_lens'][0]['gamma']
            #    phi_ext, gamma_ext = kwargs_result['kwargs_lens'][1]['gamma1'], kwargs_result['kwargs_lens'][1]['gamma2']
            mcmc_new_list.append([gamma, D_dt, cal_h0(z_l, z_s, D_dt)])
        pickle.dump([
            multi_band_list, kwargs_model, kwargs_result, chain_list,
            fix_setting, mcmc_new_list
        ], open(folder + savename, 'wb'))
    #%%Print fitting result:
    multi_band_list, kwargs_model, kwargs_result, chain_list, fix_setting, mcmc_new_list = pickle.load(
        open(folder + savename, 'rb'))
    fixed_lens, fixed_source, fixed_lens_light, fixed_ps, fixed_cosmo = fix_setting
    labels_new = [r"$\gamma$", r"$D_{\Delta t}$", "H$_0$"]
    modelPlot = ModelPlot(multi_band_list,
                          kwargs_model,
                          kwargs_result,
                          arrow_size=0.02,
                          cmap_string="gist_heat")
    f, axes = modelPlot.plot_main()
    f.show()
    # f, axes = modelPlot.plot_separate()
    # f.show()
    # f, axes = modelPlot.plot_subtract_from_data_all()
    # f.show()

    for i in range(len(chain_list)):
        chain_plot.plot_chain_list(chain_list, i)
    plt.show()

    truths = [para_s[0][0]['gamma'], TD_distance, 73.907]
    plot = corner.corner(
Beispiel #21
0
# Reload a fit-result pickle and rebuild the lenstronomy ModelPlot from it.
# The unpacked layout [best_fit, chain_list_result, trans_paras, material]
# presumably matches what fit_qso(..., dump_result=True) writes -- verify.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
picklename = 'zoutput/' + 'l35_sersicdisk_1.pkl'

with open(picklename, 'rb') as pickle_file:
    result = pickle.load(pickle_file)
best_fit, chain_list_result, trans_paras, material = result
source_result, image_host, ps_result, image_ps, _ = best_fit
chain_list, _ = chain_list_result
if chain_list[-1][0] == 'EMCEE':
    sampler_type, samples_mcmc, param_mcmc, dist_mcmc = chain_list[-1]
    # trans_paras is [] when the run was saved without the flux-ratio step,
    # in which case there is nothing to unpack.
    if trans_paras:
        mcmc_new_list, labels_new, _ = trans_paras
multi_band_list, kwargs_model, kwargs_result, QSO_msk, kwargs_fixed_source, kwargs_fixed_ps, kwargs_constraints, kwargs_numerics, classes = material

#%% Recover the plot
from lenstronomy.Plots.model_plot import ModelPlot
modelPlot = ModelPlot(multi_band_list,
                      kwargs_model,
                      kwargs_result,
                      arrow_size=0.02,
                      cmap_string="gist_heat",
                      likelihood_mask_list=[QSO_msk])
# Optional 3x3 diagnostic figure; uncomment to draw it.
# f, axes = plt.subplots(3, 3, figsize=(16, 16), sharex=False, sharey=False)
# modelPlot.data_plot(ax=axes[0,0], text="Data")
# modelPlot.model_plot(ax=axes[0,1])
# modelPlot.normalized_residual_plot(ax=axes[0,2], v_min=-6, v_max=6)
# modelPlot.decomposition_plot(ax=axes[1,0], text='Host galaxy', source_add=True, unconvolved=True)
# modelPlot.decomposition_plot(ax=axes[1,1], text='Host galaxy convolved', source_add=True)
# modelPlot.decomposition_plot(ax=axes[1,2], text='All components convolved', source_add=True, lens_light_add=True, point_source_add=True)
# modelPlot.subtract_from_data_plot(ax=axes[2,0], text='Data - Point Source', point_source_add=True)
# modelPlot.subtract_from_data_plot(ax=axes[2,1], text='Data - host galaxy', source_add=True)
# modelPlot.subtract_from_data_plot(ax=axes[2,2], text='Data - host galaxy - Point Source', source_add=True, point_source_add=True)
# f.tight_layout()
# plt.show()