Example #1
0
    def critical_cruves_caustics(self,
                                 lens_system=None,
                                 main=None,
                                 halos=None,
                                 multiplane=None,
                                 compute_window=1.5,
                                 scale=0.5,
                                 max_order=10,
                                 method=None,
                                 grid_scale=0.005):
        """Compute the critical curves (and, by default, the caustics) of a lens system.

        When ``lens_system`` is None it is assembled with
        ``self.build_system(main=main, realization=halos, multiplane=multiplane)``.

        :param lens_system: pre-built lens system; built from main/halos/multiplane when None
        :param main: forwarded to ``build_system`` as ``main``
        :param halos: forwarded to ``build_system`` as ``realization``
        :param multiplane: forwarded to ``build_system`` as ``multiplane``
        :param compute_window: window handed to ``critical_curve_tiling`` (tiling method only)
        :param scale: ``start_scale`` handed to ``critical_curve_tiling``
        :param max_order: ``max_order`` handed to ``critical_curve_tiling``
        :param method: ``'tiling'`` selects ``critical_curve_tiling``; any other value
            selects ``critical_curve_caustics``
        :param grid_scale: grid resolution handed to ``critical_curve_caustics``
        :return: ``(xcrit, ycrit)`` for the tiling method, otherwise
            ``(ra_crit_list, dec_crit_list, ra_caustic_list, dec_caustic_list)``
        """
        system = lens_system
        if system is None:
            system = self.build_system(main=main,
                                       realization=halos,
                                       multiplane=multiplane)

        builder = self.lenstronomy_build()
        lensmodel, lensmodel_params = builder.get_lensmodel(system)
        extension = LensModelExtensions(lensmodel)

        if method != 'tiling':
            # NOTE(review): ``compute_window`` is NOT forwarded on this path; a fixed
            # window of 5 is used regardless of the argument -- confirm this is intended.
            crit_ra, crit_dec, caustic_ra, caustic_dec = extension.critical_curve_caustics(
                lensmodel_params, compute_window=5, grid_scale=grid_scale)
            return crit_ra, crit_dec, caustic_ra, caustic_dec

        crit_x, crit_y = extension.critical_curve_tiling(lensmodel_params,
                                                         compute_window=compute_window,
                                                         start_scale=scale,
                                                         max_order=max_order)
        return crit_x, crit_y
Example #2
0
 def test_critical_curves(self):
     """Compare critical-curve and caustic points of an elliptical SPEP lens
     (ellipticity given as e1/e2) against stored reference values."""
     phi, q = 1., 0.8
     e1, e2 = param_util.phi_q2_ellipticity(phi, q)
     kwargs_spep = {
         'theta_E': 1.,
         'gamma': 2.,
         'e1': e1,
         'e2': e2,
         'center_x': 0,
         'center_y': 0
     }
     kwargs_lens = [kwargs_spep]
     extensions = LensModelExtensions(LensModel(['SPEP']))
     ra_crit_list, dec_crit_list, ra_caustic_list, dec_caustic_list = \
         extensions.critical_curve_caustics(kwargs_lens, compute_window=5, grid_scale=0.005)
     print(ra_caustic_list)
     # (computed point lists, reference value) -- each compared at point 3 of the first curve
     references = [
         (ra_caustic_list, -0.25629009803139047),
         (dec_caustic_list, -0.39153358367275115),
         (ra_crit_list, -0.53249999999999997),
         (dec_crit_list, -1.2536936868024853),
     ]
     for points, expected in references:
         npt.assert_almost_equal(points[0][3], expected, decimal=5)
