Example #1
0
    def setup(self):
        """Build the DataAPI fixtures used by the tests.

        Three instances share one set of instrument/observation settings:
        ``self.api`` (Gaussian PSF), ``self.api_pixel`` (pixelised PSF from a
        3x3 array of ones) and ``self.api_pixel_grid`` (Gaussian PSF plus an
        explicit ``kwargs_pixel_grid``).
        """
        numpix = 10

        # Instrument properties, kept on self so tests can compare against them.
        self.ccd_gain = 4.
        self.pixel_scale = 0.13
        self.read_noise = 10.
        kwargs_instrument = dict(read_noise=self.read_noise,
                                 pixel_scale=self.pixel_scale,
                                 ccd_gain=self.ccd_gain)

        # Observation conditions common to every fixture.
        self.magnitude_zero_point = 21.
        shared_obs = dict(exposure_time=100,
                          sky_brightness=20.,
                          magnitude_zero_point=self.magnitude_zero_point,
                          num_exposures=2,
                          seeing=0.9)

        gaussian_obs = dict(shared_obs, psf_type='GAUSSIAN', kernel_point_source=None)
        self.kwargs_data = util.merge_dicts(kwargs_instrument, gaussian_obs)
        self.api = DataAPI(numpix=numpix, data_count_unit='ADU', **self.kwargs_data)

        pixel_obs = dict(shared_obs, psf_type='PIXEL', kernel_point_source=np.ones((3, 3)))
        pixel_data = util.merge_dicts(kwargs_instrument, pixel_obs)
        self.api_pixel = DataAPI(numpix=numpix, data_count_unit='ADU', **pixel_data)

        # Explicit pixel grid whose x-axis scale is negated.
        self.ra_at_xy_0 = 0.02
        self.dec_at_xy_0 = 0.02
        self.transform_pix2angle = [[-self.pixel_scale, 0], [0, self.pixel_scale]]
        grid = {'ra_at_xy_0': self.ra_at_xy_0,
                'dec_at_xy_0': self.dec_at_xy_0,
                'transform_pix2angle': self.transform_pix2angle}
        self.api_pixel_grid = DataAPI(numpix=numpix,
                                      kwargs_pixel_grid=grid,
                                      data_count_unit='ADU',
                                      **self.kwargs_data)
Example #2
0
    def test_raise(self):
        """DataAPI must raise ValueError for invalid PSF / pixel-grid settings.

        * An unknown ``psf_type`` is accepted by the constructor, but accessing
          ``psf_class`` raises.
        * ``psf_type='PIXEL'`` with ``kernel_point_source=None`` raises (either
          at construction or on ``psf_class`` access).
        * A ``kwargs_pixel_grid`` without ``transform_pix2angle`` raises at
          construction.
        """
        numpix = 10
        self.ccd_gain = 4.
        self.pixel_scale = 0.13
        self.read_noise = 10.
        kwargs_instrument = {'read_noise': self.read_noise, 'pixel_scale': self.pixel_scale, 'ccd_gain': self.ccd_gain}

        exposure_time = 100
        sky_brightness = 20.
        magnitude_zero_point = 21.
        num_exposures = 2
        seeing = 0.9
        kwargs_observations = {'exposure_time': exposure_time, 'sky_brightness': sky_brightness,
                               'magnitude_zero_point': magnitude_zero_point, 'num_exposures': num_exposures,
                               'seeing': seeing, 'psf_type': 'wrong', 'kernel_point_source': None}
        kwargs_data = util.merge_dicts(kwargs_instrument, kwargs_observations)
        # Construction must succeed; only the psf_class lookup is expected to
        # raise, so build the object outside the assertRaises context.
        data_api = DataAPI(numpix=numpix, data_count_unit='ADU', **kwargs_data)
        with self.assertRaises(ValueError):
            _ = data_api.psf_class

        kwargs_observations = {'exposure_time': exposure_time, 'sky_brightness': sky_brightness,
                               'magnitude_zero_point': magnitude_zero_point, 'num_exposures': num_exposures,
                               'seeing': seeing, 'psf_type': 'PIXEL', 'kernel_point_source': None}
        kwargs_data = util.merge_dicts(kwargs_instrument, kwargs_observations)
        with self.assertRaises(ValueError):
            data_api = DataAPI(numpix=numpix, data_count_unit='ADU', **kwargs_data)
            _ = data_api.psf_class

        # A pixel grid missing 'transform_pix2angle' must be rejected.
        kwargs_data['kernel_point_source'] = np.ones((3, 3))
        kwargs_pixel_grid = {'ra_at_xy_0': 0.02, 'dec_at_xy_0': 0.02}
        with self.assertRaises(ValueError):
            DataAPI(numpix=numpix, kwargs_pixel_grid=kwargs_pixel_grid, **kwargs_data)
