Ejemplo n.º 1
0
 def test_raise(self):
     """Check that LightModel raises ValueError for unsupported requests.

     Models are built outside the ``assertRaises`` blocks whenever the
     constructor is not the call under test, so a ValueError raised during
     construction cannot make an assertion pass for the wrong reason.
     """
     def corrupted_model():
         # valid model whose profile bookkeeping is overwritten afterwards
         model = LightModel(light_model_list=['UNIFORM'])
         model.profile_type_list = ['WRONG']
         return model

     # here the constructor itself is the call expected to fail
     with self.assertRaises(ValueError):
         LightModel(light_model_list=['WRONG'])

     light_model = LightModel(light_model_list=['UNIFORM'])
     with self.assertRaises(ValueError):
         light_model.light_3d(r=1, kwargs_list=[{'amp': 1}])

     light_model = corrupted_model()
     with self.assertRaises(ValueError):
         light_model.functions_split(x=0, y=0, kwargs_list=[{}])

     light_model = corrupted_model()
     with self.assertRaises(ValueError):
         light_model.num_param_linear(kwargs_list=[{}])

     light_model = corrupted_model()
     with self.assertRaises(ValueError):
         light_model.update_linear(param=[1], i=0, kwargs_list=[{}])

     light_model = corrupted_model()
     with self.assertRaises(ValueError):
         light_model.total_flux(kwargs_list=[{}])
