# Example No. 1
# 0
class LensModelPlot(object):
    """
    class that manages the summary plots of a lens model

    The constructor solves the linear image model once for the given lens, source, lens light and point-source
    keyword arguments and keeps the model image, normalized residuals and parameter covariance. Each ``*_plot``
    method draws onto a matplotlib axis supplied by the caller and returns that axis; ``plot_main``,
    ``plot_separate`` and ``plot_subtract_from_data_all`` assemble 2x3 summary figures and return ``(f, axes)``.
    """
    def __init__(self, kwargs_data, kwargs_psf, kwargs_numerics, kwargs_model, kwargs_lens, kwargs_source,
                 kwargs_lens_light, kwargs_ps, arrow_size=0.02, cmap_string="gist_heat"):
        """

        :param kwargs_data: imaging-data keyword arguments; 'image_data' and 'transform_pix2angle' are read here and
            the dict is passed on to Data() and the image-model creator
        :param kwargs_psf: PSF keyword arguments (passed to the image-model creator)
        :param kwargs_numerics: numerics keyword arguments (passed to the image-model creator)
        :param kwargs_model: model keyword arguments; 'lens_model_list', 'z_source', 'redshift_list' and
            'multi_plane' are read to build the lens model
        :param kwargs_lens: lens model keyword arguments
        :param kwargs_source: list of source light keyword arguments
        :param kwargs_lens_light: lens light keyword arguments
        :param kwargs_ps: point source keyword arguments
        :param arrow_size: size of the coordinate arrows drawn on each panel
        :param cmap_string: name of a matplotlib colormap, or a colormap instance
        """
        self._kwargs_data = kwargs_data
        self._cmap = self._make_cmap(cmap_string)
        self._arrow_size = arrow_size
        data = Data(kwargs_data)
        self._coords = data._coords
        # NOTE(review): np.shape returns (rows, cols); the frame size assumes a square image
        nx, ny = np.shape(kwargs_data['image_data'])
        Mpix2coord = kwargs_data['transform_pix2angle']
        self._Mpix2coord = Mpix2coord

        self._deltaPix = self._coords.pixel_size
        self._frame_size = self._deltaPix * nx

        x_grid, y_grid = data.coordinates
        self._x_grid = util.image2array(x_grid)
        self._y_grid = util.image2array(y_grid)

        self._imageModel = class_creator.create_image_model(kwargs_data, kwargs_psf, kwargs_numerics, kwargs_model)
        self._analysis = LensAnalysis(kwargs_model)
        self._lensModel = LensModel(lens_model_list=kwargs_model.get('lens_model_list', []),
                                    z_source=kwargs_model.get('z_source', None),
                                    redshift_list=kwargs_model.get('redshift_list', None),
                                    multi_plane=kwargs_model.get('multi_plane', False))
        self._lensModelExt = LensModelExtensions(self._lensModel)
        model, error_map, cov_param, param = self._imageModel.image_linear_solve(kwargs_lens, kwargs_source,
                                                                                 kwargs_lens_light, kwargs_ps,
                                                                                 inv_bool=True)
        self._kwargs_lens = kwargs_lens
        self._kwargs_source = kwargs_source
        self._kwargs_lens_light = kwargs_lens_light
        self._kwargs_else = kwargs_ps
        self._model = model
        self._data = kwargs_data['image_data']
        self._cov_param = cov_param
        self._norm_residuals = self._imageModel.reduced_residuals(model, error_map=error_map)
        self._reduced_x2 = self._imageModel.reduced_chi2(model, error_map=error_map)
        # default log10 color range from the model image: NaNs (from non-positive pixels) are mapped to -5 and the
        # range is clipped to [-5, 10]
        log_model = np.log10(model)
        log_model[np.isnan(log_model)] = -5
        self._v_min_default = max(np.min(log_model), -5)
        self._v_max_default = min(np.max(log_model), 10)
        print("reduced chi^2 = ", self._reduced_x2)

    @staticmethod
    def _make_cmap(cmap_string):
        """
        resolve ``cmap_string`` to a colormap and render bad (NaN) and under-range values in black

        :param cmap_string: colormap name (str, or unicode under Python 2) or a colormap instance
        :return: the configured colormap
        """
        try:
            string_types = (str, unicode)  # Python 2
        except NameError:
            string_types = (str,)  # Python 3 has no separate `unicode` type
        if isinstance(cmap_string, string_types):
            cmap = plt.get_cmap(cmap_string)
        else:
            cmap = cmap_string
        # NOTE: modifies the colormap object in place (the passed-in instance or the one get_cmap returned)
        cmap.set_bad(color='k', alpha=1.)
        cmap.set_under('k')
        return cmap

    def _color_range(self, v_min, v_max):
        """
        fill in missing color limits with the defaults derived from the model image in __init__

        :return: (v_min, v_max)
        """
        if v_min is None:
            v_min = self._v_min_default
        if v_max is None:
            v_max = self._v_max_default
        return v_min, v_max

    @staticmethod
    def _hide_axes(ax):
        """hide both axes of ``ax`` and turn off autoscaling so later overlays do not change the limits"""
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)

    @staticmethod
    def _add_colorbar(ax, im, label):
        """attach a colorbar for ``im`` on the right side of ``ax`` and label it"""
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(label, fontsize=15)
        return cb

    def _critical_curves(self):
        """
        critical curves of the lens model, computed on first use and cached on the instance

        :return: ra_crit_list, dec_crit_list as returned by critical_curve_tiling()
        """
        if not hasattr(self, '_ra_crit_list') or not hasattr(self, '_dec_crit_list'):
            self._ra_crit_list, self._dec_crit_list = self._lensModelExt.critical_curve_tiling(
                self._kwargs_lens, compute_window=self._frame_size, start_scale=self._deltaPix / 5., max_order=10)
        return self._ra_crit_list, self._dec_crit_list

    def _caustics(self):
        """
        caustics obtained by ray-shooting the critical curves, computed on first use and cached on the instance

        :return: ra_caustic_list, dec_caustic_list as returned by LensModel.ray_shooting()
        """
        if not hasattr(self, '_ra_caustic_list') or not hasattr(self, '_dec_caustic_list'):
            ra_crit_list, dec_crit_list = self._critical_curves()
            self._ra_caustic_list, self._dec_caustic_list = self._lensModel.ray_shooting(ra_crit_list,
                                                                                         dec_crit_list,
                                                                                         self._kwargs_lens)
        return self._ra_caustic_list, self._dec_caustic_list

    def _source_grid(self, numPix, deltaPix_source):
        """
        build a numPix x numPix source-plane grid with pixel scale deltaPix_source, shifted by the center of the
        first source light component when there is one

        :return: x_grid_source, y_grid_source, coords_source (Coordinates instance of that grid)
        """
        transform = self._Mpix2coord * deltaPix_source / self._deltaPix
        x_grid_source, y_grid_source = util.make_grid_transformed(numPix, transform)
        if len(self._kwargs_source) > 0:
            x_grid_source += self._kwargs_source[0]['center_x']
            y_grid_source += self._kwargs_source[0]['center_y']
        coords_source = Coordinates(transform, ra_at_xy_0=x_grid_source[0], dec_at_xy_0=y_grid_source[0])
        return x_grid_source, y_grid_source, coords_source

    def data_plot(self, ax, v_min=None, v_max=None, text='Observed'):
        """
        plot the observed image in log10 flux

        :param ax: matplotlib axis to draw on
        :param v_min: lower color limit (defaults to the model-derived value)
        :param v_max: upper color limit (defaults to the model-derived value)
        :param text: label drawn onto the panel
        :return: ax
        """
        v_min, v_max = self._color_range(v_min, v_max)
        im = ax.matshow(np.log10(self._data), origin='lower',
                        extent=[0, self._frame_size, 0, self._frame_size], cmap=self._cmap, vmin=v_min, vmax=v_max)
        self._hide_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"')
        text_description(ax, self._frame_size, text=text, color="w", backgroundcolor='k')
        coordinate_arrows(ax, self._frame_size, self._coords, arrow_size=self._arrow_size)
        self._add_colorbar(ax, im, r'log$_{10}$ flux')
        return ax

    def model_plot(self, ax, v_min=None, v_max=None, image_names=False, with_caustics=False):
        """
        plot the reconstructed model image in log10 flux

        :param ax: matplotlib axis to draw on
        :param v_min: lower color limit (defaults to the model-derived value)
        :param v_max: upper color limit (defaults to the model-derived value)
        :param image_names: if True, overplot the point-source image positions
        :param with_caustics: if True, overplot caustics (blue) and critical curves (red)
        :return: ax
        """
        v_min, v_max = self._color_range(v_min, v_max)
        im = ax.matshow(np.log10(self._model), origin='lower', vmin=v_min, vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size], cmap=self._cmap)
        self._hide_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"')
        text_description(ax, self._frame_size, text="Reconstructed", color="w", backgroundcolor='k')
        coordinate_arrows(ax, self._frame_size, self._coords, arrow_size=self._arrow_size)
        self._add_colorbar(ax, im, r'log$_{10}$ flux')
        if with_caustics:
            ra_crit_list, dec_crit_list = self._critical_curves()
            ra_caustic_list, dec_caustic_list = self._caustics()
            plot_line_set(ax, self._coords, ra_caustic_list, dec_caustic_list, color='b')
            plot_line_set(ax, self._coords, ra_crit_list, dec_crit_list, color='r')
        if image_names is True:
            ra_image, dec_image = self._imageModel.image_positions(self._kwargs_else, self._kwargs_lens)
            image_position_plot(ax, self._coords, ra_image, dec_image)
        return ax

    def convergence_plot(self, ax, v_min=None, v_max=None):
        """
        plot log10 of the lens-model convergence evaluated on the image pixel grid

        :param ax: matplotlib axis to draw on
        :param v_min: lower color limit in log10 kappa (None leaves it to matplotlib)
        :param v_max: upper color limit in log10 kappa (None leaves it to matplotlib)
        :return: ax
        """
        kappa_result = util.array2image(self._lensModel.kappa(self._x_grid, self._y_grid, self._kwargs_lens))
        im = ax.matshow(np.log10(kappa_result), origin='lower',
                        extent=[0, self._frame_size, 0, self._frame_size], cmap=self._cmap, vmin=v_min, vmax=v_max)
        self._hide_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"', color='w')
        coordinate_arrows(ax, self._frame_size, self._coords, color='w', arrow_size=self._arrow_size)
        text_description(ax, self._frame_size, text="Convergence", color="w", backgroundcolor='k', flipped=False)
        self._add_colorbar(ax, im, r'log$_{10}$ $\kappa$')
        return ax

    def normalized_residual_plot(self, ax, v_min=-6, v_max=6, **kwargs):
        """
        plot the normalized residuals computed in __init__

        :param ax: matplotlib axis to draw on
        :param v_min: lower color limit
        :param v_max: upper color limit
        :param kwargs: kwargs to send to matplotlib.pyplot.matshow() ('cmap' defaults to 'bwr')
        :return: ax
        """
        kwargs.setdefault('cmap', 'bwr')
        im = ax.matshow(self._norm_residuals, vmin=v_min, vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size], origin='lower', **kwargs)
        self._hide_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"', color='k')
        text_description(ax, self._frame_size, text="Normalized Residuals", color="k", backgroundcolor='w')
        coordinate_arrows(ax, self._frame_size, self._coords, color='k', arrow_size=self._arrow_size)
        self._add_colorbar(ax, im, r'(f$_{model}$-f$_{data}$)/$\sigma$')
        return ax

    def absolute_residual_plot(self, ax, v_min=-1, v_max=1):
        """
        plot the absolute residuals model - data

        :param ax: matplotlib axis to draw on
        :param v_min: lower color limit
        :param v_max: upper color limit
        :return: ax
        """
        im = ax.matshow(self._model - self._data, vmin=v_min, vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size], cmap='bwr', origin='lower')
        self._hide_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"', color='k')
        text_description(ax, self._frame_size, text="Residuals", color="k", backgroundcolor='w')
        coordinate_arrows(ax, self._frame_size, self._coords, color='k', arrow_size=self._arrow_size)
        self._add_colorbar(ax, im, r'(f$_{model}$-f$_{data}$)')
        return ax

    def source_plot(self, ax, numPix, deltaPix_source, source_sigma=0.001, convolution=False, v_min=None, v_max=None,
                    with_caustics=False):
        """
        plot the reconstructed source surface brightness (log10) on a dedicated source-plane grid

        :param ax: matplotlib axis to draw on
        :param numPix: number of pixels per side of the source grid
        :param deltaPix_source: pixel scale of the source grid
        :param source_sigma: width of the optional Gaussian smoothing, in the same units as deltaPix_source
        :param convolution: if True, smooth the source with a Gaussian of width source_sigma
        :param v_min: lower color limit (defaults to the model-derived value)
        :param v_max: upper color limit (defaults to the model-derived value)
        :param with_caustics: if True, overplot the caustics
        :return: ax
        """
        v_min, v_max = self._color_range(v_min, v_max)
        d_s = numPix * deltaPix_source
        x_grid_source, y_grid_source, coords_source = self._source_grid(numPix, deltaPix_source)

        source = self._imageModel.SourceModel.surface_brightness(x_grid_source, y_grid_source, self._kwargs_source)
        source = util.array2image(source)
        if convolution is True:
            # `ndimage.filters` is a deprecated alias namespace; gaussian_filter lives directly in ndimage
            source = ndimage.gaussian_filter(source, sigma=source_sigma / deltaPix_source, mode='nearest',
                                             truncate=20)

        im = ax.matshow(np.log10(source), origin='lower', extent=[0, d_s, 0, d_s],
                        cmap=self._cmap, vmin=v_min, vmax=v_max)
        self._hide_axes(ax)
        self._add_colorbar(ax, im, r'log$_{10}$ flux')
        if with_caustics:
            ra_caustic_list, dec_caustic_list = self._caustics()
            plot_line_set(ax, coords_source, ra_caustic_list, dec_caustic_list, color='b')
        scale_bar(ax, d_s, dist=0.1, text='0.1"', color='w', flipped=False)
        coordinate_arrows(ax, d_s, coords_source, arrow_size=self._arrow_size, color='w')
        text_description(ax, d_s, text="Reconstructed source", color="w", backgroundcolor='k', flipped=False)
        source_position_plot(ax, coords_source, self._kwargs_source)
        return ax

    def error_map_source_plot(self, ax, numPix, deltaPix_source, v_min=None, v_max=None, with_caustics=False):
        """
        plot the source-plane error map derived from the linear-solve covariance

        :param ax: matplotlib axis to draw on
        :param numPix: number of pixels per side of the source grid
        :param deltaPix_source: pixel scale of the source grid
        :param v_min: lower color limit (None leaves it to matplotlib)
        :param v_max: upper color limit (None leaves it to matplotlib)
        :param with_caustics: if True, overplot the caustics
        :return: ax
        """
        # shared grid builder also guards against an empty source list (previously an IndexError here)
        x_grid_source, y_grid_source, coords_source = self._source_grid(numPix, deltaPix_source)
        error_map_source = self._analysis.error_map_source(self._kwargs_source, x_grid_source, y_grid_source,
                                                           self._cov_param)
        error_map_source = util.array2image(error_map_source)
        d_s = numPix * deltaPix_source
        im = ax.matshow(error_map_source, origin='lower', extent=[0, d_s, 0, d_s],
                        cmap=self._cmap, vmin=v_min, vmax=v_max)
        self._hide_axes(ax)
        self._add_colorbar(ax, im, r'error variance')
        if with_caustics:
            ra_caustic_list, dec_caustic_list = self._caustics()
            plot_line_set(ax, coords_source, ra_caustic_list, dec_caustic_list, color='b')
        scale_bar(ax, d_s, dist=0.1, text='0.1"', color='w', flipped=False)
        coordinate_arrows(ax, d_s, coords_source, arrow_size=self._arrow_size, color='w')
        text_description(ax, d_s, text="Error map in source", color="w", backgroundcolor='k', flipped=False)
        source_position_plot(ax, coords_source, self._kwargs_source)
        return ax

    def magnification_plot(self, ax, v_min=-10, v_max=10, with_caustics=False, image_name_list=None, **kwargs):
        """
        plot the lens-model magnification on the image pixel grid, with image and source positions

        :param ax: matplotlib axis to draw on
        :param v_min: lower color limit
        :param v_max: upper color limit
        :param with_caustics: if True, overplot caustics (blue) and critical curves (red)
        :param image_name_list: labels passed to image_position_plot()
        :param kwargs: kwargs to send to matplotlib.pyplot.matshow() ('cmap' defaults to the instance colormap,
            'alpha' to 0.5)
        :return: ax
        """
        kwargs.setdefault('cmap', self._cmap)
        kwargs.setdefault('alpha', 0.5)
        mag_result = util.array2image(self._lensModel.magnification(self._x_grid, self._y_grid, self._kwargs_lens))
        im = ax.matshow(mag_result, origin='lower', extent=[0, self._frame_size, 0, self._frame_size],
                        vmin=v_min, vmax=v_max, **kwargs)
        self._hide_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"', color='k')
        coordinate_arrows(ax, self._frame_size, self._coords, color='k', arrow_size=self._arrow_size)
        text_description(ax, self._frame_size, text="Magnification model", color="k", backgroundcolor='w')
        self._add_colorbar(ax, im, r'det(A$^{-1}$)')
        if with_caustics:
            ra_crit_list, dec_crit_list = self._critical_curves()
            ra_caustic_list, dec_caustic_list = self._caustics()
            plot_line_set(ax, self._coords, ra_caustic_list, dec_caustic_list, color='b')
            plot_line_set(ax, self._coords, ra_crit_list, dec_crit_list, color='r')
        ra_image, dec_image = self._imageModel.image_positions(self._kwargs_else, self._kwargs_lens)
        image_position_plot(ax, self._coords, ra_image, dec_image, color='k', image_name_list=image_name_list)
        source_position_plot(ax, self._coords, self._kwargs_source)
        return ax

    def deflection_plot(self, ax, v_min=None, v_max=None, axis=0, with_caustics=False, image_name_list=None):
        """
        plot one component of the lens-model deflection on the image pixel grid

        :param ax: matplotlib axis to draw on
        :param v_min: lower color limit (None leaves it to matplotlib)
        :param v_max: upper color limit (None leaves it to matplotlib)
        :param axis: 0 plots the first deflection component, any other value the second
        :param with_caustics: if True, overplot caustics (blue) and critical curves (red)
        :param image_name_list: labels passed to image_position_plot()
        :return: ax
        """
        alpha1, alpha2 = self._lensModel.alpha(self._x_grid, self._y_grid, self._kwargs_lens)
        alpha = util.array2image(alpha1) if axis == 0 else util.array2image(alpha2)
        im = ax.matshow(alpha, origin='lower', extent=[0, self._frame_size, 0, self._frame_size],
                        vmin=v_min, vmax=v_max, cmap=self._cmap, alpha=0.5)
        self._hide_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"', color='k')
        coordinate_arrows(ax, self._frame_size, self._coords, color='k', arrow_size=self._arrow_size)
        text_description(ax, self._frame_size, text="Deflection model", color="k", backgroundcolor='w')
        self._add_colorbar(ax, im, r'arcsec')
        if with_caustics:
            ra_crit_list, dec_crit_list = self._critical_curves()
            ra_caustic_list, dec_caustic_list = self._caustics()
            plot_line_set(ax, self._coords, ra_caustic_list, dec_caustic_list, color='b')
            plot_line_set(ax, self._coords, ra_crit_list, dec_crit_list, color='r')
        ra_image, dec_image = self._imageModel.image_positions(self._kwargs_else, self._kwargs_lens)
        image_position_plot(ax, self._coords, ra_image, dec_image, image_name_list=image_name_list)
        source_position_plot(ax, self._coords, self._kwargs_source)
        return ax

    def decomposition_plot(self, ax, text='Reconstructed', v_min=None, v_max=None, unconvolved=False,
                           point_source_add=False, source_add=False, lens_light_add=False, **kwargs):
        """
        plot (in log10 flux) a model image built from the selected components

        :param ax: matplotlib axis to draw on
        :param text: label drawn onto the panel
        :param v_min: lower color limit (defaults to the model-derived value)
        :param v_max: upper color limit (defaults to the model-derived value)
        :param unconvolved: passed to ImageModel.image()
        :param point_source_add: include the point sources
        :param source_add: include the source light
        :param lens_light_add: include the lens light
        :param kwargs: kwargs to send matplotlib.pyplot.matshow() ('cmap' defaults to the instance colormap)
        :return: ax
        """
        model = self._imageModel.image(self._kwargs_lens, self._kwargs_source, self._kwargs_lens_light,
                                       self._kwargs_else, unconvolved=unconvolved, source_add=source_add,
                                       lens_light_add=lens_light_add, point_source_add=point_source_add)
        v_min, v_max = self._color_range(v_min, v_max)
        kwargs.setdefault('cmap', self._cmap)
        im = ax.matshow(np.log10(model), origin='lower', vmin=v_min, vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size], **kwargs)
        self._hide_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"')
        text_description(ax, self._frame_size, text=text, color="w", backgroundcolor='k')
        coordinate_arrows(ax, self._frame_size, self._coords, arrow_size=self._arrow_size)
        self._add_colorbar(ax, im, r'log$_{10}$ flux')
        return ax

    def subtract_from_data_plot(self, ax, text='Subtracted', v_min=None, v_max=None, point_source_add=False,
                                source_add=False, lens_light_add=False):
        """
        plot (in log10 flux) the data minus the selected convolved model components

        :param ax: matplotlib axis to draw on
        :param text: label drawn onto the panel
        :param v_min: lower color limit (defaults to the model-derived value)
        :param v_max: upper color limit (defaults to the model-derived value)
        :param point_source_add: subtract the point sources
        :param source_add: subtract the source light
        :param lens_light_add: subtract the lens light
        :return: ax
        """
        model = self._imageModel.image(self._kwargs_lens, self._kwargs_source, self._kwargs_lens_light,
                                       self._kwargs_else, unconvolved=False, source_add=source_add,
                                       lens_light_add=lens_light_add, point_source_add=point_source_add)
        v_min, v_max = self._color_range(v_min, v_max)
        # negative differences give NaN in log10; the colormap renders those as 'bad' (black)
        im = ax.matshow(np.log10(self._data - model), origin='lower', vmin=v_min, vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size], cmap=self._cmap)
        self._hide_axes(ax)
        scale_bar(ax, self._frame_size, dist=1, text='1"')
        text_description(ax, self._frame_size, text=text, color="w", backgroundcolor='k')
        coordinate_arrows(ax, self._frame_size, self._coords, arrow_size=self._arrow_size)
        self._add_colorbar(ax, im, r'log$_{10}$ flux')
        return ax

    def plot_main(self):
        """
        print the main plots together in a joint frame

        :return: figure and 2x3 array of axes
        """
        f, axes = plt.subplots(2, 3, figsize=(16, 8))
        self.data_plot(ax=axes[0, 0])
        self.model_plot(ax=axes[0, 1])
        self.normalized_residual_plot(ax=axes[0, 2], v_min=-6, v_max=6)
        self.source_plot(ax=axes[1, 0], convolution=False, deltaPix_source=0.01, numPix=100)
        self.convergence_plot(ax=axes[1, 1], v_max=1)
        self.magnification_plot(ax=axes[1, 2])
        f.tight_layout()
        f.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0., hspace=0.05)
        return f, axes

    def plot_separate(self):
        """
        plot the different model components separately

        :return: figure and 2x3 array of axes
        """
        f, axes = plt.subplots(2, 3, figsize=(16, 8))

        self.decomposition_plot(ax=axes[0, 0], text='Lens light', lens_light_add=True, unconvolved=True)
        self.decomposition_plot(ax=axes[1, 0], text='Lens light convolved', lens_light_add=True)
        self.decomposition_plot(ax=axes[0, 1], text='Source light', source_add=True, unconvolved=True)
        self.decomposition_plot(ax=axes[1, 1], text='Source light convolved', source_add=True)
        self.decomposition_plot(ax=axes[0, 2], text='All components', source_add=True, lens_light_add=True,
                                unconvolved=True)
        self.decomposition_plot(ax=axes[1, 2], text='All components convolved', source_add=True,
                                lens_light_add=True, point_source_add=True)
        f.tight_layout()
        f.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0., hspace=0.05)
        return f, axes

    def plot_subtract_from_data_all(self):
        """
        subtract model components from data

        :return: figure and 2x3 array of axes
        """
        f, axes = plt.subplots(2, 3, figsize=(16, 8))

        self.subtract_from_data_plot(ax=axes[0, 0], text='Data')
        self.subtract_from_data_plot(ax=axes[0, 1], text='Data - Point Source', point_source_add=True)
        self.subtract_from_data_plot(ax=axes[0, 2], text='Data - Lens Light', lens_light_add=True)
        self.subtract_from_data_plot(ax=axes[1, 0], text='Data - Source Light', source_add=True)
        self.subtract_from_data_plot(ax=axes[1, 1], text='Data - Source Light - Point Source', source_add=True,
                                     point_source_add=True)
        self.subtract_from_data_plot(ax=axes[1, 2], text='Data - Lens Light - Point Source', lens_light_add=True,
                                     point_source_add=True)
        f.tight_layout()
        f.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0., hspace=0.05)
        return f, axes