Example #3
0
 def __init__(self, numpix, kwargs_single_band, kwargs_model):
     """Initialise both parent APIs: the data/observation side and the model side.

     :param numpix: number of pixels per axis
     :param kwargs_single_band: keyword arguments specifying the class instance of DataAPI
     :param kwargs_model: keyword arguments specifying the class instance of ModelAPI
     """
     # Explicit base-class initialisers: DataAPI first, then ModelAPI.
     DataAPI.__init__(self, numpix, **kwargs_single_band)
     ModelAPI.__init__(self, **kwargs_model)
Example #4
0
def get_data_api(pixel_scale, num_pix):
    """Build a lenstronomy ``DataAPI`` with fixed detector/observation settings.

    The detector uses a CCD gain of 100, zero point 30 mag, exposure time
    10000 s, no PSF (``psf_type='NONE'``) and background noise fixed to 0.

    Parameters
    ----------
    pixel_scale : float
        arcsec per pixel
    num_pix : int
        number of pixels per side

    Returns
    -------
    Lenstronomy.DataAPI object

    """
    return DataAPI(num_pix,
                   pixel_scale=pixel_scale,
                   ccd_gain=100.0,
                   magnitude_zero_point=30.0,
                   exposure_time=10000.0,
                   psf_type='NONE',
                   background_noise=0.0)
Example #5
0
 def __init__(self, numpix, kwargs_single_band, kwargs_model,
              kwargs_numerics):
     """Combine data settings and model lists into a single image simulator.

     :param numpix: number of pixels per axis
     :param kwargs_single_band: keyword arguments specifying the class instance of DataAPI
     :param kwargs_model: keyword arguments specifying the class instance of ModelAPI
     :param kwargs_numerics: keyword argument with various numeric description (see ImageNumerics class for options)
     """
     # Both parents are initialised first; the image model is then assembled
     # from attributes read off self.
     DataAPI.__init__(self, numpix, **kwargs_single_band)
     ModelAPI.__init__(self, **kwargs_model)
     data, psf = self.data_class, self.psf_class
     models = (self.lens_model_class,
               self.source_model_class,
               self.lens_light_model_class,
               self.point_source_model_class)
     self._image_model_class = ImageModel(data, psf, *models, kwargs_numerics)
Example #6
0
    def setup(self):
        """Create two DataAPI fixtures from shared instrument/observation settings.

        ``self.api`` uses a Gaussian PSF (``psf_model=None``); ``self.api_pixel``
        uses ``psf_type='PIXEL'`` with a 3x3 array of ones passed as
        ``psf_model``.
        """
        numpix = 10

        self.ccd_gain = 4.
        self.pixel_scale = 0.13
        self.read_noise = 10.
        kwargs_instrument = dict(read_noise=self.read_noise,
                                 pixel_scale=self.pixel_scale,
                                 ccd_gain=self.ccd_gain)

        self.magnitude_zero_point = 21.
        common_obs = dict(exposure_time=100,
                          sky_brightness=20.,
                          magnitude_zero_point=self.magnitude_zero_point,
                          num_exposures=2,
                          seeing=0.9)

        gaussian_obs = dict(common_obs, psf_type='GAUSSIAN', psf_model=None)
        self.kwargs_data = util.merge_dicts(kwargs_instrument, gaussian_obs)
        self.api = DataAPI(numpix=numpix, data_count_unit='ADU', **self.kwargs_data)

        pixel_obs = dict(common_obs, psf_type='PIXEL', psf_model=np.ones((3, 3)))
        pixel_kwargs = util.merge_dicts(kwargs_instrument, pixel_obs)
        self.api_pixel = DataAPI(numpix=numpix, data_count_unit='ADU', **pixel_kwargs)