Ejemplo n.º 2
0
class LightParam(object):
    """
    class manages the parameters corresponding to the LightModel() module. Also manages linear parameter handling.
    """
    # profile families whose 'amp' entry is a vector of coefficients instead of a single value
    _SHAPELETS_MODELS = ('SHAPELETS', 'SHAPELETS_POLAR', 'SHAPELETS_POLAR_EXP')
    _MULTI_GAUSSIAN_MODELS = ('MULTI_GAUSSIAN', 'MULTI_GAUSSIAN_ELLIPSE')
    _STARLETS_MODELS = ('SLIT_STARLETS', 'SLIT_STARLETS_GEN2')
    # keywords of STARLETS-like models that set_params() refuses to treat as free parameters
    _STARLETS_FIXED_NAMES = ('n_scales', 'n_pixels', 'scale', 'center_x',
                             'center_y')

    def __init__(self,
                 light_model_list,
                 kwargs_fixed,
                 kwargs_lower=None,
                 kwargs_upper=None,
                 param_type='light',
                 linear_solver=True):
        """

        :param light_model_list: list of light models
        :param kwargs_fixed: list of keyword arguments corresponding to parameters held fixed during sampling
        :param kwargs_lower: list of keyword arguments indicating hard lower limit of the parameter space
        :param kwargs_upper: list of keyword arguments indicating hard upper limit of the parameter space
        :param param_type: string (optional), adding specifications in the output strings (such as lens light or
         source light)
        :param linear_solver: bool, if True fixes the linear amplitude parameters 'amp' (avoid sampling) such that they
         get overwritten by the linear solver solution.
        """
        self._lightModel = LightModel(light_model_list=light_model_list)
        self._param_name_list = self._lightModel.param_name_list
        self._type = param_type
        self.model_list = light_model_list
        self.kwargs_fixed = kwargs_fixed
        if linear_solver:
            self.kwargs_fixed = self._lightModel.add_fixed_linear(
                self.kwargs_fixed)
        self._linear_solve = linear_solver
        # fall back to the default limits of each profile when none are supplied
        if kwargs_lower is None:
            kwargs_lower = [
                func.lower_limit_default
                for func in self._lightModel.func_list
            ]
        if kwargs_upper is None:
            kwargs_upper = [
                func.upper_limit_default
                for func in self._lightModel.func_list
            ]
        self.lower_limit = kwargs_lower
        self.upper_limit = kwargs_upper

    @property
    def param_name_list(self):
        """
        :return: parameter names as provided by the LightModel() instance (indexed per light model)
        """
        return self._param_name_list

    @staticmethod
    def _num_shapelet_coeffs(model, n_max):
        """
        number of 'amp' coefficients of a shapelet-type profile

        :param model: string, name of the shapelet-type light model
        :param n_max: maximum shapelet order
        :return: int, (n_max + 1)**2 for 'SHAPELETS_POLAR_EXP', otherwise (n_max + 1) * (n_max + 2) / 2
        """
        if model == 'SHAPELETS_POLAR_EXP':
            return int((n_max + 1)**2)
        return int((n_max + 1) * (n_max + 2) / 2)

    def get_params(self, args, i):
        """

        :param args: list of floats corresponding to the arguments being sampled
        :param i: int, index of the first argument that is managed/read-out by this class
        :return: keyword argument list of the light profile, index after reading out the arguments corresponding to
         this class
        :raises ValueError: if the keyword that determines the length of a vector-valued 'amp' is not fixed
        """
        kwargs_list = []
        for k, model in enumerate(self.model_list):
            kwargs = {}
            kwargs_fixed = self.kwargs_fixed[k]
            for name in self._param_name_list[k]:
                if name in kwargs_fixed:
                    kwargs[name] = kwargs_fixed[name]
                    continue
                if name == 'amp' and model in self._SHAPELETS_MODELS:
                    if 'n_max' not in kwargs_fixed:
                        raise ValueError('n_max needs to be fixed in %s.' %
                                         model)
                    num_param = self._num_shapelet_coeffs(
                        model, kwargs_fixed['n_max'])
                elif name == 'amp' and model in self._MULTI_GAUSSIAN_MODELS:
                    if 'sigma' not in kwargs_fixed:
                        raise ValueError('sigma needs to be fixed in %s.' %
                                         model)
                    num_param = len(kwargs_fixed['sigma'])
                elif name == 'amp' and model in self._STARLETS_MODELS:
                    if 'n_scales' not in kwargs_fixed or 'n_pixels' not in kwargs_fixed:
                        raise ValueError(
                            "'n_scales' and 'n_pixels' both need to be fixed in %s."
                            % model)
                    num_param = kwargs_fixed['n_scales'] * kwargs_fixed[
                        'n_pixels']
                else:
                    # ordinary parameter: consumes exactly one entry of args
                    kwargs[name] = args[i]
                    i += 1
                    continue
                # vector-valued amplitude: consumes num_param consecutive entries
                kwargs['amp'] = args[i:i + num_param]
                i += num_param

            kwargs_list.append(kwargs)
        return kwargs_list, i

    def set_params(self, kwargs_list):
        """

        :param kwargs_list: list of keyword arguments of the light profile (free parameter as well as optionally the
         fixed ones)
        :return: list of floats corresponding to the free parameters
        :raises ValueError: if a keyword that must be fixed for STARLETS-like or MULTI_GAUSSIAN models is not fixed
        """
        args = []
        for k, model in enumerate(self.model_list):
            kwargs = kwargs_list[k]
            kwargs_fixed = self.kwargs_fixed[k]

            for name in self._param_name_list[k]:
                if name in kwargs_fixed:
                    continue
                if model in self._SHAPELETS_MODELS and name == 'amp':
                    # only consult kwargs when n_max is not fixed: fixed entries are
                    # optional in kwargs_list, so an unconditional kwargs['n_max']
                    # lookup would raise KeyError for valid input
                    if 'n_max' in kwargs_fixed:
                        n_max = kwargs_fixed['n_max']
                    else:
                        n_max = kwargs['n_max']
                    num_param = self._num_shapelet_coeffs(model, n_max)
                    args.extend(kwargs[name][j] for j in range(num_param))
                elif model in self._STARLETS_MODELS and name == 'amp':
                    if 'n_scales' not in kwargs_fixed:
                        raise ValueError(
                            "'n_scales' for SLIT_STARLETS not found in kwargs_fixed"
                        )
                    if 'n_pixels' not in kwargs_fixed:
                        raise ValueError(
                            "'n_pixels' for SLIT_STARLETS not found in kwargs_fixed"
                        )
                    num_param = kwargs_fixed['n_scales'] * kwargs_fixed[
                        'n_pixels']
                    args.extend(kwargs[name][j] for j in range(num_param))
                elif model in self._STARLETS_MODELS and name in self._STARLETS_FIXED_NAMES:
                    raise ValueError(
                        "'{}' must be a fixed keyword argument for STARLETS-like models"
                        .format(name))
                elif model in self._MULTI_GAUSSIAN_MODELS and name == 'amp':
                    num_param = len(kwargs['sigma'])
                    args.extend(kwargs[name][j] for j in range(num_param))
                elif model in self._MULTI_GAUSSIAN_MODELS and name == 'sigma':
                    raise ValueError(
                        "'sigma' must be a fixed keyword argument for MULTI_GAUSSIAN"
                    )
                else:
                    args.append(kwargs[name])
        return args

    def num_param(self, latex_style=False):
        """
        :param latex_style: boolean; if True, returns latex strings for plotting (the flag is not read by this
         implementation, the same names are returned either way)
        :return: int, list of strings with param names
        :raises ValueError: if 'n_max' (shapelets) or 'n_scales'/'n_pixels' (STARLETS-like models) are not fixed
        """
        num = 0
        name_list = []
        for k, model in enumerate(self.model_list):
            kwargs_fixed = self.kwargs_fixed[k]
            for name in self._param_name_list[k]:
                if name in kwargs_fixed:
                    continue
                if model in self._SHAPELETS_MODELS and name == 'amp':
                    if 'n_max' not in kwargs_fixed:
                        raise ValueError(
                            "n_max needs to be fixed in this configuration!")
                    num_param = self._num_shapelet_coeffs(
                        model, kwargs_fixed['n_max'])
                elif model in self._STARLETS_MODELS and name == 'amp':
                    if 'n_scales' not in kwargs_fixed or 'n_pixels' not in kwargs_fixed:
                        raise ValueError(
                            "n_scales and n_pixels need to be fixed when using STARLETS-like models!"
                        )
                    num_param = kwargs_fixed['n_scales'] * kwargs_fixed[
                        'n_pixels']
                elif model in self._MULTI_GAUSSIAN_MODELS and name == 'amp':
                    # a missing fixed 'sigma' surfaces as KeyError here
                    num_param = len(kwargs_fixed['sigma'])
                else:
                    num_param = 1
                num += num_param
                # every coefficient of a vector-valued parameter shares one label
                name_list.extend([name + '_' + self._type + str(k)] *
                                 num_param)
        return num, name_list

    def num_param_linear(self):
        """
        :return: number of linear basis set coefficients
        """
        return self._lightModel.num_param_linear(kwargs_list=self.kwargs_fixed)