Example #3
0
 def test_critical_curves(self):
     """Compare critical-curve and caustic points of an elliptical SPEP lens
     (legacy q/phi_G parameterization) against stored reference values."""
     kwargs_spep = {
         'theta_E': 1.,
         'gamma': 2.,
         'q': 0.8,
         'phi_G': 1.,
         'center_x': 0,
         'center_y': 0
     }
     kwargs_lens = [kwargs_spep]
     extensions = LensModelExtensions(['SPEP'])
     ra_crit_list, dec_crit_list, ra_caustic_list, dec_caustic_list = \
         extensions.critical_curve_caustics(kwargs_lens, compute_window=5, grid_scale=0.005)
     print(ra_caustic_list)
     # (computed point lists, reference value) -- each compared at point 3 of the first curve
     references = [
         (ra_caustic_list, -0.25629009803139047),
         (dec_caustic_list, -0.39153358367275115),
         (ra_crit_list, -0.53249999999999997),
         (dec_crit_list, -1.2536936868024853),
     ]
     for points, expected in references:
         npt.assert_almost_equal(points[0][3], expected, decimal=5)
     """
    def test_critical_curves(self):
        """Check that critical-curve points lie at (near-)infinite magnification.

        Every critical-curve point returned by ``critical_curve_caustics`` for an
        elliptical SPEP lens must have an absolute magnification above 1e5.
        """
        lens_model_list = ['SPEP']
        phi, q = 1., 0.8
        e1, e2 = param_util.phi_q2_ellipticity(phi, q)
        kwargs_lens = [{
            'theta_E': 1.,
            'gamma': 2.,
            'e1': e1,
            'e2': e2,
            'center_x': 0,
            'center_y': 0
        }]
        lens_model = LensModel(lens_model_list)
        # reuse the same LensModel for the extension and for the magnification check
        lensModelExtensions = LensModelExtensions(lens_model)
        ra_crit_list, dec_crit_list, ra_caustic_list, dec_caustic_list = lensModelExtensions.critical_curve_caustics(
            kwargs_lens, compute_window=5, grid_scale=0.005)

        # The critical-curve points should sit at very high magnification (close to infinite);
        # "high" here means |mag| above 100000 (the earlier matplotlib-based method only reached ~170).
        for ra_crit, dec_crit in zip(ra_crit_list, dec_crit_list):
            mag = lens_model.magnification(ra_crit, dec_crit, kwargs_lens)
            assert np.all(np.abs(mag) > 100000)
Example #5
0
def lens_model_plot(ax,
                    lensModel,
                    kwargs_lens,
                    numPix=500,
                    deltaPix=0.01,
                    sourcePos_x=0,
                    sourcePos_y=0,
                    point_source=False,
                    with_caustics=False,
                    with_convergence=True,
                    coord_center_ra=0,
                    coord_center_dec=0,
                    coord_inverse=False,
                    fast_caustic=False):
    """
    plots a lens model (convergence) and the critical curves and caustics

    :param ax: matplotlib axis instance to draw on
    :param lensModel: lens model instance (its kappa, ray_shooting and magnification methods are used)
    :param kwargs_lens: lens model keyword argument list
    :param numPix: number of pixels per side of the plotted frame
    :param deltaPix: pixel size; the frame spans numPix * deltaPix
    :param sourcePos_x: x-position of the point source (only used if point_source is True)
    :param sourcePos_y: y-position of the point source (only used if point_source is True)
    :param point_source: boolean, if True, solves the lens equation for the source position and marks
     the images (labelled 'A', 'B', ...) and the source
    :param with_caustics: if exactly True, plots critical curves (red) and caustics (green)
    :param with_convergence: boolean, if True, plots the convergence of the deflector
    :param coord_center_ra: passed as center_ra to sim_util.data_configure_simple
    :param coord_center_dec: passed as center_dec to sim_util.data_configure_simple
    :param coord_inverse: passed as inverse to sim_util.data_configure_simple
    :param fast_caustic: boolean, if True, uses faster but less precise caustic calculation
     (might have troubles for the outer caustic (inner critical curve))
    :return: ax
    """
    kwargs_data = sim_util.data_configure_simple(numPix,
                                                 deltaPix,
                                                 center_ra=coord_center_ra,
                                                 center_dec=coord_center_dec,
                                                 inverse=coord_inverse)
    data = ImageData(**kwargs_data)
    _coords = data
    _frame_size = numPix * deltaPix
    x_grid, y_grid = data.pixel_coordinates
    lensModelExt = LensModelExtensions(lensModel)
    x_grid1d = util.image2array(x_grid)
    y_grid1d = util.image2array(y_grid)
    if with_convergence:
        kappa_result = lensModel.kappa(x_grid1d, y_grid1d, kwargs_lens)
        kappa_result = util.array2image(kappa_result)
        ax.matshow(np.log10(kappa_result),
                   origin='lower',
                   extent=[0, _frame_size, 0, _frame_size],
                   cmap='Greys',
                   vmin=-1,
                   vmax=1)
    if with_caustics is True:
        if fast_caustic:
            ra_crit_list, dec_crit_list, ra_caustic_list, dec_caustic_list = lensModelExt.critical_curve_caustics(
                kwargs_lens, compute_window=_frame_size, grid_scale=deltaPix)
            plot_util.plot_line_set_list(ax,
                                         _coords,
                                         ra_caustic_list,
                                         dec_caustic_list,
                                         color='g')
            plot_util.plot_line_set_list(ax,
                                         _coords,
                                         ra_crit_list,
                                         dec_crit_list,
                                         color='r')
        else:
            ra_crit_list, dec_crit_list = lensModelExt.critical_curve_tiling(
                kwargs_lens,
                compute_window=_frame_size,
                start_scale=deltaPix,
                max_order=10)
            # caustics are the critical-curve points mapped to the source plane
            ra_caustic_list, dec_caustic_list = lensModel.ray_shooting(
                ra_crit_list, dec_crit_list, kwargs_lens)
            plot_util.plot_line_set(ax,
                                    _coords,
                                    ra_caustic_list,
                                    dec_caustic_list,
                                    color='g')
            plot_util.plot_line_set(ax,
                                    _coords,
                                    ra_crit_list,
                                    dec_crit_list,
                                    color='r')
    if point_source:
        from lenstronomy.LensModel.Solver.lens_equation_solver import LensEquationSolver
        solver = LensEquationSolver(lensModel)
        theta_x, theta_y = solver.image_position_from_source(
            sourcePos_x,
            sourcePos_y,
            kwargs_lens,
            min_distance=deltaPix,
            search_window=deltaPix * numPix)
        mag_images = lensModel.magnification(theta_x, theta_y, kwargs_lens)
        x_image, y_image = _coords.map_coord2pix(theta_x, theta_y)
        abc_list = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K']
        for i, x_pix in enumerate(x_image):
            # pixel index -> position in the [0, _frame_size] extent used by matshow
            x_ = (x_pix + 0.5) * deltaPix
            y_ = (y_image[i] + 0.5) * deltaPix
            ax.plot(x_,
                    y_,
                    'dk',
                    markersize=4 * (1 + np.log(np.abs(mag_images[i]))),
                    alpha=0.5)
            # fall back to the numeric index when there are more images than letter labels
            label = abc_list[i] if i < len(abc_list) else str(i)
            ax.text(x_, y_, label, fontsize=20, color='k')
        x_source, y_source = _coords.map_coord2pix(sourcePos_x, sourcePos_y)
        ax.plot((x_source + 0.5) * deltaPix, (y_source + 0.5) * deltaPix,
                '*k',
                markersize=10)
    ax.set_xlim([0, _frame_size])
    ax.set_ylim([0, _frame_size])
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.autoscale(False)
    return ax
Example #6
0
class LensModelPlot(object):
    """
    class that manages the summary plots of a lens model

    At construction it builds the image model, solves for the linear amplitudes
    of the given model, and pre-computes critical curves and caustics; the
    ``*_plot`` methods then draw individual panels onto a matplotlib axis.
    """
    def __init__(self,
                 kwargs_data,
                 kwargs_psf,
                 kwargs_numerics,
                 kwargs_model,
                 kwargs_lens,
                 kwargs_source,
                 kwargs_lens_light,
                 kwargs_ps,
                 arrow_size=0.1,
                 cmap_string="gist_heat",
                 high_res=5):
        """

        :param kwargs_data: data keyword arguments ('image_data' and 'transform_pix2angle' are read here)
        :param kwargs_psf: PSF keyword arguments
        :param kwargs_numerics: numerics keyword arguments
        :param kwargs_model: model keyword arguments ('lens_model_list', 'z_source', 'redshift_list' and
         'multi_plane' are read here)
        :param kwargs_lens: lens model keyword argument list
        :param kwargs_source: source light keyword argument list
        :param kwargs_lens_light: lens light keyword argument list
        :param kwargs_ps: point source keyword argument list
        :param arrow_size: size of the orientation arrows
        :param cmap_string: name of a matplotlib colormap, or a colormap object
        :param high_res: not used by this class
        """
        import copy

        self._kwargs_data = kwargs_data
        try:
            string_types = (str, unicode)  # noqa: F821 -- Python 2 only
        except NameError:  # Python 3 has no `unicode`
            string_types = (str,)
        if isinstance(cmap_string, string_types):
            cmap = plt.get_cmap(cmap_string)
        else:
            cmap = cmap_string
        # copy so that set_bad/set_under below do not alter the globally registered
        # colormap (or a colormap object owned by the caller)
        cmap = copy.copy(cmap)
        cmap.set_bad(color='k', alpha=1.)
        cmap.set_under('k')
        self._cmap = cmap
        self._arrow_size = arrow_size
        data = Data(kwargs_data)
        self._coords = data._coords
        nx, ny = np.shape(kwargs_data['image_data'])
        Mpix2coord = kwargs_data['transform_pix2angle']
        self._Mpix2coord = Mpix2coord

        self._deltaPix = self._coords.pixel_size
        self._frame_size = self._deltaPix * nx

        self._x_grid, self._y_grid = data.coordinates

        self._imageModel = class_creator.creat_image_model(
            kwargs_data, kwargs_psf, kwargs_numerics, kwargs_model)
        self._analysis = LensAnalysis(kwargs_model)
        self._lensModel = LensModelExtensions(
            lens_model_list=kwargs_model.get('lens_model_list', ['NONE']),
            z_source=kwargs_model.get('z_source', None),
            redshift_list=kwargs_model.get('redshift_list', None),
            multi_plane=kwargs_model.get('multi_plane', False))
        self._ra_crit_list, self._dec_crit_list, self._ra_caustic_list, self._dec_caustic_list = self._lensModel.critical_curve_caustics(
            kwargs_lens, compute_window=self._frame_size, grid_scale=0.01)

        model, error_map, cov_param, param = self._imageModel.image_linear_solve(
            kwargs_lens,
            kwargs_source,
            kwargs_lens_light,
            kwargs_ps,
            inv_bool=True)
        self._kwargs_lens = kwargs_lens
        self._kwargs_source = kwargs_source
        self._kwargs_lens_light = kwargs_lens_light
        self._kwargs_else = kwargs_ps
        self._model = model
        self._data = kwargs_data['image_data']
        self._cov_param = cov_param
        self._norm_residuals = self._imageModel.reduced_residuals(
            model, error_map=error_map)
        self._reduced_x2 = self._imageModel.reduced_chi2(model,
                                                         error_map=error_map)
        # log10 of negative pixels gives nan (replaced by -5) and of zero pixels -inf
        # (clipped by the max() below); the defaults are bounded to [-5, 10]
        log_model = np.log10(model)
        log_model[np.isnan(log_model)] = -5
        self._v_min_default = max(np.min(log_model), -5)
        self._v_max_default = min(np.max(log_model), 10)
        print("reduced chi^^ = ", self._reduced_x2)

    @staticmethod
    def _freeze_axes(ax):
        """Hide both axes and disable autoscaling so that later overlays keep the frame limits."""
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)

    @staticmethod
    def _add_colorbar(ax, im, label):
        """Attach a labelled colorbar for ``im`` in a new axis to the right of ``ax``."""
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(label, fontsize=15)
        return cb

    def _plot_lens_overlays(self, ax, **image_kwargs):
        """Overlay caustics (blue), critical curves (red), image positions and source positions.

        :param image_kwargs: extra keyword arguments forwarded to ``image_position_plot``
        """
        plot_line_set(ax,
                      self._coords,
                      self._ra_caustic_list,
                      self._dec_caustic_list,
                      color='b')
        plot_line_set(ax,
                      self._coords,
                      self._ra_crit_list,
                      self._dec_crit_list,
                      color='r')
        ra_image, dec_image = self._imageModel.image_positions(
            self._kwargs_else, self._kwargs_lens)
        image_position_plot(ax, self._coords, ra_image[0], dec_image[0], **image_kwargs)
        source_position_plot(ax, self._coords, self._kwargs_source)

    def _source_grid(self, numPix, deltaPix_source):
        """Build the source-plane grid used by the source panels.

        The grid has ``numPix`` pixels per side with pixel size ``deltaPix_source`` and is
        shifted by the first source component's ``center_x``/``center_y``.

        :return: x_grid_source, y_grid_source, coords_source (Coordinates of that grid)
        """
        transform = self._Mpix2coord * deltaPix_source / self._deltaPix
        x_grid_source, y_grid_source = util.make_grid_transformed(numPix, transform)
        x_grid_source += self._kwargs_source[0]['center_x']
        y_grid_source += self._kwargs_source[0]['center_y']
        coords_source = Coordinates(transform,
                                    ra_at_xy_0=x_grid_source[0],
                                    dec_at_xy_0=y_grid_source[0])
        return x_grid_source, y_grid_source, coords_source

    def data_plot(self, ax, v_min=None, v_max=None):
        """
        plots the observed image (log10 flux)

        :param ax: matplotlib axis instance
        :param v_min: lower color limit; defaults to the value derived from the model
        :param v_max: upper color limit; defaults to the value derived from the model
        :return: ax
        """
        if v_min is None:
            v_min = self._v_min_default
        if v_max is None:
            v_max = self._v_max_default
        im = ax.matshow(np.log10(self._data),
                        origin='lower',
                        extent=[0, self._frame_size, 0, self._frame_size],
                        cmap=self._cmap,
                        vmin=v_min,
                        vmax=v_max)
        self._freeze_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"')
        text_description(ax,
                         self._frame_size,
                         text="Observed",
                         color="w",
                         backgroundcolor='k')
        coordinate_arrows(ax,
                          self._frame_size,
                          self._coords,
                          arrow_size=self._arrow_size)
        self._add_colorbar(ax, im, r'log$_{10}$ flux')
        return ax

    def model_plot(self, ax, v_min=None, v_max=None):
        """
        plots the reconstructed image (log10 flux) with critical curves, caustics,
        image positions and source positions

        :param ax: matplotlib axis instance
        :param v_min: lower color limit; defaults to the value derived from the model
        :param v_max: upper color limit; defaults to the value derived from the model
        :return: ax
        """
        if v_min is None:
            v_min = self._v_min_default
        if v_max is None:
            v_max = self._v_max_default
        im = ax.matshow(np.log10(self._model),
                        origin='lower',
                        vmin=v_min,
                        vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size],
                        cmap=self._cmap)
        self._freeze_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"')
        text_description(ax,
                         self._frame_size,
                         text="Reconstructed",
                         color="w",
                         backgroundcolor='k')
        coordinate_arrows(ax,
                          self._frame_size,
                          self._coords,
                          arrow_size=self._arrow_size)
        self._add_colorbar(ax, im, r'log$_{10}$ flux')
        self._plot_lens_overlays(ax)
        # return the axis like every other *_plot method
        return ax

    def convergence_plot(self, ax, v_min=None, v_max=None):
        """
        plots the convergence (log10 kappa) of the lens model on the data grid

        :param ax: matplotlib axis instance
        :param v_min: lower color limit (None lets matplotlib choose)
        :param v_max: upper color limit (None lets matplotlib choose)
        :return: ax
        """
        kappa_result = util.array2image(
            self._lensModel.kappa(self._x_grid, self._y_grid,
                                  self._kwargs_lens))
        im = ax.matshow(np.log10(kappa_result),
                        origin='lower',
                        extent=[0, self._frame_size, 0, self._frame_size],
                        cmap=self._cmap,
                        vmin=v_min,
                        vmax=v_max)
        self._freeze_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"', color='w')
        coordinate_arrows(ax,
                          self._frame_size,
                          self._coords,
                          color='w',
                          arrow_size=self._arrow_size)
        text_description(ax,
                         self._frame_size,
                         text="Convergence",
                         color="w",
                         backgroundcolor='k',
                         flipped=False)
        self._add_colorbar(ax, im, r'log$_{10}$ $\kappa$')
        return ax

    def normalized_residual_plot(self, ax, v_min=-6, v_max=6):
        """
        plots the normalized residuals computed at construction

        :param ax: matplotlib axis instance
        :param v_min: lower color limit
        :param v_max: upper color limit
        :return: ax
        """
        im = ax.matshow(self._norm_residuals,
                        vmin=v_min,
                        vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size],
                        cmap='bwr',
                        origin='lower')
        self._freeze_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"', color='k')
        text_description(ax,
                         self._frame_size,
                         text="Normalized Residuals",
                         color="k",
                         backgroundcolor='w')
        coordinate_arrows(ax,
                          self._frame_size,
                          self._coords,
                          color='k',
                          arrow_size=self._arrow_size)
        self._add_colorbar(ax, im, r'(f$_{model}$-f$_{data}$)/$\sigma$')
        return ax

    def absolute_residual_plot(self, ax, v_min=-1, v_max=1):
        """
        plots model minus data

        :param ax: matplotlib axis instance
        :param v_min: lower color limit
        :param v_max: upper color limit
        :return: ax
        """
        im = ax.matshow(self._model - self._data,
                        vmin=v_min,
                        vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size],
                        cmap='bwr',
                        origin='lower')
        self._freeze_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"', color='k')
        text_description(ax,
                         self._frame_size,
                         text="Residuals",
                         color="k",
                         backgroundcolor='w')
        coordinate_arrows(ax,
                          self._frame_size,
                          self._coords,
                          color='k',
                          arrow_size=self._arrow_size)
        self._add_colorbar(ax, im, r'(f$_{model}$-f$_{data}$)')
        return ax

    def source_plot(self,
                    ax,
                    numPix,
                    deltaPix_source,
                    source_sigma=0.001,
                    convolution=False,
                    v_min=None,
                    v_max=None):
        """
        plots the reconstructed source surface brightness (log10) with the caustics

        :param ax: matplotlib axis instance
        :param numPix: number of source pixels per side
        :param deltaPix_source: source pixel size
        :param source_sigma: Gaussian smoothing width (same units as deltaPix_source), used if convolution=True
        :param convolution: boolean, if True, smooths the source with a Gaussian of width source_sigma
        :param v_min: lower color limit; defaults to the value derived from the model
        :param v_max: upper color limit; defaults to the value derived from the model
        :return: ax
        """
        if v_min is None:
            v_min = self._v_min_default
        if v_max is None:
            v_max = self._v_max_default
        d_s = numPix * deltaPix_source
        x_grid_source, y_grid_source, coords_source = self._source_grid(numPix, deltaPix_source)

        source = self._imageModel.SourceModel.surface_brightness(
            x_grid_source, y_grid_source, self._kwargs_source)
        source = util.array2image(source)
        if convolution:
            # ndimage.gaussian_filter: the ndimage.filters namespace is deprecated in SciPy
            source = ndimage.gaussian_filter(source,
                                             sigma=source_sigma / deltaPix_source,
                                             mode='nearest',
                                             truncate=20)

        im = ax.matshow(np.log10(source),
                        origin='lower',
                        extent=[0, d_s, 0, d_s],
                        cmap=self._cmap,
                        vmin=v_min,
                        vmax=v_max)
        self._freeze_axes(ax)
        self._add_colorbar(ax, im, r'log$_{10}$ flux')
        plot_line_set(ax,
                      coords_source,
                      self._ra_caustic_list,
                      self._dec_caustic_list,
                      color='b')
        scale_bar(ax, d_s, dist=0.1, text='0.1"', color='w', flipped=False)
        coordinate_arrows(ax,
                          d_s,
                          coords_source,
                          arrow_size=self._arrow_size,
                          color='w')
        text_description(ax,
                         d_s,
                         text="Reconstructed source",
                         color="w",
                         backgroundcolor='k',
                         flipped=False)
        source_position_plot(ax, coords_source, self._kwargs_source)
        return ax

    def error_map_source_plot(self,
                              ax,
                              numPix,
                              deltaPix_source,
                              v_min=None,
                              v_max=None):
        """
        plots the source-plane error map derived from the linear covariance matrix

        :param ax: matplotlib axis instance
        :param numPix: number of source pixels per side
        :param deltaPix_source: source pixel size
        :param v_min: lower color limit (None lets matplotlib choose)
        :param v_max: upper color limit (None lets matplotlib choose)
        :return: ax
        """
        x_grid_source, y_grid_source, coords_source = self._source_grid(numPix, deltaPix_source)
        error_map_source = self._analysis.error_map_source(
            self._kwargs_source, x_grid_source, y_grid_source, self._cov_param)
        error_map_source = util.array2image(error_map_source)
        d_s = numPix * deltaPix_source
        im = ax.matshow(error_map_source,
                        origin='lower',
                        extent=[0, d_s, 0, d_s],
                        cmap=self._cmap,
                        vmin=v_min,
                        vmax=v_max)
        self._freeze_axes(ax)
        self._add_colorbar(ax, im, r'error variance')
        plot_line_set(ax,
                      coords_source,
                      self._ra_caustic_list,
                      self._dec_caustic_list,
                      color='b')
        scale_bar(ax, d_s, dist=0.1, text='0.1"', color='w', flipped=False)
        coordinate_arrows(ax,
                          d_s,
                          coords_source,
                          arrow_size=self._arrow_size,
                          color='w')
        text_description(ax,
                         d_s,
                         text="Error map in source",
                         color="w",
                         backgroundcolor='k',
                         flipped=False)
        source_position_plot(ax, coords_source, self._kwargs_source)
        return ax

    def magnification_plot(self, ax, v_min=-10, v_max=10):
        """
        plots the magnification of the lens model with critical curves, caustics,
        image positions and source positions

        :param ax: matplotlib axis instance
        :param v_min: lower color limit
        :param v_max: upper color limit
        :return: ax
        """
        mag_result = util.array2image(
            self._lensModel.magnification(self._x_grid, self._y_grid,
                                          self._kwargs_lens))
        im = ax.matshow(mag_result,
                        origin='lower',
                        extent=[0, self._frame_size, 0, self._frame_size],
                        vmin=v_min,
                        vmax=v_max,
                        cmap=self._cmap,
                        alpha=0.5)
        self._freeze_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"', color='k')
        coordinate_arrows(ax,
                          self._frame_size,
                          self._coords,
                          color='k',
                          arrow_size=self._arrow_size)
        text_description(ax,
                         self._frame_size,
                         text="Magnification model",
                         color="k",
                         backgroundcolor='w')
        self._add_colorbar(ax, im, r'det(A$^{-1}$)')
        self._plot_lens_overlays(ax, color='k')
        return ax

    def deflection_plot(self, ax, v_min=None, v_max=None, axis=0):
        """
        plots one component of the deflection field with critical curves, caustics,
        image positions and source positions

        :param ax: matplotlib axis instance
        :param v_min: lower color limit (None lets matplotlib choose)
        :param v_max: upper color limit (None lets matplotlib choose)
        :param axis: 0 plots the first deflection component, any other value the second
        :return: ax
        """

        alpha1, alpha2 = self._lensModel.alpha(self._x_grid, self._y_grid,
                                               self._kwargs_lens)
        alpha1 = util.array2image(alpha1)
        alpha2 = util.array2image(alpha2)
        if axis == 0:
            alpha = alpha1
        else:
            alpha = alpha2
        im = ax.matshow(alpha,
                        origin='lower',
                        extent=[0, self._frame_size, 0, self._frame_size],
                        vmin=v_min,
                        vmax=v_max,
                        cmap=self._cmap,
                        alpha=0.5)
        self._freeze_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"', color='k')
        coordinate_arrows(ax,
                          self._frame_size,
                          self._coords,
                          color='k',
                          arrow_size=self._arrow_size)
        text_description(ax,
                         self._frame_size,
                         text="Deflection model",
                         color="k",
                         backgroundcolor='w')
        self._add_colorbar(ax, im, r'arcsec')
        self._plot_lens_overlays(ax)
        return ax

    def decomposition_plot(self,
                           ax,
                           text='Reconstructed',
                           v_min=None,
                           v_max=None,
                           unconvolved=False,
                           point_source_add=False,
                           source_add=False,
                           lens_light_add=False):
        """
        plots the model image restricted to the selected components (log10 flux)

        :param ax: matplotlib axis instance
        :param text: panel label
        :param v_min: lower color limit; defaults to the value derived from the model
        :param v_max: upper color limit; defaults to the value derived from the model
        :param unconvolved: forwarded to the image model
        :param point_source_add: boolean, include the point sources
        :param source_add: boolean, include the lensed source light
        :param lens_light_add: boolean, include the lens light
        :return: ax
        """

        model = self._imageModel.image(self._kwargs_lens,
                                       self._kwargs_source,
                                       self._kwargs_lens_light,
                                       self._kwargs_else,
                                       unconvolved=unconvolved,
                                       source_add=source_add,
                                       lens_light_add=lens_light_add,
                                       point_source_add=point_source_add)
        if v_min is None:
            v_min = self._v_min_default
        if v_max is None:
            v_max = self._v_max_default
        im = ax.matshow(np.log10(model),
                        origin='lower',
                        vmin=v_min,
                        vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size],
                        cmap=self._cmap)
        self._freeze_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"')
        text_description(ax,
                         self._frame_size,
                         text=text,
                         color="w",
                         backgroundcolor='k')
        coordinate_arrows(ax,
                          self._frame_size,
                          self._coords,
                          arrow_size=self._arrow_size)
        self._add_colorbar(ax, im, r'log$_{10}$ flux')
        return ax

    def subtract_from_data_plot(self,
                                ax,
                                text='Subtracted',
                                v_min=None,
                                v_max=None,
                                point_source_add=False,
                                source_add=False,
                                lens_light_add=False):
        """
        plots the data minus the selected (convolved) model components (log10 flux)

        :param ax: matplotlib axis instance
        :param text: panel label
        :param v_min: lower color limit; defaults to the value derived from the model
        :param v_max: upper color limit; defaults to the value derived from the model
        :param point_source_add: boolean, subtract the point sources
        :param source_add: boolean, subtract the lensed source light
        :param lens_light_add: boolean, subtract the lens light
        :return: ax
        """
        model = self._imageModel.image(self._kwargs_lens,
                                       self._kwargs_source,
                                       self._kwargs_lens_light,
                                       self._kwargs_else,
                                       unconvolved=False,
                                       source_add=source_add,
                                       lens_light_add=lens_light_add,
                                       point_source_add=point_source_add)
        if v_min is None:
            v_min = self._v_min_default
        if v_max is None:
            v_max = self._v_max_default
        im = ax.matshow(np.log10(self._data - model),
                        origin='lower',
                        vmin=v_min,
                        vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size],
                        cmap=self._cmap)
        self._freeze_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"')
        text_description(ax,
                         self._frame_size,
                         text=text,
                         color="w",
                         backgroundcolor='k')
        coordinate_arrows(ax,
                          self._frame_size,
                          self._coords,
                          arrow_size=self._arrow_size)
        self._add_colorbar(ax, im, r'log$_{10}$ flux')
        return ax
class ModelBandPlot(ModelBand):
    """
    class to plot a single band given the modeling results

    """
    def __init__(self,
                 multi_band_list,
                 kwargs_model,
                 model,
                 error_map,
                 cov_param,
                 param,
                 kwargs_params,
                 likelihood_mask_list=None,
                 band_index=0,
                 arrow_size=0.02,
                 cmap_string="gist_heat",
                 fast_caustic=True):
        """
        Set up the plotting helper for a single imaging band of a (multi-band) modeling result.

        :param multi_band_list: list of imaging data configuration [[kwargs_data, kwargs_psf, kwargs_numerics], [...]]
        :param kwargs_model: model keyword argument list for the full multi-band modeling
        :param model: 2d numpy array of modeled image for the specified band
        :param error_map: 2d numpy array of size of the image, additional error in the pixels coming from PSF uncertainties
        :param cov_param: covariance matrix of the linear inversion
        :param param: 1d numpy array of the linear coefficients of this imaging band
        :param kwargs_params: keyword argument of keyword argument lists of the different model components selected for
         the imaging band, NOT including linear amplitudes (not required as being overwritten by the param list)
        :param likelihood_mask_list: list of 2d numpy arrays of likelihood masks (for all bands)
        :param band_index: integer of the band to be considered in this class
        :param arrow_size: size of the scale and orientation arrow
        :param cmap_string: string of color map (or cmap matplotlib object)
        :param fast_caustic: boolean; if True, uses fast (but less accurate) caustic calculation method
        """
        ModelBand.__init__(self, multi_band_list, kwargs_model, model, error_map, cov_param, param,
                           kwargs_params, image_likelihood_mask_list=likelihood_mask_list,
                           band_index=band_index)

        # lens model of this band plus the extension class used for critical curves / caustics
        self._lensModel = self._bandmodel.LensModel
        self._lensModelExt = LensModelExtensions(self._lensModel)

        # default color-scale limits from the log10 model image: NaN pixels are replaced by -5
        # and the resulting range is clipped to [-5, 10]
        log_model = np.log10(model)
        log_model[np.isnan(log_model)] = -5
        self._v_min_default = max(np.min(log_model), -5)
        self._v_max_default = min(np.max(log_model), 10)

        # pixel grid / coordinate information of the imaging data
        coords = self._bandmodel.Data
        self._coords = coords
        self._data = coords.data
        self._deltaPix = coords.pixel_width
        self._frame_size = np.max(coords.width)
        x_pix_grid, y_pix_grid = coords.pixel_coordinates
        self._x_grid = util.image2array(x_pix_grid)
        self._y_grid = util.image2array(y_pix_grid)
        self._x_center, self._y_center = coords.center

        # plot styling and caustic-computation settings
        self._cmap = plot_util.cmap_conf(cmap_string)
        self._arrow_size = arrow_size
        self._fast_caustic = fast_caustic

    def _critical_curves(self):
        """
        Compute (once) and return the critical curves of the lens model inside the image frame.

        The critical curves, the caustics and the flag ``self._caustic_points_only`` are cached on the instance.
        With ``fast_caustic`` the lens model extension returns critical curves and caustics directly
        (``_caustic_points_only = False``). Otherwise the tiling method is used, which only provides individual
        points; the caustics are then obtained by ray-shooting these points (``_caustic_points_only = True``).

        :return: ra and dec coordinates of the critical curves
        """
        if hasattr(self, '_ra_crit_list') and hasattr(self, '_dec_crit_list'):
            return self._ra_crit_list, self._dec_crit_list

        lens_ext = self._lensModelExt
        kwargs_lens = self._kwargs_lens_partial
        if not self._fast_caustic:
            # tiling only supports individual points due to its output definition
            self._caustic_points_only = True
            self._ra_crit_list, self._dec_crit_list = lens_ext.critical_curve_tiling(
                kwargs_lens,
                compute_window=self._frame_size,
                start_scale=self._deltaPix / 5.,
                max_order=10,
                center_x=self._x_center,
                center_y=self._y_center)
            self._ra_caustic_list, self._dec_caustic_list = self._lensModel.ray_shooting(
                self._ra_crit_list, self._dec_crit_list, kwargs_lens)
        else:
            (self._ra_crit_list, self._dec_crit_list,
             self._ra_caustic_list, self._dec_caustic_list) = lens_ext.critical_curve_caustics(
                kwargs_lens,
                compute_window=self._frame_size,
                grid_scale=self._deltaPix,
                center_x=self._x_center,
                center_y=self._y_center)
            self._caustic_points_only = False
        return self._ra_crit_list, self._dec_crit_list

    def _caustics(self):
        """
        Return the caustics of the lens model, running the cached critical curve computation first if needed.

        :return: ra and dec coordinates of the caustics
        """
        if not (hasattr(self, '_ra_caustic_list') and hasattr(self, '_dec_caustic_list')):
            self._critical_curves()
        return self._ra_caustic_list, self._dec_caustic_list

    def data_plot(self,
                  ax,
                  v_min=None,
                  v_max=None,
                  text='Observed',
                  font_size=15,
                  colorbar_label=r'log$_{10}$ flux',
                  **kwargs):
        """
        Plot the observed imaging data of this band in log10 scale.

        :param ax: matplotlib axis instance
        :param v_min: lower limit of the color scale (log10 flux); defaults to the model-derived default
        :param v_max: upper limit of the color scale (log10 flux); defaults to the model-derived default
        :param text: label written on the panel
        :param font_size: font size of labels
        :param colorbar_label: label of the color bar
        :param kwargs: only 'no_arrow' is read; if True, the coordinate arrows are omitted
        :return: the axis instance
        """
        vmin = self._v_min_default if v_min is None else v_min
        vmax = self._v_max_default if v_max is None else v_max
        frame = self._frame_size

        im = ax.matshow(np.log10(self._data), origin='lower', extent=[0, frame, 0, frame],
                        cmap=self._cmap, vmin=vmin, vmax=vmax)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)

        plot_util.scale_bar(ax, frame, dist=1, text='1"', font_size=font_size)
        plot_util.text_description(ax, frame, text=text, color="w", backgroundcolor='k',
                                   font_size=font_size)

        if not kwargs.get('no_arrow', False):
            plot_util.coordinate_arrows(ax, frame, self._coords, color='w',
                                        arrow_size=self._arrow_size, font_size=font_size)

        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        colorbar = plt.colorbar(im, cax=cax, orientation='vertical')
        colorbar.set_label(colorbar_label, fontsize=font_size)
        return ax

    def model_plot(self,
                   ax,
                   v_min=None,
                   v_max=None,
                   image_names=False,
                   colorbar_label=r'log$_{10}$ flux',
                   font_size=15,
                   text='Reconstructed',
                   **kwargs):
        """
        Plot the modeled image of this band in log10 scale.

        :param ax: matplotlib axis instance
        :param v_min: lower limit of the color scale (log10 flux); defaults to the model-derived default
        :param v_max: upper limit of the color scale (log10 flux); defaults to the model-derived default
        :param image_names: bool, if True, marks the point source image positions on the plot
        :param colorbar_label: label of the color bar
        :param font_size: font size of labels
        :param text: label written on the panel
        :param kwargs: only 'no_arrow' is read; if True, the coordinate arrows are omitted
        :return: the axis instance (like the other plotting methods of this class)
        """
        if v_min is None:
            v_min = self._v_min_default
        if v_max is None:
            v_max = self._v_max_default
        im = ax.matshow(np.log10(self._model),
                        origin='lower',
                        vmin=v_min,
                        vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size],
                        cmap=self._cmap)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)
        plot_util.scale_bar(ax,
                            self._frame_size,
                            dist=1,
                            text='1"',
                            font_size=font_size)
        plot_util.text_description(ax,
                                   self._frame_size,
                                   text=text,
                                   color="w",
                                   backgroundcolor='k',
                                   font_size=font_size)
        if 'no_arrow' not in kwargs or not kwargs['no_arrow']:
            plot_util.coordinate_arrows(ax,
                                        self._frame_size,
                                        self._coords,
                                        color='w',
                                        arrow_size=self._arrow_size,
                                        font_size=font_size)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(colorbar_label, fontsize=font_size)

        if image_names is True:
            ra_image, dec_image = self._bandmodel.PointSource.image_position(
                self._kwargs_ps_partial, self._kwargs_lens_partial)
            plot_util.image_position_plot(ax, self._coords, ra_image,
                                          dec_image)
        # previously returned None, unlike every other plotting method of this class
        return ax

    def convergence_plot(self,
                         ax,
                         text='Convergence',
                         v_min=None,
                         v_max=None,
                         font_size=15,
                         colorbar_label=r'$\log_{10}\ \kappa$',
                         **kwargs):
        """
        Plot the log10 of the lens model convergence evaluated on the image pixel grid.

        :param ax: matplotlib axis instance
        :param text: label written on the panel
        :param v_min: lower limit of the color scale (passed to matshow as is)
        :param v_max: upper limit of the color scale (passed to matshow as is)
        :param font_size: font size of labels
        :param colorbar_label: label of the color bar
        :param kwargs: 'cmap' overrides the color map; 'no_arrow' (bool) suppresses the arrows and the label text
        :return: convergence plot in ax instance
        """
        cmap = kwargs.get('cmap', self._cmap)
        frame = self._frame_size

        kappa = self._lensModel.kappa(self._x_grid, self._y_grid, self._kwargs_lens_partial)
        log_kappa = np.log10(util.array2image(kappa))
        im = ax.matshow(log_kappa, origin='lower', extent=[0, frame, 0, frame], cmap=cmap,
                        vmin=v_min, vmax=v_max)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)
        plot_util.scale_bar(ax, frame, dist=1, text='1"', color='w', font_size=font_size)

        show_decorations = 'no_arrow' not in kwargs or not kwargs['no_arrow']
        if show_decorations:
            plot_util.coordinate_arrows(ax, frame, self._coords, color='w',
                                        arrow_size=self._arrow_size, font_size=font_size)
            plot_util.text_description(ax, frame, text=text, color="w", backgroundcolor='k',
                                       flipped=False, font_size=font_size)

        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        colorbar = plt.colorbar(im, cax=cax)
        colorbar.set_label(colorbar_label, fontsize=font_size)
        return ax

    def normalized_residual_plot(
            self,
            ax,
            v_min=-6,
            v_max=6,
            font_size=15,
            text="Normalized Residuals",
            colorbar_label=r'(f${}_{\rm model}$ - f${}_{\rm data}$)/$\sigma$',
            no_arrow=False,
            color_bar=True,
            **kwargs):
        """
        Plot the normalized residuals of this band.

        :param ax: matplotlib axis instance
        :param v_min: lower limit of the color scale
        :param v_max: upper limit of the color scale
        :param font_size: font size of labels
        :param text: label written on the panel
        :param colorbar_label: label of the color bar
        :param no_arrow: bool, if True, the coordinate arrows are omitted
        :param color_bar: Option to display the color bar
        :param kwargs: kwargs to send to matplotlib.pyplot.matshow() (cmap defaults to 'bwr')
        :return: the axis instance
        """
        kwargs.setdefault('cmap', 'bwr')
        frame = self._frame_size

        im = ax.matshow(self._norm_residuals, vmin=v_min, vmax=v_max,
                        extent=[0, frame, 0, frame], origin='lower', **kwargs)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)
        plot_util.scale_bar(ax, frame, dist=1, text='1"', color='k', font_size=font_size)
        plot_util.text_description(ax, frame, text=text, color="k", backgroundcolor='w',
                                   font_size=font_size)
        if no_arrow:
            pass
        else:
            plot_util.coordinate_arrows(ax, frame, self._coords, color='w',
                                        arrow_size=self._arrow_size, font_size=font_size)
        if color_bar:
            cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
            colorbar = plt.colorbar(im, cax=cax)
            colorbar.set_label(colorbar_label, fontsize=font_size)
        return ax

    def absolute_residual_plot(self,
                               ax,
                               v_min=-1,
                               v_max=1,
                               font_size=15,
                               text="Residuals",
                               colorbar_label=r'(f$_{model}$-f$_{data}$)'):
        """
        Plot the absolute residuals (model minus data) of this band.

        :param ax: matplotlib axis instance
        :param v_min: lower limit of the color scale
        :param v_max: upper limit of the color scale
        :param font_size: font size of labels
        :param text: label written on the panel
        :param colorbar_label: label of the color bar
        :return: the axis instance
        """
        frame = self._frame_size
        residuals = self._model - self._data

        im = ax.matshow(residuals, vmin=v_min, vmax=v_max, extent=[0, frame, 0, frame],
                        cmap='bwr', origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)
        plot_util.scale_bar(ax, frame, dist=1, text='1"', color='k', font_size=font_size)
        plot_util.text_description(ax, frame, text=text, color="k", backgroundcolor='w',
                                   font_size=font_size)
        plot_util.coordinate_arrows(ax, frame, self._coords, font_size=font_size, color='k',
                                    arrow_size=self._arrow_size)

        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        colorbar = plt.colorbar(im, cax=cax)
        colorbar.set_label(colorbar_label, fontsize=font_size)
        return ax

    def source(self, numPix, deltaPix, center=None, image_orientation=True):
        """
        Evaluate the reconstructed source surface brightness on a regular grid in the source plane.

        The grid is centered on ``center`` if given, otherwise on the center of the first source component,
        otherwise on the origin.

        :param numPix: number of pixels per axes
        :param deltaPix: pixel size
        :param center: optional [center_x, center_y] of the source grid
        :param image_orientation: bool, if True, uses frame in orientation of the image, otherwise in RA-DEC coordinates
        :return: 2d surface brightness grid of the reconstructed source and Coordinates() instance of source grid
        """
        if image_orientation is True:
            # image pixel-to-angle transform rescaled to the requested source pixel size
            Mpix2coord = self._coords.transform_pix2angle * deltaPix / self._deltaPix
            x_grid_source, y_grid_source = util.make_grid_transformed(numPix, Mpix2Angle=Mpix2coord)
            ra_at_xy_0 = x_grid_source[0]
            dec_at_xy_0 = y_grid_source[0]
        else:
            (x_grid_source, y_grid_source, ra_at_xy_0, dec_at_xy_0,
             _, _, Mpix2coord, _) = util.make_grid_with_coordtransform(numPix, deltaPix)

        if center is not None:
            center_x, center_y = center[0], center[1]
        elif len(self._kwargs_source_partial) > 0:
            first_source = self._kwargs_source_partial[0]
            center_x, center_y = first_source['center_x'], first_source['center_y']
        else:
            center_x, center_y = 0, 0

        x_grid_source += center_x
        y_grid_source += center_y

        coords_source = Coordinates(transform_pix2angle=Mpix2coord,
                                    ra_at_xy_0=ra_at_xy_0 + center_x,
                                    dec_at_xy_0=dec_at_xy_0 + center_y)

        flux = self._bandmodel.SourceModel.surface_brightness(x_grid_source, y_grid_source,
                                                              self._kwargs_source_partial)
        # surface brightness times pixel area
        flux_image = util.array2image(flux) * deltaPix**2
        return flux_image, coords_source

    def source_plot(self,
                    ax,
                    numPix,
                    deltaPix_source,
                    center=None,
                    v_min=None,
                    v_max=None,
                    with_caustics=False,
                    caustic_color='yellow',
                    font_size=15,
                    plot_scale='log',
                    scale_size=0.1,
                    text="Reconstructed source",
                    colorbar_label=r'log$_{10}$ flux',
                    point_source_position=True,
                    **kwargs):
        """
        Plot the reconstructed source surface brightness on a regular grid in the source plane.

        :param ax: matplotlib axis instance
        :param numPix: number of pixels per axis of the source grid
        :param deltaPix_source: pixel size of the source grid
        :param center: [center_x, center_y], if specified, uses this as the center
        :param v_min: lower limit of the color scale; defaults to the model-derived default
        :param v_max: upper limit of the color scale; defaults to the model-derived default
        :param with_caustics: bool, if True, plots the caustics on top of the source (may take some time)
        :param caustic_color: color of the caustics
        :param font_size: font size of labels
        :param plot_scale: string, log or linear, scale of surface brightness plot
        :param scale_size: length of the scale bar (labelled in arcsec)
        :param text: label written on the panel
        :param colorbar_label: label of the color bar
        :param point_source_position: bool, if True, marks the point source position(s) in the source plane
        :param kwargs: 'no_arrow' (bool) suppresses the coordinate arrows and the label text;
         'kwargs_caustic' (dict) is forwarded to the caustic plotting
        :return: the axis instance
        :raises ValueError: if plot_scale is neither 'log' nor 'linear'
        """
        if v_min is None:
            v_min = self._v_min_default
        if v_max is None:
            v_max = self._v_max_default
        d_s = numPix * deltaPix_source
        source, coords_source = self.source(numPix,
                                            deltaPix_source,
                                            center=center)
        if plot_scale == 'log':
            # clip pixels below the lower color limit (incl. zeros) to remove weird shadow in plot
            source[source < 10**v_min] = 10**(v_min)
            source_scale = np.log10(source)
        elif plot_scale == 'linear':
            source_scale = source
        else:
            raise ValueError(
                'variable plot_scale needs to be "log" or "linear", not %s.' %
                plot_scale)
        im = ax.matshow(source_scale,
                        origin='lower',
                        extent=[0, d_s, 0, d_s],
                        cmap=self._cmap,
                        vmin=v_min,
                        vmax=v_max)  # source
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(colorbar_label, fontsize=font_size)

        if with_caustics is True:
            ra_caustic_list, dec_caustic_list = self._caustics()
            # draw the caustics once (previously drawn twice), forwarding optional styling
            plot_util.plot_line_set(ax,
                                    coords_source,
                                    ra_caustic_list,
                                    dec_caustic_list,
                                    color=caustic_color,
                                    points_only=self._caustic_points_only,
                                    **kwargs.get('kwargs_caustic', {}))
        # the scale bar does not depend on whether caustics are drawn
        plot_util.scale_bar(ax,
                            d_s,
                            dist=scale_size,
                            text='{:.1f}"'.format(scale_size),
                            color='w',
                            flipped=False,
                            font_size=font_size)
        if 'no_arrow' not in kwargs or not kwargs['no_arrow']:
            # arrows belong to this source-plane panel (extent d_s, source grid coordinates),
            # as in error_map_source_plot
            plot_util.coordinate_arrows(ax,
                                        d_s,
                                        coords_source,
                                        color='w',
                                        arrow_size=self._arrow_size,
                                        font_size=font_size)
            plot_util.text_description(ax,
                                       d_s,
                                       text=text,
                                       color="w",
                                       backgroundcolor='k',
                                       flipped=False,
                                       font_size=font_size)
        if point_source_position is True:
            ra_source, dec_source = self._bandmodel.PointSource.source_position(
                self._kwargs_ps_partial, self._kwargs_lens_partial)
            plot_util.source_position_plot(ax, coords_source, ra_source,
                                           dec_source)
        return ax

    def error_map_source_plot(self,
                              ax,
                              numPix,
                              deltaPix_source,
                              v_min=None,
                              v_max=None,
                              with_caustics=False,
                              font_size=15,
                              point_source_position=True):
        """
        plots the uncertainty in the surface brightness in the source from the linear inversion by taking the diagonal
        elements of the covariance matrix of the inversion of the basis set to be propagated to the source plane.
        #TODO illustration of the uncertainties in real space with the full covariance matrix is subtle. The best way is probably to draw realizations from the covariance matrix.

        The source grid is centered on the first source component, or on the origin if there is no source component.

        :param ax: matplotlib axis instance
        :param numPix: number of pixels in plot per axis
        :param deltaPix_source: pixel spacing in the source resolution illustrated in plot
        :param v_min: minimum plotting scale of the map
        :param v_max: maximum plotting scale of the map
        :param with_caustics: plot the caustics on top of the source reconstruction (may take some time)
        :param font_size: font size of labels
        :param point_source_position: boolean, if True, plots a point at the position of the point source
        :return: plot of source surface brightness errors in the reconstruction on the axis instance
        """
        # image pixel-to-angle transform rescaled to the source pixel size (computed once)
        Mpix2coord = self._coords.transform_pix2angle * deltaPix_source / self._deltaPix
        x_grid_source, y_grid_source = util.make_grid_transformed(numPix, Mpix2coord)
        # fall back to the origin when there is no source component, as source() does;
        # indexing [0] unconditionally raised IndexError in that case
        if len(self._kwargs_source_partial) > 0:
            x_center = self._kwargs_source_partial[0]['center_x']
            y_center = self._kwargs_source_partial[0]['center_y']
        else:
            x_center, y_center = 0, 0
        x_grid_source += x_center
        y_grid_source += y_center
        coords_source = Coordinates(Mpix2coord,
                                    ra_at_xy_0=x_grid_source[0],
                                    dec_at_xy_0=y_grid_source[0])
        error_map_source = self._bandmodel.error_map_source(
            self._kwargs_source_partial,
            x_grid_source,
            y_grid_source,
            self._cov_param,
            model_index_select=False)
        error_map_source = util.array2image(error_map_source)
        d_s = numPix * deltaPix_source
        im = ax.matshow(error_map_source,
                        origin='lower',
                        extent=[0, d_s, 0, d_s],
                        cmap=self._cmap,
                        vmin=v_min,
                        vmax=v_max)  # source
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(r'error variance', fontsize=font_size)
        if with_caustics:
            ra_caustic_list, dec_caustic_list = self._caustics()
            plot_util.plot_line_set(ax,
                                    coords_source,
                                    ra_caustic_list,
                                    dec_caustic_list,
                                    color='b',
                                    points_only=self._caustic_points_only)
        plot_util.scale_bar(ax,
                            d_s,
                            dist=0.1,
                            text='0.1"',
                            color='w',
                            flipped=False,
                            font_size=font_size)
        plot_util.coordinate_arrows(ax,
                                    d_s,
                                    coords_source,
                                    arrow_size=self._arrow_size,
                                    color='w',
                                    font_size=font_size)
        plot_util.text_description(ax,
                                   d_s,
                                   text="Error map in source",
                                   color="w",
                                   backgroundcolor='k',
                                   flipped=False,
                                   font_size=font_size)
        if point_source_position is True:
            ra_source, dec_source = self._bandmodel.PointSource.source_position(
                self._kwargs_ps_partial, self._kwargs_lens_partial)
            plot_util.source_position_plot(ax, coords_source, ra_source,
                                           dec_source)
        return ax

    def magnification_plot(self,
                           ax,
                           v_min=-10,
                           v_max=10,
                           image_name_list=None,
                           font_size=15,
                           no_arrow=False,
                           text="Magnification model",
                           colorbar_label=r"$\det\ (\mathsf{A}^{-1})$",
                           **kwargs):
        """
        Plot the lens model magnification on the image pixel grid, together with the point source image positions.

        :param ax: matplotlib axis instance
        :param v_min: minimum range of plotting
        :param v_max: maximum range of plotting
        :param image_name_list: optional names used to label the images
        :param font_size: font size of labels
        :param no_arrow: bool, if True, the coordinate arrows are omitted
        :param text: label written on the panel
        :param colorbar_label: label of the color bar
        :param kwargs: kwargs to send to matplotlib.pyplot.matshow() (defaults: cmap of this class, alpha=0.5)
        :return: the axis instance
        """
        kwargs.setdefault('cmap', self._cmap)
        kwargs.setdefault('alpha', 0.5)
        frame = self._frame_size

        magnification = self._lensModel.magnification(self._x_grid, self._y_grid,
                                                      self._kwargs_lens_partial)
        mag_image = util.array2image(magnification)
        im = ax.matshow(mag_image, origin='lower', extent=[0, frame, 0, frame],
                        vmin=v_min, vmax=v_max, **kwargs)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)
        plot_util.scale_bar(ax, frame, dist=1, text='1"', color='k', font_size=font_size)
        if not no_arrow:
            plot_util.coordinate_arrows(ax, frame, self._coords, color='k',
                                        arrow_size=self._arrow_size, font_size=font_size)
        plot_util.text_description(ax, frame, text=text, color="k", backgroundcolor='w',
                                   font_size=font_size)

        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        colorbar = plt.colorbar(im, cax=cax)
        colorbar.set_label(colorbar_label, fontsize=font_size)

        ra_image, dec_image = self._bandmodel.PointSource.image_position(
            self._kwargs_ps_partial, self._kwargs_lens_partial)
        plot_util.image_position_plot(ax, self._coords, ra_image, dec_image, color='k',
                                      image_name_list=image_name_list)
        return ax

    def deflection_plot(self,
                        ax,
                        v_min=None,
                        v_max=None,
                        axis=0,
                        with_caustics=False,
                        image_name_list=None,
                        text="Deflection model",
                        font_size=15,
                        colorbar_label=r'arcsec'):
        """
        Plot one component of the lens model deflection angle on the image pixel grid.

        :param ax: matplotlib axis instance
        :param v_min: lower limit of the color scale
        :param v_max: upper limit of the color scale
        :param axis: 0 plots the first deflection component, any other value the second
        :param with_caustics: bool, if True, overplots caustics (blue) and critical curves (red)
        :param image_name_list: optional names used to label the point source images
        :param text: label written on the panel
        :param font_size: font size of labels
        :param colorbar_label: label of the color bar
        :return: the axis instance
        """
        alpha_x, alpha_y = self._lensModel.alpha(self._x_grid, self._y_grid,
                                                 self._kwargs_lens_partial)
        deflection_maps = (util.array2image(alpha_x), util.array2image(alpha_y))
        alpha = deflection_maps[0] if axis == 0 else deflection_maps[1]

        frame = self._frame_size
        im = ax.matshow(alpha, origin='lower', extent=[0, frame, 0, frame], vmin=v_min,
                        vmax=v_max, cmap=self._cmap, alpha=0.5)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)
        plot_util.scale_bar(ax, frame, dist=1, text='1"', color='k', font_size=font_size)
        plot_util.coordinate_arrows(ax, frame, self._coords, color='k',
                                    arrow_size=self._arrow_size, font_size=font_size)
        plot_util.text_description(ax, frame, text=text, color="k", backgroundcolor='w',
                                   font_size=font_size)

        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        colorbar = plt.colorbar(im, cax=cax)
        colorbar.set_label(colorbar_label, fontsize=font_size)

        if with_caustics is True:
            ra_crit_list, dec_crit_list = self._critical_curves()
            ra_caustic_list, dec_caustic_list = self._caustics()
            points_only = self._caustic_points_only
            plot_util.plot_line_set(ax, self._coords, ra_caustic_list, dec_caustic_list,
                                    color='b', points_only=points_only)
            plot_util.plot_line_set(ax, self._coords, ra_crit_list, dec_crit_list,
                                    color='r', points_only=points_only)

        ra_image, dec_image = self._bandmodel.PointSource.image_position(
            self._kwargs_ps_partial, self._kwargs_lens_partial)
        plot_util.image_position_plot(ax, self._coords, ra_image, dec_image,
                                      image_name_list=image_name_list)
        return ax

    def decomposition_plot(self,
                           ax,
                           text='Reconstructed',
                           v_min=None,
                           v_max=None,
                           unconvolved=False,
                           point_source_add=False,
                           font_size=15,
                           source_add=False,
                           lens_light_add=False,
                           **kwargs):
        """
        plot the log10 flux of the selected model components of this band

        :param ax: matplotlib axis instance to draw on
        :param text: label shown in the text box of the panel
        :param v_min: lower color limit (log10 flux); self._v_min_default if None
        :param v_max: upper color limit (log10 flux); self._v_max_default if None
        :param unconvolved: bool, passed as `unconvolved` to the band model's image() call
        :param point_source_add: bool, include the point source component in the model image
        :param font_size: font size of the scale bar, text label, coordinate arrows and colorbar label
        :param source_add: bool, include the source light component in the model image
        :param lens_light_add: bool, include the lens light component in the model image
        :param kwargs: kwargs to send matplotlib.pyplot.matshow(); 'cmap' defaults to self._cmap
        :return: updated matplotlib axis instance
        """
        model = self._bandmodel.image(self._kwargs_lens_partial,
                                      self._kwargs_source_partial,
                                      self._kwargs_lens_light_partial,
                                      self._kwargs_ps_partial,
                                      unconvolved=unconvolved,
                                      source_add=source_add,
                                      lens_light_add=lens_light_add,
                                      point_source_add=point_source_add)
        if v_min is None:
            v_min = self._v_min_default
        if v_max is None:
            v_max = self._v_max_default
        # a caller-supplied colormap takes precedence over the plotter default
        kwargs.setdefault('cmap', self._cmap)
        im = ax.matshow(np.log10(model),
                        origin='lower',
                        vmin=v_min,
                        vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size],
                        **kwargs)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)
        plot_util.scale_bar(ax,
                            self._frame_size,
                            dist=1,
                            text='1"',
                            font_size=font_size)
        # forward font_size so the label matches the other annotations
        # (previously the label always used text_description's own default)
        plot_util.text_description(ax,
                                   self._frame_size,
                                   text=text,
                                   color="w",
                                   backgroundcolor='k',
                                   font_size=font_size)
        plot_util.coordinate_arrows(ax,
                                    self._frame_size,
                                    self._coords,
                                    arrow_size=self._arrow_size,
                                    font_size=font_size)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(r'log$_{10}$ flux', fontsize=font_size)
        return ax

    def subtract_from_data_plot(self,
                                ax,
                                text='Subtracted',
                                v_min=None,
                                v_max=None,
                                point_source_add=False,
                                source_add=False,
                                lens_light_add=False,
                                font_size=15):
        """
        plot log10 of the data minus the selected (convolved) model components

        Pixels where the residual is zero or negative become -inf/NaN after the log10.

        :param ax: matplotlib axis instance to draw on
        :param text: label shown in the text box of the panel
        :param v_min: lower color limit (log10 flux); self._v_min_default if None
        :param v_max: upper color limit (log10 flux); self._v_max_default if None
        :param point_source_add: bool, subtract the point source component
        :param source_add: bool, subtract the source light component
        :param lens_light_add: bool, subtract the lens light component
        :param font_size: font size of all annotations and of the colorbar label
        :return: updated matplotlib axis instance
        """
        subtracted_model = self._bandmodel.image(self._kwargs_lens_partial,
                                                 self._kwargs_source_partial,
                                                 self._kwargs_lens_light_partial,
                                                 self._kwargs_ps_partial,
                                                 unconvolved=False,
                                                 source_add=source_add,
                                                 lens_light_add=lens_light_add,
                                                 point_source_add=point_source_add)
        lower = self._v_min_default if v_min is None else v_min
        upper = self._v_max_default if v_max is None else v_max
        size = self._frame_size
        residual = self._data - subtracted_model

        image = ax.matshow(np.log10(residual),
                           origin='lower',
                           vmin=lower,
                           vmax=upper,
                           extent=[0, size, 0, size],
                           cmap=self._cmap)
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
        ax.autoscale(False)

        plot_util.scale_bar(ax, size, dist=1, text='1"', font_size=font_size)
        plot_util.text_description(ax, size, text=text, color="w",
                                   backgroundcolor='k', font_size=font_size)
        plot_util.coordinate_arrows(ax, size, self._coords,
                                    arrow_size=self._arrow_size,
                                    font_size=font_size)

        colorbar_axis = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        colorbar = plt.colorbar(image, cax=colorbar_axis)
        colorbar.set_label(r'log$_{10}$ flux', fontsize=font_size)
        return ax

    def plot_main(self, with_caustics=False):
        """
        draw the main diagnostic panels together in one 2x3 figure

        Top row: data, model (with image names), normalized residuals (range -6..6).
        Bottom row: source reconstruction, convergence (v_max=1), magnification.

        :param with_caustics: bool, forwarded to source_plot()
        :return: matplotlib figure and the 2x3 array of axes
        """
        fig, ax_grid = plt.subplots(2, 3, figsize=(16, 8))

        self.data_plot(ax=ax_grid[0, 0])
        self.model_plot(ax=ax_grid[0, 1], image_names=True)
        self.normalized_residual_plot(ax=ax_grid[0, 2], v_min=-6, v_max=6)

        self.source_plot(ax=ax_grid[1, 0], deltaPix_source=0.01, numPix=100,
                         with_caustics=with_caustics)
        self.convergence_plot(ax=ax_grid[1, 1], v_max=1)
        self.magnification_plot(ax=ax_grid[1, 2])

        fig.tight_layout()
        fig.subplots_adjust(left=None, bottom=None, right=None, top=None,
                            wspace=0., hspace=0.05)
        return fig, ax_grid

    def plot_separate(self):
        """
        plot the lens light, source light and their sum in separate panels

        The top row shows the unconvolved components; the bottom row shows the convolved
        ones. Only the bottom-right panel also includes the point sources.

        :return: matplotlib figure and the 2x3 array of axes
        """
        fig, ax_grid = plt.subplots(2, 3, figsize=(16, 8))

        # (row, column) -> keyword arguments for decomposition_plot, in drawing order
        panel_specs = [
            ((0, 0), dict(text='Lens light', lens_light_add=True, unconvolved=True)),
            ((1, 0), dict(text='Lens light convolved', lens_light_add=True)),
            ((0, 1), dict(text='Source light', source_add=True, unconvolved=True)),
            ((1, 1), dict(text='Source light convolved', source_add=True)),
            ((0, 2), dict(text='All components', source_add=True,
                          lens_light_add=True, unconvolved=True)),
            ((1, 2), dict(text='All components convolved', source_add=True,
                          lens_light_add=True, point_source_add=True)),
        ]
        for (row, col), options in panel_specs:
            self.decomposition_plot(ax=ax_grid[row, col], **options)

        fig.tight_layout()
        fig.subplots_adjust(left=None, bottom=None, right=None, top=None,
                            wspace=0., hspace=0.05)
        return fig, ax_grid

    def plot_subtract_from_data_all(self):
        """
        show the data with different combinations of model components subtracted

        :return: matplotlib figure and the 2x3 array of axes
        """
        fig, ax_grid = plt.subplots(2, 3, figsize=(16, 8))

        # (row, column) -> keyword arguments for subtract_from_data_plot, in drawing order
        panel_specs = [
            ((0, 0), dict(text='Data')),
            ((0, 1), dict(text='Data - Point Source', point_source_add=True)),
            ((0, 2), dict(text='Data - Lens Light', lens_light_add=True)),
            ((1, 0), dict(text='Data - Source Light', source_add=True)),
            ((1, 1), dict(text='Data - Source Light - Point Source',
                          source_add=True, point_source_add=True)),
            ((1, 2), dict(text='Data - Lens Light - Point Source',
                          lens_light_add=True, point_source_add=True)),
        ]
        for (row, col), options in panel_specs:
            self.subtract_from_data_plot(ax=ax_grid[row, col], **options)

        fig.tight_layout()
        fig.subplots_adjust(left=None, bottom=None, right=None, top=None,
                            wspace=0., hspace=0.05)
        return fig, ax_grid

    def plot_extinction_map(self, ax, v_min=None, v_max=None, **kwargs):
        """
        display the band model's extinction map (no log transform applied)

        :param ax: matplotlib axis instance to draw on
        :param v_min: lower color limit; 0 if None
        :param v_max: upper color limit; 1 if None
        :param kwargs: additional keyword arguments forwarded to ax.matshow()
        :return: the same matplotlib axis instance
        """
        extinction = self._bandmodel.extinction_map(self._kwargs_extinction_partial,
                                                    self._kwargs_special_partial)
        lower = 0 if v_min is None else v_min
        upper = 1 if v_max is None else v_max
        size = self._frame_size
        ax.matshow(extinction,
                   origin='lower',
                   vmin=lower,
                   vmax=upper,
                   extent=[0, size, 0, size],
                   **kwargs)
        return ax
Exemple #8
0
def caustics_plot(ax,
                  pixel_grid,
                  lens_model,
                  kwargs_lens,
                  fast_caustic=True,
                  coord_inverse=False,
                  color_crit='r',
                  color_caustic='g',
                  pixel_offset=False,
                  *args,
                  **kwargs):
    """
    plot the critical curves and caustics of a lens model in the frame of a pixel grid

    :param ax: matplotlib axis instance
    :param pixel_grid: lenstronomy PixelGrid() instance (or an instance of a class inheriting from PixelGrid())
    :param lens_model: LensModel() class instance
    :param kwargs_lens: lens model keyword argument list
    :param fast_caustic: boolean, if True, uses the faster but less precise caustic calculation
     (critical_curve_caustics), which might have trouble with the outer caustic (inner critical curve).
     If False, critical points are found with critical_curve_tiling (max_order=10), mapped to the
     source plane with ray shooting, and drawn as individual points only.
    :param coord_inverse: bool, if True, inverts the x-coordinates to go from right-to-left
     (effectively the RA definition)
    :param color_crit: string, color of critical curve
    :param color_caustic: string, color of caustic curve
    :param pixel_offset: boolean (default False); if True, the coordinates are shifted by half a pixel to
     match the matshow() command, which centers the coordinates in the pixel center
    :param args: argument for plotting curve
    :param kwargs: keyword arguments for plotting curves
    :return: updated matplotlib axis instance
    """
    lens_model_ext = LensModelExtensions(lens_model)
    pixel_width = pixel_grid.pixel_width
    # compute window covers the largest extent of the grid
    # (presumably width is (width_x, width_y) — confirm against PixelGrid)
    frame_size = np.max(pixel_grid.width)
    coord_center_ra, coord_center_dec = pixel_grid.center
    ra0, dec0 = pixel_grid.radec_at_xy_0
    origin = [ra0, dec0]
    if fast_caustic:
        ra_crit_list, dec_crit_list, ra_caustic_list, dec_caustic_list = lens_model_ext.critical_curve_caustics(
            kwargs_lens,
            compute_window=frame_size,
            grid_scale=pixel_width,
            center_x=coord_center_ra,
            center_y=coord_center_dec)
        points_only = False
    else:
        # only supports individual points due to output of critical_curve_tiling definition
        points_only = True
        ra_crit_list, dec_crit_list = lens_model_ext.critical_curve_tiling(
            kwargs_lens,
            compute_window=frame_size,
            start_scale=pixel_width,
            max_order=10,
            center_x=coord_center_ra,
            center_y=coord_center_dec)
        # map the critical points to the source plane to obtain the caustic points
        ra_caustic_list, dec_caustic_list = lens_model.ray_shooting(
            ra_crit_list, dec_crit_list, kwargs_lens)

    # caustics are plotted first, then the critical curves
    curve_sets = ((ra_caustic_list, dec_caustic_list, color_caustic),
                  (ra_crit_list, dec_crit_list, color_crit))
    for ra_list, dec_list, color in curve_sets:
        plot_util.plot_line_set(ax,
                                pixel_grid,
                                ra_list,
                                dec_list,
                                color=color,
                                origin=origin,
                                flipped_x=coord_inverse,
                                points_only=points_only,
                                pixel_offset=pixel_offset,
                                *args,
                                **kwargs)
    return ax