Example #7
0
    def _set_sim_api(self, num_pix, kwargs_detector, psf_kernel_size, which_psf_maps):
        """Set the simulation API objects.

        Creates ``self.data_api`` from the detector settings, a PSF model via
        ``psf_utils.get_PSF_model``, the lens-equation-solver precision
        attributes, and two image models. ``self.image_model`` is the full scene.
        ``self.unlensed_image_model`` has no lens mass and no lens light, and
        includes ``self.unlensed_ps_model`` only when ``'agn_light'`` is in
        ``self.components``.

        :param num_pix: number of pixels per side
        :param kwargs_detector: detector/observation keyword arguments; must
            contain 'pixel_scale', 'psf_type' and 'seeing'
        :param psf_kernel_size: kernel size forwarded to get_PSF_model
        :param which_psf_maps: PSF map selection forwarded to get_PSF_model
        """
        self.data_api = DataAPI(num_pix, **kwargs_detector)
        scale = kwargs_detector['pixel_scale']
        psf_model = psf_utils.get_PSF_model(kwargs_detector['psf_type'],
                                            scale,
                                            seeing=kwargs_detector['seeing'],
                                            kernel_size=psf_kernel_size,
                                            which_psf_maps=which_psf_maps)
        # Precision level of the lens equation solver.
        self.min_distance = 0.05
        self.search_window = scale * num_pix

        self.image_model = ImageModel(self.data_api.data_class, psf_model,
                                      self.lens_mass_model, self.src_light_model,
                                      self.lens_light_model, self.ps_model,
                                      kwargs_numerics=self.kwargs_numerics)
        # The unlensed model only carries point sources when an AGN is simulated.
        unlensed_ps = self.unlensed_ps_model if 'agn_light' in self.components else None
        self.unlensed_image_model = ImageModel(self.data_api.data_class, psf_model,
                                               None, self.src_light_model,
                                               None, unlensed_ps,
                                               kwargs_numerics=self.kwargs_numerics)
Example #8
0
    def survey_kwargs(self, survey_kwargs):
        """Configure per-band data and image models from a lenstronomy survey.

        :param survey_kwargs: dict with required keys ``'survey_name'`` (name of
            a module defining a class of the same name, looked up first in
            ``lenstronomy.SimulationAPI.ObservationConfig``) and
            ``'bandpass_list'``; optional keys ``'coadd_years'``,
            ``'override_obs_kwargs'`` and ``'override_camera_kwargs'`` (dicts
            merged into the survey object's ``obs`` and ``camera`` settings).

        Rebuilds ``self._data_api`` and ``self._image_model`` with one entry per
        band, in ``bandpass_list`` order.
        """
        survey_name = survey_kwargs['survey_name']
        bandpass_list = survey_kwargs['bandpass_list']
        coadd_years = survey_kwargs.get('coadd_years')
        override_obs_kwargs = survey_kwargs.get('override_obs_kwargs', {})
        override_camera_kwargs = survey_kwargs.get('override_camera_kwargs', {})

        import lenstronomy.SimulationAPI.ObservationConfig as ObsConfig
        from importlib import import_module
        # Import the survey module through its package instead of pushing the
        # ObservationConfig directory onto sys.path. That approach added a
        # duplicate sys.path entry on every assignment and let those module
        # names shadow unrelated top-level modules for the whole process.
        qualified_name = ObsConfig.__name__ + '.' + survey_name
        try:
            survey_module = import_module(qualified_name)
        except ModuleNotFoundError as err:
            if err.name != qualified_name:
                raise  # the survey module exists but one of its imports failed
            # Not a bundled survey: fall back to a top-level module on sys.path.
            survey_module = import_module(survey_name)
        SurveyClass = getattr(survey_module, survey_name)

        self._data_api = []  # init
        self._image_model = []  # init
        for bp in bandpass_list:
            survey_obj = SurveyClass(band=bp,
                                     psf_type=self.psf_type,
                                     coadd_years=coadd_years)
            # Override as specified in survey_kwargs
            survey_obj.camera.update(override_camera_kwargs)
            survey_obj.obs.update(override_obs_kwargs)
            # This is what we'll actually use
            kwargs_detector = survey_obj.kwargs_single_band()
            data_api = DataAPI(self.n_pix, **kwargs_detector)
            psf_model = psf_utils.get_PSF_model(self.psf_type,
                                                self.pixel_scale,
                                                seeing=kwargs_detector['seeing'],
                                                kernel_size=self.psf_kernel_size,
                                                which_psf_maps=self.which_psf_maps)
            image_model_bp = ImageModel(data_api.data_class,
                                        psf_model,
                                        self.lens_model,
                                        self.src_model,
                                        None,
                                        None,
                                        kwargs_numerics=self.kwargs_numerics)
            self._data_api.append(data_api)
            self._image_model.append(image_model_bp)