Ejemplo n.º 3
0
class TestLightModel(object):
    """
    Exercises the LightModel routines on a model list that mixes many light profiles.
    """
    def setup(self):
        # model names and keyword dictionaries are kept in matching order
        self.light_model_list = [
            'GAUSSIAN', 'MULTI_GAUSSIAN', 'SERSIC', 'SERSIC_ELLIPSE',
            'CORE_SERSIC', 'SHAPELETS', 'HERNQUIST', 'HERNQUIST_ELLIPSE',
            'PJAFFE', 'PJAFFE_ELLIPSE', 'UNIFORM', 'POWER_LAW', 'NIE',
            'INTERPOL', 'SHAPELETS_POLAR_EXP', 'ELLIPSOID'
        ]
        e1, e2 = param_util.phi_q2_ellipticity(0.5, 0.8)
        centered = {'center_x': 0, 'center_y': 0}
        self.kwargs = [
            # 'GAUSSIAN'
            dict({'amp': 1., 'sigma': 1.}, **centered),
            # 'MULTI_GAUSSIAN'
            dict({'amp': [1., 2], 'sigma': [1, 3]}, **centered),
            # 'SERSIC'
            dict({'amp': 1, 'R_sersic': 0.5, 'n_sersic': 1}, **centered),
            # 'SERSIC_ELLIPSE'
            dict({'amp': 1, 'R_sersic': 0.5, 'n_sersic': 1, 'e1': e1, 'e2': e2},
                 **centered),
            # 'CORE_SERSIC'
            dict({'amp': 1, 'R_sersic': 0.5, 'Rb': 0.1, 'gamma': 2.,
                  'n_sersic': 1, 'e1': e1, 'e2': e2}, **centered),
            # 'SHAPELETS'
            dict({'amp': [1, 1, 1], 'beta': 0.5, 'n_max': 1}, **centered),
            # 'HERNQUIST'
            dict({'amp': 1, 'Rs': 0.5}, **centered),
            # 'HERNQUIST_ELLIPSE'
            {'amp': 1, 'Rs': 0.5, 'center_x': 0, 'center_y': 0, 'e1': e1, 'e2': e2},
            # 'PJAFFE'
            dict({'amp': 1, 'Ra': 1, 'Rs': 0.5}, **centered),
            # 'PJAFFE_ELLIPSE'
            {'amp': 1, 'Ra': 1, 'Rs': 0.5, 'center_x': 0, 'center_y': 0,
             'e1': e1, 'e2': e2},
            # 'UNIFORM'
            {'amp': 1},
            # 'POWER_LAW'
            dict({'amp': 1., 'gamma': 2., 'e1': e1, 'e2': e2}, **centered),
            # 'NIE'
            {'amp': .001, 'e1': 0, 'e2': 1., 'center_x': 0, 'center_y': 0,
             's_scale': 1.},
            # 'INTERPOL'
            dict({'image': np.zeros((20, 5)), 'scale': 1, 'phi_G': 0},
                 **centered),
            # 'SHAPELETS_POLAR_EXP'
            dict({'amp': [1], 'n_max': 0, 'beta': 1}, **centered),
            # 'ELLIPSOID'
            dict({'amp': 1, 'radius': 1., 'e1': 0, 'e2': 0.1}, **centered),
        ]

        self.LightModel = LightModel(light_model_list=self.light_model_list)

    def test_init(self):
        # one profile type entry is expected per requested model
        names = [
            'CORE_SERSIC', 'SHAPELETS', 'SHAPELETS_POLAR',
            'SHAPELETS_POLAR_EXP', 'UNIFORM', 'CHAMELEON', 'DOUBLE_CHAMELEON',
            'TRIPLE_CHAMELEON'
        ]
        light_model = LightModel(light_model_list=names)
        assert len(light_model.profile_type_list) == len(names)

    def test_surface_brightness(self):
        # scalar coordinates
        flux = self.LightModel.surface_brightness(
            x=1., y=1., kwargs_list=self.kwargs)
        npt.assert_almost_equal(flux, 2.5886852663397137, decimal=6)

    def test_surface_brightness_array(self):
        # list coordinates, first element compared
        flux = self.LightModel.surface_brightness(
            x=[1], y=[1], kwargs_list=self.kwargs)
        npt.assert_almost_equal(flux[0], 2.5886852663397137, decimal=6)

    def test_functions_split(self):
        split = self.LightModel.functions_split(
            x=1., y=1., kwargs_list=self.kwargs)
        npt.assert_almost_equal(split[0][0], 0.058549831524319168, decimal=6)

    def test_param_name_list(self):
        names = self.LightModel.param_name_list
        assert len(self.light_model_list) == len(names)

    def test_num_param_linear(self):
        total = self.LightModel.num_param_linear(self.kwargs,
                                                 list_return=False)
        assert total == 19

        per_model = self.LightModel.num_param_linear(self.kwargs,
                                                     list_return=True)
        assert per_model[0] == 1

    def test_update_linear(self):
        _, num_linear = self.LightModel.functions_split(1, 1, self.kwargs)
        coefficients = np.ones(num_linear) * 2
        updated, index = self.LightModel.update_linear(
            coefficients, i=0, kwargs_list=self.kwargs)
        assert index == num_linear
        assert updated[0]['amp'] == 2

    def test_total_flux(self):
        model_names = [
            'SERSIC', 'SERSIC_ELLIPSE', 'INTERPOL', 'GAUSSIAN',
            'GAUSSIAN_ELLIPSE', 'MULTI_GAUSSIAN', 'MULTI_GAUSSIAN_ELLIPSE'
        ]
        model_kwargs = [
            # 'SERSIC'
            {'amp': 1, 'R_sersic': 0.5, 'n_sersic': 1, 'center_x': 0,
             'center_y': 0},
            # 'SERSIC_ELLIPSE'
            {'amp': 1, 'R_sersic': 0.5, 'n_sersic': 1, 'e1': 0.1, 'e2': 0,
             'center_x': 0, 'center_y': 0},
            # 'INTERPOL'
            {'image': np.ones((20, 5)), 'scale': 1, 'phi_G': 0, 'center_x': 0,
             'center_y': 0},
            # 'GAUSSIAN'
            {'amp': 2, 'sigma': 2, 'center_x': 0, 'center_y': 0},
            # 'GAUSSIAN_ELLIPSE'
            {'amp': 2, 'sigma': 2, 'e1': 0.1, 'e2': 0, 'center_x': 0,
             'center_y': 0},
            # 'MULTI_GAUSSIAN'
            {'amp': [1, 1], 'sigma': [2, 1], 'center_x': 0, 'center_y': 0},
            # 'MULTI_GAUSSIAN_ELLIPSE'
            {'amp': [1, 1], 'sigma': [2, 1], 'e1': 0.1, 'e2': 0, 'center_x': 0,
             'center_y': 0},
        ]
        light_model = LightModel(light_model_list=model_names)

        fluxes = light_model.total_flux(model_kwargs)
        assert fluxes[2] == 100
        assert fluxes[3] == 2
        assert fluxes[4] == 2
        assert fluxes[5] == 2
        assert fluxes[6] == 2

        # same call with norm=True
        fluxes = light_model.total_flux(model_kwargs, norm=True)
        assert fluxes[2] == 100
        assert fluxes[3] == 1
        assert fluxes[4] == 1
        assert fluxes[5] == 2
        assert fluxes[6] == 2

    def test_delete_interpol_caches(self):
        x, y = util.make_grid(numPix=20, deltapix=1.)
        flux = Gaussian().function(x,
                                   y,
                                   amp=1.,
                                   center_x=0.,
                                   center_y=0.,
                                   sigma=1.)
        image = util.array2image(flux)

        # two separate keyword dictionaries that share the same image
        interpol_kwargs = [{
            'image': image,
            'scale': 1,
            'phi_G': 0,
            'center_x': 0,
            'center_y': 0
        } for _ in range(2)]
        light_model = LightModel(light_model_list=['INTERPOL', 'INTERPOL'])
        # evaluating the model is what creates the '_image_interp' attribute checked below
        light_model.surface_brightness(x, y, interpol_kwargs)
        for profile in light_model.func_list:
            assert hasattr(profile, '_image_interp')
        light_model.delete_interpol_caches()
        for profile in light_model.func_list:
            assert not hasattr(profile, '_image_interp')

    def test_check_positive_flux_profile(self):
        light_model = LightModel(light_model_list=['GAUSSIAN'])

        # zero amplitude is accepted
        is_positive = light_model.check_positive_flux_profile([{
            'amp': 0,
            'sigma': 1
        }])
        assert is_positive

        # negative amplitude is rejected
        is_positive = light_model.check_positive_flux_profile([{
            'amp': -1,
            'sigma': 1
        }])
        assert not is_positive