# Example No. 2
# 0
class LensModelPlot(object):
    """
    class that manages the summary plots of a lens model
    """
    def __init__(self,
                 kwargs_data,
                 kwargs_psf,
                 kwargs_numerics,
                 kwargs_model,
                 kwargs_lens,
                 kwargs_source,
                 kwargs_lens_light,
                 kwargs_ps,
                 arrow_size=0.1,
                 cmap_string="gist_heat",
                 high_res=5):
        """
        set up the image and lens models, compute critical curves and caustics, and solve the linear image model
        once so the plot methods can reuse the results

        :param kwargs_data: imaging-data keyword arguments; 'image_data' and 'transform_pix2angle' are read here and
            the dict is passed on to Data() and the image-model creator
        :param kwargs_psf: PSF keyword arguments (passed to the image-model creator)
        :param kwargs_numerics: numerics keyword arguments (passed to the image-model creator)
        :param kwargs_model: model keyword arguments; 'lens_model_list', 'z_source', 'redshift_list' and
            'multi_plane' are read to build the lens model
        :param kwargs_lens: lens model keyword arguments
        :param kwargs_source: source light keyword arguments
        :param kwargs_lens_light: lens light keyword arguments
        :param kwargs_ps: point source keyword arguments
        :param arrow_size: size of the coordinate arrows drawn on each panel
        :param cmap_string: name of a matplotlib colormap, or a colormap instance
        :param high_res: NOTE(review): accepted but not used by this constructor
        """
        self._kwargs_data = kwargs_data
        try:
            string_types = (str, unicode)  # Python 2
        except NameError:
            string_types = (str,)  # Python 3 has no separate `unicode` type
        if isinstance(cmap_string, string_types):
            cmap = plt.get_cmap(cmap_string)
        else:
            cmap = cmap_string
        # render NaN ('bad') and under-range pixels in black
        cmap.set_bad(color='k', alpha=1.)
        cmap.set_under('k')
        self._cmap = cmap
        self._arrow_size = arrow_size
        data = Data(kwargs_data)
        self._coords = data._coords
        # NOTE(review): np.shape returns (rows, cols); the frame size assumes a square image
        nx, ny = np.shape(kwargs_data['image_data'])
        Mpix2coord = kwargs_data['transform_pix2angle']
        self._Mpix2coord = Mpix2coord

        self._deltaPix = self._coords.pixel_size
        self._frame_size = self._deltaPix * nx

        self._x_grid, self._y_grid = data.coordinates

        # NOTE(review): spelled `creat_image_model` here; keep in sync with the class_creator API in use
        self._imageModel = class_creator.creat_image_model(
            kwargs_data, kwargs_psf, kwargs_numerics, kwargs_model)
        self._analysis = LensAnalysis(kwargs_model)
        self._lensModel = LensModelExtensions(
            lens_model_list=kwargs_model.get('lens_model_list', ['NONE']),
            z_source=kwargs_model.get('z_source', None),
            redshift_list=kwargs_model.get('redshift_list', None),
            multi_plane=kwargs_model.get('multi_plane', False))
        self._ra_crit_list, self._dec_crit_list, self._ra_caustic_list, self._dec_caustic_list = self._lensModel.critical_curve_caustics(
            kwargs_lens, compute_window=self._frame_size, grid_scale=0.01)

        model, error_map, cov_param, param = self._imageModel.image_linear_solve(
            kwargs_lens,
            kwargs_source,
            kwargs_lens_light,
            kwargs_ps,
            inv_bool=True)
        self._kwargs_lens = kwargs_lens
        self._kwargs_source = kwargs_source
        self._kwargs_lens_light = kwargs_lens_light
        self._kwargs_else = kwargs_ps
        self._model = model
        self._data = kwargs_data['image_data']
        self._cov_param = cov_param
        self._norm_residuals = self._imageModel.reduced_residuals(
            model, error_map=error_map)
        self._reduced_x2 = self._imageModel.reduced_chi2(model,
                                                         error_map=error_map)
        # default log10 color range from the model image: NaNs (from non-positive pixels) are mapped to -5 and the
        # range is clipped to [-5, 10]
        log_model = np.log10(model)
        log_model[np.isnan(log_model)] = -5
        self._v_min_default = max(np.min(log_model), -5)
        self._v_max_default = min(np.max(log_model), 10)
        print("reduced chi^2 = ", self._reduced_x2)

    def data_plot(self, ax, v_min=None, v_max=None):
        """
        draw the observed image (log10 of the data) with scale bar, label, coordinate arrows and colorbar

        :param ax: matplotlib axis to draw on
        :param v_min: lower color limit; the default computed in __init__ is used when None
        :param v_max: upper color limit; the default computed in __init__ is used when None
        :return: ax
        """
        lower = self._v_min_default if v_min is None else v_min
        upper = self._v_max_default if v_max is None else v_max
        size = self._frame_size
        im = ax.matshow(np.log10(self._data), origin='lower', extent=[0, size, 0, size],
                        cmap=self._cmap, vmin=lower, vmax=upper)
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
        ax.autoscale(False)

        scale_bar(ax, size, dist=1, text='1"')
        text_description(ax, size, text="Observed", color="w", backgroundcolor='k')
        coordinate_arrows(ax, size, self._coords, arrow_size=self._arrow_size)
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im, cax=cax).set_label(r'log$_{10}$ flux', fontsize=15)
        return ax

    def model_plot(self, ax, v_min=None, v_max=None):
        """
        plot the reconstructed model image in log10 flux with critical curves, caustics, image and source positions

        :param ax: matplotlib axis to draw on
        :param v_min: lower color limit (defaults to the value computed in __init__)
        :param v_max: upper color limit (defaults to the value computed in __init__)
        :return: ax (returned for consistency with the other plot methods)
        """
        if v_min is None:
            v_min = self._v_min_default
        if v_max is None:
            v_max = self._v_max_default
        im = ax.matshow(np.log10(self._model),
                        origin='lower',
                        vmin=v_min,
                        vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size],
                        cmap=self._cmap)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)
        scale_bar(ax, self._frame_size, dist=1, text='1"')
        text_description(ax,
                         self._frame_size,
                         text="Reconstructed",
                         color="w",
                         backgroundcolor='k')
        coordinate_arrows(ax,
                          self._frame_size,
                          self._coords,
                          arrow_size=self._arrow_size)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(r'log$_{10}$ flux', fontsize=15)

        # caustics in blue, critical curves in red (both computed in __init__)
        plot_line_set(ax,
                      self._coords,
                      self._ra_caustic_list,
                      self._dec_caustic_list,
                      color='b')
        plot_line_set(ax,
                      self._coords,
                      self._ra_crit_list,
                      self._dec_crit_list,
                      color='r')
        ra_image, dec_image = self._imageModel.image_positions(
            self._kwargs_else, self._kwargs_lens)
        # only the first entry returned by image_positions() is plotted
        image_position_plot(ax, self._coords, ra_image[0], dec_image[0])
        source_position_plot(ax, self._coords, self._kwargs_source)
        return ax

    def convergence_plot(self, ax, v_min=None, v_max=None):
        """
        draw log10 of the lens-model convergence evaluated on the stored pixel grid

        :param ax: matplotlib axis to draw on
        :param v_min: lower color limit in log10 kappa; None leaves the choice to matplotlib
        :param v_max: upper color limit in log10 kappa; None leaves the choice to matplotlib
        :return: ax
        """
        size = self._frame_size
        kappa = self._lensModel.kappa(self._x_grid, self._y_grid, self._kwargs_lens)
        log_kappa = np.log10(util.array2image(kappa))
        im = ax.matshow(log_kappa, origin='lower', extent=[0, size, 0, size],
                        cmap=self._cmap, vmin=v_min, vmax=v_max)
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
        ax.autoscale(False)

        scale_bar(ax, size, dist=1, text='1"', color='w')
        coordinate_arrows(ax, size, self._coords, color='w', arrow_size=self._arrow_size)
        text_description(ax, size, text="Convergence", color="w", backgroundcolor='k', flipped=False)
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im, cax=cax).set_label(r'log$_{10}$ $\kappa$', fontsize=15)
        return ax

    def normalized_residual_plot(self, ax, v_min=-6, v_max=6, **kwargs):
        """
        plot the normalized residuals computed in __init__

        :param ax: matplotlib axis to draw on
        :param v_min: lower color limit
        :param v_max: upper color limit
        :param kwargs: extra keyword arguments for matplotlib's matshow(); 'cmap' defaults to 'bwr'
        :return: ax
        """
        kwargs.setdefault('cmap', 'bwr')
        im = ax.matshow(self._norm_residuals,
                        vmin=v_min,
                        vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size],
                        origin='lower',
                        **kwargs)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)
        scale_bar(ax, self._frame_size, dist=1, text='1"', color='k')
        text_description(ax,
                         self._frame_size,
                         text="Normalized Residuals",
                         color="k",
                         backgroundcolor='w')
        coordinate_arrows(ax,
                          self._frame_size,
                          self._coords,
                          color='k',
                          arrow_size=self._arrow_size)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(r'(f$_{model}$-f$_{data}$)/$\sigma$', fontsize=15)
        return ax

    def absolute_residual_plot(self, ax, v_min=-1, v_max=1):
        """
        Draw the absolute residuals ``self._model - self._data``.

        The difference image is shown with the diverging 'bwr' colormap and
        annotated with a 1" scale bar, a text label, the coordinate arrows and
        a colorbar.

        :param ax: matplotlib axes to draw on
        :param v_min: lower limit of the color scale
        :param v_max: upper limit of the color scale
        :return: the same axes ``ax``
        """
        residual = self._model - self._data
        size = self._frame_size
        image = ax.matshow(residual,
                           origin='lower',
                           extent=[0, size, 0, size],
                           cmap='bwr',
                           vmin=v_min,
                           vmax=v_max)
        # hide tick labels on both axes and freeze the limits before adding overlays
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
        ax.autoscale(False)
        scale_bar(ax, size, dist=1, text='1"', color='k')
        text_description(ax, size, text="Residuals", color="k", backgroundcolor='w')
        coordinate_arrows(ax, size, self._coords, color='k', arrow_size=self._arrow_size)
        colorbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        colorbar = plt.colorbar(image, cax=colorbar_axes)
        colorbar.set_label(r'(f$_{model}$-f$_{data}$)', fontsize=15)
        return ax

    def source_plot(self,
                    ax,
                    numPix,
                    deltaPix_source,
                    source_sigma=0.001,
                    convolution=False,
                    v_min=None,
                    v_max=None):
        """
        Plot the reconstructed source surface brightness on a regular grid.

        A ``numPix`` x ``numPix`` grid is built with the orientation of the
        image pixel grid (``self._Mpix2coord``), rescaled from the image pixel
        size to ``deltaPix_source``. It is then shifted to the center of the
        first source component. The surface brightness of ``self._kwargs_source``
        is evaluated on this grid and displayed in log10 scale. The plot is
        overlaid with the caustic lines, a 0.1" scale bar, coordinate arrows, a
        text label and the source positions.

        :param ax: matplotlib axes to draw on
        :param numPix: number of pixels per side of the source grid
        :param deltaPix_source: pixel size of the source grid, in the same units as the image pixel size
        :param source_sigma: width of the optional Gaussian smoothing, in the same units as ``deltaPix_source``
        :param convolution: if True, smooth the source image with a Gaussian of width ``source_sigma``
        :param v_min: lower color limit (log10 flux); defaults to ``self._v_min_default``
        :param v_max: upper color limit (log10 flux); defaults to ``self._v_max_default``
        :return: the same axes ``ax``
        """
        if v_min is None:
            v_min = self._v_min_default
        if v_max is None:
            v_max = self._v_max_default
        d_s = numPix * deltaPix_source
        # pixel-to-angle matrix of the source grid: same orientation as the image grid,
        # rescaled to the source pixel size (computed once, used for both grid and Coordinates)
        Mpix2coord_source = self._Mpix2coord * deltaPix_source / self._deltaPix
        x_grid_source, y_grid_source = util.make_grid_transformed(numPix, Mpix2coord_source)
        x_center = self._kwargs_source[0]['center_x']
        y_center = self._kwargs_source[0]['center_y']
        x_grid_source += x_center
        y_grid_source += y_center
        coords_source = Coordinates(Mpix2coord_source,
                                    ra_at_xy_0=x_grid_source[0],
                                    dec_at_xy_0=y_grid_source[0])

        source = self._imageModel.SourceModel.surface_brightness(
            x_grid_source, y_grid_source, self._kwargs_source)
        source = util.array2image(source)
        if convolution:
            # ``scipy.ndimage.filters`` is a deprecated namespace; the function lives in ``scipy.ndimage``
            source = ndimage.gaussian_filter(source,
                                             sigma=source_sigma / deltaPix_source,
                                             mode='nearest',
                                             truncate=20)

        im = ax.matshow(np.log10(source),
                        origin='lower',
                        extent=[0, d_s, 0, d_s],
                        cmap=self._cmap,
                        vmin=v_min,
                        vmax=v_max)  # source
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(r'log$_{10}$ flux', fontsize=15)
        plot_line_set(ax,
                      coords_source,
                      self._ra_caustic_list,
                      self._dec_caustic_list,
                      color='b')
        scale_bar(ax, d_s, dist=0.1, text='0.1"', color='w', flipped=False)
        coordinate_arrows(ax,
                          d_s,
                          coords_source,
                          arrow_size=self._arrow_size,
                          color='w')
        text_description(ax,
                         d_s,
                         text="Reconstructed source",
                         color="w",
                         backgroundcolor='k',
                         flipped=False)
        source_position_plot(ax, coords_source, self._kwargs_source)
        return ax

    def error_map_source_plot(self,
                              ax,
                              numPix,
                              deltaPix_source,
                              v_min=None,
                              v_max=None):
        """
        Plot the source-plane error map computed by ``self._analysis.error_map_source``.

        The source grid has ``numPix`` x ``numPix`` pixels of size
        ``deltaPix_source``. It has the orientation of the image grid and is
        centered on the first source component. The map is evaluated there from
        ``self._cov_param`` and drawn with the caustic lines, a 0.1" scale bar,
        coordinate arrows, a text label and the source positions.

        :param ax: matplotlib axes to draw on
        :param numPix: number of pixels per side of the source grid
        :param deltaPix_source: pixel size of the source grid, in the same units as the image pixel size
        :param v_min: lower color limit, passed straight to ``matshow`` (None lets matplotlib choose)
        :param v_max: upper color limit, passed straight to ``matshow`` (None lets matplotlib choose)
        :return: the same axes ``ax``
        """
        pix2coord_source = self._Mpix2coord * deltaPix_source / self._deltaPix
        ra_grid, dec_grid = util.make_grid_transformed(numPix, pix2coord_source)
        first_source = self._kwargs_source[0]
        ra_grid += first_source['center_x']
        dec_grid += first_source['center_y']
        coords_source = Coordinates(pix2coord_source,
                                    ra_at_xy_0=ra_grid[0],
                                    dec_at_xy_0=dec_grid[0])
        error_map = util.array2image(
            self._analysis.error_map_source(self._kwargs_source, ra_grid, dec_grid, self._cov_param))

        width = numPix * deltaPix_source
        image = ax.matshow(error_map,
                           origin='lower',
                           extent=[0, width, 0, width],
                           cmap=self._cmap,
                           vmin=v_min,
                           vmax=v_max)
        # hide tick labels on both axes and freeze the limits before adding overlays
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
        ax.autoscale(False)
        colorbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        plt.colorbar(image, cax=colorbar_axes).set_label(r'error variance', fontsize=15)
        plot_line_set(ax, coords_source, self._ra_caustic_list, self._dec_caustic_list, color='b')
        scale_bar(ax, width, dist=0.1, text='0.1"', color='w', flipped=False)
        coordinate_arrows(ax, width, coords_source, arrow_size=self._arrow_size, color='w')
        text_description(ax, width, text="Error map in source", color="w",
                         backgroundcolor='k', flipped=False)
        source_position_plot(ax, coords_source, self._kwargs_source)
        return ax

    def magnification_plot(self, ax, v_min=-10, v_max=10):
        """
        Plot the lens-model magnification evaluated on the image pixel grid.

        The map (semi-transparent, ``alpha=0.5``) is overlaid with the caustic
        lines (blue) and critical curves (red). It also shows the image
        positions returned by ``self._imageModel.image_positions`` (entry 0 of
        each coordinate list), the source positions, a 1" scale bar,
        coordinate arrows, a text label and a colorbar.

        :param ax: matplotlib axes to draw on
        :param v_min: lower limit of the color scale
        :param v_max: upper limit of the color scale
        :return: the same axes ``ax``
        """
        magnification_map = util.array2image(
            self._lensModel.magnification(self._x_grid, self._y_grid, self._kwargs_lens))
        size = self._frame_size
        image = ax.matshow(magnification_map,
                           origin='lower',
                           extent=[0, size, 0, size],
                           cmap=self._cmap,
                           alpha=0.5,
                           vmin=v_min,
                           vmax=v_max)
        # hide tick labels on both axes and freeze the limits before adding overlays
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
        ax.autoscale(False)
        scale_bar(ax, size, dist=1, text='1"', color='k')
        coordinate_arrows(ax, size, self._coords, color='k', arrow_size=self._arrow_size)
        text_description(ax, size, text="Magnification model", color="k", backgroundcolor='w')
        colorbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        plt.colorbar(image, cax=colorbar_axes).set_label(r'det(A$^{-1}$)', fontsize=15)

        # caustics in blue, then critical curves in red
        plot_line_set(ax, self._coords, self._ra_caustic_list, self._dec_caustic_list, color='b')
        plot_line_set(ax, self._coords, self._ra_crit_list, self._dec_crit_list, color='r')
        ra_images, dec_images = self._imageModel.image_positions(self._kwargs_else, self._kwargs_lens)
        # NOTE(review): only entry 0 is drawn -- presumably the first point-source component; confirm
        image_position_plot(ax, self._coords, ra_images[0], dec_images[0], color='k')
        source_position_plot(ax, self._coords, self._kwargs_source)
        return ax

    def deflection_plot(self, ax, v_min=None, v_max=None, axis=0):
        """
        Plot one component of the lens-model deflection on the image pixel grid.

        The map (semi-transparent, ``alpha=0.5``) is overlaid with the caustic
        lines (blue) and critical curves (red). It also shows the image
        positions returned by ``self._imageModel.image_positions`` (entry 0 of
        each coordinate list), the source positions, a 1" scale bar,
        coordinate arrows, a text label and a colorbar.

        :param ax: matplotlib axes to draw on
        :param v_min: lower color limit, passed straight to ``matshow`` (None lets matplotlib choose)
        :param v_max: upper color limit, passed straight to ``matshow`` (None lets matplotlib choose)
        :param axis: 0 selects the first deflection component; any other value selects the second
        :return: the same axes ``ax``
        """
        components = [util.array2image(component)
                      for component in self._lensModel.alpha(self._x_grid, self._y_grid, self._kwargs_lens)]
        deflection = components[0] if axis == 0 else components[1]
        size = self._frame_size
        image = ax.matshow(deflection,
                           origin='lower',
                           extent=[0, size, 0, size],
                           cmap=self._cmap,
                           alpha=0.5,
                           vmin=v_min,
                           vmax=v_max)
        # hide tick labels on both axes and freeze the limits before adding overlays
        for ax_axis in (ax.get_xaxis(), ax.get_yaxis()):
            ax_axis.set_visible(False)
        ax.autoscale(False)
        scale_bar(ax, size, dist=1, text='1"', color='k')
        coordinate_arrows(ax, size, self._coords, color='k', arrow_size=self._arrow_size)
        text_description(ax, size, text="Deflection model", color="k", backgroundcolor='w')
        colorbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        plt.colorbar(image, cax=colorbar_axes).set_label(r'arcsec', fontsize=15)

        # caustics in blue, then critical curves in red
        plot_line_set(ax, self._coords, self._ra_caustic_list, self._dec_caustic_list, color='b')
        plot_line_set(ax, self._coords, self._ra_crit_list, self._dec_crit_list, color='r')
        ra_images, dec_images = self._imageModel.image_positions(self._kwargs_else, self._kwargs_lens)
        # NOTE(review): only entry 0 is drawn -- presumably the first point-source component; confirm
        image_position_plot(ax, self._coords, ra_images[0], dec_images[0])
        source_position_plot(ax, self._coords, self._kwargs_source)
        return ax

    def decomposition_plot(self,
                           ax,
                           text='Reconstructed',
                           v_min=None,
                           v_max=None,
                           unconvolved=False,
                           point_source_add=False,
                           source_add=False,
                           lens_light_add=False):
        """
        Plot selected model components in log10 flux.

        The image is produced by ``self._imageModel.image``. The ``*_add``
        flags and ``unconvolved`` are forwarded unchanged, so the caller picks
        which components (point sources, source, lens light) are included.

        :param ax: matplotlib axes to draw on
        :param text: label written onto the panel
        :param v_min: lower color limit (log10 flux); defaults to ``self._v_min_default``
        :param v_max: upper color limit (log10 flux); defaults to ``self._v_max_default``
        :param unconvolved: forwarded to ``image`` as ``unconvolved``
        :param point_source_add: forwarded to ``image`` as ``point_source_add``
        :param source_add: forwarded to ``image`` as ``source_add``
        :param lens_light_add: forwarded to ``image`` as ``lens_light_add``
        :return: the same axes ``ax``
        """
        component_image = self._imageModel.image(self._kwargs_lens,
                                                 self._kwargs_source,
                                                 self._kwargs_lens_light,
                                                 self._kwargs_else,
                                                 unconvolved=unconvolved,
                                                 point_source_add=point_source_add,
                                                 source_add=source_add,
                                                 lens_light_add=lens_light_add)
        lower = self._v_min_default if v_min is None else v_min
        upper = self._v_max_default if v_max is None else v_max
        size = self._frame_size
        image = ax.matshow(np.log10(component_image),
                           origin='lower',
                           extent=[0, size, 0, size],
                           cmap=self._cmap,
                           vmin=lower,
                           vmax=upper)
        # hide tick labels on both axes and freeze the limits before adding overlays
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
        ax.autoscale(False)
        scale_bar(ax, size, dist=1, text='1"')
        text_description(ax, size, text=text, color="w", backgroundcolor='k')
        coordinate_arrows(ax, size, self._coords, arrow_size=self._arrow_size)
        colorbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        plt.colorbar(image, cax=colorbar_axes).set_label(r'log$_{10}$ flux', fontsize=15)
        return ax

    def subtract_from_data_plot(self,
                                ax,
                                text='Subtracted',
                                v_min=None,
                                v_max=None,
                                point_source_add=False,
                                source_add=False,
                                lens_light_add=False):
        """
        Plot the data minus selected (convolved) model components, in log10 flux.

        The subtracted model comes from ``self._imageModel.image`` with
        ``unconvolved=False``; the ``*_add`` flags choose which components are
        removed from ``self._data``.

        :param ax: matplotlib axes to draw on
        :param text: label written onto the panel
        :param v_min: lower color limit (log10 flux); defaults to ``self._v_min_default``
        :param v_max: upper color limit (log10 flux); defaults to ``self._v_max_default``
        :param point_source_add: forwarded to ``image`` as ``point_source_add``
        :param source_add: forwarded to ``image`` as ``source_add``
        :param lens_light_add: forwarded to ``image`` as ``lens_light_add``
        :return: the same axes ``ax``
        """
        subtracted_model = self._imageModel.image(self._kwargs_lens,
                                                  self._kwargs_source,
                                                  self._kwargs_lens_light,
                                                  self._kwargs_else,
                                                  unconvolved=False,
                                                  point_source_add=point_source_add,
                                                  source_add=source_add,
                                                  lens_light_add=lens_light_add)
        lower = self._v_min_default if v_min is None else v_min
        upper = self._v_max_default if v_max is None else v_max
        # pixels where the difference is <= 0 become non-finite under log10
        remainder = self._data - subtracted_model
        size = self._frame_size
        image = ax.matshow(np.log10(remainder),
                           origin='lower',
                           extent=[0, size, 0, size],
                           cmap=self._cmap,
                           vmin=lower,
                           vmax=upper)
        # hide tick labels on both axes and freeze the limits before adding overlays
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
        ax.autoscale(False)
        scale_bar(ax, size, dist=1, text='1"')
        text_description(ax, size, text=text, color="w", backgroundcolor='k')
        coordinate_arrows(ax, size, self._coords, arrow_size=self._arrow_size)
        colorbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        plt.colorbar(image, cax=colorbar_axes).set_label(r'log$_{10}$ flux', fontsize=15)
        return ax