Example #9
0
def _append_metadata_chunk(rows, metadata_path, columns=None):
    """Write buffered label rows to the metadata CSV.

    The first call (``columns is None``) creates the file with a header whose
    columns are sorted lexicographically. Later calls append headerless rows
    reordered to that same header, so every line lines up with the header.
    Header columns absent from a chunk are left empty; keys not in the header
    are dropped.

    :param rows: list of dicts, one per saved image (may be empty)
    :param metadata_path: path of the metadata CSV file
    :param columns: header column order returned by the first call, or None
    :return: the header column order of the CSV (None if nothing written yet)
    """
    if not rows:
        return columns
    chunk = pd.DataFrame(rows)
    if columns is None:
        columns = sorted(chunk.columns)
        chunk.reindex(columns=columns).to_csv(metadata_path, index=False)
    else:
        chunk.reindex(columns=columns).to_csv(metadata_path, index=False,
                                              mode='a', header=False)
    return columns


def main():
    """Simulate a dataset of lensed images and write their labels to CSV.

    Reads the Baobab config given on the command line, seeds ``numpy`` and
    ``random`` with ``cfg.seed``, creates ``cfg.out_dir`` (raising ``OSError``
    if it already exists) and exports the config log. It then samples the BNN
    prior until ``cfg.n_data`` images pass the selection cuts. Each image is
    saved as ``X_{index:07d}.npy`` (0-based index) and its labels go to
    ``metadata.csv``, which is flushed after the first image, every
    ``cfg.checkpoint_interval`` images, and once more at the end.
    """
    args = parse_args()
    cfg = BaobabConfig.from_file(args.config)
    if args.n_data is not None:
        cfg.n_data = args.n_data
    # Seed for reproducibility
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)
    # Create data directory
    save_dir = cfg.out_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        print("Destination folder path: {:s}".format(save_dir))
        print("Log path: {:s}".format(cfg.log_path))
        cfg.export_log()
    else:
        raise OSError("Destination folder already exists.")
    # Instantiate PSF models
    psf_models = instantiate_PSF_models(cfg.psf, cfg.instrument.pixel_scale)
    n_psf = len(psf_models)
    # Instantiate density models
    kwargs_model = dict(
        lens_model_list=[
            cfg.bnn_omega.lens_mass.profile,
            cfg.bnn_omega.external_shear.profile
        ],
        source_light_model_list=[cfg.bnn_omega.src_light.profile],
    )
    lens_mass_model = LensModel(
        lens_model_list=kwargs_model['lens_model_list'])
    src_light_model = LightModel(
        light_model_list=kwargs_model['source_light_model_list'])
    lens_eq_solver = LensEquationSolver(lens_mass_model)
    lens_light_model = None
    ps_model = None
    if 'lens_light' in cfg.components:
        kwargs_model['lens_light_model_list'] = [
            cfg.bnn_omega.lens_light.profile
        ]
        lens_light_model = LightModel(
            light_model_list=kwargs_model['lens_light_model_list'])
    if 'agn_light' in cfg.components:
        kwargs_model['point_source_model_list'] = [
            cfg.bnn_omega.agn_light.profile
        ]
        ps_model = PointSource(
            point_source_type_list=kwargs_model['point_source_model_list'],
            fixed_magnification_list=[False])
    # Instantiate Selection object
    selection = Selection(cfg.selection, cfg.components)
    # Initialize BNN prior
    bnn_prior = getattr(bnn_priors, cfg.bnn_prior_class)(cfg.bnn_omega,
                                                         cfg.components)
    # Label rows not yet written to disk, and the CSV header order fixed by
    # the first flush (later chunks are reindexed to it).
    metadata_rows = []
    metadata_columns = None
    metadata_path = os.path.join(save_dir, 'metadata.csv')
    current_idx = 0  # running idx of dataset
    pbar = tqdm(total=cfg.n_data)
    while current_idx < cfg.n_data:
        sample = bnn_prior.sample()  # FIXME: sampling in batches
        # Selections on sampled parameters
        if selection.reject_initial(sample):
            continue
        psf_model = get_PSF_model(psf_models, n_psf, current_idx)
        # Instantiate the image maker data_api with detector and observation conditions
        kwargs_detector = util.merge_dicts(cfg.instrument, cfg.bandpass,
                                           cfg.observation)
        kwargs_detector.update(seeing=cfg.psf.fwhm,
                               psf_type=cfg.psf.type,
                               kernel_point_source=psf_model,
                               background_noise=0.0)
        data_api = DataAPI(cfg.image.num_pix, **kwargs_detector)
        # Generate the image
        img, img_features = generate_image(
            sample,
            psf_model,
            data_api,
            lens_mass_model,
            src_light_model,
            lens_eq_solver,
            cfg.instrument.pixel_scale,
            cfg.image.num_pix,
            cfg.components,
            cfg.numerics,
            min_magnification=cfg.selection.magnification.min,
            lens_light_model=lens_light_model,
            ps_model=ps_model)
        if img is None:  # couldn't make the magnification cut
            continue
        # Save image file
        img_filename = 'X_{0:07d}.npy'.format(current_idx)
        img_path = os.path.join(save_dir, img_filename)
        np.save(img_path, img)
        # Save labels
        meta = {}
        for comp in cfg.components:
            for param_name, param_value in sample[comp].items():
                meta['{:s}_{:s}'.format(comp, param_name)] = param_value
        #if cfg.bnn_prior_class in ['DiagonalCosmoBNNPrior']:
        #    if cfg.bnn_omega.time_delays.calculate_time_delays:
        #        # Order time delays in increasing dec
        #        unordered_td = sample['misc']['true_td'] # np array
        #        increasing_dec_i = np.argsort(img_features['y_image'])
        #        td = unordered_td[increasing_dec_i]
        #        td = td[1:] - td[0] # take BCD - A
        #        sample['misc']['true_td'] = list(td)
        #        img_features['x_image'] = img_features['x_image'][increasing_dec_i]
        #        img_features['y_image'] = img_features['y_image'][increasing_dec_i]
        if cfg.bnn_prior_class in [
                'EmpiricalBNNPrior', 'DiagonalCosmoBNNPrior'
        ]:
            for misc_name, misc_value in sample['misc'].items():
                meta['{:s}'.format(misc_name)] = misc_value
        if 'agn_light' in cfg.components:
            x_image = np.zeros(4)
            y_image = np.zeros(4)
            n_img = len(img_features['x_image'])
            meta['n_img'] = n_img
            x_image[:n_img] = img_features['x_image']
            y_image[:n_img] = img_features['y_image']
            for i in range(4):
                meta['x_image_{:d}'.format(i)] = x_image[i]
                meta['y_image_{:d}'.format(i)] = y_image[i]
        meta['total_magnification'] = img_features['total_magnification']
        meta['img_filename'] = img_filename
        metadata_rows.append(meta)
        # Export metadata.csv for the first time (writes the sorted header)
        if current_idx == 0:
            metadata_columns = _append_metadata_chunk(metadata_rows,
                                                      metadata_path)
            metadata_rows = []
            gc.collect()

        # Export metadata every checkpoint interval
        if (current_idx + 1) % cfg.checkpoint_interval == 0:
            metadata_columns = _append_metadata_chunk(metadata_rows,
                                                      metadata_path,
                                                      metadata_columns)
            metadata_rows = []
            gc.collect()
        # Update progress
        current_idx += 1
        pbar.update(1)
    # Export whatever is left after the last checkpoint
    _append_metadata_chunk(metadata_rows, metadata_path, metadata_columns)
    pbar.close()