Ejemplo n.º 4
0
class TestLightModel(object):
    """
    Exercises the LightModel routines on a model list that mixes many light profiles.
    """
    def setup(self):
        # model names and keyword dictionaries are kept in matching order
        self.light_model_list = [
            'GAUSSIAN', 'MULTI_GAUSSIAN', 'SERSIC', 'SERSIC_ELLIPSE',
            'CORE_SERSIC', 'SHAPELETS', 'HERNQUIST', 'HERNQUIST_ELLIPSE',
            'PJAFFE', 'PJAFFE_ELLIPSE', 'UNIFORM', 'POWER_LAW', 'NIE',
            'INTERPOL', 'SHAPELETS_POLAR_EXP'
        ]
        e1, e2 = param_util.phi_q2_ellipticity(0.5, 0.8)
        centered = {'center_x': 0, 'center_y': 0}
        self.kwargs = [
            # 'GAUSSIAN'
            dict({'amp': 1., 'sigma_x': 1, 'sigma_y': 1.}, **centered),
            # 'MULTI_GAUSSIAN'
            dict({'amp': [1., 2], 'sigma': [1, 3]}, **centered),
            # 'SERSIC'
            dict({'amp': 1, 'R_sersic': 0.5, 'n_sersic': 1}, **centered),
            # 'SERSIC_ELLIPSE'
            dict({'amp': 1, 'R_sersic': 0.5, 'n_sersic': 1, 'e1': e1, 'e2': e2},
                 **centered),
            # 'CORE_SERSIC'
            dict({'amp': 1, 'R_sersic': 0.5, 'Re': 0.1, 'gamma': 2.,
                  'n_sersic': 1, 'e1': e1, 'e2': e2}, **centered),
            # 'SHAPELETS'
            dict({'amp': [1, 1, 1], 'beta': 0.5, 'n_max': 1}, **centered),
            # 'HERNQUIST'
            dict({'amp': 1, 'Rs': 0.5}, **centered),
            # 'HERNQUIST_ELLIPSE'
            {'amp': 1, 'Rs': 0.5, 'center_x': 0, 'center_y': 0, 'e1': e1, 'e2': e2},
            # 'PJAFFE'
            dict({'amp': 1, 'Ra': 1, 'Rs': 0.5}, **centered),
            # 'PJAFFE_ELLIPSE'
            {'amp': 1, 'Ra': 1, 'Rs': 0.5, 'center_x': 0, 'center_y': 0,
             'e1': e1, 'e2': e2},
            # 'UNIFORM'
            {'amp': 1},
            # 'POWER_LAW'
            dict({'amp': 1., 'gamma': 2., 'e1': e1, 'e2': e2}, **centered),
            # 'NIE'
            {'amp': .001, 'e1': 0, 'e2': 1., 'center_x': 0, 'center_y': 0,
             's_scale': 1.},
            # 'INTERPOL'
            dict({'image': np.zeros((10, 10)), 'scale': 1, 'phi_G': 0},
                 **centered),
            # 'SHAPELETS_POLAR_EXP'
            dict({'amp': [1], 'n_max': 0, 'beta': 1}, **centered),
        ]

        self.LightModel = LightModel(light_model_list=self.light_model_list)

    def test_init(self):
        # one profile type entry is expected per requested model
        names = [
            'CORE_SERSIC', 'SHAPELETS', 'SHAPELETS_POLAR',
            'SHAPELETS_POLAR_EXP', 'UNIFORM', 'CHAMELEON', 'DOUBLE_CHAMELEON',
            'TRIPLE_CHAMELEON'
        ]
        light_model = LightModel(light_model_list=names)
        assert len(light_model.profile_type_list) == len(names)

    def test_surface_brightness(self):
        # scalar coordinates
        flux = self.LightModel.surface_brightness(
            x=1, y=1, kwargs_list=self.kwargs)
        npt.assert_almost_equal(flux, 3.7065728131855824, decimal=6)

    def test_surface_brightness_array(self):
        # list coordinates, first element compared
        flux = self.LightModel.surface_brightness(
            x=[1], y=[1], kwargs_list=self.kwargs)
        npt.assert_almost_equal(flux[0], 3.7065728131855824, decimal=6)

    def test_functions_split(self):
        split = self.LightModel.functions_split(
            x=1., y=1., kwargs_list=self.kwargs)
        npt.assert_almost_equal(split[0][0], 0.058549831524319168, decimal=6)

    def test_re_normalize_flux(self):
        # the first profile's amplitude must be scaled by the norm factor
        rescaled = self.LightModel.re_normalize_flux(
            kwargs_list=self.kwargs, norm_factor=2)
        assert rescaled[0]['amp'] == 2 * self.kwargs[0]['amp']

    def test_param_name_list(self):
        names = self.LightModel.param_name_list()
        assert len(self.light_model_list) == len(names)

    def test_num_param_linear(self):
        total = self.LightModel.num_param_linear(self.kwargs,
                                                 list_return=False)
        assert total == 18

        per_model = self.LightModel.num_param_linear(self.kwargs,
                                                     list_return=True)
        assert per_model[0] == 1

    def test_update_linear(self):
        _, num_linear = self.LightModel.functions_split(1, 1, self.kwargs)
        coefficients = np.ones(num_linear) * 2
        updated, index = self.LightModel.update_linear(
            coefficients, i=0, kwargs_list=self.kwargs)
        assert index == num_linear
        assert updated[0]['amp'] == 2

    def test_total_flux(self):
        model_names = [
            'SERSIC', 'SERSIC_ELLIPSE', 'INTERPOL', 'GAUSSIAN',
            'GAUSSIAN_ELLIPSE', 'MULTI_GAUSSIAN', 'MULTI_GAUSSIAN_ELLIPSE'
        ]
        model_kwargs = [
            # 'SERSIC'
            {'amp': 1, 'R_sersic': 0.5, 'n_sersic': 1, 'center_x': 0,
             'center_y': 0},
            # 'SERSIC_ELLIPSE'
            {'amp': 1, 'R_sersic': 0.5, 'n_sersic': 1, 'e1': 0.1, 'e2': 0,
             'center_x': 0, 'center_y': 0},
            # 'INTERPOL'
            {'image': np.ones((10, 10)), 'scale': 1, 'phi_G': 0, 'center_x': 0,
             'center_y': 0},
            # 'GAUSSIAN'
            {'amp': 2, 'sigma_x': 2, 'sigma_y': 1, 'center_x': 0,
             'center_y': 0},
            # 'GAUSSIAN_ELLIPSE'
            {'amp': 2, 'sigma': 2, 'e1': 0.1, 'e2': 0, 'center_x': 0,
             'center_y': 0},
            # 'MULTI_GAUSSIAN'
            {'amp': [1, 1], 'sigma': [2, 1], 'center_x': 0, 'center_y': 0},
            # 'MULTI_GAUSSIAN_ELLIPSE'
            {'amp': [1, 1], 'sigma': [2, 1], 'e1': 0.1, 'e2': 0, 'center_x': 0,
             'center_y': 0},
        ]
        light_model = LightModel(light_model_list=model_names)

        fluxes = light_model.total_flux(model_kwargs)
        assert fluxes[2] == 100
        assert fluxes[3] == 2
        assert fluxes[4] == 2
        assert fluxes[5] == 2
        assert fluxes[6] == 2

        # same call with norm=True
        fluxes = light_model.total_flux(model_kwargs, norm=True)
        assert fluxes[2] == 100
        assert fluxes[3] == 1
        assert fluxes[4] == 1
        assert fluxes[5] == 2
        assert fluxes[6] == 2
Ejemplo n.º 5
0
class LightParam(object):
    """
    Translates between the per-profile keyword-argument dictionaries of a LightModel() and the flat sequence of
    their non-fixed parameters. Also handles fixing of the linear 'amp' parameters.
    """
    def __init__(self,
                 light_model_list,
                 kwargs_fixed,
                 kwargs_lower=None,
                 kwargs_upper=None,
                 type='light',
                 linear_solver=True):
        """

        :param light_model_list: list of light model names, passed on to LightModel()
        :param kwargs_fixed: list of dictionaries (one per model) with the parameters held fixed
        :param kwargs_lower: list of dictionaries with lower limits; if None, the lower_limit_default of each
         profile in LightModel().func_list is used
        :param kwargs_upper: list of dictionaries with upper limits; if None, the upper_limit_default of each
         profile in LightModel().func_list is used
        :param type: string inserted into the parameter labels returned by num_param()
        :param linear_solver: bool, if True, 'amp' is added to the fixed parameters of every model that has it
        """
        self._lightModel = LightModel(light_model_list=light_model_list)
        self._param_name_list = self._lightModel.param_name_list()
        self._type = type
        self.model_list = light_model_list
        self.kwargs_fixed = kwargs_fixed
        if linear_solver:
            self.kwargs_fixed = self.add_fixed_linear(self.kwargs_fixed)
        self._linear_solve = linear_solver
        if kwargs_lower is None:
            kwargs_lower = [
                func.lower_limit_default
                for func in self._lightModel.func_list
            ]
        if kwargs_upper is None:
            kwargs_upper = [
                func.upper_limit_default
                for func in self._lightModel.func_list
            ]
        self.lower_limit = kwargs_lower
        self.upper_limit = kwargs_upper

    @property
    def param_name_list(self):
        """

        :return: the parameter names stored at construction, indexed by model
        """
        return self._param_name_list

    @staticmethod
    def _num_shapelet_amp(model, n_max):
        """
        number of 'amp' coefficients of a shapelet-type model

        :param model: model name
        :param n_max: value of the 'n_max' parameter
        :return: (n_max + 1)^2 for 'SHAPELETS_POLAR_EXP', (n_max + 1) * (n_max + 2) / 2 otherwise
        """
        if model in ['SHAPELETS_POLAR_EXP']:
            return int((n_max + 1)**2)
        return int((n_max + 1) * (n_max + 2) / 2)

    def getParams(self, args, i):
        """
        reads the non-fixed parameters out of a flat sequence and merges them with the fixed ones

        :param args: flat sequence of parameter values
        :param i: index in args at which the parameters of this class start
        :return: list of keyword arguments (one dictionary per model), index in args after the last value read
        :raises ValueError: if a non-fixed 'amp' of a shapelet model has no fixed 'n_max', or a non-fixed 'amp'
         of a multi-Gaussian model has no fixed 'sigma'
        """
        kwargs_list = []
        for k, model in enumerate(self.model_list):
            kwargs = {}
            kwargs_fixed = self.kwargs_fixed[k]
            param_names = self._param_name_list[k]
            for name in param_names:
                if name in kwargs_fixed:
                    kwargs[name] = kwargs_fixed[name]
                    continue
                if model in [
                        'SHAPELETS', 'SHAPELETS_POLAR',
                        'SHAPELETS_POLAR_EXP'
                ] and name == 'amp':
                    if 'n_max' not in kwargs_fixed:
                        raise ValueError('n_max needs to be fixed in %s.' %
                                         model)
                    num_param = self._num_shapelet_amp(
                        model, kwargs_fixed['n_max'])
                    kwargs['amp'] = args[i:i + num_param]
                    i += num_param
                elif model in ['MULTI_GAUSSIAN', 'MULTI_GAUSSIAN_ELLIPSE'
                               ] and name == 'amp':
                    if 'sigma' not in kwargs_fixed:
                        raise ValueError('sigma needs to be fixed in %s.' %
                                         model)
                    # one amplitude per fixed 'sigma' entry
                    num_param = len(kwargs_fixed['sigma'])
                    kwargs['amp'] = args[i:i + num_param]
                    i += num_param
                else:
                    kwargs[name] = args[i]
                    i += 1

            kwargs_list.append(kwargs)
        return kwargs_list, i

    def setParams(self, kwargs_list):
        """
        flattens the non-fixed parameters of the keyword arguments into a list (inverse of getParams())

        :param kwargs_list: list of keyword arguments, one dictionary per model
        :return: list of the non-fixed parameter values, in the order of the stored parameter names
        :raises ValueError: if 'sigma' of a multi-Gaussian model is not fixed
        """
        args = []
        for k, model in enumerate(self.model_list):
            kwargs = kwargs_list[k]
            kwargs_fixed = self.kwargs_fixed[k]

            param_names = self._param_name_list[k]
            for name in param_names:
                if name in kwargs_fixed:
                    continue
                if model in [
                        'SHAPELETS', 'SHAPELETS_POLAR',
                        'SHAPELETS_POLAR_EXP'
                ] and name == 'amp':
                    # Only fall back to kwargs when 'n_max' is not fixed. A dict.get() default is evaluated
                    # eagerly and would raise KeyError whenever kwargs lacks the (fixed) 'n_max'.
                    if 'n_max' in kwargs_fixed:
                        n_max = kwargs_fixed['n_max']
                    else:
                        n_max = kwargs['n_max']
                    num_param = self._num_shapelet_amp(model, n_max)
                    args.extend(kwargs[name][j] for j in range(num_param))
                elif model in ['MULTI_GAUSSIAN', 'MULTI_GAUSSIAN_ELLIPSE'
                               ] and name == 'amp':
                    num_param = len(kwargs['sigma'])
                    args.extend(kwargs[name][j] for j in range(num_param))
                elif model in ['MULTI_GAUSSIAN', 'MULTI_GAUSSIAN_ELLIPSE'
                               ] and name == 'sigma':
                    raise ValueError(
                        "'sigma' must be a fixed keyword argument for MULTI_GAUSSIAN"
                    )
                else:
                    args.append(kwargs[name])
        return args

    def num_param(self):
        """
        counts and labels the non-fixed parameters

        :return: number of non-fixed parameters, list of their labels ('<name>_<type><model index>', repeated
         once per coefficient for multi-valued 'amp')
        :raises ValueError: if a non-fixed 'amp' of a shapelet model has no fixed 'n_max'
        """
        num = 0
        name_list = []
        for k, model in enumerate(self.model_list):
            kwargs_fixed = self.kwargs_fixed[k]
            param_names = self._param_name_list[k]
            for name in param_names:
                if name in kwargs_fixed:
                    continue
                if model in [
                        'SHAPELETS', 'SHAPELETS_POLAR',
                        'SHAPELETS_POLAR_EXP'
                ] and name == 'amp':
                    if 'n_max' not in kwargs_fixed:
                        raise ValueError(
                            "n_max needs to be fixed in this configuration!"
                        )
                    num_param = self._num_shapelet_amp(
                        model, kwargs_fixed['n_max'])
                elif model in ['MULTI_GAUSSIAN', 'MULTI_GAUSSIAN_ELLIPSE'
                               ] and name == 'amp':
                    num_param = len(kwargs_fixed['sigma'])
                else:
                    num_param = 1
                num += num_param
                name_list.extend(
                    str(name + '_' + self._type + str(k))
                    for _ in range(num_param))
        return num, name_list

    def add_fixed_linear(self, kwargs_fixed_list):
        """
        fixes 'amp' (to the value 1) for every model that has an 'amp' parameter not yet fixed;
        the dictionaries of the input list are modified in place

        :param kwargs_fixed_list: list of fixed keyword arguments
        :return: updated kwargs_fixed_list with additional linear parameters being fixed.
        """
        for k, model in enumerate(self.model_list):
            kwargs_fixed = kwargs_fixed_list[k]
            param_names = self._param_name_list[k]
            if 'amp' in param_names and 'amp' not in kwargs_fixed:
                kwargs_fixed['amp'] = 1
        return kwargs_fixed_list

    def num_param_linear(self):
        """

        :return: number of linear basis set coefficients
        """
        return self._lightModel.num_param_linear(kwargs_list=self.kwargs_fixed)

    def check_positive_flux_profile(self, kwargs_list):
        """
        checks the sign of 'amp' for the models named in the list below; other models are not checked

        :param kwargs_list: list of keyword arguments, one dictionary per model
        :return: False if any checked model has 'amp' < 0, True otherwise
        """
        for k, model in enumerate(self.model_list):
            if 'amp' not in kwargs_list[k]:
                continue
            if model in [
                    'SERSIC', 'SERSIC_ELLIPSE', 'CORE_SERSIC', 'HERNQUIST',
                    'PJAFFE', 'PJAFFE_ELLIPSE', 'HERNQUIST_ELLIPSE',
                    'GAUSSIAN', 'GAUSSIAN_ELLIPSE', 'POWER_LAW', 'NIE',
                    'CHAMELEON', 'DOUBLE_CHAMELEON'
            ]:
                if kwargs_list[k]['amp'] < 0:
                    return False
        return True