Example #10
0
def test_supersampling_simple():
    """Supersampled PSF convolution in lenstronomy must match scipy.

    A Gaussian profile is evaluated on a grid supersampled by
    ``supersampling_factor`` and convolved with a supersampled Gaussian PSF in
    two ways. The first is lenstronomy's ``NumericsSubFrame`` with supersampled
    convolution. The second is ``scipy.signal.fftconvolve``, resized to detector
    resolution with ``image_util.re_size``. The two results must agree. A
    non-supersampled convolution is also run as a smoke check of that code path.
    """
    from lenstronomy.Data.psf import PSF
    from lenstronomy.SimulationAPI.data_api import DataAPI

    detector_pixel_scale = 0.04
    numpix = 64
    supersampling_factor = 2

    # generate a Gaussian image on the supersampled grid
    x, y = util.make_grid(numPix=numpix * supersampling_factor,
                          deltapix=detector_pixel_scale / supersampling_factor)
    from lenstronomy.LightModel.Profiles.gaussian import Gaussian
    gaussian = Gaussian()
    image_1d = gaussian.function(x, y, amp=1, sigma=0.1)
    image = util.array2image(image_1d)

    # generate the supersampled psf kernel
    kernel_super = kernel_util.kernel_gaussian(
        kernel_numPix=21 * supersampling_factor + 1,
        deltaPix=detector_pixel_scale / supersampling_factor,
        fwhm=0.2)

    psf_parameters = {
        'psf_type': 'PIXEL',
        'kernel_point_source': kernel_super,
        'point_source_supersampling_factor': supersampling_factor
    }
    kwargs_detector = {
        'pixel_scale': detector_pixel_scale,
        'ccd_gain': 2.5,
        'read_noise': 4.0,
        'magnitude_zero_point': 25.0,
        'exposure_time': 5400.0,
        'sky_brightness': 22,
        'num_exposures': 1,
        'background_noise': None
    }
    # Numerics follow supersampling_factor so that changing it above keeps the
    # grid, the PSF and the numerics consistent (previously hard-coded to 2).
    kwargs_numerics = {
        'supersampling_factor': supersampling_factor,
        'supersampling_convolution': True,
        'point_source_supersampling_factor': supersampling_factor,
        'supersampling_kernel_size': 21
    }
    psf_model = PSF(**psf_parameters)
    data_class = DataAPI(numpix=numpix, **kwargs_detector).data_class

    from lenstronomy.ImSim.Numerics.numerics_subframe import NumericsSubFrame
    image_numerics = NumericsSubFrame(pixel_grid=data_class,
                                      psf=psf_model,
                                      **kwargs_numerics)
    conv_flat = image_numerics.convolution_class.convolution2d(image)

    # Reference: scipy convolution on the supersampled grid, then resized.
    from scipy import signal
    from lenstronomy.Util import image_util
    scipy_image = signal.fftconvolve(image, kernel_super, mode='same')
    image_scipy_resized = image_util.re_size(scipy_image, supersampling_factor)
    image_unconvolved = image_util.re_size(image, supersampling_factor)

    # Low-resolution (non-supersampled) convolution: smoke check only.
    kwargs_numerics_low_res = {
        'supersampling_factor': supersampling_factor,
        'supersampling_convolution': False,
        'point_source_supersampling_factor': supersampling_factor,
    }
    image_numerics_low_res = NumericsSubFrame(pixel_grid=data_class,
                                              psf=psf_model,
                                              **kwargs_numerics_low_res)
    image_numerics_low_res.convolution_class.convolution2d(image_unconvolved)

    np.testing.assert_almost_equal(conv_flat, image_scipy_resized)