Ejemplo n.º 6
0
class LightParam(object):
    """

    """

    def __init__(self, light_model_list, kwargs_fixed, kwargs_lower=None, kwargs_upper=None, type='light',
                 linear_solver=True):
        """

        :param light_model_list: list of light model names, passed on to LightModel()
        :param kwargs_fixed: list of dictionaries (one per model) with the parameters held fixed
        :param kwargs_lower: list of dictionaries with lower limits; if None, the lower_limit_default of each
         profile in LightModel().func_list is used
        :param kwargs_upper: list of dictionaries with upper limits; if None, the upper_limit_default of each
         profile in LightModel().func_list is used
        :param type: string inserted into the parameter labels returned by num_param()
        :param linear_solver: bool, if True, kwargs_fixed is passed through LightModel().add_fixed_linear()
        """
        light_model = LightModel(light_model_list=light_model_list)
        self._lightModel = light_model
        self._param_name_list = light_model.param_name_list
        self._type = type
        self.model_list = light_model_list
        if linear_solver:
            kwargs_fixed = light_model.add_fixed_linear(kwargs_fixed)
        self.kwargs_fixed = kwargs_fixed
        self._linear_solve = linear_solver
        # fall back to the defaults of the individual profiles when no limits are given
        if kwargs_lower is None:
            kwargs_lower = [profile.lower_limit_default for profile in light_model.func_list]
        if kwargs_upper is None:
            kwargs_upper = [profile.upper_limit_default for profile in light_model.func_list]
        self.lower_limit = kwargs_lower
        self.upper_limit = kwargs_upper
    
    @property
    def param_name_list(self):
        return self._param_name_list

    def getParams(self, args, i):
        """

        :param args:
        :param i:
        :return:
        """
        kwargs_list = []
        for k, model in enumerate(self.model_list):
            kwargs = {}
            kwargs_fixed = self.kwargs_fixed[k]
            param_names = self._param_name_list[k]
            for name in param_names:
                if not name in kwargs_fixed:
                    if model in ['SHAPELETS', 'SHAPELETS_POLAR', 'SHAPELETS_POLAR_EXP'] and name == 'amp':
                        if 'n_max' in kwargs_fixed:
                            n_max = kwargs_fixed['n_max']
                        else:
                            raise ValueError('n_max needs to be fixed in %s.' % model)
                        if model in ['SHAPELETS_POLAR_EXP']:
                            num_param = int((n_max + 1) ** 2)
                        else:
                            num_param = int((n_max + 1) * (n_max + 2) / 2)
                        kwargs['amp'] = args[i:i + num_param]
                        i += num_param
                    elif model in ['MULTI_GAUSSIAN', 'MULTI_GAUSSIAN_ELLIPSE'] and name == 'amp':
                        if 'sigma' in kwargs_fixed:
                            num_param = len(kwargs_fixed['sigma'])
                        else:
                            raise ValueError('sigma needs to be fixed in %s.' % model)
                        kwargs['amp'] = args[i:i + num_param]
                        i += num_param
                    elif model in ['SLIT_STARLETS', 'SLIT_STARLETS_GEN2'] and name == 'amp':
                        if 'n_scales' in kwargs_fixed and 'n_pixels' in kwargs_fixed:
                            n_scales = kwargs_fixed['n_scales']
                            n_pixels = kwargs_fixed['n_pixels']
                        else:
                            raise ValueError("'n_scales' and 'n_pixels' both need to be fixed in %s." % model)
                        num_param = n_scales * n_pixels
                        kwargs['amp'] = args[i:i + num_param]
                        i += num_param
                    else:
                        kwargs[name] = args[i]
                        i += 1
                else:
                    kwargs[name] = kwargs_fixed[name]

            kwargs_list.append(kwargs)
        return kwargs_list, i

    def setParams(self, kwargs_list):
        """

        :param kwargs_list:
        :param bounds: bool, if True, ellitpicity of min/max
        :return:
        """
        args = []
        for k, model in enumerate(self.model_list):
            kwargs = kwargs_list[k]
            kwargs_fixed = self.kwargs_fixed[k]

            param_names = self._param_name_list[k]
            for name in param_names:
                if not name in kwargs_fixed:
                    if model in ['SHAPELETS', 'SHAPELETS_POLAR', 'SHAPELETS_POLAR_EXP'] and name == 'amp':
                        n_max = kwargs_fixed.get('n_max', kwargs['n_max'])
                        if model in ['SHAPELETS_POLAR_EXP']:
                            num_param = int((n_max + 1) ** 2)
                        else:
                            num_param = int((n_max + 1) * (n_max + 2) / 2)
                        for i in range(num_param):
                            args.append(kwargs[name][i])
                    elif model in ['SLIT_STARLETS', 'SLIT_STARLETS_GEN2'] and name == 'amp':
                        if 'n_scales' in kwargs_fixed:
                            n_scales = kwargs_fixed['n_scales']
                        else:
                            raise ValueError("'n_scakes' for SLIT_STARLETS not found in kwargs_fixed")
                        if 'n_pixels' in kwargs_fixed:
                            n_pixels = kwargs_fixed['n_pixels']
                        else:
                            raise ValueError("'n_pixels' for SLIT_STARLETS not found in kwargs_fixed")
                        num_param = n_scales * n_pixels
                        for i in range(num_param):
                            args.append(kwargs[name][i])
                    elif model in ['SLIT_STARLETS', 'SLIT_STARLETS_GEN2'] and name in ['n_scales', 'n_pixels', 'scale', 'center_x', 'center_y']:
                        raise ValueError("'{}' must be a fixed keyword argument for STARLETS-like models".format(name))
                    elif model in ['MULTI_GAUSSIAN', 'MULTI_GAUSSIAN_ELLIPSE'] and name == 'amp':
                        num_param = len(kwargs['sigma'])
                        for i in range(num_param):
                            args.append(kwargs[name][i])
                    elif model in ['MULTI_GAUSSIAN', 'MULTI_GAUSSIAN_ELLIPSE'] and name == 'sigma':
                        raise ValueError("'sigma' must be a fixed keyword argument for MULTI_GAUSSIAN")
                    else:
                        args.append(kwargs[name])
        return args

    def num_param(self):
        """

        :return:
        """
        num = 0
        list = []
        for k, model in enumerate(self.model_list):
            kwargs_fixed = self.kwargs_fixed[k]
            param_names = self._param_name_list[k]
            for name in param_names:
                if not name in kwargs_fixed:
                    if model in ['SHAPELETS', 'SHAPELETS_POLAR', 'SHAPELETS_POLAR_EXP'] and name == 'amp':
                        if 'n_max' not in kwargs_fixed:
                            raise ValueError("n_max needs to be fixed in this configuration!")
                        n_max = kwargs_fixed['n_max']
                        if model in ['SHAPELETS_POLAR_EXP']:
                            num_param = int((n_max + 1) ** 2)
                        else:
                            num_param = int((n_max + 1) * (n_max + 2) / 2)
                        num += num_param
                        for i in range(num_param):
                            list.append(str(name + '_' + self._type + str(k)))
                    elif model in ['SLIT_STARLETS', 'SLIT_STARLETS_GEN2'] and name == 'amp':
                        if 'n_scales' not in kwargs_fixed or 'n_pixels' not in kwargs_fixed:
                            raise ValueError("n_scales and n_pixels need to be fixed when using STARLETS-like models!")
                        n_scales = kwargs_fixed['n_scales']
                        n_pixels = kwargs_fixed['n_pixels']
                        num_param = n_scales * n_pixels
                        num += num_param
                        for i in range(num_param):
                            list.append(str(name + '_' + self._type + str(k)))
                    elif model in ['MULTI_GAUSSIAN', 'MULTI_GAUSSIAN_ELLIPSE'] and name == 'amp':
                        num_param = len(kwargs_fixed['sigma'])
                        num += num_param
                        for i in range(num_param):
                            list.append(str(name + '_' + self._type + str(k)))
                    else:
                        num += 1
                        list.append(str(name + '_' + self._type + str(k)))
        return num, list

    def num_param_linear(self):
        """
        :return: number of linear basis set coefficients
        """
        return self._lightModel.num_param_linear(kwargs_list=self.kwargs_fixed)