Example #11
0
def main():
    """Simulate a dataset of lensed images and write their labels to CSV.

    Reads the config given on the command line, seeds ``numpy`` and ``random``
    with ``cfg.seed``, and creates ``cfg.out_dir`` (raising ``OSError`` if it
    already exists). It then samples the BNN prior and rejects samples failing
    the Einstein-radius or total-magnification cuts. Each accepted image gets
    background and Poisson noise added and is saved as ``X_{index:07d}.npy``
    (1-based index). All labels are written once at the end to
    ``metadata.csv``, with columns in ``param_list`` order.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.n_data is not None:
        cfg.n_data = args.n_data
    # Seed for reproducibility
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    if not os.path.exists(cfg.out_dir):
        os.makedirs(cfg.out_dir)
        print("Destination folder: {:s}".format(cfg.out_dir))
    else:
        raise OSError("Destination folder already exists.")

    # Instantiate PSF models
    psf_models = get_PSF_models(cfg.psf, cfg.instrument.pixel_scale)
    n_psf = len(psf_models)

    # Instantiate density models
    kwargs_model = dict(
        lens_model_list=[
            cfg.bnn_omega.lens_mass.profile,
            cfg.bnn_omega.external_shear.profile
        ],
        source_light_model_list=[cfg.bnn_omega.src_light.profile],
    )
    lens_mass_model = LensModel(
        lens_model_list=kwargs_model['lens_model_list'])
    src_light_model = LightModel(
        light_model_list=kwargs_model['source_light_model_list'])
    lens_eq_solver = LensEquationSolver(lens_mass_model)
    lens_light_model = None
    ps_model = None

    if 'lens_light' in cfg.components:
        kwargs_model['lens_light_model_list'] = [
            cfg.bnn_omega.lens_light.profile
        ]
        lens_light_model = LightModel(
            light_model_list=kwargs_model['lens_light_model_list'])
    if 'agn_light' in cfg.components:
        kwargs_model['point_source_model_list'] = [
            cfg.bnn_omega.agn_light.profile
        ]
        ps_model = PointSource(
            point_source_type_list=kwargs_model['point_source_model_list'],
            fixed_magnification_list=[False])

    # Initialize BNN prior
    bnn_prior = getattr(bnn_priors, cfg.bnn_prior_class)(cfg.bnn_omega,
                                                         cfg.components)

    # Column order of the exported labels
    param_list = []
    for comp in cfg.components:
        param_list += [
            '{:s}_{:s}'.format(comp, param)
            for param in bnn_prior.params[cfg.bnn_omega[comp]['profile']]
        ]
    if 'agn_light' in cfg.components:
        param_list += ['magnification_{:d}'.format(i) for i in range(4)]
        param_list += ['x_image_{:d}'.format(i) for i in range(4)]
        param_list += ['y_image_{:d}'.format(i) for i in range(4)]
        param_list += ['n_img']
    param_list += ['img_path', 'total_magnification']
    if cfg.bnn_prior_class == 'EmpiricalBNNPrior':
        param_list += [
            'z_lens', 'z_src', 'vel_disp_iso', 'R_eff_lens', 'R_eff_src',
            'abmag_src'
        ]
    # One dict of labels per saved image; turned into a DataFrame at the end
    # (DataFrame.append was quadratic and no longer exists in pandas >= 2.0).
    metadata_rows = []

    print("Starting simulation...")
    current_idx = 0  # running idx of dataset
    pbar = tqdm(total=cfg.n_data)
    while current_idx < cfg.n_data:
        psf_model = psf_models[current_idx % n_psf]
        sample = bnn_prior.sample()  # FIXME: sampling in batches
        if sample['lens_mass']['theta_E'] < cfg.selection.theta_E.min:
            continue

        # Instantiate DataAPI (used for mag -> amp conversion, the pixel grid and noise)
        kwargs_detector = util.merge_dicts(cfg.instrument, cfg.bandpass,
                                           cfg.observation)
        # NOTE(review): the original comment here about a keyword deprecation
        # was truncated; confirm `kernel_point_source` is still the DataAPI name.
        kwargs_detector.update(
            seeing=cfg.psf.fwhm,
            psf_type=cfg.psf.type,
            kernel_point_source=psf_model
        )
        data_api = DataAPI(cfg.image.num_pix, **kwargs_detector)
        image_data = data_api.data_class

        kwargs_lens_mass = [sample['lens_mass'], sample['external_shear']]
        kwargs_src_light = [sample['src_light']]
        kwargs_src_light = amp_to_mag_extended(kwargs_src_light,
                                               src_light_model, data_api)
        kwargs_lens_light = None
        kwargs_ps = None

        if 'agn_light' in cfg.components:
            x_image, y_image = lens_eq_solver.findBrightImage(
                sample['src_light']['center_x'],
                sample['src_light']['center_y'],
                kwargs_lens_mass,
                numImages=4,
                min_distance=cfg.instrument.pixel_scale,
                search_window=cfg.image.num_pix * cfg.instrument.pixel_scale)
            magnification = np.abs(
                lens_mass_model.magnification(x_image,
                                              y_image,
                                              kwargs=kwargs_lens_mass))
            unlensed_mag = sample['agn_light']['magnitude']  # unlensed agn mag
            kwargs_unlensed_mag_ps = [{
                'ra_image': x_image,
                'dec_image': y_image,
                'magnitude': unlensed_mag
            }]  # note unlensed magnitude
            kwargs_unlensed_amp_ps = amp_to_mag_point(
                kwargs_unlensed_mag_ps, ps_model,
                data_api)  # note unlensed amp
            kwargs_ps = copy.deepcopy(kwargs_unlensed_amp_ps)
            for kw in kwargs_ps:
                kw.update(point_amp=kw['point_amp'] * magnification)
        else:
            kwargs_unlensed_amp_ps = None

        if 'lens_light' in cfg.components:
            kwargs_lens_light = [sample['lens_light']]
            kwargs_lens_light = amp_to_mag_extended(kwargs_lens_light,
                                                    lens_light_model, data_api)

        # Instantiate image model
        image_model = ImageModel(image_data,
                                 psf_model,
                                 lens_mass_model,
                                 src_light_model,
                                 lens_light_model,
                                 ps_model,
                                 kwargs_numerics=cfg.numerics)

        # Compute magnification
        lensed_total_flux = get_lensed_total_flux(kwargs_lens_mass,
                                                  kwargs_src_light,
                                                  kwargs_lens_light, kwargs_ps,
                                                  image_model)
        unlensed_total_flux = get_unlensed_total_flux(kwargs_src_light,
                                                      src_light_model,
                                                      kwargs_unlensed_amp_ps,
                                                      ps_model)
        total_magnification = lensed_total_flux / unlensed_total_flux

        # Apply magnification cut
        if total_magnification < cfg.selection.magnification.min:
            continue

        # Generate image for export
        img = image_model.image(kwargs_lens_mass, kwargs_src_light,
                                kwargs_lens_light, kwargs_ps)

        # Add noise
        noise = data_api.noise_for_model(img,
                                         background_noise=True,
                                         poisson_noise=True,
                                         seed=None)
        img += noise

        # Save image file
        img_path = os.path.join(cfg.out_dir,
                                'X_{0:07d}.npy'.format(current_idx + 1))
        np.save(img_path, img)

        # Save labels
        meta = {}
        for comp in cfg.components:
            for param_name, param_value in sample[comp].items():
                meta['{:s}_{:s}'.format(comp, param_name)] = param_value
        if 'agn_light' in cfg.components:
            n_img = len(x_image)
            # Set outside the loop so it is recorded even when no image is found.
            meta['n_img'] = n_img
            for i in range(n_img):
                meta['magnification_{:d}'.format(i)] = magnification[i]
                meta['x_image_{:d}'.format(i)] = x_image[i]
                meta['y_image_{:d}'.format(i)] = y_image[i]
        if cfg.bnn_prior_class == 'EmpiricalBNNPrior':
            for misc_name, misc_value in sample['misc'].items():
                meta['{:s}'.format(misc_name)] = misc_value
        meta['total_magnification'] = total_magnification
        meta['img_path'] = img_path
        metadata_rows.append(meta)

        # Update progress
        current_idx += 1
        pbar.update(1)
    pbar.close()

    # `columns=` fixes the column order and drops keys outside param_list,
    # matching the former `metadata[param_list]` selection.
    metadata = pd.DataFrame(metadata_rows, columns=param_list)
    metadata_path = os.path.join(cfg.out_dir, 'metadata.csv')
    metadata.to_csv(metadata_path, index=False)
    print("Labels include: ", metadata.columns.values)