예제 #1
0
    def __init__(self, filter_name: str) -> None:
        """
        Set up the instance for a single filter.

        Parameters
        ----------
        filter_name : str
            Filter name as listed in the database. Filters from the
            SVO Filter Profile Service are automatically downloaded
            and added to the database.

        Returns
        -------
        NoneType
            None
        """

        self.filter_name = filter_name

        # Placeholders that start out as None
        self.filter_interp = None
        self.wavel_range = None

        self.vega_mag = 0.03  # (mag)

        # The database path is taken from the configuration file
        # in the current working directory
        parser = configparser.ConfigParser()
        parser.read(os.path.join(os.getcwd(), "species_config.ini"))

        self.database = parser["species"]["database"]

        # Detector type as reported by the filter reader
        filter_reader = read_filter.ReadFilter(self.filter_name)
        self.det_type = filter_reader.detector_type()
예제 #2
0
    def __init__(self, model, wavelength, teff=None):
        """
        Store the model selection and resolve the wavelength range.

        Parameters
        ----------
        model : str
            Model name.
        wavelength : tuple(float, float) or str
            Wavelength range (micron) or filter name. Full spectrum is used if set to None.
        teff : tuple(float, float)
            Effective temperature (K) range. Restricting the temperature range will speed up the
            computation.

        Returns
        -------
        NoneType
            None
        """

        self.model = model
        self.teff = teff

        # Placeholders that start out as None
        self.spectrum_interp = None
        self.wl_points = None
        self.wl_index = None

        if not isinstance(wavelength, str):
            # Anything that is not a string is stored unchanged as the range
            self.filter_name = None
            self.wavelength = wavelength

        else:
            # A string is interpreted as a filter name and the
            # wavelength range is taken from the filter profile
            self.filter_name = wavelength
            self.wavelength = read_filter.ReadFilter(wavelength).wavelength_range()
예제 #3
0
    def __init__(self, spectrum, filter_name):
        """
        Store the calibration spectrum tag, resolve the wavelength
        range of the filter, and read the database path from the
        configuration file in the current working directory.

        Parameters
        ----------
        spectrum : str
            Database tag of the calibration spectrum.
        filter_name : str
            Filter ID. Full spectrum is used if set to None (any falsy
            value, e.g. an empty string, is treated the same way).

        Returns
        -------
        NoneType
            None

        Raises
        ------
        FileNotFoundError
            If ``species_config.ini`` is not present in the working
            directory.
        """

        self.spectrum = spectrum
        self.filter_name = filter_name

        if filter_name:
            transmission = read_filter.ReadFilter(filter_name)
            self.wl_range = transmission.wavelength_range()

        else:
            self.wl_range = None

        config_file = os.path.join(os.getcwd(), 'species_config.ini')

        config = configparser.ConfigParser()

        # Open the file explicitly so that a missing configuration still
        # raises, but make sure that the handle is closed afterwards
        with open(config_file) as config_handle:
            config.read_file(config_handle)

        self.database = config['species']['database']
예제 #4
0
    def __init__(self,
                 spec_library,
                 filter_name=None):
        """
        Store the library selection, resolve the wavelength range of
        the filter, and read the database path from the configuration
        file in the current working directory.

        Parameters
        ----------
        spec_library : str
            Name of the spectral library ('irtf', 'spex') or other type of spectrum ('vega').
        filter_name : str, None
            Filter ID for the wavelength range. Full spectra are read if set to None.

        Returns
        -------
        NoneType
            None

        Raises
        ------
        FileNotFoundError
            If ``species_config.ini`` is not present in the working directory.
        """

        self.spec_library = spec_library
        self.filter_name = filter_name

        if filter_name is None:
            self.wavel_range = None

        else:
            transmission = read_filter.ReadFilter(filter_name)
            self.wavel_range = transmission.wavelength_range()

        config_file = os.path.join(os.getcwd(), 'species_config.ini')

        config = configparser.ConfigParser()

        # Open the file explicitly so that a missing configuration still
        # raises, but make sure that the handle is closed afterwards
        with open(config_file) as config_handle:
            config.read_file(config_handle)

        self.database = config['species']['database']
예제 #5
0
    def __init__(self, tag: str, filter_name: Optional[str] = None) -> None:
        """
        Store the calibration tag and resolve the wavelength range.

        Parameters
        ----------
        tag : str
            Database tag of the calibration spectrum.
        filter_name : str, None
            Filter that is used for the wavelength range. The full
            spectrum is read if the argument is set to ``None``.

        Returns
        -------
        NoneType
            None
        """

        self.tag = tag
        self.filter_name = filter_name

        if filter_name is not None:
            # Wavelength range of the filter profile
            filter_reader = read_filter.ReadFilter(filter_name)
            self.wavel_range = filter_reader.wavelength_range()

        else:
            self.wavel_range = None

        # The database path is taken from the configuration file
        # in the current working directory
        parser = configparser.ConfigParser()
        parser.read(os.path.join(os.getcwd(), "species_config.ini"))

        self.database = parser["species"]["database"]
예제 #6
0
    def get_magnitude(self, sptypes: List[str] = None) -> box.PhotometryBox:
        """
        Compute synthetic apparent and absolute magnitudes in ``filter_name``
        for the selected spectra.

        Parameters
        ----------
        sptypes : list(str)
            Spectral types to select from a library. The spectral types should be indicated with
            two characters (e.g. 'M5', 'L2', 'T3'). All spectra are selected if set to ``None``.

        Returns
        -------
        species.core.box.PhotometryBox
            Box with the synthetic photometry.
        """

        specbox = self.get_spectrum(sptypes=sptypes, exclude_nan=True)
        n_spectra = len(specbox.wavelength)

        # Every entry of the box gets the same mean wavelength and filter name
        mean_wavel = read_filter.ReadFilter(filter_name=self.filter_name).mean_wavelength()

        wavelengths = np.full(n_spectra, mean_wavel)
        filters = np.full(n_spectra, self.filter_name)

        synphot = photometry.SyntheticPhotometry(filter_name=self.filter_name)

        app_mag, abs_mag = [], []

        for idx in range(n_spectra):
            dist_item = specbox.distance[idx]

            if not np.isnan(dist_item[0]):
                app_tmp, abs_tmp = synphot.spectrum_to_magnitude(
                    specbox.wavelength[idx],
                    specbox.flux[idx],
                    error=specbox.error[idx],
                    distance=(float(dist_item[0]), float(dist_item[1])))

            else:
                # No magnitudes are computed when the distance is NaN
                app_tmp = (np.nan, np.nan)
                abs_tmp = (np.nan, np.nan)

            app_mag.append(app_tmp)
            abs_mag.append(abs_tmp)

        return box.create_box(boxtype='photometry',
                              name=specbox.name,
                              sptype=specbox.sptype,
                              wavelength=wavelengths,
                              flux=None,
                              app_mag=np.asarray(app_mag),
                              abs_mag=np.asarray(abs_mag),
                              filter_name=filters)
예제 #7
0
    def synthetic_photometry(
            self, filter_name: Union[str, List[str]]) -> PhotometryBox:
        """
        Compute synthetic fluxes and magnitudes for one or more
        filters from the spectrum stored in this ``ModelBox``.

        Parameters
        ----------
        filter_name : str, list(str)
            Single filter name or a list of filter names for which
            synthetic photometry will be calculated.

        Returns
        -------
        species.core.box.PhotometryBox
            Box with the synthetic photometry.
        """

        # A single name is handled as a list with one element
        filters = [filter_name] if isinstance(filter_name, str) else filter_name

        mean_wavelengths = []
        fluxes = []
        app_mags = []
        abs_mags = []

        for filt in filters:
            syn_phot = photometry.SyntheticPhotometry(filter_name=filt)

            fluxes.append(
                syn_phot.spectrum_to_flux(wavelength=self.wavelength,
                                          flux=self.flux))

            # The first element is stored as apparent magnitude
            # and the second element as absolute magnitude
            magnitudes = syn_phot.spectrum_to_magnitude(
                wavelength=self.wavelength, flux=self.flux)

            app_mags.append(magnitudes[0])
            abs_mags.append(magnitudes[1])

            mean_wavelengths.append(
                read_filter.ReadFilter(filter_name=filt).mean_wavelength())

        return create_box(
            boxtype="photometry",
            name=None,
            sptype=None,
            wavelength=mean_wavelengths,
            flux=fluxes,
            app_mag=app_mags,
            abs_mag=abs_mags,
            filter_name=filters,
        )
예제 #8
0
    def get_flux(self, sptypes: List[str] = None) -> box.PhotometryBox:
        """
        Compute the synthetic flux in ``filter_name`` for the
        selected spectra.

        Parameters
        ----------
        sptypes : list(str), None
            Spectral types to select from a library. The spectral types
            should be indicated with two characters (e.g. 'M5', 'L2',
            'T3'). All spectra are selected if set to ``None``.

        Returns
        -------
        species.core.box.PhotometryBox
            Box with the synthetic photometry.
        """

        specbox = self.get_spectrum(sptypes=sptypes, exclude_nan=True)
        n_spectra = len(specbox.wavelength)

        # Every entry of the box gets the same mean wavelength and filter name
        mean_wavel = read_filter.ReadFilter(filter_name=self.filter_name).mean_wavelength()

        wavelengths = np.full(n_spectra, mean_wavel)
        filters = np.full(n_spectra, self.filter_name)

        synphot = photometry.SyntheticPhotometry(filter_name=self.filter_name)

        # One result of spectrum_to_flux per spectrum
        phot_flux = np.asarray([
            synphot.spectrum_to_flux(
                wavelength=specbox.wavelength[idx],
                flux=specbox.flux[idx],
                error=specbox.error[idx],
            )
            for idx in range(n_spectra)
        ])

        return box.create_box(
            boxtype="photometry",
            name=specbox.name,
            sptype=specbox.sptype,
            wavelength=wavelengths,
            flux=phot_flux,
            app_mag=None,
            abs_mag=None,
            filter_name=filters,
        )
예제 #9
0
    def __init__(
        self,
        wavel_range: Optional[
            Tuple[Union[float, np.float32], Union[float, np.float32]]
        ] = None,
        filter_name: Optional[str] = None,
    ) -> None:
        """
        Resolve the wavelength range and read the database path.

        Parameters
        ----------
        wavel_range : tuple(float, float), None
            Wavelength range (um). A wavelength range of 0.1-1000 um
            is used if set to ``None``. Not used if ``filter_name``
            is not set to ``None``.
        filter_name : str, None
            Filter name that is used for the wavelength range. The
            ``wavel_range`` is used if set to ``None``.

        Returns
        -------
        NoneType
            None
        """

        # Placeholders that start out as None
        self.spectrum_interp = None
        self.wl_points = None
        self.wl_index = None

        self.filter_name = filter_name

        if filter_name is not None:
            # The filter takes precedence over an explicit range
            filter_reader = read_filter.ReadFilter(filter_name)
            self.wavel_range = filter_reader.wavelength_range()

        elif wavel_range is None:
            # Default range when neither argument is provided
            self.wavel_range = (0.1, 1000.0)

        else:
            self.wavel_range = wavel_range

        # The database path is taken from the configuration file
        # in the current working directory
        parser = configparser.ConfigParser()
        parser.read(os.path.join(os.getcwd(), "species_config.ini"))

        self.database = parser["species"]["database"]
예제 #10
0
    def zero_point(self) -> np.float64:
        """
        Internal function for calculating the zero point
        of the provided ``filter_name``.

        The Vega calibration spectrum is added to the database when
        it is not yet present. The spectrum is then cropped to the
        (exclusive) wavelength range of the filter and passed to
        ``spectrum_to_flux``. As a side effect, ``self.wavel_range``
        is set from the filter profile if it was still ``None``.

        Returns
        -------
        float
            Zero-point flux (W m-2 um-1).
        """

        if self.wavel_range is None:
            transmission = read_filter.ReadFilter(self.filter_name)
            self.wavel_range = transmission.wavelength_range()

        # Only check for the Vega spectrum here; the context manager
        # closes the file even if the check raises, and the file is
        # closed again before the database is modified
        with h5py.File(self.database, "r") as h5_file:
            has_vega = "spectra/calibration/vega" in h5_file

        if not has_vega:
            species_db = database.Database()
            species_db.add_spectra("vega")

        readcalib = read_calibration.ReadCalibration("vega", None)
        calibbox = readcalib.get_spectrum()

        wavelength = calibbox.wavelength
        flux = calibbox.flux

        # Same selection for the wavelengths and the fluxes
        in_range = ((wavelength > self.wavel_range[0])
                    & (wavelength < self.wavel_range[1]))

        return self.spectrum_to_flux(wavelength[in_range], flux[in_range])[0]
예제 #11
0
    def __init__(self,
                 model: str,
                 wavel_range: Optional[Tuple[float, float]] = None,
                 filter_name: Optional[str] = None):
        """
        Store the model name, resolve the wavelength range, and read
        the database path from the configuration file in the current
        working directory.

        Parameters
        ----------
        model : str
            Name of the atmospheric model.
        wavel_range : tuple(float, float), None
            Wavelength range (um). Full spectrum is selected if set to ``None``. Not used if
            ``filter_name`` is not ``None``.
        filter_name : str, None
            Filter name that is used for the wavelength range. The ``wavel_range`` is used if set
            to ``None``.

        Returns
        -------
        NoneType
            None

        Raises
        ------
        FileNotFoundError
            If ``species_config.ini`` is not present in the working directory.
        """

        self.model = model

        self.spectrum_interp = None
        self.wl_points = None
        self.wl_index = None

        self.filter_name = filter_name
        self.wavel_range = wavel_range

        if self.filter_name is not None:
            # The filter takes precedence over an explicit range
            transmission = read_filter.ReadFilter(self.filter_name)
            self.wavel_range = transmission.wavelength_range()

        config_file = os.path.join(os.getcwd(), 'species_config.ini')

        config = configparser.ConfigParser()

        # Open the file explicitly so that a missing configuration still
        # raises, but make sure that the handle is closed afterwards
        with open(config_file) as config_handle:
            config.read_file(config_handle)

        self.database = config['species']['database']
예제 #12
0
    def zero_point(self):
        """
        Calculate the zero point of ``filter_name`` from the Vega
        calibration spectrum.

        The Vega spectrum is added to the database when it is not yet
        present. The spectrum is then cropped to the (exclusive)
        wavelength range of the filter and passed to
        ``spectrum_to_photometry``. As a side effect, ``self.wl_range``
        is set from the filter profile if it was still ``None``.

        Returns
        -------
        tuple(float, float)
            Return value of ``spectrum_to_photometry`` for the
            cropped Vega spectrum.
        """

        if self.wl_range is None:
            transmission = read_filter.ReadFilter(self.filter_name)
            self.wl_range = transmission.wavelength_range()

        # Only check for the Vega spectrum here; the context manager
        # closes the file even if the check raises, and the file is
        # closed again before the database is modified
        with h5py.File(self.database, 'r') as h5_file:
            has_vega = 'spectra/calibration/vega' in h5_file

        if not has_vega:
            species_db = database.Database()
            species_db.add_spectrum('vega')

        readcalib = read_calibration.ReadCalibration('vega', None)
        calibbox = readcalib.get_spectrum()

        wavelength = calibbox.wavelength
        flux = calibbox.flux

        # Same selection for the wavelengths and the fluxes
        in_range = ((wavelength > self.wl_range[0])
                    & (wavelength < self.wl_range[1]))

        return self.spectrum_to_photometry(wavelength[in_range], flux[in_range])
예제 #13
0
def get_residuals(
    datatype: str,
    spectrum: str,
    parameters: Dict[str, float],
    objectbox: box.ObjectBox,
    inc_phot: Union[bool, List[str]] = True,
    inc_spec: Union[bool, List[str]] = True,
    radtrans: Optional[read_radtrans.ReadRadtrans] = None,
) -> box.ResidualsBox:
    """
    Function for calculating the residuals from fitting model or
    calibration spectra to a set of spectra and/or photometry.

    The residuals are expressed relative to the data uncertainties,
    that is, ``(data - model) / sigma``. A summary and the reduced
    chi-square are printed to the standard output.

    Parameters
    ----------
    datatype : str
        Data type ('model' or 'calibration').
    spectrum : str
        Name of the atmospheric model or calibration spectrum.
    parameters : dict
        Parameters and values for the spectrum
    objectbox : species.core.box.ObjectBox
        Box with the photometry and/or spectra of an object. A scaling
        and/or error inflation of the spectra should be applied with
        :func:`~species.util.read_util.update_spectra` beforehand.
    inc_phot : bool, list(str)
        Include photometric data in the fit. If a boolean, either all
        (``True``) or none (``False``) of the data are selected. If a
        list, a subset of filter names (as stored in the database) can
        be provided.
    inc_spec : bool, list(str)
        Include spectroscopic data in the fit. If a boolean, either all
        (``True``) or none (``False``) of the data are selected. If a
        list, a subset of spectrum names (as stored in the database
        with :func:`~species.data.database.Database.add_object`) can be
        provided.
    radtrans : read_radtrans.ReadRadtrans, None
        Instance of :class:`~species.read.read_radtrans.ReadRadtrans`.
        Only required with ``spectrum='petitradtrans'``. Make sure that
        the ``wavel_range`` of the ``ReadRadtrans`` instance is
        sufficiently broad to cover all the photometric and
        spectroscopic data of ``inc_phot`` and ``inc_spec``. Not used
        if set to ``None``.

    Returns
    -------
    species.core.box.ResidualsBox
        Box with the residuals.

    Raises
    ------
    ValueError
        If spectra are included with ``spectrum='petitradtrans'``
        while ``radtrans`` is ``None``.
    """

    if isinstance(inc_phot, bool) and inc_phot:
        inc_phot = objectbox.filters

    if inc_phot:
        model_phot = multi_photometry(
            datatype=datatype,
            spectrum=spectrum,
            filters=inc_phot,
            parameters=parameters,
            radtrans=radtrans,
        )

        res_phot = {}

        for item in inc_phot:
            transmission = read_filter.ReadFilter(item)
            obj_flux = objectbox.flux[item]

            # Row 0: mean wavelength of the filter, row 1: residual
            res_phot[item] = np.zeros(obj_flux.shape)

            if obj_flux.ndim == 1:
                # Single measurement stored as (flux, uncertainty)
                res_phot[item][0] = transmission.mean_wavelength()
                res_phot[item][1] = (
                    obj_flux[0] - model_phot.flux[item]) / obj_flux[1]

            elif obj_flux.ndim == 2:
                # Several measurements for the same filter, one per
                # column. The mean wavelength is queried only once.
                res_phot[item][0, :] = transmission.mean_wavelength()
                res_phot[item][1, :] = (
                    obj_flux[0, :] - model_phot.flux[item]) / obj_flux[1, :]

    else:
        res_phot = None

    if inc_spec:
        res_spec = {}

        if spectrum == "petitradtrans":
            if radtrans is None:
                raise ValueError("The 'radtrans' argument is required "
                                 "with spectrum='petitradtrans'.")

            # Calculate the petitRADTRANS spectrum only once
            model = radtrans.get_model(parameters)

        for key in objectbox.spectrum:

            if isinstance(inc_spec, bool) or key in inc_spec:
                # Wavelength range with a 10% margin on both
                # sides of the first and last data point
                wavel_range = (
                    0.9 * objectbox.spectrum[key][0][0, 0],
                    1.1 * objectbox.spectrum[key][0][-1, 0],
                )

                wl_new = objectbox.spectrum[key][0][:, 0]
                spec_res = objectbox.spectrum[key][3]

                if spectrum == "planck":
                    readmodel = read_planck.ReadPlanck(wavel_range=wavel_range)

                    model = readmodel.get_spectrum(model_param=parameters,
                                                   spec_res=1000.0)

                    # Separate resampling to the new wavelength points

                    flux_new = spectres.spectres(
                        wl_new,
                        model.wavelength,
                        model.flux,
                        spec_errs=None,
                        fill=0.0,
                        verbose=True,
                    )

                elif spectrum == "petitradtrans":
                    # Separate resampling to the new wavelength points
                    flux_new = spectres.spectres(
                        wl_new,
                        model.wavelength,
                        model.flux,
                        spec_errs=None,
                        fill=0.0,
                        verbose=True,
                    )

                else:
                    # Resampling to the new wavelength points
                    # is done by the get_model method

                    readmodel = read_model.ReadModel(spectrum,
                                                     wavel_range=wavel_range)

                    if "teff_0" in parameters and "teff_1" in parameters:
                        # Binary system

                        param_0 = read_util.binary_to_single(parameters, 0)

                        model_spec_0 = readmodel.get_model(
                            param_0,
                            spec_res=spec_res,
                            wavel_resample=wl_new,
                            smooth=True,
                        )

                        param_1 = read_util.binary_to_single(parameters, 1)

                        model_spec_1 = readmodel.get_model(
                            param_1,
                            spec_res=spec_res,
                            wavel_resample=wl_new,
                            smooth=True,
                        )

                        # Weighted sum of the two components
                        flux_comb = (
                            parameters["spec_weight"] * model_spec_0.flux +
                            (1.0 - parameters["spec_weight"]) *
                            model_spec_1.flux)

                        model_spec = box.create_box(
                            boxtype="model",
                            model=spectrum,
                            wavelength=wl_new,
                            flux=flux_comb,
                            parameters=parameters,
                            quantity="flux",
                        )

                    else:
                        # Single object

                        model_spec = readmodel.get_model(
                            parameters,
                            spec_res=spec_res,
                            wavel_resample=wl_new,
                            smooth=True,
                        )

                    flux_new = model_spec.flux

                # Columns that are used: 0 = wavelength,
                # 1 = flux, 2 = uncertainty
                data_spec = objectbox.spectrum[key][0]
                res_tmp = (data_spec[:, 1] - flux_new) / data_spec[:, 2]

                res_spec[key] = np.column_stack([wl_new, res_tmp])

    else:
        res_spec = None

    print("Calculating residuals... [DONE]")

    print("Residuals (sigma):")

    if res_phot is not None:
        for item in inc_phot:
            if res_phot[item].ndim == 1:
                print(f"   - {item}: {res_phot[item][1]:.2f}")

            elif res_phot[item].ndim == 2:
                for j in range(res_phot[item].shape[1]):
                    print(f"   - {item}: {res_phot[item][1, j]:.2f}")

    if res_spec is not None:
        for key in objectbox.spectrum:
            if isinstance(inc_spec, bool) or key in inc_spec:
                # Only the residual column, so the wavelengths
                # do not end up in the reported extremes
                res_column = res_spec[key][:, 1]

                print(f"   - {key}: min: {np.nanmin(res_column):.2f}, "
                      f"max: {np.nanmax(res_column):.2f}")

    chi2_stat = 0
    n_dof = 0

    if res_phot is not None:
        for value in res_phot.values():
            # value[1] is a scalar for a single measurement and an
            # array when a filter has several measurements, so sum
            # the squares and count every measurement
            chi2_stat += np.sum(value[1]**2)
            n_dof += int(np.size(value[1]))

    if res_spec is not None:
        for value in res_spec.values():
            chi2_stat += np.sum(value[:, 1]**2)
            n_dof += value.shape[0]

    # Every parameter lowers the number of degrees of freedom,
    # except for the ones listed here
    for item in parameters:
        if item not in ["mass", "luminosity", "distance"]:
            n_dof -= 1

    chi2_red = chi2_stat / n_dof

    print(f"Reduced chi2 = {chi2_red:.2f}")
    print(f"Number of degrees of freedom = {n_dof}")

    return box.create_box(
        boxtype="residuals",
        name=objectbox.name,
        photometry=res_phot,
        spectrum=res_spec,
        chi2_red=chi2_red,
    )
예제 #14
0
    def spectrum_to_flux(
        self,
        wavelength: np.ndarray,
        flux: np.ndarray,
        error: Optional[np.ndarray] = None,
        threshold: Optional[float] = 0.05,
    ) -> Tuple[Union[np.float32, np.float64], Union[Optional[np.float32],
                                                    Optional[np.float64]]]:
        """
        Function for calculating the average flux from a spectrum and
        a filter profile. The uncertainty is propagated by sampling
        200 random values from the error distributions.

        The filter interpolator and the wavelength range are read
        from the filter profile on first use and stored on the
        instance.

        Parameters
        ----------
        wavelength : np.ndarray
            Wavelength points (um).
        flux : np.ndarray
            Flux (W m-2 um-1).
        error : np.ndarray, None
            Uncertainty (W m-2 um-1). Not used if set to ``None``.
        threshold : float, None
            Transmission threshold (value between 0 and 1). If the
            minimum transmission value is larger than the threshold,
            a NaN is returned. This will happen if the input spectrum
            does not cover the full wavelength range of the filter
            profile. The parameter is not used if set to ``None``.

        Returns
        -------
        float
            Average flux (W m-2 um-1). NaN if fewer than two
            wavelength points fall inside the filter range or if the
            coverage checks fail.
        float, None
            Uncertainty (W m-2 um-1). ``None`` if ``error`` is
            ``None`` or contains NaNs.

        Raises
        ------
        ValueError
            If the wavelength array is empty or if the detector type
            is neither 'energy' nor 'photon'.
        """

        # np.trapz has been deprecated since NumPy 2.0 in favor of
        # np.trapezoid, so use the new name whenever it is available
        integrate = getattr(np, "trapezoid", None) or np.trapz

        if error is not None:
            # The error calculation requires the original
            # spectrum because spectrum_to_flux is used
            wavel_error = wavelength.copy()
            flux_error = flux.copy()

        if self.filter_interp is None:
            transmission = read_filter.ReadFilter(self.filter_name)
            self.filter_interp = transmission.interpolate_filter()

            if self.wavel_range is None:
                self.wavel_range = transmission.wavelength_range()

        if wavelength.size == 0:
            raise ValueError(f"Calculation of the mean flux for "
                             f"{self.filter_name} is not possible "
                             f"because the wavelength array is empty.")

        # Wavelength points inside the (inclusive) filter range
        indices = np.where((self.wavel_range[0] <= wavelength)
                           & (wavelength <= self.wavel_range[1]))[0]

        if indices.size < 2:
            syn_flux = np.nan

            warnings.warn("Calculating a synthetic flux requires more than "
                          "one wavelength point. Photometry is set to NaN.")

        else:
            if threshold is None and (wavelength[0] > self.wavel_range[0]
                                      or wavelength[-1] < self.wavel_range[1]):

                warnings.warn(
                    f"The filter profile of {self.filter_name} "
                    f"({self.wavel_range[0]:.4f}-{self.wavel_range[1]:.4f}) extends "
                    f"beyond the wavelength range of the spectrum ({wavelength[0]:.4f} "
                    f"-{wavelength[-1]:.4f}). The flux is set to NaN. Setting the "
                    f"'threshold' parameter will loosen the wavelength constraints."
                )

                syn_flux = np.nan

            else:
                wavelength = wavelength[indices]
                flux = flux[indices]

                transmission = self.filter_interp(wavelength)

                # NOTE(review): the wavelength array has already been
                # cropped to the filter range at this point, so the last
                # part of this condition looks like it cannot be true
                # for a sorted spectrum -- confirm the intended check
                if (threshold is not None and (transmission[0] > threshold
                                               or transmission[-1] > threshold)
                        and (wavelength[0] < self.wavel_range[0]
                             or wavelength[-1] > self.wavel_range[-1])):

                    warnings.warn(
                        f"The filter profile of {self.filter_name} "
                        f"({self.wavel_range[0]:.4f}-{self.wavel_range[1]:.4f}) "
                        f"extends beyond the wavelength range of the spectrum "
                        f"({wavelength[0]:.4f}-{wavelength[-1]:.4f}). The flux "
                        f"is set to NaN. Increasing the 'threshold' parameter "
                        f"({threshold}) will loosen the wavelength constraint."
                    )

                    syn_flux = np.nan

                else:
                    # Ignore the points where the interpolated
                    # transmission is NaN
                    indices = np.isnan(transmission)
                    indices = np.logical_not(indices)

                    if self.det_type == "energy":
                        # Energy counting detector
                        integrand1 = transmission[indices] * flux[indices]
                        integrand2 = transmission[indices]

                    elif self.det_type == "photon":
                        # Photon counting detector
                        integrand1 = (wavelength[indices] *
                                      transmission[indices] * flux[indices])
                        integrand2 = wavelength[indices] * transmission[indices]

                    else:
                        raise ValueError(
                            f"The detector type '{self.det_type}' of "
                            f"{self.filter_name} is not supported. The "
                            f"type should be 'energy' or 'photon'.")

                    integral1 = integrate(integrand1, wavelength[indices])
                    integral2 = integrate(integrand2, wavelength[indices])

                    syn_flux = integral1 / integral2

        if error is not None and not np.any(np.isnan(error)):
            phot_random = np.zeros(200)

            for i in range(200):
                # Use the original spectrum size (i.e. wavel_error and flux_error)
                spec_random = (flux_error + np.random.normal(
                    loc=0.0, scale=1.0, size=wavel_error.shape[0]) * error)

                phot_random[i] = self.spectrum_to_flux(wavel_error,
                                                       spec_random,
                                                       error=None,
                                                       threshold=threshold)[0]

            error_flux = np.std(phot_random)

        elif error is not None and np.any(np.isnan(error)):
            warnings.warn(
                "Spectum contains NaN so can not calculate the error.")
            error_flux = None

        else:
            error_flux = None

        return syn_flux, error_flux
예제 #15
0
def interp_powerlaw(inc_phot: List[str],
                    inc_spec: List[str],
                    spec_data: Optional[Dict[str, Tuple[np.ndarray, Optional[np.ndarray],
                                                        Optional[np.ndarray], float]]]) -> \
                        Tuple[Dict[str, Union[interp2d, List[interp2d]]], np.ndarray, np.ndarray]:
    """
    Function for interpolating the power-law dust cross sections for each filter and spectrum.

    For each filter, the cross sections are averaged over the filter profile and interpolated
    as function of the power-law exponent and maximum radius. For each spectrum, a separate
    interpolator is created for every wavelength point. The cross sections of the
    'Generic/Bessell.V' filter are always included.

    Parameters
    ----------
    inc_phot : list(str)
        List with filter names. Not used if the list is empty. Note that 'Generic/Bessell.V'
        is appended in place to this list.
    inc_spec : list(str)
        List with the spectrum names (as stored in the database with
        :func:`~species.data.database.Database.add_object`). Not used if the list is empty.
    spec_data : dict, None
        Dictionary with the spectrum data. Only required in combination with ``inc_spec``,
        otherwise the argument needs to be set to ``None``.

    Returns
    -------
    dict
        Dictionary with the extinction cross section for each filter and spectrum
    np.ndarray
        Grid points of the maximum radius.
    np.ndarray
        Grid points of the power-law exponent.
    """

    # np.trapz has been deprecated since NumPy 2.0 in favor of
    # np.trapezoid, so use the new name whenever it is available
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    database_path = check_dust_database()

    with h5py.File(database_path, 'r') as h5_file:
        cross_section = np.asarray(
            h5_file['dust/powerlaw/mgsio3/crystalline/cross_section'])
        wavelength = np.asarray(
            h5_file['dust/powerlaw/mgsio3/crystalline/wavelength'])
        radius_max = np.asarray(
            h5_file['dust/powerlaw/mgsio3/crystalline/radius_max'])
        exponent = np.asarray(
            h5_file['dust/powerlaw/mgsio3/crystalline/exponent'])

    print('Grid boundaries of the dust opacities:')
    print(f'   - Wavelength (um) = {wavelength[0]:.2f} - {wavelength[-1]:.2f}')
    print(
        f'   - Maximum radius (um) = {radius_max[0]:.2e} - {radius_max[-1]:.2e}'
    )
    print(f'   - Power-law exponent = {exponent[0]:.2f} - {exponent[-1]:.2f}')

    # NOTE(review): this modifies the list of the caller in place and adds a
    # duplicate if the filter is already included -- kept as is because callers
    # may rely on the V band being part of the list afterwards
    inc_phot.append('Generic/Bessell.V')

    cross_sections = {}

    # NOTE(review): interp2d has been removed from recent SciPy releases. It is kept
    # because the returned interpolators are part of the interface of this function.

    for phot_item in inc_phot:
        read_filt = read_filter.ReadFilter(phot_item)
        filt_trans = read_filt.get_filter()

        # Columns that are used: 0 = wavelength, 1 = transmission
        filt_wavel = filt_trans[:, 0]
        filt_profile = filt_trans[:, 1]

        # The normalization does not depend on the grid point
        # so the integral is calculated only once per filter
        filt_norm = integrate(filt_profile, filt_wavel)

        cross_phot = np.zeros((radius_max.shape[0], exponent.shape[0]))

        for i in range(radius_max.shape[0]):
            for j in range(exponent.shape[0]):
                cross_interp = interp1d(wavelength,
                                        cross_section[:, i, j],
                                        kind='linear',
                                        bounds_error=True)

                cross_tmp = cross_interp(filt_wavel)

                # Filter-weighted average of the extinction cross section
                cross_phot[i, j] = integrate(filt_profile * cross_tmp,
                                             filt_wavel) / filt_norm

        cross_sections[phot_item] = interp2d(exponent,
                                             radius_max,
                                             cross_phot,
                                             kind='linear',
                                             bounds_error=True)

    print('Interpolating dust opacities...', end='')

    for spec_item in inc_spec:
        wavel_spec = spec_data[spec_item][0][:, 0]

        cross_spec = np.zeros(
            (wavel_spec.shape[0], radius_max.shape[0], exponent.shape[0]))

        # Interpolate the cross sections to the wavelengths of the spectrum
        for i in range(radius_max.shape[0]):
            for j in range(exponent.shape[0]):
                cross_interp = interp1d(wavelength,
                                        cross_section[:, i, j],
                                        kind='linear',
                                        bounds_error=True)

                cross_spec[:, i, j] = cross_interp(wavel_spec)

        # One interpolator per wavelength point of the spectrum
        cross_sections[spec_item] = []

        for i in range(wavel_spec.shape[0]):

            cross_tmp = interp2d(exponent,
                                 radius_max,
                                 cross_spec[i, :, :],
                                 kind='linear',
                                 bounds_error=True)

            cross_sections[spec_item].append(cross_tmp)

    print(' [DONE]')

    return cross_sections, radius_max, exponent
예제 #16
0
def calc_reddening(filters_color: Tuple[str, str],
                   extinction: Tuple[str, float],
                   composition: str = 'MgSiO3',
                   structure: str = 'crystalline',
                   radius_g: float = 1.) -> Tuple[float, float]:
    """
    Function for calculating the reddening of a color given the extinction for a given filter. A
    log-normal size distribution with a geometric standard deviation of 2 is used as
    parametrization for the grain sizes (Ackerman & Marley 2001).

    Parameters
    ----------
    filters_color : tuple(str, str)
        Filter names for which the extinction is calculated.
    extinction : tuple(str, float)
        Filter name and extinction (mag) in that filter. The number of grains is normalized
        such that this extinction is reproduced.
    composition : str
        Dust composition ('MgSiO3' or 'Fe').
    structure : str
        Grain structure ('crystalline' or 'amorphous').
    radius_g : float
        Geometric radius of the grain size distribution (um).

    Returns
    -------
    float
        Extinction (mag) for ``filters_color[0]``.
    float
        Extinction (mag) for ``filters_color[1]``.

    Raises
    ------
    ValueError
        If the combination of ``composition`` and ``structure`` is not supported.
    """

    # Select the datasets with the optical constants. Crystalline MgSiO3 is stored
    # per crystallographic axis while the other species have a single dataset.
    if composition == 'MgSiO3' and structure == 'crystalline':
        data_paths = [f'dust/mgsio3/crystalline/axis_{i+1}' for i in range(3)]

    elif composition == 'MgSiO3' and structure == 'amorphous':
        data_paths = ['dust/mgsio3/amorphous/']

    elif composition == 'Fe' and structure == 'crystalline':
        data_paths = ['dust/fe/crystalline/']

    elif composition == 'Fe' and structure == 'amorphous':
        data_paths = ['dust/fe/amorphous/']

    else:
        raise ValueError(
            f'The combination of composition=\'{composition}\' and '
            f'structure=\'{structure}\' is not supported. The composition '
            f'should be \'MgSiO3\' or \'Fe\' and the structure should be '
            f'\'crystalline\' or \'amorphous\'.')

    database_path = check_dust_database()

    # Read the tables once (instead of once per filter) and make sure that the
    # file is closed again, also when an exception is raised
    with h5py.File(database_path, 'r') as h5_file:
        dust_data = [np.asarray(h5_file[path]) for path in data_paths]

    filters = [extinction[0], filters_color[0], filters_color[1]]

    dn_dr, r_width, radii = log_normal_distribution(radius_g, 2., 100)

    c_ext = {}

    for item in filters:
        read_filt = read_filter.ReadFilter(item)
        filter_wavel = read_filt.mean_wavelength()

        # Average cross section of the available datasets (i.e. the three
        # axes for crystalline MgSiO3 and a single dataset otherwise)
        c_ext[item] = 0.

        for data in dust_data:
            # Use the tabulated wavelength (first column) that is
            # nearest to the mean wavelength of the filter
            wavel_index = (np.abs(data[:, 0] - filter_wavel)).argmin()

            # Columns 1 and 2 are passed on as optical constants
            # (presumably the real and imaginary refractive index)
            c_ext[item] += dust_cross_section(
                dn_dr, r_width, radii, data[wavel_index, 0],
                data[wavel_index, 1], data[wavel_index, 2]) / len(dust_data)

    # Scale the number of grains such that the requested extinction is
    # obtained, with A = 2.5 * log10(e) * C_ext * n_grains
    n_grains = extinction[1] / c_ext[extinction[0]] / 2.5 / np.log10(
        np.exp(1.))

    return 2.5 * np.log10(np.exp(1.)) * c_ext[filters_color[0]] * n_grains, \
        2.5 * np.log10(np.exp(1.)) * c_ext[filters_color[1]] * n_grains
예제 #17
0
        def lnlike_multinest(cube, n_dim: int, n_param: int) -> np.float64:
            """
            Function for the logarithm of the likelihood, computed from the parameter cube.

            The function is a closure: ``self``, ``cube_index``, and ``prior`` are taken from
            the enclosing scope. The parameter values are read from ``cube`` at the positions
            given by ``cube_index`` and are used without any further transformation.

            Parameters
            ----------
            cube : pymultinest.run.LP_c_double
                Unit cube.
            n_dim : int
                Number of dimensions. Not used in the function body.
            n_param : int
                Number of parameters. Not used in the function body.

            Returns
            -------
            float
                Log likelihood. Set to ``-np.inf`` if the ordering of the temperatures or radii
                of multiple Planck components is violated.
            """

            param_dict = {}
            spec_scaling = {}
            err_offset = {}
            corr_len = {}
            corr_amp = {}
            dust_param = {}

            # Separate the nuisance and dust parameters from the
            # parameters that are passed on to the model spectrum
            for item in self.bounds:
                if item[:8] == 'scaling_' and item[8:] in self.spectrum:
                    spec_scaling[item[8:]] = cube[cube_index[item]]

                elif item[:6] == 'error_' and item[6:] in self.spectrum:
                    # log10 of the offset that is added to the flux uncertainties
                    err_offset[item[6:]] = cube[cube_index[item]]

                elif item[:9] == 'corr_len_' and item[9:] in self.spectrum:
                    corr_len[item[9:]] = 10.**cube[cube_index[item]]  # (um)

                elif item[:9] == 'corr_amp_' and item[9:] in self.spectrum:
                    corr_amp[item[9:]] = cube[cube_index[item]]

                elif item.startswith(('lognorm_', 'powerlaw_', 'ism_')):
                    dust_param[item] = cube[cube_index[item]]

                else:
                    param_dict[item] = cube[cube_index[item]]

            if self.model == 'planck':
                param_dict['distance'] = self.distance[0]

            else:
                # Flux scaling (R/d)^2 from the surface to the observer
                flux_scaling = (param_dict['radius']*constants.R_JUP)**2 / \
                               (self.distance[0]*constants.PARSEC)**2

                # The scaling is applied manually because of the interpolation
                del param_dict['radius']

            # Default values for the spectra without a scaling or error offset
            for item in self.spectrum:
                if item not in spec_scaling:
                    spec_scaling[item] = 1.

                if item not in err_offset:
                    err_offset[item] = None

            ln_like = 0.

            # Reject samples for which the Planck components are not ordered with a
            # decreasing temperature and an increasing radius
            if self.model == 'planck' and self.n_planck > 1:
                for i in range(self.n_planck - 1):
                    if param_dict[f'teff_{i+1}'] > param_dict[f'teff_{i}']:
                        return -np.inf

                    if param_dict[f'radius_{i}'] > param_dict[f'radius_{i+1}']:
                        return -np.inf

            # Gaussian priors, with value[0] the mean and value[1] the standard deviation
            if prior is not None:
                for key, value in prior.items():
                    if key == 'mass':
                        mass = read_util.get_mass(cube[cube_index['logg']],
                                                  cube[cube_index['radius']])

                        ln_like += -0.5 * (mass - value[0])**2 / value[1]**2

                    else:
                        ln_like += -0.5 * (cube[cube_index[key]] -
                                           value[0])**2 / value[1]**2

            # Number of grains that reproduces the sampled extinction in the
            # Generic/Bessell.V filter. The interpolator returns an array so the
            # first element is selected to keep n_grains (and ln_like) a scalar.
            if 'lognorm_ext' in dust_param:
                cross_tmp = self.cross_sections['Generic/Bessell.V'](
                    dust_param['lognorm_sigma'],
                    10.**dust_param['lognorm_radius'])[0]

                n_grains = dust_param[
                    'lognorm_ext'] / cross_tmp / 2.5 / np.log10(np.exp(1.))

            elif 'powerlaw_ext' in dust_param:
                cross_tmp = self.cross_sections['Generic/Bessell.V'](
                    dust_param['powerlaw_exp'],
                    10.**dust_param['powerlaw_max'])[0]

                n_grains = dust_param[
                    'powerlaw_ext'] / cross_tmp / 2.5 / np.log10(np.exp(1.))

            # Contribution of the photometric fluxes
            for i, obj_item in enumerate(self.objphot):
                if self.model == 'planck':
                    readplanck = read_planck.ReadPlanck(
                        filter_name=self.modelphot[i].filter_name)
                    phot_flux = readplanck.get_flux(
                        param_dict, synphot=self.modelphot[i])[0]

                else:
                    phot_flux = self.modelphot[i].spectrum_interp(
                        list(param_dict.values()))[0][0]
                    phot_flux *= flux_scaling

                if 'lognorm_ext' in dust_param:
                    cross_tmp = self.cross_sections[
                        self.modelphot[i].filter_name](
                            dust_param['lognorm_sigma'],
                            10.**dust_param['lognorm_radius'])[0]

                    phot_flux *= np.exp(-cross_tmp * n_grains)

                elif 'powerlaw_ext' in dust_param:
                    cross_tmp = self.cross_sections[
                        self.modelphot[i].filter_name](
                            dust_param['powerlaw_exp'],
                            10.**dust_param['powerlaw_max'])[0]

                    phot_flux *= np.exp(-cross_tmp * n_grains)

                elif 'ism_ext' in dust_param:
                    read_filt = read_filter.ReadFilter(
                        self.modelphot[i].filter_name)
                    filt_wavel = np.array([read_filt.mean_wavelength()])

                    ext_filt = dust_util.ism_extinction(
                        dust_param['ism_ext'], dust_param['ism_red'],
                        filt_wavel)

                    phot_flux *= 10.**(-0.4 * ext_filt[0])

                if obj_item.ndim == 1:
                    # Single measurement: (flux, uncertainty)
                    ln_like += -0.5 * (obj_item[0] -
                                       phot_flux)**2 / obj_item[1]**2

                else:
                    # Multiple measurements for the same filter
                    for j in range(obj_item.shape[1]):
                        ln_like += -0.5 * (obj_item[0, j] -
                                           phot_flux)**2 / obj_item[1, j]**2

            # Contribution of the spectra
            for i, item in enumerate(self.spectrum.keys()):
                data_flux = spec_scaling[item] * self.spectrum[item][0][:, 1]

                if err_offset[item] is None:
                    data_var = self.spectrum[item][0][:, 2]**2
                else:
                    data_var = (self.spectrum[item][0][:, 2] +
                                10.**err_offset[item])**2

                if self.spectrum[item][2] is not None:
                    if err_offset[item] is None:
                        data_cov_inv = self.spectrum[item][2]

                    else:
                        # Ratio of the inflated and original uncertainties
                        sigma_ratio = np.sqrt(
                            data_var) / self.spectrum[item][0][:, 2]
                        sigma_j, sigma_i = np.meshgrid(sigma_ratio,
                                                       sigma_ratio)

                        # Calculate the inversion of the inflated covariances
                        data_cov_inv = np.linalg.inv(self.spectrum[item][1] *
                                                     sigma_i * sigma_j)

                if self.model == 'planck':
                    readplanck = read_planck.ReadPlanck(
                        (0.9 * self.spectrum[item][0][0, 0],
                         1.1 * self.spectrum[item][0][-1, 0]))

                    model_box = readplanck.get_spectrum(param_dict,
                                                        1000.,
                                                        smooth=True)

                    # Resample the model to the wavelengths of the data
                    model_flux = spectres.spectres(
                        self.spectrum[item][0][:, 0], model_box.wavelength,
                        model_box.flux)

                else:
                    model_flux = self.modelspec[i].spectrum_interp(
                        list(param_dict.values()))[0, :]
                    model_flux *= flux_scaling

                if 'lognorm_ext' in dust_param:
                    # One cross section interpolator per wavelength point
                    for j, cross_item in enumerate(self.cross_sections[item]):
                        cross_tmp = cross_item(
                            dust_param['lognorm_sigma'],
                            10.**dust_param['lognorm_radius'])[0]

                        model_flux[j] *= np.exp(-cross_tmp * n_grains)

                elif 'powerlaw_ext' in dust_param:
                    # One cross section interpolator per wavelength point
                    for j, cross_item in enumerate(self.cross_sections[item]):
                        cross_tmp = cross_item(
                            dust_param['powerlaw_exp'],
                            10.**dust_param['powerlaw_max'])[0]

                        model_flux[j] *= np.exp(-cross_tmp * n_grains)

                elif 'ism_ext' in dust_param:
                    ext_filt = dust_util.ism_extinction(
                        dust_param['ism_ext'], dust_param['ism_red'],
                        self.spectrum[item][0][:, 0])

                    model_flux *= 10.**(-0.4 * ext_filt)

                if self.spectrum[item][2] is not None:
                    # Use the inverted covariance matrix
                    dot_tmp = np.dot(
                        data_flux - model_flux,
                        np.dot(data_cov_inv, data_flux - model_flux))

                    ln_like += -0.5 * dot_tmp - 0.5 * np.nansum(
                        np.log(2. * np.pi * data_var))

                else:
                    if item in self.fit_corr:
                        # Covariance model (Wang et al. 2020)
                        wavel = self.spectrum[item][0][:, 0]  # (um)
                        wavel_j, wavel_i = np.meshgrid(wavel, wavel)

                        error = np.sqrt(data_var)  # (W m-2 um-1)
                        error_j, error_i = np.meshgrid(error, error)

                        cov_matrix = corr_amp[item]**2 * error_i * error_j * \
                            np.exp(-(wavel_i-wavel_j)**2 / (2.*corr_len[item]**2)) + \
                            (1.-corr_amp[item]**2) * np.eye(wavel.shape[0])*error_i**2

                        dot_tmp = np.dot(
                            data_flux - model_flux,
                            np.dot(np.linalg.inv(cov_matrix),
                                   data_flux - model_flux))

                        ln_like += -0.5 * dot_tmp - 0.5 * np.nansum(
                            np.log(2. * np.pi * data_var))

                    else:
                        # Calculate the chi-square without a covariance matrix
                        ln_like += np.nansum(
                            -0.5 * (data_flux - model_flux)**2 / data_var -
                            0.5 * np.log(2. * np.pi * data_var))

            return ln_like
예제 #18
0
def plot_spectrum(
        boxes: list,
        filters: Optional[List[str]] = None,
        residuals: Optional[box.ResidualsBox] = None,
        plot_kwargs: Optional[List[Optional[dict]]] = None,
        xlim: Optional[Tuple[float, float]] = None,
        ylim: Optional[Tuple[float, float]] = None,
        ylim_res: Optional[Tuple[float, float]] = None,
        scale: Optional[Tuple[str, str]] = None,
        title: Optional[str] = None,
        offset: Optional[Tuple[float, float]] = None,
        legend: Optional[Union[str, dict, Tuple[float, float],
                               List[Optional[Union[dict, str,
                                                   Tuple[float,
                                                         float]]]]]] = None,
        figsize: Optional[Tuple[float, float]] = (10., 5.),
        object_type: str = 'planet',
        quantity: str = 'flux density',
        output: str = 'spectrum.pdf'):
    """
    Parameters
    ----------
    boxes : list(species.core.box, )
        Boxes with data.
    filters : list(str, ), None
        Filter IDs for which the transmission profile is plotted. Not plotted if set to None.
    residuals : species.core.box.ResidualsBox, None
        Box with residuals of a fit. Not plotted if set to None.
    plot_kwargs : list(dict, ), None
        List with dictionaries of keyword arguments for each box. For example, if the ``boxes``
        are a ``ModelBox`` and ``ObjectBox``:

        .. code-block:: python

            plot_kwargs=[{'ls': '-', 'lw': 1., 'color': 'black'},
                         {'spectrum_1': {'marker': 'o', 'ms': 3., 'color': 'tab:brown', 'ls': 'none'},
                          'spectrum_2': {'marker': 'o', 'ms': 3., 'color': 'tab:blue', 'ls': 'none'},
                          'Paranal/SPHERE.IRDIS_D_H23_3': {'marker': 's', 'ms': 4., 'color': 'tab:cyan', 'ls': 'none'},
                          'Paranal/SPHERE.IRDIS_D_K12_1': [{'marker': 's', 'ms': 4., 'color': 'tab:orange', 'ls': 'none'},
                                                           {'marker': 's', 'ms': 4., 'color': 'tab:red', 'ls': 'none'}],
                          'Paranal/NACO.Lp': {'marker': 's', 'ms': 4., 'color': 'tab:green', 'ls': 'none'},
                          'Paranal/NACO.Mp': {'marker': 's', 'ms': 4., 'color': 'tab:green', 'ls': 'none'}}]

        For an ``ObjectBox``, the dictionary contains items for the different spectrum and filter
        names stored with :func:`~species.data.database.Database.add_object`. In case both
        and ``ObjectBox`` and a ``SynphotBox`` are provided, then the latter can be set to ``None``
        in order to use the same (but open) symbols as the data from the ``ObjectBox``. Note that
        if a filter name is duplicated in an ``ObjectBox`` (Paranal/SPHERE.IRDIS_D_K12_1 in the
        example) then a list with two dictionaries should be provided. Colors are automatically
        chosen if ``plot_kwargs`` is set to ``None``.
    xlim : tuple(float, float)
        Limits of the wavelength axis.
    ylim : tuple(float, float)
        Limits of the flux axis.
    ylim_res : tuple(float, float), None
        Limits of the residuals axis. Automatically chosen (based on the minimum and maximum
        residual value) if set to None.
    scale : tuple(str, str), None
        Scale of the x and y axes ('linear' or 'log'). The scale is set to ``('linear', 'linear')``
        if set to ``None``.
    title : str
        Title.
    offset : tuple(float, float)
        Offset for the label of the x- and y-axis.
    legend : str, tuple, dict, list(dict, dict), None
        Location of the legend (str or tuple(float, float)) or a dictionary with the ``**kwargs``
        of ``matplotlib.pyplot.legend``, for example ``{'loc': 'upper left', 'fontsize: 12.}``.
        Alternatively, a list with two values can be provided to separate the model and data
        handles in two legends. Each of these two elements can be set to ``None``. For example,
        ``[None, {'loc': 'upper left', 'fontsize: 12.}]``, if only the data points should be
        included in a legend.                  
    figsize : tuple(float, float)
        Figure size.
    object_type : str
        Object type ('planet' or 'star'). With 'planet', the radius and mass are expressed in
        Jupiter units. With 'star', the radius and mass are expressed in solar units.
    quantity: str
        The quantity of the y-axis ('flux density', 'flux', or 'magnitude').
    output : str
        Output filename.

    Returns
    -------
    NoneType
        None
    """

    mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
    mpl.rcParams['font.family'] = 'serif'

    plt.rc('axes', edgecolor='black', linewidth=2.2)
    plt.rcParams['axes.axisbelow'] = False

    if plot_kwargs is None:
        plot_kwargs = []

    elif plot_kwargs is not None and len(boxes) != len(plot_kwargs):
        raise ValueError(
            f'The number of \'boxes\' ({len(boxes)}) should be equal to the '
            f'number of items in \'plot_kwargs\' ({len(plot_kwargs)}).')

    if residuals is not None and filters is not None:
        plt.figure(1, figsize=figsize)
        gridsp = mpl.gridspec.GridSpec(3, 1, height_ratios=[1, 3, 1])
        gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[1, 0])
        ax2 = plt.subplot(gridsp[0, 0])
        ax3 = plt.subplot(gridsp[2, 0])

    elif residuals is not None:
        plt.figure(1, figsize=figsize)
        gridsp = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])
        gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[0, 0])
        ax2 = None
        ax3 = plt.subplot(gridsp[1, 0])

    elif filters is not None:
        plt.figure(1, figsize=figsize)
        gridsp = mpl.gridspec.GridSpec(2, 1, height_ratios=[1, 4])
        gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[1, 0])
        ax2 = plt.subplot(gridsp[0, 0])
        ax3 = None

    else:
        plt.figure(1, figsize=figsize)
        gridsp = mpl.gridspec.GridSpec(1, 1)
        gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[0, 0])
        ax2 = None
        ax3 = None

    if residuals is not None:
        labelbottom = False
    else:
        labelbottom = True

    if scale is None:
        scale = ('linear', 'linear')

    ax1.set_xscale(scale[0])
    ax1.set_yscale(scale[1])

    if filters is not None:
        ax2.set_xscale(scale[0])

    if residuals is not None:
        ax3.set_xscale(scale[0])

    ax1.tick_params(axis='both',
                    which='major',
                    colors='black',
                    labelcolor='black',
                    direction='in',
                    width=1,
                    length=5,
                    labelsize=12,
                    top=True,
                    bottom=True,
                    left=True,
                    right=True,
                    labelbottom=labelbottom)

    ax1.tick_params(axis='both',
                    which='minor',
                    colors='black',
                    labelcolor='black',
                    direction='in',
                    width=1,
                    length=3,
                    labelsize=12,
                    top=True,
                    bottom=True,
                    left=True,
                    right=True,
                    labelbottom=labelbottom)

    if filters is not None:
        ax2.tick_params(axis='both',
                        which='major',
                        colors='black',
                        labelcolor='black',
                        direction='in',
                        width=1,
                        length=5,
                        labelsize=12,
                        top=True,
                        bottom=True,
                        left=True,
                        right=True,
                        labelbottom=False)

        ax2.tick_params(axis='both',
                        which='minor',
                        colors='black',
                        labelcolor='black',
                        direction='in',
                        width=1,
                        length=3,
                        labelsize=12,
                        top=True,
                        bottom=True,
                        left=True,
                        right=True,
                        labelbottom=False)

    if residuals is not None:
        ax3.tick_params(axis='both',
                        which='major',
                        colors='black',
                        labelcolor='black',
                        direction='in',
                        width=1,
                        length=5,
                        labelsize=12,
                        top=True,
                        bottom=True,
                        left=True,
                        right=True)

        ax3.tick_params(axis='both',
                        which='minor',
                        colors='black',
                        labelcolor='black',
                        direction='in',
                        width=1,
                        length=3,
                        labelsize=12,
                        top=True,
                        bottom=True,
                        left=True,
                        right=True)

    if scale[0] == 'linear':
        ax1.xaxis.set_minor_locator(AutoMinorLocator(5))

    if scale[1] == 'linear':
        ax1.yaxis.set_minor_locator(AutoMinorLocator(5))

    # ax1.set_yticks([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0])
    # ax3.set_yticks([-2., 0., 2.])

    if filters is not None and scale[0] == 'linear':
        ax2.xaxis.set_minor_locator(AutoMinorLocator(5))

    if residuals is not None and scale[0] == 'linear':
        ax3.xaxis.set_minor_locator(AutoMinorLocator(5))

    if residuals is not None and filters is not None:
        ax1.set_xlabel('')
        ax2.set_xlabel('')
        ax3.set_xlabel('Wavelength (µm)', fontsize=13)

    elif residuals is not None:
        ax1.set_xlabel('')
        ax3.set_xlabel('Wavelength (µm)', fontsize=11)

    elif filters is not None:
        ax1.set_xlabel('Wavelength (µm)', fontsize=13)
        ax2.set_xlabel('')

    else:
        ax1.set_xlabel('Wavelength (µm)', fontsize=13)

    if filters is not None:
        ax2.set_ylabel('Transmission', fontsize=13)

    if residuals is not None:
        if quantity == 'flux density':
            ax3.set_ylabel(r'$\Delta$$\mathregular{F}_\lambda$ ($\sigma$)',
                           fontsize=11)

        elif quantity == 'flux':
            ax3.set_ylabel(r'$\Delta$$\mathregular{F}_\lambda$ ($\sigma$)',
                           fontsize=11)

    if xlim is None:
        ax1.set_xlim(0.6, 6.)
    else:
        ax1.set_xlim(xlim[0], xlim[1])

    if quantity == 'magnitude':
        scaling = 1.
        ax1.set_ylabel('Flux contrast (mag)', fontsize=13)

        if ylim:
            ax1.set_ylim(ylim[0], ylim[1])

    else:
        if ylim:
            ax1.set_ylim(ylim[0], ylim[1])

            ylim = ax1.get_ylim()

            exponent = math.floor(math.log10(ylim[1]))
            scaling = 10.**exponent

            if quantity == 'flux density':
                ylabel = r'$\mathregular{F}_\lambda$ (10$^{' + str(
                    exponent) + r'}$ W m$^{-2}$ µm$^{-1}$)'

            elif quantity == 'flux':
                ylabel = r'$\lambda$$\mathregular{F}_\lambda$ (10$^{' + str(
                    exponent) + r'}$ W m$^{-2}$)'

            ax1.set_ylabel(ylabel, fontsize=11)
            ax1.set_ylim(ylim[0] / scaling, ylim[1] / scaling)

            if ylim[0] < 0.:
                ax1.axhline(0.0,
                            ls='--',
                            lw=0.7,
                            color='gray',
                            dashes=(2, 4),
                            zorder=0.5)

        else:
            if quantity == 'flux density':
                ax1.set_ylabel(
                    r'$\mathregular{F}_\lambda$ (W m$^{-2}$ µm$^{-1}$)',
                    fontsize=11)

            elif quantity == 'flux':
                ax1.set_ylabel(
                    r'$\lambda$$\mathregular{F}_\lambda$ (W m$^{-2}$)',
                    fontsize=11)

            scaling = 1.

    xlim = ax1.get_xlim()

    if filters is not None:
        ax2.set_xlim(xlim[0], xlim[1])
        ax2.set_ylim(0., 1.)

    if residuals is not None:
        ax3.set_xlim(xlim[0], xlim[1])

    if offset is not None and residuals is not None and filters is not None:
        ax3.get_xaxis().set_label_coords(0.5, offset[0])

        ax1.get_yaxis().set_label_coords(offset[1], 0.5)
        ax2.get_yaxis().set_label_coords(offset[1], 0.5)
        ax3.get_yaxis().set_label_coords(offset[1], 0.5)

    elif offset is not None and filters is not None:
        ax1.get_xaxis().set_label_coords(0.5, offset[0])

        ax1.get_yaxis().set_label_coords(offset[1], 0.5)
        ax2.get_yaxis().set_label_coords(offset[1], 0.5)

    elif offset is not None and residuals is not None:
        ax3.get_xaxis().set_label_coords(0.5, offset[0])

        ax1.get_yaxis().set_label_coords(offset[1], 0.5)
        ax3.get_yaxis().set_label_coords(offset[1], 0.5)

    elif offset is not None:
        ax1.get_xaxis().set_label_coords(0.5, offset[0])
        ax1.get_yaxis().set_label_coords(offset[1], 0.5)

    else:
        ax1.get_xaxis().set_label_coords(0.5, -0.12)
        ax1.get_yaxis().set_label_coords(-0.1, 0.5)

    for j, boxitem in enumerate(boxes):
        flux_scaling = 1.

        if j < len(boxes):
            plot_kwargs.append(None)

        if isinstance(boxitem, (box.SpectrumBox, box.ModelBox)):
            wavelength = boxitem.wavelength
            flux = boxitem.flux

            if isinstance(wavelength[0], (np.float32, np.float64)):
                data = np.array(flux, dtype=np.float64)
                masked = np.ma.array(data, mask=np.isnan(data))

                if isinstance(boxitem, box.ModelBox):
                    param = boxitem.parameters

                    par_key, par_unit, par_label = plot_util.quantity_unit(
                        param=list(param.keys()), object_type=object_type)

                    label = ''
                    newline = False

                    for i, item in enumerate(par_key):
                        if item[:4] == 'teff':
                            value = f'{param[item]:.0f}'

                        elif item in [
                                'logg', 'feh', 'co', 'fsed', 'lognorm_ext',
                                'powerlaw_ext', 'ism_ext'
                        ]:
                            value = f'{param[item]:.2f}'

                        elif item[:6] == 'radius':
                            if object_type == 'planet':
                                value = f'{param[item]:.1f}'

                                # if item == 'radius_1':
                                #     value = f'{param[item]:.0f}'
                                # else:
                                #     value = f'{param[item]:.1f}'

                            elif object_type == 'star':
                                value = f'{param[item]*constants.R_JUP/constants.R_SUN:.1f}'

                        elif item == 'mass':
                            if object_type == 'planet':
                                value = f'{param[item]:.1f}'

                            elif object_type == 'star':
                                value = f'{param[item]*constants.M_JUP/constants.M_SUN:.1f}'

                        elif item == 'luminosity':
                            value = f'{np.log10(param[item]):.2f}'

                        else:
                            continue

                        # if len(label) > 80 and newline == False:
                        #     label += '\n'
                        #     newline = True

                        if par_unit[i] is None:
                            label += f'{par_label[i]} = {value}'
                        else:
                            label += f'{par_label[i]} = {value} {par_unit[i]}'

                        if i < len(par_key) - 1:
                            label += ', '

                else:
                    label = None

                if plot_kwargs[j]:
                    kwargs_copy = plot_kwargs[j].copy()

                    if 'label' in kwargs_copy:
                        if kwargs_copy['label'] is None:
                            label = None
                        else:
                            label = kwargs_copy['label']

                        del kwargs_copy['label']

                    if quantity == 'flux':
                        flux_scaling = wavelength

                    ax1.plot(wavelength,
                             flux_scaling * masked / scaling,
                             zorder=2,
                             label=label,
                             **kwargs_copy)

                else:
                    if quantity == 'flux':
                        flux_scaling = wavelength

                    ax1.plot(wavelength,
                             flux_scaling * masked / scaling,
                             lw=0.5,
                             label=label,
                             zorder=2)

            elif isinstance(wavelength[0], (np.ndarray)):
                for i, item in enumerate(wavelength):
                    data = np.array(flux[i], dtype=np.float64)
                    masked = np.ma.array(data, mask=np.isnan(data))

                    if isinstance(boxitem.name[i], bytes):
                        label = boxitem.name[i].decode('utf-8')
                    else:
                        label = boxitem.name[i]

                    if quantity == 'flux':
                        flux_scaling = item

                    ax1.plot(item,
                             flux_scaling * masked / scaling,
                             lw=0.5,
                             label=label)

        elif isinstance(boxitem, list):
            for i, item in enumerate(boxitem):
                wavelength = item.wavelength
                flux = item.flux

                data = np.array(flux, dtype=np.float64)
                masked = np.ma.array(data, mask=np.isnan(data))

                if quantity == 'flux':
                    flux_scaling = wavelength

                if plot_kwargs[j]:
                    ax1.plot(wavelength,
                             flux_scaling * masked / scaling,
                             zorder=1,
                             **plot_kwargs[j])
                else:
                    ax1.plot(wavelength,
                             flux_scaling * masked / scaling,
                             color='gray',
                             lw=0.2,
                             alpha=0.5,
                             zorder=1)

        elif isinstance(boxitem, box.PhotometryBox):
            label_check = []

            for i, item in enumerate(boxitem.wavelength):
                transmission = read_filter.ReadFilter(boxitem.filter_name[i])
                fwhm = transmission.filter_fwhm()

                if quantity == 'flux':
                    flux_scaling = item

                if plot_kwargs[j]:
                    if 'label' in plot_kwargs[j] and plot_kwargs[j][
                            'label'] not in label_check:
                        label_check.append(plot_kwargs[j]['label'])

                    elif 'label' in plot_kwargs[j] and plot_kwargs[j][
                            'label'] in label_check:
                        del plot_kwargs[j]['label']

                    if boxitem.flux[i][1] is None:
                        ax1.errorbar(item,
                                     flux_scaling * boxitem.flux[i][0] /
                                     scaling,
                                     xerr=fwhm / 2.,
                                     yerr=None,
                                     zorder=3,
                                     **plot_kwargs[j])

                    else:
                        ax1.errorbar(
                            item,
                            flux_scaling * boxitem.flux[i][0] / scaling,
                            xerr=fwhm / 2.,
                            yerr=flux_scaling * boxitem.flux[i][1] / scaling,
                            zorder=3,
                            **plot_kwargs[j])

                else:
                    if boxitem.flux[i][1] is None:
                        ax1.errorbar(item,
                                     flux_scaling * boxitem.flux[i][0] /
                                     scaling,
                                     xerr=fwhm / 2.,
                                     yerr=None,
                                     marker='s',
                                     ms=6,
                                     color='black',
                                     zorder=3)

                    else:
                        ax1.errorbar(
                            item,
                            flux_scaling * boxitem.flux[i][0] / scaling,
                            xerr=fwhm / 2.,
                            yerr=flux_scaling * boxitem.flux[i][1] / scaling,
                            marker='s',
                            ms=6,
                            color='black',
                            zorder=3)

        elif isinstance(boxitem, box.ObjectBox):
            if boxitem.spectrum is not None:
                spec_list = []
                wavel_list = []

                for item in boxitem.spectrum:
                    spec_list.append(item)
                    wavel_list.append(boxitem.spectrum[item][0][0, 0])

                sort_index = np.argsort(wavel_list)
                spec_sort = []

                for i in range(sort_index.size):
                    spec_sort.append(spec_list[sort_index[i]])

                for key in spec_sort:
                    masked = np.ma.array(boxitem.spectrum[key][0],
                                         mask=np.isnan(
                                             boxitem.spectrum[key][0]))

                    if quantity == 'flux':
                        flux_scaling = masked[:, 0]

                    if not plot_kwargs[j] or key not in plot_kwargs[j]:
                        plot_obj = ax1.errorbar(
                            masked[:, 0],
                            flux_scaling * masked[:, 1] / scaling,
                            yerr=flux_scaling * masked[:, 2] / scaling,
                            ms=2,
                            marker='s',
                            zorder=2.5,
                            ls='none')

                        if plot_kwargs[j] is None:
                            plot_kwargs[j] = {}

                        plot_kwargs[j][key] = {
                            'marker': 's',
                            'ms': 2.,
                            'ls': 'none',
                            'color': plot_obj[0].get_color()
                        }

                    else:
                        ax1.errorbar(masked[:, 0],
                                     flux_scaling * masked[:, 1] / scaling,
                                     yerr=flux_scaling * masked[:, 2] /
                                     scaling,
                                     zorder=2.5,
                                     **plot_kwargs[j][key])

            if boxitem.flux is not None:
                filter_list = []
                wavel_list = []

                for item in boxitem.flux:
                    read_filt = read_filter.ReadFilter(item)
                    filter_list.append(item)
                    wavel_list.append(read_filt.mean_wavelength())

                sort_index = np.argsort(wavel_list)
                filter_sort = []

                for i in range(sort_index.size):
                    filter_sort.append(filter_list[sort_index[i]])

                for item in filter_sort:
                    transmission = read_filter.ReadFilter(item)
                    wavelength = transmission.mean_wavelength()
                    fwhm = transmission.filter_fwhm()

                    if not plot_kwargs[j] or item not in plot_kwargs[j]:
                        if not plot_kwargs[j]:
                            plot_kwargs[j] = {}

                        if quantity == 'flux':
                            flux_scaling = wavelength

                        if isinstance(boxitem.flux[item][0], np.ndarray):
                            for i in range(boxitem.flux[item].shape[1]):

                                plot_obj = ax1.errorbar(
                                    wavelength,
                                    flux_scaling * boxitem.flux[item][0, i] /
                                    scaling,
                                    xerr=fwhm / 2.,
                                    yerr=flux_scaling *
                                    boxitem.flux[item][1, i] / scaling,
                                    marker='s',
                                    ms=5,
                                    zorder=3)

                        else:
                            plot_obj = ax1.errorbar(
                                wavelength,
                                flux_scaling * boxitem.flux[item][0] / scaling,
                                xerr=fwhm / 2.,
                                yerr=flux_scaling * boxitem.flux[item][1] /
                                scaling,
                                marker='s',
                                ms=5,
                                zorder=3)

                        plot_kwargs[j][item] = {
                            'marker': 's',
                            'ms': 5.,
                            'color': plot_obj[0].get_color()
                        }

                    else:
                        if quantity == 'flux':
                            flux_scaling = wavelength

                        if isinstance(boxitem.flux[item][0], np.ndarray):
                            if not isinstance(plot_kwargs[j][item], list):
                                raise ValueError(
                                    f'A list with {boxitem.flux[item].shape[1]} '
                                    f'dictionaries are required because the filter '
                                    f'{item} has {boxitem.flux[item].shape[1]} '
                                    f'values.')

                            for i in range(boxitem.flux[item].shape[1]):
                                ax1.errorbar(
                                    wavelength,
                                    flux_scaling * boxitem.flux[item][0, i] /
                                    scaling,
                                    xerr=fwhm / 2.,
                                    yerr=flux_scaling *
                                    boxitem.flux[item][1, i] / scaling,
                                    zorder=3,
                                    **plot_kwargs[j][item][i])

                        else:
                            if boxitem.flux[item][1] == 0.:
                                ax1.errorbar(wavelength,
                                             flux_scaling *
                                             boxitem.flux[item][0] / scaling,
                                             xerr=fwhm / 2.,
                                             yerr=0.5 * flux_scaling *
                                             boxitem.flux[item][0] / scaling,
                                             uplims=True,
                                             capsize=2.,
                                             capthick=0.,
                                             zorder=3,
                                             **plot_kwargs[j][item])

                            else:
                                ax1.errorbar(wavelength,
                                             flux_scaling *
                                             boxitem.flux[item][0] / scaling,
                                             xerr=fwhm / 2.,
                                             yerr=flux_scaling *
                                             boxitem.flux[item][1] / scaling,
                                             zorder=3,
                                             **plot_kwargs[j][item])

        elif isinstance(boxitem, box.SynphotBox):
            for i, find_item in enumerate(boxes):
                if isinstance(find_item, box.ObjectBox):
                    obj_index = i
                    break

            for item in boxitem.flux:
                transmission = read_filter.ReadFilter(item)
                wavelength = transmission.mean_wavelength()
                fwhm = transmission.filter_fwhm()

                if quantity == 'flux':
                    flux_scaling = wavelength

                if not plot_kwargs[obj_index] or item not in plot_kwargs[
                        obj_index]:
                    ax1.errorbar(wavelength,
                                 flux_scaling * boxitem.flux[item] / scaling,
                                 xerr=fwhm / 2.,
                                 yerr=None,
                                 alpha=0.7,
                                 marker='s',
                                 ms=5,
                                 zorder=4,
                                 mfc='white')

                else:
                    if isinstance(plot_kwargs[obj_index][item], list):
                        # In case of multiple photometry values for the same filter, use the
                        # plot_kwargs of the first data point

                        kwargs_copy = plot_kwargs[obj_index][item][0].copy()

                        if 'label' in kwargs_copy:
                            del kwargs_copy['label']

                        ax1.errorbar(wavelength,
                                     flux_scaling * boxitem.flux[item] /
                                     scaling,
                                     xerr=fwhm / 2.,
                                     yerr=None,
                                     zorder=4,
                                     mfc='white',
                                     **kwargs_copy)

                    else:
                        kwargs_copy = plot_kwargs[obj_index][item].copy()

                        if 'label' in kwargs_copy:
                            del kwargs_copy['label']

                        ax1.errorbar(wavelength,
                                     flux_scaling * boxitem.flux[item] /
                                     scaling,
                                     xerr=fwhm / 2.,
                                     yerr=None,
                                     zorder=4,
                                     mfc='white',
                                     **kwargs_copy)

    if filters is not None:
        for i, item in enumerate(filters):
            transmission = read_filter.ReadFilter(item)
            data = transmission.get_filter()

            ax2.plot(data[:, 0],
                     data[:, 1],
                     '-',
                     lw=0.7,
                     color='black',
                     zorder=1)

    if residuals is not None:
        for i, find_item in enumerate(boxes):
            if isinstance(find_item, box.ObjectBox):
                obj_index = i
                break

        res_max = 0.

        if residuals.photometry is not None:
            for item in residuals.photometry:
                if not plot_kwargs[obj_index] or item not in plot_kwargs[
                        obj_index]:
                    ax3.plot(residuals.photometry[item][0],
                             residuals.photometry[item][1],
                             marker='s',
                             ms=5,
                             linestyle='none',
                             zorder=2)

                else:
                    if residuals.photometry[item].ndim == 1:
                        ax3.errorbar(residuals.photometry[item][0],
                                     residuals.photometry[item][1],
                                     zorder=2,
                                     **plot_kwargs[obj_index][item])

                    elif residuals.photometry[item].ndim == 2:
                        for i in range(residuals.photometry[item].shape[1]):
                            if isinstance(plot_kwargs[obj_index][item], list):
                                ax3.errorbar(residuals.photometry[item][0, i],
                                             residuals.photometry[item][1, i],
                                             zorder=2,
                                             **plot_kwargs[obj_index][item][i])

                            else:
                                ax3.errorbar(residuals.photometry[item][0, i],
                                             residuals.photometry[item][1, i],
                                             zorder=2,
                                             **plot_kwargs[obj_index][item])

                res_max = np.nanmax(np.abs(residuals.photometry[item][1]))

        if residuals.spectrum is not None:
            for key, value in residuals.spectrum.items():
                if not plot_kwargs[obj_index] or key not in plot_kwargs[
                        obj_index]:
                    ax3.errorbar(value[:, 0],
                                 value[:, 1],
                                 marker='o',
                                 ms=2,
                                 ls='none',
                                 zorder=1)

                else:
                    ax3.errorbar(value[:, 0],
                                 value[:, 1],
                                 zorder=1,
                                 **plot_kwargs[obj_index][key])

                max_tmp = np.nanmax(np.abs(value[:, 1]))

                if max_tmp > res_max:
                    res_max = max_tmp

        res_lim = math.ceil(1.1 * res_max)

        if res_lim > 10.:
            res_lim = 5.

        ax3.axhline(0.,
                    ls='--',
                    lw=0.7,
                    color='gray',
                    dashes=(2, 4),
                    zorder=0.5)
        # ax3.axhline(-2.5, ls=':', lw=0.7, color='gray', dashes=(1, 4), zorder=0.5)
        # ax3.axhline(2.5, ls=':', lw=0.7, color='gray', dashes=(1, 4), zorder=0.5)

        if ylim_res is None:
            ax3.set_ylim(-res_lim, res_lim)

        else:
            ax3.set_ylim(ylim_res[0], ylim_res[1])

    if filters is not None:
        ax2.set_ylim(0., 1.1)

    print(f'Plotting spectrum: {output}...', end='', flush=True)

    if title is not None:
        if filters:
            ax2.set_title(title, y=1.02, fontsize=13)
        else:
            ax1.set_title(title, y=1.02, fontsize=13)

    handles, labels = ax1.get_legend_handles_labels()

    if handles and legend is not None:
        if isinstance(legend, list):
            model_handles = []
            data_handles = []

            model_labels = []
            data_labels = []

            for i, item in enumerate(handles):
                if isinstance(item, mpl.lines.Line2D):
                    model_handles.append(item)
                    model_labels.append(labels[i])

                elif isinstance(item, mpl.container.ErrorbarContainer):
                    data_handles.append(item)
                    data_labels.append(labels[i])

                else:
                    warnings.warn(
                        f'The object type {item} is not implemented for the legend.'
                    )

            if legend[0] is not None:
                if isinstance(legend[0], (str, tuple)):
                    leg_1 = ax1.legend(model_handles,
                                       model_labels,
                                       loc=legend[0],
                                       fontsize=10.,
                                       frameon=False)
                else:
                    leg_1 = ax1.legend(model_handles, model_labels,
                                       **legend[0])

            else:
                leg_1 = None

            if legend[1] is not None:
                if isinstance(legend[1], (str, tuple)):
                    leg_2 = ax1.legend(data_handles,
                                       data_labels,
                                       loc=legend[1],
                                       fontsize=8,
                                       frameon=False)
                else:
                    leg_2 = ax1.legend(data_handles, data_labels, **legend[1])

            if leg_1 is not None:
                ax1.add_artist(leg_1)

        elif isinstance(legend, (str, tuple)):
            ax1.legend(loc=legend, fontsize=8, frameon=False)

        else:
            ax1.legend(**legend)

    # filters = ['Paranal/SPHERE.ZIMPOL_N_Ha',
    #            'MUSE/Hbeta',
    #            'ALMA/855']
    #
    # filters = ['Paranal/SPHERE.IRDIS_B_Y',
    #            'MKO/NSFCam.J',
    #            'Paranal/SPHERE.IRDIS_D_H23_2',
    #            'Paranal/SPHERE.IRDIS_D_H23_3',
    #            'Paranal/SPHERE.IRDIS_D_K12_1',
    #            'Paranal/SPHERE.IRDIS_D_K12_2',
    #            'Paranal/NACO.Lp',
    #            'Paranal/NACO.NB405',
    #            'Paranal/NACO.Mp']
    #
    # for i, item in enumerate(filters):
    #     readfilter = read_filter.ReadFilter(item)
    #     filter_wavelength = readfilter.mean_wavelength()
    #     filter_width = readfilter.filter_fwhm()
    #
    #     # if i == 5:
    #     #     ax1.errorbar(filter_wavelength, 1.3e4, xerr=filter_width/2., color='dimgray', elinewidth=2.5, zorder=10)
    #     # else:
    #     #     ax1.errorbar(filter_wavelength, 6e3, xerr=filter_width/2., color='dimgray', elinewidth=2.5, zorder=10)
    #
    #     if i == 0:
    #         ax1.text(filter_wavelength, 1e-2, r'H$\alpha$', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 1:
    #         ax1.text(filter_wavelength, 1e-2, r'H$\beta$', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 2:
    #         ax1.text(filter_wavelength, 1e-2, 'ALMA\nband 7 rms', ha='center', va='center', fontsize=8, color='black')
    #
    #     if i == 0:
    #         ax1.text(filter_wavelength, 1.4, 'Y', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 1:
    #         ax1.text(filter_wavelength, 1.4, 'J', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 2:
    #         ax1.text(filter_wavelength-0.04, 1.4, 'H2', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 3:
    #         ax1.text(filter_wavelength+0.04, 1.4, 'H3', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 4:
    #         ax1.text(filter_wavelength, 1.4, 'K1', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 5:
    #         ax1.text(filter_wavelength, 1.4, 'K2', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 6:
    #         ax1.text(filter_wavelength, 1.4, 'L$\'$', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 7:
    #         ax1.text(filter_wavelength, 1.4, 'NB4.05', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 8:
    #         ax1.text(filter_wavelength, 1.4, 'M$\'}$', ha='center', va='center', fontsize=10, color='black')
    #
    # ax1.text(1.26, 0.58, 'VLT/SPHERE', ha='center', va='center', fontsize=8., color='slateblue', rotation=43.)
    # ax1.text(2.5, 1.28, 'VLT/SINFONI', ha='left', va='center', fontsize=8., color='darkgray')

    plt.savefig(os.getcwd() + '/' + output, bbox_inches='tight')
    plt.clf()
    plt.close()

    print(' [DONE]')
예제 #19
0
    def __init__(
        self,
        line_species: Optional[List[str]] = None,
        cloud_species: Optional[List[str]] = None,
        scattering: bool = False,
        wavel_range: Optional[Tuple[float, float]] = None,
        filter_name: Optional[str] = None,
        pressure_grid: str = "smaller",
        res_mode: str = "c-k",
        cloud_wavel: Optional[Tuple[float, float]] = None,
        max_press: Optional[float] = None,
        pt_manual: Optional[np.ndarray] = None,
    ) -> None:
        """
        Store the configuration, build the pressure grid, and create
        the ``Radtrans`` object of petitRADTRANS with its opacity
        structure initialized on that grid.

        Parameters
        ----------
        line_species : list, None
            List with the line species. No line species are used if set
            to ``None``.
        cloud_species : list, None
            List with the cloud species. No clouds are used if set to
            ``None``.
        scattering : bool
            Include scattering in the radiative transfer.
        wavel_range : tuple(float, float), None
            Wavelength range (um). The wavelength range is set to
            0.8-10 um if set to ``None`` or not used if ``filter_name``
            is not ``None``.
        filter_name : str, None
            Filter name that is used for the wavelength range. The
            wavelength range of the filter is widened by a factor 0.9
            at the lower and 1.2 at the upper boundary. The
            ``wavel_range`` is used if ``filter_name`` is set to
            ``None``.
        pressure_grid : str
            The type of pressure grid that is used for the radiative
            transfer. Either 'standard', to use 180 layers both for
            the atmospheric structure (e.g. when interpolating the
            abundances) and 180 layers with the radiative transfer,
            or 'smaller' to use 60 (instead of 180) with the radiative
            transfer, or 'clouds' to start with 1440 layers but
            resample to ~100 layers (depending on the number of cloud
            species) with a refinement around the cloud decks, or
            'manual' to use the pressures from the first column of
            ``pt_manual``. For cloudless atmospheres it is recommended
            to use 'smaller', which runs faster than 'standard' and
            provides sufficient accuracy. For cloudy atmosphere, one
            can test with 'smaller' but it is recommended to use
            'clouds' for improved accuracy fluxes.
        res_mode : str
            Resolution mode ('c-k' or 'lbl'). The low-resolution mode
            ('c-k') calculates the spectrum with the correlated-k
            assumption at :math:`\\lambda/\\Delta \\lambda = 1000`. The
            high-resolution mode ('lbl') calculates the spectrum with a
            line-by-line treatment at
            :math:`\\lambda/\\Delta \\lambda = 10^6`.
        cloud_wavel : tuple(float, float), None
            Tuple with the wavelength range (um) that is used for
            calculating the median optical depth of the clouds at the
            gas-only photosphere and then scaling the cloud optical
            depth to the value of ``log_tau_cloud``. The range of
            ``cloud_wavel`` should be encompassed by the range of
            ``wavel_range``.  The full wavelength range (i.e.
            ``wavel_range``) is used if the argument is set to
            ``None``.
        max_press : float, None
            Maximum pressure (bar) of the logarithmically spaced
            pressure grid, which starts at 1e-6 bar. The default is
            set to 1000 bar if the argument is set to ``None``. The
            argument is not used for the grid when
            ``pressure_grid="manual"``.
        pt_manual : np.ndarray, None
            A 2D array that contains the P-T profile that is used
            when ``pressure_grid="manual"``. The shape of array should
            be (n_pressure, 2), with pressure (bar) as first column
            and temperature (K) as second column. It is recommended
            that the pressures are logarithmically spaced.

        Returns
        -------
        NoneType
            None

        Raises
        ------
        ValueError
            If ``pressure_grid`` is not one of 'standard', 'smaller',
            'clouds', or 'manual', or if ``pressure_grid="manual"``
            is used without providing ``pt_manual``.
        """

        # Set several of the required ReadRadtrans attributes

        self.filter_name = filter_name
        self.wavel_range = wavel_range
        self.scattering = scattering
        self.pressure_grid = pressure_grid
        self.cloud_wavel = cloud_wavel
        self.pt_manual = pt_manual

        # Set maximum pressure

        if max_press is None:
            self.max_press = 1e3
        else:
            self.max_press = max_press

        # Set the wavelength range

        if self.filter_name is not None:
            transmission = read_filter.ReadFilter(self.filter_name)
            self.wavel_range = transmission.wavelength_range()
            self.wavel_range = (0.9 * self.wavel_range[0],
                                1.2 * self.wavel_range[1])

        elif self.wavel_range is None:
            self.wavel_range = (0.8, 10.0)

        # Set the list with line species

        if line_species is None:
            self.line_species = []
        else:
            self.line_species = line_species

        # Set the list with cloud species

        if cloud_species is None:
            self.cloud_species = []
        else:
            self.cloud_species = cloud_species

        # Create the pressure grid. The 'manual' case is handled
        # first such that it is not rejected by the check on the
        # names of the regular pressure grids.

        if self.pressure_grid == "manual":
            if self.pt_manual is None:
                raise ValueError("A 2D array with the P-T profile "
                                 "should be provided as argument "
                                 "of pt_manual when using "
                                 "pressure_grid='manual'.")

            # The first column of the P-T profile contains the pressures
            self.pressure = self.pt_manual[:, 0]

        else:
            if self.pressure_grid in ["standard", "smaller"]:
                # Initiate 180 pressure layers but use only
                # 60 layers during the radiative transfer
                # when pressure_grid is set to 'smaller'
                n_pressure = 180

            elif self.pressure_grid == "clouds":
                # Initiate 1440 pressure layers but use fewer
                # layers (~100) during the radiative transfer
                # after running make_half_pressure_better
                n_pressure = 1440

            else:
                raise ValueError(f"The argument of pressure_grid "
                                 f"('{self.pressure_grid}') is "
                                 f"not recognized. Please use "
                                 f"'standard', 'smaller', 'clouds', "
                                 f"or 'manual'.")

            # Create the pressure layers in log space

            self.pressure = np.logspace(-6, np.log10(self.max_press),
                                        n_pressure)

        # Import petitRADTRANS here because it is slow

        print("Importing petitRADTRANS...", end="", flush=True)
        from petitRADTRANS.radtrans import Radtrans

        print(" [DONE]")

        # Create the Radtrans object

        self.rt_object = Radtrans(
            line_species=self.line_species,
            rayleigh_species=["H2", "He"],
            cloud_species=self.cloud_species,
            continuum_opacities=["H2-H2", "H2-He"],
            wlen_bords_micron=self.wavel_range,
            mode=res_mode,
            test_ck_shuffle_comp=self.scattering,
            do_scat_emis=self.scattering,
        )

        # Setup the opacity arrays

        if self.pressure_grid in ["standard", "manual"]:
            # Use the full pressure grid
            self.rt_object.setup_opa_structure(self.pressure)

        elif self.pressure_grid == "smaller":
            # Use every 3rd layer (60 out of 180)
            self.rt_object.setup_opa_structure(self.pressure[::3])

        elif self.pressure_grid == "clouds":
            # Use every 24th layer (60 out of 1440)
            self.rt_object.setup_opa_structure(self.pressure[::24])
예제 #20
0
    def compare_model(
        self,
        tag: str,
        model: str,
        av_points: Optional[Union[List[float], np.array]] = None,
        fix_logg: Optional[float] = None,
        scale_spec: Optional[List[str]] = None,
        weights: bool = True,
        inc_phot: Optional[List[str]] = None,
    ) -> None:
        """
        Method for finding the best fitting spectrum from a grid of atmospheric model spectra by
        evaluating the goodness-of-fit statistic from Cushing et al. (2008). Currently, this method
        only supports model grids with only :math:`T_\\mathrm{eff}` and :math:`\\log(g)` as free
        parameters (e.g. BT-Settl). Please create an issue on Github if support for models with
        more than two parameters is required.

        Parameters
        ----------
        tag : str
            Database tag where for each spectrum from the spectral library the best-fit parameters
            will be stored. So when testing a range of values for ``av_ext`` and ``rad_vel``, only
            the parameters that minimize the goodness-of-fit statistic will be stored.
        model : str
            Name of the atmospheric model grid with synthetic spectra.
        av_points : list(float), np.array, None
            List of :math:`A_V` extinction values for which the goodness-of-fit statistic will be
            tested. The extinction is calculated with the relation from Cardelli et al. (1989).
        fix_logg : float, None
            Fix the value of :math:`\\log(g)`, for example if estimated from gravity-sensitive
            spectral features. Typically, :math:`\\log(g)` can not be accurately determined when
            comparing the spectra over a broad wavelength range.
        scale_spec : list(str), None
            List with names of observed spectra to which a flux scaling is applied to best match
            the spectral templates.
        weights : bool
            Apply a weighting based on the widths of the wavelengths bins.
        inc_phot : list(str), None
            Filter names of the photometry to include in the comparison. Photometry points are
            weighted by the FWHM of the filter profile. No photometric fluxes will be used if the
            argument is set to ``None``.

        Returns
        -------
        NoneType
            None
        """

        w_i = {}

        for spec_item in self.spec_name:
            obj_wavel = self.object.get_spectrum()[spec_item][0][:, 0]

            diff = (np.diff(obj_wavel)[1:] + np.diff(obj_wavel)[:-1]) / 2.0
            diff = np.insert(diff, 0, diff[0])
            diff = np.append(diff, diff[-1])

            if weights:
                w_i[spec_item] = diff
            else:
                w_i[spec_item] = np.ones(obj_wavel.shape[0])

        if inc_phot is None:
            inc_phot = []

        if scale_spec is None:
            scale_spec = []

        phot_wavel = {}

        for phot_item in inc_phot:
            read_filt = read_filter.ReadFilter(phot_item)
            w_i[phot_item] = read_filt.filter_fwhm()
            phot_wavel[phot_item] = read_filt.mean_wavelength()

        if av_points is None:
            av_points = np.array([0.0])

        elif isinstance(av_points, list):
            av_points = np.array(av_points)

        readmodel = read_model.ReadModel(model)

        model_param = readmodel.get_parameters()
        grid_points = readmodel.get_points()

        coord_points = []
        for key, value in grid_points.items():
            if key == "logg" and fix_logg is not None:
                if fix_logg in value:
                    coord_points.append(np.array([fix_logg]))

                else:
                    raise ValueError(
                        f"The argument of 'fix_logg' ({fix_logg}) is not found "
                        f"in the parameter grid of the model spectra. The following "
                        f"values of log(g) are available: {value}")

            else:
                coord_points.append(value)

        if av_points is not None:
            model_param.append("ism_ext")
            coord_points.append(av_points)

        grid_shape = []

        for item in coord_points:
            grid_shape.append(len(item))

        fit_stat = np.zeros(grid_shape)
        flux_scaling = np.zeros(grid_shape)

        if len(scale_spec) == 0:
            extra_scaling = None

        else:
            grid_shape.append(len(scale_spec))
            extra_scaling = np.zeros(grid_shape)

        count = 1

        if len(coord_points) == 3:
            n_iter = len(coord_points[0]) * len(coord_points[1]) * len(
                coord_points[2])

            for i, item_i in enumerate(coord_points[0]):
                for j, item_j in enumerate(coord_points[1]):
                    for k, item_k in enumerate(coord_points[2]):
                        print(
                            f"\rProcessing model spectrum {count}/{n_iter}...",
                            end="")

                        model_spec = {}
                        model_phot = {}

                        for spec_item in self.spec_name:
                            obj_spec = self.object.get_spectrum()[spec_item][0]
                            obj_res = self.object.get_spectrum()[spec_item][3]

                            param_dict = {
                                model_param[0]: item_i,
                                model_param[1]: item_j,
                                model_param[2]: item_k,
                            }

                            wavel_range = (0.9 * obj_spec[0, 0],
                                           1.1 * obj_spec[-1, 0])
                            readmodel = read_model.ReadModel(
                                model, wavel_range=wavel_range)

                            model_box = readmodel.get_data(
                                param_dict,
                                spec_res=obj_res,
                                wavel_resample=obj_spec[:, 0],
                            )

                            model_spec[spec_item] = model_box.flux

                        for phot_item in inc_phot:
                            readmodel = read_model.ReadModel(
                                model, filter_name=phot_item)

                            model_phot[phot_item] = readmodel.get_flux(
                                param_dict)[0]

                        def g_fit(x, scaling):
                            g_stat = 0.0

                            for spec_item in self.spec_name:
                                obs_spec = self.object.get_spectrum(
                                )[spec_item][0]

                                if spec_item in scale_spec:
                                    spec_idx = scale_spec.index(spec_item)

                                    c_numer = (w_i[spec_item] *
                                               obs_spec[:, 1] *
                                               model_spec[spec_item] /
                                               obs_spec[:, 2]**2)

                                    c_denom = (w_i[spec_item] *
                                               model_spec[spec_item]**2 /
                                               obs_spec[:, 2]**2)

                                    extra_scaling[i, j, k, spec_idx] = np.sum(
                                        c_numer) / np.sum(c_denom)

                                    g_stat += np.sum(
                                        w_i[spec_item] *
                                        (obs_spec[:, 1] -
                                         extra_scaling[i, j, k, spec_idx] *
                                         model_spec[spec_item])**2 /
                                        obs_spec[:, 2]**2)

                                else:
                                    g_stat += np.sum(
                                        w_i[spec_item] *
                                        (obs_spec[:, 1] -
                                         scaling * model_spec[spec_item])**2 /
                                        obs_spec[:, 2]**2)

                            for phot_item in inc_phot:
                                obs_phot = self.object.get_photometry(
                                    phot_item)

                                g_stat += (
                                    w_i[phot_item] *
                                    (obs_phot[2] -
                                     scaling * model_phot[phot_item])**2 /
                                    obs_phot[3]**2)

                            return g_stat

                        popt, _ = curve_fit(g_fit, xdata=[0.0], ydata=[0.0])
                        scaling = popt[0]

                        flux_scaling[i, j, k] = scaling
                        fit_stat[i, j, k] = g_fit(0.0, scaling)

                        count += 1

        print(" [DONE]")

        species_db = database.Database()

        species_db.add_comparison(
            tag=tag,
            goodness_of_fit=fit_stat,
            flux_scaling=flux_scaling,
            model_param=model_param,
            coord_points=coord_points,
            object_name=self.object_name,
            spec_name=self.spec_name,
            model=model,
            scale_spec=scale_spec,
            extra_scaling=extra_scaling,
        )
예제 #21
0
    def apply_powerlaw_ext(wavelength: np.ndarray,
                           flux: np.ndarray,
                           r_max_interp: float,
                           exp_interp: float,
                           v_band_ext: float) -> np.ndarray:
        """
        Internal function for applying extinction by dust to a spectrum. The extinction
        cross sections of a power-law size distribution are read from the dust database
        (``dust/powerlaw/mgsio3/crystalline``), linearly interpolated to the requested
        size distribution, and normalized such that the extinction in the
        ``Generic/Bessell.V`` filter is equal to ``v_band_ext``.

        Parameters
        ----------
        wavelength : np.ndarray
            Wavelengths (um) of the spectrum.
        flux : np.ndarray
            Fluxes (W m-2 um-1) of the spectrum.
        r_max_interp : float
            Logarithm (base 10) of the maximum radius (um) of the power-law size
            distribution. The grid is evaluated at ``10**r_max_interp``.
        exp_interp : float
            Exponent of the power-law size distribution.
        v_band_ext : float
            The extinction (mag) in the V band.

        Returns
        -------
        np.ndarray
            Fluxes (W m-2 um-1) with the extinction applied.

        Raises
        ------
        ValueError
            If a wavelength, the maximum radius, or the exponent is outside the range
            of the tabulated grid (all interpolators use ``bounds_error=True``).
        """

        database_path = dust_util.check_dust_database()

        with h5py.File(database_path, 'r') as h5_file:
            dust_cross = np.asarray(h5_file['dust/powerlaw/mgsio3/crystalline/cross_section'])
            dust_wavel = np.asarray(h5_file['dust/powerlaw/mgsio3/crystalline/wavelength'])
            dust_r_max = np.asarray(h5_file['dust/powerlaw/mgsio3/crystalline/radius_max'])
            dust_exp = np.asarray(h5_file['dust/powerlaw/mgsio3/crystalline/exponent'])

        # Cross section as function of (wavelength, maximum radius, exponent)
        dust_interp = RegularGridInterpolator((dust_wavel, dust_r_max, dust_exp),
                                              dust_cross,
                                              method='linear',
                                              bounds_error=True)

        read_filt = read_filter.ReadFilter('Generic/Bessell.V')
        filt_trans = read_filt.get_filter()

        # Column 0 is used as the wavelength axis, column 1 as the weight
        filt_wavel = filt_trans[:, 0]
        filt_weight = filt_trans[:, 1]

        # The normalization of the filter profile does not depend on
        # the grid point, so it is integrated only once
        filt_norm = np.trapz(filt_weight, filt_wavel)

        cross_phot = np.zeros((dust_r_max.shape[0], dust_exp.shape[0]))

        for i in range(dust_r_max.shape[0]):
            for j in range(dust_exp.shape[0]):
                cross_interp = interp1d(dust_wavel,
                                        dust_cross[:, i, j],
                                        kind='linear',
                                        bounds_error=True)

                cross_tmp = cross_interp(filt_wavel)

                # Filter-weighted average of the extinction cross section
                cross_phot[i, j] = np.trapz(filt_weight*cross_tmp, filt_wavel) / filt_norm

        # Bilinear interpolation of the V-band cross section on the
        # (maximum radius, exponent) grid. RegularGridInterpolator is used
        # instead of interp2d, which has been removed from SciPy (>= 1.14).
        cross_v_interp = RegularGridInterpolator((dust_r_max, dust_exp),
                                                 cross_phot,
                                                 method='linear',
                                                 bounds_error=True)

        r_max_linear = 10.**r_max_interp

        cross_v_band = cross_v_interp([[r_max_linear, exp_interp]])[0]

        r_max_full = np.full(wavelength.shape[0], r_max_linear)
        exp_full = np.full(wavelength.shape[0], exp_interp)

        cross_new = dust_interp(np.column_stack((wavelength, r_max_full, exp_full)))

        # A_V = 2.5 * log10(e) * cross_v_band * n_grains, solved for n_grains
        n_grains = v_band_ext / cross_v_band / 2.5 / np.log10(np.exp(1.))

        return flux * np.exp(-cross_new*n_grains)
예제 #22
0
    def apply_lognorm_ext(wavelength: np.ndarray,
                          flux: np.ndarray,
                          radius_interp: float,
                          sigma_interp: float,
                          v_band_ext: float) -> np.ndarray:
        """
        Internal function for applying extinction by dust to a spectrum. The extinction
        cross sections of a log-normal size distribution are read from the dust database
        (``dust/lognorm/mgsio3/crystalline``), linearly interpolated to the requested
        size distribution, and normalized such that the extinction in the
        ``Generic/Bessell.V`` filter is equal to ``v_band_ext``.

        Parameters
        ----------
        wavelength : np.ndarray
            Wavelengths (um) of the spectrum.
        flux : np.ndarray
            Fluxes (W m-2 um-1) of the spectrum.
        radius_interp : float
            Logarithm of the mean geometric radius (um) of the log-normal size
            distribution. The grid is evaluated at ``10**radius_interp``.
        sigma_interp : float
            Geometric standard deviation (dimensionless) of the log-normal size distribution.
        v_band_ext : float
            The extinction (mag) in the V band.

        Returns
        -------
        np.ndarray
            Fluxes (W m-2 um-1) with the extinction applied.

        Raises
        ------
        ValueError
            If a wavelength, the geometric radius, or the geometric standard deviation
            is outside the range of the tabulated grid (all interpolators use
            ``bounds_error=True``).
        """

        database_path = dust_util.check_dust_database()

        with h5py.File(database_path, 'r') as h5_file:
            dust_cross = np.asarray(h5_file['dust/lognorm/mgsio3/crystalline/cross_section'])
            dust_wavel = np.asarray(h5_file['dust/lognorm/mgsio3/crystalline/wavelength'])
            dust_radius = np.asarray(h5_file['dust/lognorm/mgsio3/crystalline/radius_g'])
            dust_sigma = np.asarray(h5_file['dust/lognorm/mgsio3/crystalline/sigma_g'])

        # Cross section as function of (wavelength, geometric radius, sigma)
        dust_interp = RegularGridInterpolator((dust_wavel, dust_radius, dust_sigma),
                                              dust_cross,
                                              method='linear',
                                              bounds_error=True)

        read_filt = read_filter.ReadFilter('Generic/Bessell.V')
        filt_trans = read_filt.get_filter()

        # Column 0 is used as the wavelength axis, column 1 as the weight
        filt_wavel = filt_trans[:, 0]
        filt_weight = filt_trans[:, 1]

        # The normalization of the filter profile does not depend on
        # the grid point, so it is integrated only once
        filt_norm = np.trapz(filt_weight, filt_wavel)

        cross_phot = np.zeros((dust_radius.shape[0], dust_sigma.shape[0]))

        for i in range(dust_radius.shape[0]):
            for j in range(dust_sigma.shape[0]):
                cross_interp = interp1d(dust_wavel,
                                        dust_cross[:, i, j],
                                        kind='linear',
                                        bounds_error=True)

                cross_tmp = cross_interp(filt_wavel)

                # Filter-weighted average of the extinction cross section
                cross_phot[i, j] = np.trapz(filt_weight*cross_tmp, filt_wavel) / filt_norm

        # Bilinear interpolation of the V-band cross section on the
        # (geometric radius, sigma) grid. RegularGridInterpolator is used
        # instead of interp2d, which has been removed from SciPy (>= 1.14).
        cross_v_interp = RegularGridInterpolator((dust_radius, dust_sigma),
                                                 cross_phot,
                                                 method='linear',
                                                 bounds_error=True)

        radius_linear = 10.**radius_interp

        cross_v_band = cross_v_interp([[radius_linear, sigma_interp]])[0]

        radius_full = np.full(wavelength.shape[0], radius_linear)
        sigma_full = np.full(wavelength.shape[0], sigma_interp)

        cross_new = dust_interp(np.column_stack((wavelength, radius_full, sigma_full)))

        # A_V = 2.5 * log10(e) * cross_v_band * n_grains, solved for n_grains
        n_grains = v_band_ext / cross_v_band / 2.5 / np.log10(np.exp(1.))

        return flux * np.exp(-cross_new*n_grains)
예제 #23
0
    def interpolate_grid(self,
                         wavel_resample: Optional[np.ndarray] = None,
                         smooth: bool = False,
                         spec_res: Optional[float] = None) -> None:
        """
        Internal function for linearly interpolating the grid of model spectra for a given
        filter or wavelength sampling. The flux is evaluated at every grid point, either
        as a synthetic flux with ``get_flux`` (when ``self.filter_name`` is set) or as a
        resampled spectrum with ``get_model``. The result is stored as a
        ``RegularGridInterpolator`` in ``self.spectrum_interp`` and the corresponding
        wavelength points are stored in ``self.wl_points``. The interpolator returns NaN
        for parameter values outside the grid. Grids with any number of parameters are
        supported.

        Parameters
        ----------
        wavel_resample : np.ndarray, None
            Wavelength points for the resampling of the spectrum. The ``filter_name`` is used
            if set to ``None``.
        smooth : bool
            Smooth the spectrum with a Gaussian line spread function. Only recommended in case the
            input wavelength sampling has a uniform spectral resolution.
        spec_res : float
            Spectral resolution that is used for the Gaussian filter when ``smooth=True``.

        Returns
        -------
        NoneType
            None

        Raises
        ------
        ValueError
            If ``smooth=True`` while ``wavel_resample`` is ``None``, or if neither a filter
            name nor ``wavel_resample`` is available.
        """

        self.interpolate_model()

        if smooth and wavel_resample is None:
            raise ValueError('Smoothing is only required if the spectra are resampled to a new '
                             'wavelength grid, therefore requiring the \'wavel_resample\' '
                             'argument.')

        use_filter = self.filter_name is not None

        if not use_filter and wavel_resample is None:
            # Without this check the code would fail further down with an
            # uninformative AttributeError on ``wavel_resample.size``
            raise ValueError('The \'wavel_resample\' argument is required if the '
                             '\'filter_name\' has not been set.')

        points = [list(item) for item in self.get_points().values()]
        param = self.get_parameters()

        grid_shape = tuple(len(item) for item in points)

        # The last axis contains a single synthetic flux or the resampled spectrum
        n_wavel = 1 if use_filter else wavel_resample.size

        flux_new = np.zeros(grid_shape + (n_wavel, ))

        # np.ndindex iterates in C order, which is the same order as
        # nested loops over the parameter axes
        for grid_idx in np.ndindex(*grid_shape):
            model_param = {name: axis[idx]
                           for name, axis, idx in zip(param, points, grid_idx)}

            if use_filter:
                flux_new[grid_idx] = self.get_flux(model_param)[0]

            else:
                flux_new[grid_idx] = self.get_model(model_param,
                                                    spec_res=spec_res,
                                                    wavel_resample=wavel_resample,
                                                    smooth=smooth).flux

        if use_filter:
            transmission = read_filter.ReadFilter(self.filter_name)
            self.wl_points = [transmission.mean_wavelength()]

        else:
            self.wl_points = wavel_resample

        self.spectrum_interp = RegularGridInterpolator(points,
                                                       flux_new,
                                                       method='linear',
                                                       bounds_error=False,
                                                       fill_value=np.nan)
예제 #24
0
def plot_spectrum(boxes,
                  filters,
                  output,
                  colors=None,
                  residuals=None,
                  xlim=None,
                  ylim=None,
                  scale=('linear', 'linear'),
                  title=None,
                  offset=None,
                  legend='upper left',
                  figsize=(7., 5.),
                  object_type='planet'):
    """
    Parameters
    ----------
    boxes : tuple(species.core.box, )
        Boxes with data.
    filters : tuple(str, )
        Filter IDs for which the transmission profile is plotted.
    output : str
        Output filename.
    colors : tuple(str, )
        Colors to be used for the different boxes. Note that a box with residuals requires a tuple
        with two colors (i.e., for the photometry and spectrum). Automatic colors are used if set
        to None.
    residuals : species.core.box.ResidualsBox
        Box with residuals of a fit.
    xlim : tuple(float, float)
        Limits of the x-axis.
    ylim : tuple(float, float)
        Limits of the y-axis.
    scale : tuple(str, str)
        Scale of the axes ('linear' or 'log').
    title : str
        Title.
    offset : tuple(float, float)
        Offset for the label of the x- and y-axis.
    legend : str, None
        Location of the legend.
    figsize : tuple(float, float)
        Figure size.
    object_type : str
        Object type ('planet' or 'star'). With 'planet', the radius and mass are expressed in
        Jupiter units. With 'star', the radius and mass are expressed in solar units.

    Returns
    -------
    NoneType
        None
    """

    marker = itertools.cycle(('o', 's', '*', 'p', '<', '>', 'P', 'v', '^'))

    if residuals and filters:
        plt.figure(1, figsize=figsize)
        gridsp = mpl.gridspec.GridSpec(3, 1, height_ratios=[1, 3, 1])
        gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[1, 0])
        ax2 = plt.subplot(gridsp[0, 0])
        ax3 = plt.subplot(gridsp[2, 0])

    elif residuals:
        plt.figure(1, figsize=figsize)
        gridsp = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])
        gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[0, 0])
        ax3 = plt.subplot(gridsp[1, 0])

    elif filters:
        plt.figure(1, figsize=figsize)
        gridsp = mpl.gridspec.GridSpec(2, 1, height_ratios=[1, 4])
        gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[1, 0])
        ax2 = plt.subplot(gridsp[0, 0])

    else:
        plt.figure(1, figsize=figsize)
        gridsp = mpl.gridspec.GridSpec(1, 1)
        gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[0, 0])

    ax1.grid(True,
             linestyle=':',
             linewidth=0.7,
             color='gray',
             dashes=(1, 4),
             alpha=0.3,
             zorder=0)

    if residuals:
        labelbottom = False
    else:
        labelbottom = True

    ax1.tick_params(axis='both',
                    which='major',
                    colors='black',
                    labelcolor='black',
                    direction='in',
                    width=0.8,
                    length=5,
                    labelsize=12,
                    top=True,
                    bottom=True,
                    left=True,
                    right=True,
                    labelbottom=labelbottom)

    ax1.tick_params(axis='both',
                    which='minor',
                    colors='black',
                    labelcolor='black',
                    direction='in',
                    width=0.8,
                    length=3,
                    labelsize=12,
                    top=True,
                    bottom=True,
                    left=True,
                    right=True,
                    labelbottom=labelbottom)

    if filters:
        ax2.grid(True,
                 linestyle=':',
                 linewidth=0.7,
                 color='gray',
                 dashes=(1, 4),
                 alpha=0.3,
                 zorder=0)

        ax2.tick_params(axis='both',
                        which='major',
                        colors='black',
                        labelcolor='black',
                        direction='in',
                        width=0.8,
                        length=5,
                        labelsize=12,
                        top=True,
                        bottom=True,
                        left=True,
                        right=True,
                        labelbottom=False)

        ax2.tick_params(axis='both',
                        which='minor',
                        colors='black',
                        labelcolor='black',
                        direction='in',
                        width=0.8,
                        length=3,
                        labelsize=12,
                        top=True,
                        bottom=True,
                        left=True,
                        right=True,
                        labelbottom=False)

    if residuals:
        ax3.grid(True,
                 linestyle=':',
                 linewidth=0.7,
                 color='gray',
                 dashes=(1, 4),
                 alpha=0.3,
                 zorder=0)

        ax3.tick_params(axis='both',
                        which='major',
                        colors='black',
                        labelcolor='black',
                        direction='in',
                        width=0.8,
                        length=5,
                        labelsize=12,
                        top=True,
                        bottom=True,
                        left=True,
                        right=True)

        ax3.tick_params(axis='both',
                        which='minor',
                        colors='black',
                        labelcolor='black',
                        direction='in',
                        width=0.8,
                        length=3,
                        labelsize=12,
                        top=True,
                        bottom=True,
                        left=True,
                        right=True)

    if residuals and filters:
        ax1.set_xlabel('', fontsize=13)
        ax2.set_xlabel('', fontsize=13)
        ax3.set_xlabel('Wavelength [micron]', fontsize=13)

    elif residuals:
        ax1.set_xlabel('', fontsize=13)
        ax3.set_xlabel('Wavelength [micron]', fontsize=13)

    elif filters:
        ax1.set_xlabel('Wavelength [micron]', fontsize=13)
        ax2.set_xlabel('', fontsize=13)

    else:
        ax1.set_xlabel('Wavelength [micron]', fontsize=13)

    if filters:
        ax2.set_ylabel('Transmission', fontsize=13)

    if residuals:
        ax3.set_ylabel(r'Residual [$\sigma$]', fontsize=13)

    if xlim:
        ax1.set_xlim(xlim[0], xlim[1])
    else:
        ax1.set_xlim(0.6, 6.)

    if ylim:
        ax1.set_ylim(ylim[0], ylim[1])

        ylim = ax1.get_ylim()

        exponent = math.floor(math.log10(ylim[1]))
        scaling = 10.**exponent

        ax1.set_ylabel(r'Flux [10$^{' + str(exponent) +
                       r'}$ W m$^{-2}$ $\mu$m$^{-1}$]',
                       fontsize=13)
        ax1.set_ylim(ylim[0] / scaling, ylim[1] / scaling)

        if ylim[0] < 0.:
            ax1.axhline(0.0,
                        linestyle='--',
                        color='gray',
                        dashes=(2, 4),
                        zorder=0.5)

    else:
        ax1.set_ylabel(r'Flux [W m$^{-2}$ $\mu$m$^{-1}$]', fontsize=13)
        scaling = 1.

    if filters:
        ax2.set_ylim(0., 1.)

    xlim = ax1.get_xlim()

    if filters:
        ax2.set_xlim(xlim[0], xlim[1])

    if residuals:
        ax3.set_xlim(xlim[0], xlim[1])

    if offset and residuals and filters:
        ax3.get_xaxis().set_label_coords(0.5, offset[0])

        ax1.get_yaxis().set_label_coords(offset[1], 0.5)
        ax2.get_yaxis().set_label_coords(offset[1], 0.5)
        ax3.get_yaxis().set_label_coords(offset[1], 0.5)

    elif offset and filters:
        ax1.get_xaxis().set_label_coords(0.5, offset[0])

        ax1.get_yaxis().set_label_coords(offset[1], 0.5)
        ax2.get_yaxis().set_label_coords(offset[1], 0.5)

    elif offset and residuals:
        ax3.get_xaxis().set_label_coords(0.5, offset[0])

        ax1.get_yaxis().set_label_coords(offset[1], 0.5)
        ax3.get_yaxis().set_label_coords(offset[1], 0.5)

    elif offset:
        ax1.get_xaxis().set_label_coords(0.5, offset[0])
        ax1.get_yaxis().set_label_coords(offset[1], 0.5)

    else:
        ax1.get_xaxis().set_label_coords(0.5, -0.12)
        ax1.get_yaxis().set_label_coords(-0.1, 0.5)

    ax1.set_xscale(scale[0])
    ax1.set_yscale(scale[1])

    if filters:
        ax2.set_xscale(scale[0])

    if residuals:
        ax3.set_xscale(scale[0])

    color_obj_phot = None
    color_obj_spec = None

    for j, boxitem in enumerate(boxes):
        if isinstance(boxitem, (box.SpectrumBox, box.ModelBox)):
            wavelength = boxitem.wavelength
            flux = boxitem.flux

            if isinstance(wavelength[0], (np.float32, np.float64)):
                data = np.array(flux, dtype=np.float64)
                masked = np.ma.array(data, mask=np.isnan(data))

                if isinstance(boxitem, box.ModelBox):
                    param = boxitem.parameters

                    par_key, par_unit = plot_util.quantity_unit(
                        param=list(param.keys()), object_type=object_type)

                    par_val = list(param.values())

                    label = ''
                    for i, item in enumerate(par_key):

                        if item == r'$T_\mathregular{eff}$':
                            value = f'{par_val[i]:.1f}'
                        elif item in (r'$\log\,g$', '[Fe/H]'):
                            value = f'{par_val[i]:.2f}'
                        elif item == r'$R$':
                            if object_type == 'planet':
                                value = f'{par_val[i]:.2f}'
                            elif object_type == 'star':
                                value = f'{par_val[i]*constants.R_JUP/constants.R_SUN:.2f}'
                        elif item == r'$M$':
                            if object_type == 'planet':
                                value = f'{par_val[i]:.2f}'
                            elif object_type == 'star':
                                value = f'{par_val[i]*constants.M_JUP/constants.M_SUN:.2f}'
                        elif item == r'$L$':
                            value = f'{par_val[i]:.1e}'
                        else:
                            continue

                        label += item + ' = ' + str(value) + ' ' + par_unit[i]

                        if i < len(par_key) - 1:
                            label += ', '

                else:
                    label = None

                if colors:
                    ax1.plot(wavelength,
                             masked / scaling,
                             color=colors[j],
                             lw=0.5,
                             label=label,
                             zorder=2)
                else:
                    ax1.plot(wavelength,
                             masked / scaling,
                             lw=0.5,
                             label=label,
                             zorder=2)

            elif isinstance(wavelength[0], (np.ndarray)):
                for i, item in enumerate(wavelength):
                    data = np.array(flux[i], dtype=np.float64)
                    masked = np.ma.array(data, mask=np.isnan(data))

                    ax1.plot(item, masked / scaling, lw=0.5)

        elif isinstance(boxitem, tuple):
            for i, item in enumerate(boxitem):
                wavelength = item.wavelength
                flux = item.flux

                data = np.array(flux, dtype=np.float64)
                masked = np.ma.array(data, mask=np.isnan(data))

                if colors:
                    ax1.plot(wavelength,
                             masked / scaling,
                             lw=0.2,
                             color=colors[j],
                             alpha=0.5,
                             zorder=1)
                else:
                    ax1.plot(wavelength,
                             masked / scaling,
                             lw=0.2,
                             alpha=0.5,
                             zorder=1)

        elif isinstance(boxitem, box.PhotometryBox):
            if colors:
                ax1.plot(boxitem.wavelength,
                         boxitem.flux / scaling,
                         marker=next(marker),
                         ms=6,
                         color=colors[j],
                         label=boxitem.name,
                         zorder=3)
            else:
                ax1.plot(boxitem.wavelength,
                         boxitem.flux / scaling,
                         marker=next(marker),
                         ms=6,
                         label=boxitem.name,
                         zorder=3)

        elif isinstance(boxitem, box.ObjectBox):
            if boxitem.flux is not None:
                for item in boxitem.flux:
                    transmission = read_filter.ReadFilter(item)
                    wavelength = transmission.mean_wavelength()
                    fwhm = transmission.filter_fwhm()

                    color_obj_phot = colors[j][0]

                    ax1.errorbar(wavelength,
                                 boxitem.flux[item][0] / scaling,
                                 xerr=fwhm / 2.,
                                 yerr=boxitem.flux[item][1] / scaling,
                                 marker='s',
                                 ms=5,
                                 zorder=3,
                                 color=color_obj_phot,
                                 markerfacecolor=color_obj_phot)

            if boxitem.spectrum is not None:
                masked = np.ma.array(boxitem.spectrum,
                                     mask=np.isnan(boxitem.spectrum))

                color_obj_spec = colors[j][1]

                if colors is None:
                    ax1.errorbar(masked[:, 0],
                                 masked[:, 1] / scaling,
                                 yerr=masked[:, 2] / scaling,
                                 ms=2,
                                 marker='s',
                                 zorder=2.5,
                                 ls='none')

                else:
                    ax1.errorbar(masked[:, 0],
                                 masked[:, 1] / scaling,
                                 yerr=masked[:, 2] / scaling,
                                 marker='o',
                                 ms=2,
                                 zorder=2.5,
                                 color=color_obj_spec,
                                 markerfacecolor=color_obj_spec,
                                 ls='none')

        elif isinstance(boxitem, box.SynphotBox):
            for item in boxitem.flux:
                transmission = read_filter.ReadFilter(item)
                wavelength = transmission.mean_wavelength()
                fwhm = transmission.filter_fwhm()

                ax1.errorbar(wavelength,
                             boxitem.flux[item] / scaling,
                             xerr=fwhm / 2.,
                             yerr=None,
                             alpha=0.7,
                             marker='s',
                             ms=5,
                             zorder=4,
                             color=colors[j],
                             markerfacecolor='white')

    if filters:
        for i, item in enumerate(filters):
            transmission = read_filter.ReadFilter(item)
            data = transmission.get_filter()

            ax2.plot(data[0, ],
                     data[1, ],
                     '-',
                     lw=0.7,
                     color='black',
                     zorder=1)

    if residuals:
        res_max = 0.

        if residuals.photometry is not None:
            ax3.plot(residuals.photometry[0, ],
                     residuals.photometry[1, ],
                     marker='s',
                     ms=5,
                     linestyle='none',
                     color=color_obj_phot,
                     zorder=2)

            res_max = np.nanmax(np.abs(residuals.photometry[1, ]))

        if residuals.spectrum is not None:
            ax3.plot(residuals.spectrum[0, ],
                     residuals.spectrum[1, ],
                     marker='o',
                     ms=2,
                     linestyle='none',
                     color=color_obj_spec,
                     zorder=1)

            max_tmp = np.nanmax(np.abs(residuals.spectrum[1, ]))

            if max_tmp > res_max:
                res_max = max_tmp

        res_lim = math.ceil(res_max)

        ax3.axhline(0.0,
                    linestyle='--',
                    color='gray',
                    dashes=(2, 4),
                    zorder=0.5)
        ax3.set_ylim(-res_lim, res_lim)

    if filters:
        ax2.set_ylim(0., 1.1)

    sys.stdout.write('Plotting spectrum: ' + output + '...')
    sys.stdout.flush()

    if title:
        if filters:
            ax2.set_title(title, y=1.02, fontsize=15)
        else:
            ax1.set_title(title, y=1.02, fontsize=15)

    handles, _ = ax1.get_legend_handles_labels()

    if handles and legend:
        ax1.legend(loc=legend, prop={'size': 9}, frameon=False)

    plt.savefig(os.getcwd() + '/' + output, bbox_inches='tight')
    plt.close()

    sys.stdout.write(' [DONE]\n')
    sys.stdout.flush()
예제 #25
0
def interp_lognorm(
    inc_phot: List[str],
    inc_spec: List[str],
    spec_data: Optional[Dict[str, Tuple[np.ndarray, Optional[np.ndarray],
                                        Optional[np.ndarray], float]]],
) -> Tuple[Dict[str, Union[interp2d, List[interp2d]]], np.ndarray, np.ndarray]:
    """
    Function for interpolating the log-normal dust cross sections for
    each filter and spectrum.

    The cross sections are read from the dust database, interpolated
    to the wavelengths of each filter profile (and averaged with the
    filter transmission as weight) or to the wavelength points of each
    spectrum, and finally wrapped in 2D interpolators over the grid of
    the geometric standard deviation and geometric mean radius.

    Parameters
    ----------
    inc_phot : list(str)
        List with filter names. Not used if the list is empty. The
        list is not modified. The ``Generic/Bessell.V`` filter is
        always included in the returned dictionary, also when it is
        not part of ``inc_phot``.
    inc_spec : list(str)
        List with the spectrum names (as stored in the database with
        :func:`~species.data.database.Database.add_object`). Not used
        if the list is empty.
    spec_data : dict, None
        Dictionary with the spectrum data. Only required in combination
        with ``inc_spec``, otherwise the argument needs to be set to
        ``None``.

    Returns
    -------
    dict
        Dictionary with the extinction cross section for each filter
        and spectrum. A filter name maps to a single interpolator and
        a spectrum name maps to a list with one interpolator for each
        wavelength point.
    np.ndarray
        Grid points of the geometric mean radius.
    np.ndarray
        Grid points of the geometric standard deviation.

    Raises
    ------
    ValueError
        Raised by the wavelength interpolation (``bounds_error=True``)
        if a filter profile or spectrum extends beyond the wavelength
        range of the dust opacities.
    """

    database_path = check_dust_database()

    with h5py.File(database_path, "r") as h5_file:
        cross_section = np.asarray(
            h5_file["dust/lognorm/mgsio3/crystalline/cross_section"])
        wavelength = np.asarray(
            h5_file["dust/lognorm/mgsio3/crystalline/wavelength"])
        radius_g = np.asarray(
            h5_file["dust/lognorm/mgsio3/crystalline/radius_g"])
        sigma_g = np.asarray(
            h5_file["dust/lognorm/mgsio3/crystalline/sigma_g"])

    print("Grid boundaries of the dust opacities:")
    print(f"   - Wavelength (um) = {wavelength[0]:.2f} - {wavelength[-1]:.2f}")
    print(
        f"   - Geometric mean radius (um) = {radius_g[0]:.2e} - {radius_g[-1]:.2e}"
    )
    print(
        f"   - Geometric standard deviation = {sigma_g[0]:.2f} - {sigma_g[-1]:.2f}"
    )

    # Work on a copy such that the list of the caller is not changed
    # and the V-band filter is not added again with every call
    filter_names = list(inc_phot)

    if "Generic/Bessell.V" not in filter_names:
        filter_names.append("Generic/Bessell.V")

    # The cross sections are indexed as (wavelength, radius_g, sigma_g)
    # so a single interpolator along the first axis evaluates all grid
    # points of the radius and standard deviation at once
    cross_interp = interp1d(wavelength,
                            cross_section,
                            axis=0,
                            kind="linear",
                            bounds_error=True)

    cross_sections = {}

    for phot_item in filter_names:
        read_filt = read_filter.ReadFilter(phot_item)
        filt_trans = read_filt.get_filter()

        filt_wavel = filt_trans[:, 0]
        filt_weight = filt_trans[:, 1]

        # Shape: (n_filter_wavelengths, n_radius, n_sigma)
        cross_filt = cross_interp(filt_wavel)

        integral1 = np.trapz(
            filt_weight[:, np.newaxis, np.newaxis] * cross_filt,
            filt_wavel,
            axis=0)
        integral2 = np.trapz(filt_weight, filt_wavel)

        # Filter-weighted average of the extinction cross section
        cross_phot = integral1 / integral2

        cross_sections[phot_item] = interp2d(sigma_g,
                                             radius_g,
                                             cross_phot,
                                             kind="linear",
                                             bounds_error=True)

    print("Interpolating dust opacities...", end="")

    for spec_item in inc_spec:
        wavel_spec = spec_data[spec_item][0][:, 0]

        # Shape: (n_spectrum_wavelengths, n_radius, n_sigma)
        cross_spec = cross_interp(wavel_spec)

        # One 2D interpolator for each wavelength point of the spectrum
        cross_sections[spec_item] = [
            interp2d(sigma_g,
                     radius_g,
                     cross_wavel,
                     kind="linear",
                     bounds_error=True)
            for cross_wavel in cross_spec
        ]

    print(" [DONE]")

    return cross_sections, radius_g, sigma_g
예제 #26
0
파일: phot_util.py 프로젝트: gotten/species
def get_residuals(datatype: str,
                  spectrum: str,
                  parameters: Dict[str, float],
                  objectbox: box.ObjectBox,
                  inc_phot: Union[bool, List[str]] = True,
                  inc_spec: Union[bool, List[str]] = True,
                  **kwargs_radtrans: Optional[dict]) -> box.ResidualsBox:
    """
    Function for calculating the residuals between the data of an object and the synthetic
    photometry and spectra of a model or calibration spectrum. The residuals are calculated
    as the difference between the data and the model, divided by the second entry of the
    photometric fluxes or the third column of the spectra (presumably the uncertainties, as the
    residuals are reported in units of sigma).

    Parameters
    ----------
    datatype : str
        Data type ('model' or 'calibration').
    spectrum : str
        Name of the atmospheric model or calibration spectrum.
    parameters : dict
        Parameters and values for the spectrum
    objectbox : species.core.box.ObjectBox
        Box with the photometry and/or spectra of an object. A scaling and/or error inflation of
        the spectra should be applied with :func:`~species.util.read_util.update_spectra`
        beforehand.
    inc_phot : bool, list(str)
        Include photometric data in the fit. If a boolean, either all (``True``) or none
        (``False``) of the data are selected. If a list, a subset of filter names (as stored in
        the database) can be provided.
    inc_spec : bool, list(str)
        Include spectroscopic data in the fit. If a boolean, either all (``True``) or none
        (``False``) of the data are selected. If a list, a subset of spectrum names (as stored
        in the database with :func:`~species.data.database.Database.add_object`) can be
        provided. No residuals of spectra are calculated if the ``objectbox`` does not contain
        any spectra.

    Keyword arguments
    -----------------
    kwargs_radtrans : dict
        Dictionary with the keyword arguments for the ``ReadRadtrans`` object, containing
        ``line_species``, ``cloud_species``, and ``scattering``. These are currently not used
        because the 'petitradtrans' branch is disabled.

    Returns
    -------
    species.core.box.ResidualsBox
        Box with the residuals. The photometry is a dictionary with for each filter an array
        that contains the mean wavelength and the residual. The spectrum is a dictionary with
        for each spectrum an array with the wavelengths and residuals as columns. Either of
        the two is ``None`` if not included.

    Raises
    ------
    NotImplementedError
        If residuals of a spectrum are requested with ``spectrum='petitradtrans'``.
    """

    if 'filters' in kwargs_radtrans:
        warnings.warn('The \'filters\' parameter has been deprecated. Please use the \'inc_phot\' '
                      'parameter instead. The \'filters\' parameter is ignored.')

    if isinstance(inc_phot, bool) and inc_phot:
        inc_phot = objectbox.filters

    if inc_phot:
        model_phot = multi_photometry(datatype=datatype,
                                      spectrum=spectrum,
                                      filters=inc_phot,
                                      parameters=parameters)

        res_phot = {}

        for item in inc_phot:
            transmission = read_filter.ReadFilter(item)
            mean_wavel = transmission.mean_wavelength()

            obj_flux = objectbox.flux[item]
            res_phot[item] = np.zeros(obj_flux.shape)

            if obj_flux.ndim == 1:
                # Single measurement: (flux, divisor)
                res_phot[item][0] = mean_wavel
                res_phot[item][1] = (obj_flux[0]-model_phot.flux[item]) / obj_flux[1]

            elif obj_flux.ndim == 2:
                # Multiple measurements with the same filter, stored along the second axis
                res_phot[item][0, :] = mean_wavel
                res_phot[item][1, :] = (obj_flux[0, :]-model_phot.flux[item]) / obj_flux[1, :]

    else:
        res_phot = None

    # The spectrum attribute is None for an object with only photometry so
    # the default inc_spec=True should not fail in that case
    if inc_spec and objectbox.spectrum is not None:
        res_spec = {}

        for key in objectbox.spectrum:
            if isinstance(inc_spec, bool) or key in inc_spec:
                data_spec = objectbox.spectrum[key][0]

                # Read the model with a somewhat broader wavelength range than the data
                wavel_range = (0.9*data_spec[0, 0], 1.1*data_spec[-1, 0])

                wl_new = data_spec[:, 0]
                spec_res = objectbox.spectrum[key][3]

                if spectrum == 'planck':
                    readmodel = read_planck.ReadPlanck(wavel_range=wavel_range)

                    model = readmodel.get_spectrum(model_param=parameters, spec_res=1000.)

                    flux_new = spectres.spectres(wl_new,
                                                 model.wavelength,
                                                 model.flux,
                                                 spec_errs=None,
                                                 fill=0.,
                                                 verbose=True)

                elif spectrum == 'petitradtrans':
                    # TODO restore this branch: create the model spectrum with ReadRadtrans
                    # from the kwargs_radtrans and resample it to wl_new with spectres
                    raise NotImplementedError('Calculating the residuals of a spectrum is '
                                              'currently not supported for '
                                              'spectrum=\'petitradtrans\'.')

                else:
                    readmodel = read_model.ReadModel(spectrum, wavel_range=wavel_range)

                    # resampling to the new wavelength points is done in the get_model function

                    model_spec = readmodel.get_model(parameters,
                                                     spec_res=spec_res,
                                                     wavel_resample=wl_new,
                                                     smooth=True)

                    flux_new = model_spec.flux

                res_tmp = (data_spec[:, 1]-flux_new) / data_spec[:, 2]

                res_spec[key] = np.column_stack([wl_new, res_tmp])

    else:
        res_spec = None

    print('Calculating residuals... [DONE]')

    print('Residuals (sigma):')

    if res_phot is not None:
        for item in inc_phot:
            if res_phot[item].ndim == 1:
                print(f'   - {item}: {res_phot[item][1]:.2f}')

            elif res_phot[item].ndim == 2:
                for j in range(res_phot[item].shape[1]):
                    print(f'   - {item}: {res_phot[item][1, j]:.2f}')

    if res_spec is not None:
        for key, res_item in res_spec.items():
            # Only the second column contains residuals, the first column are the wavelengths
            print(f'   - {key}: min: {np.nanmin(res_item[:, 1]):.2f}, '
                  f'max: {np.nanmax(res_item[:, 1]):.2f}')

    return box.create_box(boxtype='residuals',
                          name=objectbox.name,
                          photometry=res_phot,
                          spectrum=res_spec)
예제 #27
0
def plot_spectrum(
    boxes: list,
    filters: Optional[List[str]] = None,
    residuals: Optional[box.ResidualsBox] = None,
    plot_kwargs: Optional[List[Optional[dict]]] = None,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    ylim_res: Optional[Tuple[float, float]] = None,
    scale: Optional[Tuple[str, str]] = None,
    title: Optional[str] = None,
    offset: Optional[Tuple[float, float]] = None,
    legend: Optional[Union[str, dict, Tuple[float, float],
                           List[Optional[Union[dict, str,
                                               Tuple[float,
                                                     float]]]], ]] = None,
    figsize: Optional[Tuple[float, float]] = (10.0, 5.0),
    object_type: str = "planet",
    quantity: str = "flux density",
    output: Optional[str] = "spectrum.pdf",
    leg_param: Optional[List[str]] = None,
):
    """
    Function for plotting a spectral energy distribution and combining various data such as spectra,
    photometric fluxes, model spectra, synthetic photometry, fit residuals, and filter profiles.

    Parameters
    ----------
    boxes : list(species.core.box)
        Boxes with data.
    filters : list(str), None
        Filter IDs for which the transmission profile is plotted. Not plotted if set to None.
    residuals : species.core.box.ResidualsBox, None
        Box with residuals of a fit. Not plotted if set to None.
    plot_kwargs : list(dict), None
        List with dictionaries of keyword arguments for each box. For example, if the ``boxes``
        are a ``ModelBox`` and ``ObjectBox``:

        .. code-block:: python

            plot_kwargs=[{'ls': '-', 'lw': 1., 'color': 'black'},
                         {'spectrum_1': {'marker': 'o', 'ms': 3., 'color': 'tab:brown', 'ls': 'none'},
                          'spectrum_2': {'marker': 'o', 'ms': 3., 'color': 'tab:blue', 'ls': 'none'},
                          'Paranal/SPHERE.IRDIS_D_H23_3': {'marker': 's', 'ms': 4., 'color': 'tab:cyan', 'ls': 'none'},
                          'Paranal/SPHERE.IRDIS_D_K12_1': [{'marker': 's', 'ms': 4., 'color': 'tab:orange', 'ls': 'none'},
                                                           {'marker': 's', 'ms': 4., 'color': 'tab:red', 'ls': 'none'}],
                          'Paranal/NACO.Lp': {'marker': 's', 'ms': 4., 'color': 'tab:green', 'ls': 'none'},
                          'Paranal/NACO.Mp': {'marker': 's', 'ms': 4., 'color': 'tab:green', 'ls': 'none'}}]

        For an ``ObjectBox``, the dictionary contains items for the different spectrum and filter
        names stored with :func:`~species.data.database.Database.add_object`. In case both
        and ``ObjectBox`` and a ``SynphotBox`` are provided, then the latter can be set to ``None``
        in order to use the same (but open) symbols as the data from the ``ObjectBox``. Note that
        if a filter name is duplicated in an ``ObjectBox`` (Paranal/SPHERE.IRDIS_D_K12_1 in the
        example) then a list with two dictionaries should be provided. Colors are automatically
        chosen if ``plot_kwargs`` is set to ``None``.
    xlim : tuple(float, float)
        Limits of the wavelength axis.
    ylim : tuple(float, float)
        Limits of the flux axis.
    ylim_res : tuple(float, float), None
        Limits of the residuals axis. Automatically chosen (based on the minimum and maximum
        residual value) if set to None.
    scale : tuple(str, str), None
        Scale of the x and y axes ('linear' or 'log'). The scale is set to ``('linear', 'linear')``
        if set to ``None``.
    title : str
        Title.
    offset : tuple(float, float)
        Offset for the label of the x- and y-axis.
    legend : str, tuple, dict, list(dict, dict), None
        Location of the legend (str or tuple(float, float)) or a dictionary with the ``**kwargs``
        of ``matplotlib.pyplot.legend``, for example ``{'loc': 'upper left', 'fontsize: 12.}``.
        Alternatively, a list with two values can be provided to separate the model and data
        handles in two legends. Each of these two elements can be set to ``None``. For example,
        ``[None, {'loc': 'upper left', 'fontsize: 12.}]``, if only the data points should be
        included in a legend.
    figsize : tuple(float, float)
        Figure size.
    object_type : str
        Object type ('planet' or 'star'). With 'planet', the radius and mass are expressed in
        Jupiter units. With 'star', the radius and mass are expressed in solar units.
    quantity: str
        The quantity of the y-axis ('flux density', 'flux', or 'magnitude').
    output : str
        Output filename for the plot. The plot is shown in an
        interface window if the argument is set to ``None``.
    leg_param : list(str), None
        List with the parameters to include in the legend of the model spectra. Apart from
        atmospheric parameters (e.g. 'teff', 'logg', 'radius') also parameters such as 'mass'
        and 'luminosity' can be included. The default atmospheric parameters are included in the
        legend if the argument is set to ``None``.

    Returns
    -------
    NoneType
        None
    """

    mpl.rcParams["font.serif"] = ["Bitstream Vera Serif"]
    mpl.rcParams["font.family"] = "serif"

    plt.rc("axes", edgecolor="black", linewidth=2.2)
    plt.rcParams["axes.axisbelow"] = False

    if plot_kwargs is None:
        plot_kwargs = []

    elif plot_kwargs is not None and len(boxes) != len(plot_kwargs):
        raise ValueError(
            f"The number of 'boxes' ({len(boxes)}) should be equal to the "
            f"number of items in 'plot_kwargs' ({len(plot_kwargs)}).")

    if residuals is not None and filters is not None:
        plt.figure(1, figsize=figsize)
        gridsp = mpl.gridspec.GridSpec(3, 1, height_ratios=[1, 3, 1])
        gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[1, 0])
        ax2 = plt.subplot(gridsp[0, 0])
        ax3 = plt.subplot(gridsp[2, 0])

    elif residuals is not None:
        plt.figure(1, figsize=figsize)
        gridsp = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])
        gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[0, 0])
        ax2 = None
        ax3 = plt.subplot(gridsp[1, 0])

    elif filters is not None:
        plt.figure(1, figsize=figsize)
        gridsp = mpl.gridspec.GridSpec(2, 1, height_ratios=[1, 4])
        gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[1, 0])
        ax2 = plt.subplot(gridsp[0, 0])
        ax3 = None

    else:
        plt.figure(1, figsize=figsize)
        gridsp = mpl.gridspec.GridSpec(1, 1)
        gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[0, 0])
        ax2 = None
        ax3 = None

    if residuals is not None:
        labelbottom = False
    else:
        labelbottom = True

    if scale is None:
        scale = ("linear", "linear")

    ax1.set_xscale(scale[0])
    ax1.set_yscale(scale[1])

    if filters is not None:
        ax2.set_xscale(scale[0])

    if residuals is not None:
        ax3.set_xscale(scale[0])

    ax1.tick_params(
        axis="both",
        which="major",
        colors="black",
        labelcolor="black",
        direction="in",
        width=1,
        length=5,
        labelsize=12,
        top=True,
        bottom=True,
        left=True,
        right=True,
        labelbottom=labelbottom,
    )

    ax1.tick_params(
        axis="both",
        which="minor",
        colors="black",
        labelcolor="black",
        direction="in",
        width=1,
        length=3,
        labelsize=12,
        top=True,
        bottom=True,
        left=True,
        right=True,
        labelbottom=labelbottom,
    )

    if filters is not None:
        ax2.tick_params(
            axis="both",
            which="major",
            colors="black",
            labelcolor="black",
            direction="in",
            width=1,
            length=5,
            labelsize=12,
            top=True,
            bottom=True,
            left=True,
            right=True,
            labelbottom=False,
        )

        ax2.tick_params(
            axis="both",
            which="minor",
            colors="black",
            labelcolor="black",
            direction="in",
            width=1,
            length=3,
            labelsize=12,
            top=True,
            bottom=True,
            left=True,
            right=True,
            labelbottom=False,
        )

    if residuals is not None:
        ax3.tick_params(
            axis="both",
            which="major",
            colors="black",
            labelcolor="black",
            direction="in",
            width=1,
            length=5,
            labelsize=12,
            top=True,
            bottom=True,
            left=True,
            right=True,
        )

        ax3.tick_params(
            axis="both",
            which="minor",
            colors="black",
            labelcolor="black",
            direction="in",
            width=1,
            length=3,
            labelsize=12,
            top=True,
            bottom=True,
            left=True,
            right=True,
        )

    if scale[0] == "linear":
        ax1.xaxis.set_minor_locator(AutoMinorLocator(5))

    if scale[1] == "linear":
        ax1.yaxis.set_minor_locator(AutoMinorLocator(5))

    # ax1.set_yticks([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0])
    # ax3.set_yticks([-2., 0., 2.])

    if filters is not None:
        if scale[0] == "linear":
            ax2.xaxis.set_minor_locator(AutoMinorLocator(5))

    if residuals is not None:
        if scale[0] == "linear":
            ax3.xaxis.set_minor_locator(AutoMinorLocator(5))

    if residuals is not None and filters is not None:
        ax1.set_xlabel("")
        ax2.set_xlabel("")
        ax3.set_xlabel("Wavelength (µm)", fontsize=13)

    elif residuals is not None:
        ax1.set_xlabel("")
        ax3.set_xlabel("Wavelength (µm)", fontsize=11)

    elif filters is not None:
        ax1.set_xlabel("Wavelength (µm)", fontsize=13)
        ax2.set_xlabel("")

    else:
        ax1.set_xlabel("Wavelength (µm)", fontsize=13)

    if filters is not None:
        ax2.set_ylabel(r"T$_\lambda$", fontsize=13)

    if residuals is not None:
        if quantity == "flux density":
            ax3.set_ylabel(r"$\Delta$$\mathregular{F}_\lambda$ ($\sigma$)",
                           fontsize=11)

        elif quantity == "flux":
            ax3.set_ylabel(r"$\Delta$$\mathregular{F}_\lambda$ ($\sigma$)",
                           fontsize=11)

    if xlim is None:
        ax1.set_xlim(0.6, 6.0)
    else:
        ax1.set_xlim(xlim[0], xlim[1])

    if quantity == "magnitude":
        scaling = 1.0
        ax1.set_ylabel("Flux contrast (mag)", fontsize=13)

        if ylim:
            ax1.set_ylim(ylim[0], ylim[1])

    else:
        if ylim:
            ax1.set_ylim(ylim[0], ylim[1])

            ylim = ax1.get_ylim()

            exponent = math.floor(math.log10(ylim[1]))
            scaling = 10.0**exponent

            if quantity == "flux density":
                ylabel = (r"$\mathregular{F}_\lambda$ (10$^{" + str(exponent) +
                          r"}$ W m$^{-2}$ µm$^{-1}$)")

            elif quantity == "flux":
                ylabel = (r"$\lambda$$\mathregular{F}_\lambda$ (10$^{" +
                          str(exponent) + r"}$ W m$^{-2}$)")

            ax1.set_ylabel(ylabel, fontsize=11)
            ax1.set_ylim(ylim[0] / scaling, ylim[1] / scaling)

            if ylim[0] < 0.0:
                ax1.axhline(0.0,
                            ls="--",
                            lw=0.7,
                            color="gray",
                            dashes=(2, 4),
                            zorder=0.5)

        else:
            if quantity == "flux density":
                ax1.set_ylabel(
                    r"$\mathregular{F}_\lambda$ (W m$^{-2}$ µm$^{-1}$)",
                    fontsize=11)

            elif quantity == "flux":
                ax1.set_ylabel(
                    r"$\lambda$$\mathregular{F}_\lambda$ (W m$^{-2}$)",
                    fontsize=11)

            scaling = 1.0

    xlim = ax1.get_xlim()

    if filters is not None:
        ax2.set_xlim(xlim[0], xlim[1])
        ax2.set_ylim(0.0, 1.0)

    if residuals is not None:
        ax3.set_xlim(xlim[0], xlim[1])

    if offset is not None and residuals is not None and filters is not None:
        ax3.get_xaxis().set_label_coords(0.5, offset[0])

        ax1.get_yaxis().set_label_coords(offset[1], 0.5)
        ax2.get_yaxis().set_label_coords(offset[1], 0.5)
        ax3.get_yaxis().set_label_coords(offset[1], 0.5)

    elif offset is not None and filters is not None:
        ax1.get_xaxis().set_label_coords(0.5, offset[0])

        ax1.get_yaxis().set_label_coords(offset[1], 0.5)
        ax2.get_yaxis().set_label_coords(offset[1], 0.5)

    elif offset is not None and residuals is not None:
        ax3.get_xaxis().set_label_coords(0.5, offset[0])

        ax1.get_yaxis().set_label_coords(offset[1], 0.5)
        ax3.get_yaxis().set_label_coords(offset[1], 0.5)

    elif offset is not None:
        ax1.get_xaxis().set_label_coords(0.5, offset[0])
        ax1.get_yaxis().set_label_coords(offset[1], 0.5)

    else:
        ax1.get_xaxis().set_label_coords(0.5, -0.12)
        ax1.get_yaxis().set_label_coords(-0.1, 0.5)

    for j, boxitem in enumerate(boxes):
        flux_scaling = 1.0

        if j < len(boxes):
            plot_kwargs.append(None)

        if isinstance(boxitem, (box.SpectrumBox, box.ModelBox)):
            wavelength = boxitem.wavelength
            flux = boxitem.flux

            if isinstance(wavelength[0], (np.float32, np.float64)):
                data = np.array(flux, dtype=np.float64)
                masked = np.ma.array(data, mask=np.isnan(data))

                if isinstance(boxitem, box.ModelBox):
                    param = boxitem.parameters

                    if leg_param is not None:
                        for item in list(param.keys()):
                            if item not in leg_param:
                                del param[item]

                    par_key, par_unit, par_label = plot_util.quantity_unit(
                        param=list(param.keys()), object_type=object_type)

                    label = ""
                    # newline = False

                    for i, item in enumerate(par_key):
                        if item[:4] == "teff":
                            value = f"{param[item]:.0f}"

                        elif item in [
                                "logg",
                                "feh",
                                "metallicity",
                                "lognorm_ext",
                                "powerlaw_ext",
                                "ism_ext",
                        ]:

                            value = f"{param[item]:.1f}"

                        elif item in ["co", "c_o_ratio"]:

                            value = f"{param[item]:.2f}"

                        elif item[:6] == "radius":
                            if object_type == "planet":
                                value = f"{param[item]:.1f}"

                                # if item == 'radius_1':
                                #     value = f'{param[item]:.0f}'
                                # else:
                                #     value = f'{param[item]:.1f}'

                            elif object_type == "star":
                                value = (
                                    f"{param[item]*constants.R_JUP/constants.R_SUN:.1f}"
                                )

                        elif (item == "mass" and leg_param is not None
                              and item in leg_param):
                            if object_type == "planet":
                                value = f"{param[item]:.0f}"

                            elif object_type == "star":
                                value = (
                                    f"{param[item]*constants.M_JUP/constants.M_SUN:.1f}"
                                )

                        elif (item == "luminosity" and leg_param is not None
                              and item in leg_param):
                            value = f"{np.log10(param[item]):.2f}"

                        else:
                            continue

                        # if len(label) > 80 and newline == False:
                        #     label += '\n'
                        #     newline = True

                        if par_unit[i] is None:
                            if len(label) > 0:
                                label += ", "

                            label += f"{par_label[i]} = {value}"

                        else:
                            if len(label) > 0:
                                label += ", "

                            label += f"{par_label[i]} = {value} {par_unit[i]}"

                else:
                    label = None

                if plot_kwargs[j]:
                    kwargs_copy = plot_kwargs[j].copy()

                    if "label" in kwargs_copy:
                        if kwargs_copy["label"] is None:
                            label = None
                        else:
                            label = kwargs_copy["label"]

                        del kwargs_copy["label"]

                    if quantity == "flux":
                        flux_scaling = wavelength

                    if "zorder" not in kwargs_copy:
                        kwargs_copy["zorder"] = 2.

                    ax1.plot(
                        wavelength,
                        flux_scaling * masked / scaling,
                        label=label,
                        **kwargs_copy,
                    )

                else:
                    if quantity == "flux":
                        flux_scaling = wavelength

                    ax1.plot(
                        wavelength,
                        flux_scaling * masked / scaling,
                        lw=0.5,
                        label=label,
                        zorder=2,
                    )

            elif isinstance(wavelength[0], (np.ndarray)):
                for i, item in enumerate(wavelength):
                    data = np.array(flux[i], dtype=np.float64)
                    masked = np.ma.array(data, mask=np.isnan(data))

                    if isinstance(boxitem.name[i], bytes):
                        label = boxitem.name[i].decode("utf-8")
                    else:
                        label = boxitem.name[i]

                    if quantity == "flux":
                        flux_scaling = item

                    ax1.plot(item,
                             flux_scaling * masked / scaling,
                             lw=0.5,
                             label=label)

        elif isinstance(boxitem, list):
            for i, item in enumerate(boxitem):
                wavelength = item.wavelength
                flux = item.flux

                data = np.array(flux, dtype=np.float64)
                masked = np.ma.array(data, mask=np.isnan(data))

                if quantity == "flux":
                    flux_scaling = wavelength

                if plot_kwargs[j]:
                    if "zorder" not in plot_kwargs[j]:
                        plot_kwargs[j]["zorder"] = 1.

                    ax1.plot(
                        wavelength,
                        flux_scaling * masked / scaling,
                        **plot_kwargs[j],
                    )
                else:
                    ax1.plot(
                        wavelength,
                        flux_scaling * masked / scaling,
                        color="gray",
                        lw=0.2,
                        alpha=0.5,
                        zorder=1,
                    )

        elif isinstance(boxitem, box.PhotometryBox):
            label_check = []

            for i, item in enumerate(boxitem.wavelength):
                transmission = read_filter.ReadFilter(boxitem.filter_name[i])
                fwhm = transmission.filter_fwhm()

                if quantity == "flux":
                    flux_scaling = item

                if plot_kwargs[j]:
                    if ("label" in plot_kwargs[j]
                            and plot_kwargs[j]["label"] not in label_check):
                        label_check.append(plot_kwargs[j]["label"])

                    elif ("label" in plot_kwargs[j]
                          and plot_kwargs[j]["label"] in label_check):
                        del plot_kwargs[j]["label"]

                    if boxitem.flux[i][1] is None:
                        if "zorder" not in plot_kwargs[j]:
                            plot_kwargs[j]["zorder"] = 3.

                        ax1.errorbar(
                            item,
                            flux_scaling * boxitem.flux[i][0] / scaling,
                            xerr=fwhm / 2.0,
                            yerr=None,
                            **plot_kwargs[j],
                        )

                    else:
                        if "zorder" not in plot_kwargs[j]:
                            plot_kwargs[j]["zorder"] = 3.

                        ax1.errorbar(
                            item,
                            flux_scaling * boxitem.flux[i][0] / scaling,
                            xerr=fwhm / 2.0,
                            yerr=flux_scaling * boxitem.flux[i][1] / scaling,
                            **plot_kwargs[j],
                        )

                else:
                    if boxitem.flux[i][1] is None:
                        ax1.errorbar(
                            item,
                            flux_scaling * boxitem.flux[i][0] / scaling,
                            xerr=fwhm / 2.0,
                            yerr=None,
                            marker="s",
                            ms=6,
                            color="black",
                            zorder=3,
                        )

                    else:
                        ax1.errorbar(
                            item,
                            flux_scaling * boxitem.flux[i][0] / scaling,
                            xerr=fwhm / 2.0,
                            yerr=flux_scaling * boxitem.flux[i][1] / scaling,
                            marker="s",
                            ms=6,
                            color="black",
                            zorder=3,
                        )

        elif isinstance(boxitem, box.ObjectBox):
            if boxitem.spectrum is not None:
                spec_list = []
                wavel_list = []

                for item in boxitem.spectrum:
                    spec_list.append(item)
                    wavel_list.append(boxitem.spectrum[item][0][0, 0])

                sort_index = np.argsort(wavel_list)
                spec_sort = []

                for i in range(sort_index.size):
                    spec_sort.append(spec_list[sort_index[i]])

                for key in spec_sort:
                    masked = np.ma.array(
                        boxitem.spectrum[key][0],
                        mask=np.isnan(boxitem.spectrum[key][0]),
                    )

                    if quantity == "flux":
                        flux_scaling = masked[:, 0]

                    if not plot_kwargs[j] or key not in plot_kwargs[j]:
                        plot_obj = ax1.errorbar(
                            masked[:, 0],
                            flux_scaling * masked[:, 1] / scaling,
                            yerr=flux_scaling * masked[:, 2] / scaling,
                            ms=2,
                            marker="s",
                            zorder=2.5,
                            ls="none",
                        )

                        if plot_kwargs[j] is None:
                            plot_kwargs[j] = {}

                        plot_kwargs[j][key] = {
                            "marker": "s",
                            "ms": 2.0,
                            "ls": "none",
                            "color": plot_obj[0].get_color(),
                        }

                    elif "marker" not in plot_kwargs[j][key]:
                        # Plot the spectrum as a line without error bars
                        # (e.g. when the spectrum has a high spectral resolution)
                        plot_obj = ax1.plot(
                            masked[:, 0],
                            flux_scaling * masked[:, 1] / scaling,
                            **plot_kwargs[j][key],
                        )

                    else:
                        if "zorder" not in plot_kwargs[j][key]:
                            plot_kwargs[j][key]["zorder"] = 2.5

                        ax1.errorbar(
                            masked[:, 0],
                            flux_scaling * masked[:, 1] / scaling,
                            yerr=flux_scaling * masked[:, 2] / scaling,
                            **plot_kwargs[j][key],
                        )

            if boxitem.flux is not None:
                filter_list = []
                wavel_list = []

                for item in boxitem.flux:
                    read_filt = read_filter.ReadFilter(item)
                    filter_list.append(item)
                    wavel_list.append(read_filt.mean_wavelength())

                sort_index = np.argsort(wavel_list)
                filter_sort = []

                for i in range(sort_index.size):
                    filter_sort.append(filter_list[sort_index[i]])

                for item in filter_sort:
                    transmission = read_filter.ReadFilter(item)
                    wavelength = transmission.mean_wavelength()
                    fwhm = transmission.filter_fwhm()

                    if not plot_kwargs[j] or item not in plot_kwargs[j]:
                        if not plot_kwargs[j]:
                            plot_kwargs[j] = {}

                        if quantity == "flux":
                            flux_scaling = wavelength

                        scale_tmp = flux_scaling / scaling

                        if isinstance(boxitem.flux[item][0], np.ndarray):
                            for i in range(boxitem.flux[item].shape[1]):

                                plot_obj = ax1.errorbar(
                                    wavelength,
                                    scale_tmp * boxitem.flux[item][0, i],
                                    xerr=fwhm / 2.0,
                                    yerr=scale_tmp * boxitem.flux[item][1, i],
                                    marker="s",
                                    ms=5,
                                    zorder=3,
                                    color="black",
                                )

                        else:

                            plot_obj = ax1.errorbar(
                                wavelength,
                                scale_tmp * boxitem.flux[item][0],
                                xerr=fwhm / 2.0,
                                yerr=scale_tmp * boxitem.flux[item][1],
                                marker="s",
                                ms=5,
                                zorder=3,
                                color="black",
                            )

                        plot_kwargs[j][item] = {
                            "marker": "s",
                            "ms": 5.0,
                            "color": plot_obj[0].get_color(),
                        }

                    else:
                        if quantity == "flux":
                            flux_scaling = wavelength

                        if isinstance(boxitem.flux[item][0], np.ndarray):
                            if not isinstance(plot_kwargs[j][item], list):
                                raise ValueError(
                                    f"A list with {boxitem.flux[item].shape[1]} "
                                    f"dictionaries are required because the filter "
                                    f"{item} has {boxitem.flux[item].shape[1]} "
                                    f"values.")

                            for i in range(boxitem.flux[item].shape[1]):
                                if "zorder" not in plot_kwargs[j][item][i]:
                                    plot_kwargs[j][item][i]["zorder"] = 3.

                                ax1.errorbar(
                                    wavelength,
                                    flux_scaling * boxitem.flux[item][0, i] /
                                    scaling,
                                    xerr=fwhm / 2.0,
                                    yerr=flux_scaling *
                                    boxitem.flux[item][1, i] / scaling,
                                    **plot_kwargs[j][item][i],
                                )

                        else:
                            if boxitem.flux[item][1] == 0.0:
                                if "zorder" not in plot_kwargs[j][item]:
                                    plot_kwargs[j][item]["zorder"] = 3.

                                ax1.errorbar(
                                    wavelength,
                                    flux_scaling * boxitem.flux[item][0] /
                                    scaling,
                                    xerr=fwhm / 2.0,
                                    yerr=0.5 * flux_scaling *
                                    boxitem.flux[item][0] / scaling,
                                    uplims=True,
                                    capsize=2.0,
                                    capthick=0.0,
                                    **plot_kwargs[j][item],
                                )

                            else:
                                if "zorder" not in plot_kwargs[j][item]:
                                    plot_kwargs[j][item]["zorder"] = 3.

                                ax1.errorbar(
                                    wavelength,
                                    flux_scaling * boxitem.flux[item][0] /
                                    scaling,
                                    xerr=fwhm / 2.0,
                                    yerr=flux_scaling * boxitem.flux[item][1] /
                                    scaling,
                                    **plot_kwargs[j][item],
                                )

        elif isinstance(boxitem, box.SynphotBox):
            for i, find_item in enumerate(boxes):
                if isinstance(find_item, box.ObjectBox):
                    obj_index = i
                    break

            for item in boxitem.flux:
                transmission = read_filter.ReadFilter(item)
                wavelength = transmission.mean_wavelength()
                fwhm = transmission.filter_fwhm()

                if quantity == "flux":
                    flux_scaling = wavelength

                if not plot_kwargs[obj_index] or item not in plot_kwargs[
                        obj_index]:
                    ax1.errorbar(
                        wavelength,
                        flux_scaling * boxitem.flux[item] / scaling,
                        xerr=fwhm / 2.0,
                        yerr=None,
                        alpha=0.7,
                        marker="s",
                        ms=5,
                        zorder=4,
                        mfc="white",
                    )

                else:
                    if isinstance(plot_kwargs[obj_index][item], list):
                        # In case of multiple photometry values for the same filter, use the
                        # plot_kwargs of the first data point

                        kwargs_copy = plot_kwargs[obj_index][item][0].copy()

                        if "label" in kwargs_copy:
                            del kwargs_copy["label"]

                        if "zorder" not in kwargs_copy:
                            kwargs_copy["zorder"] = 4.

                        ax1.errorbar(
                            wavelength,
                            flux_scaling * boxitem.flux[item] / scaling,
                            xerr=fwhm / 2.0,
                            yerr=None,
                            mfc="white",
                            **kwargs_copy,
                        )

                    else:
                        kwargs_copy = plot_kwargs[obj_index][item].copy()

                        if "label" in kwargs_copy:
                            del kwargs_copy["label"]

                        if "mfc" in kwargs_copy:
                            del kwargs_copy["mfc"]

                        if "zorder" not in kwargs_copy:
                            kwargs_copy["zorder"] = 4.

                        ax1.errorbar(
                            wavelength,
                            flux_scaling * boxitem.flux[item] / scaling,
                            xerr=fwhm / 2.0,
                            yerr=None,
                            mfc="white",
                            **kwargs_copy,
                        )

    if filters is not None:
        for i, item in enumerate(filters):
            transmission = read_filter.ReadFilter(item)
            data = transmission.get_filter()

            ax2.plot(data[:, 0],
                     data[:, 1],
                     "-",
                     lw=0.7,
                     color="black",
                     zorder=1)

    if residuals is not None:
        for i, find_item in enumerate(boxes):
            if isinstance(find_item, box.ObjectBox):
                obj_index = i
                break

        res_max = 0.0

        if residuals.photometry is not None:
            for item in residuals.photometry:
                if not plot_kwargs[obj_index] or item not in plot_kwargs[
                        obj_index]:
                    ax3.plot(
                        residuals.photometry[item][0],
                        residuals.photometry[item][1],
                        marker="s",
                        ms=5,
                        linestyle="none",
                        zorder=2,
                    )

                else:
                    if residuals.photometry[item].ndim == 1:
                        if "zorder" not in plot_kwargs[obj_index][item]:
                            plot_kwargs[obj_index][item]["zorder"] = 2.

                        ax3.errorbar(
                            residuals.photometry[item][0],
                            residuals.photometry[item][1],
                            **plot_kwargs[obj_index][item],
                        )

                    elif residuals.photometry[item].ndim == 2:
                        for i in range(residuals.photometry[item].shape[1]):
                            if isinstance(plot_kwargs[obj_index][item], list):
                                if "zorder" not in plot_kwargs[obj_index][
                                        item][i]:
                                    plot_kwargs[obj_index][item][i][
                                        "zorder"] = 2.

                                ax3.errorbar(
                                    residuals.photometry[item][0, i],
                                    residuals.photometry[item][1, i],
                                    **plot_kwargs[obj_index][item][i],
                                )

                            else:
                                if "zorder" not in plot_kwargs[obj_index][
                                        item]:
                                    plot_kwargs[obj_index][item]["zorder"] = 2.

                                ax3.errorbar(
                                    residuals.photometry[item][0, i],
                                    residuals.photometry[item][1, i],
                                    **plot_kwargs[obj_index][item],
                                )

                finite = np.isfinite(residuals.photometry[item][1])
                res_max = np.max(np.abs(residuals.photometry[item][1][finite]))

        if residuals.spectrum is not None:
            for key, value in residuals.spectrum.items():
                if not plot_kwargs[obj_index] or key not in plot_kwargs[
                        obj_index]:
                    ax3.errorbar(value[:, 0],
                                 value[:, 1],
                                 marker="o",
                                 ms=2,
                                 ls="none",
                                 zorder=1)

                else:
                    if "zorder" not in plot_kwargs[obj_index][key]:
                        plot_kwargs[obj_index][key]["zorder"] = 1.

                    ax3.errorbar(
                        value[:, 0],
                        value[:, 1],
                        **plot_kwargs[obj_index][key],
                    )

                max_tmp = np.nanmax(np.abs(value[:, 1]))

                if max_tmp > res_max:
                    res_max = max_tmp

        res_lim = math.ceil(1.1 * res_max)

        if res_lim > 10.0:
            res_lim = 5.0

        ax3.axhline(0.0,
                    ls="--",
                    lw=0.7,
                    color="gray",
                    dashes=(2, 4),
                    zorder=0.5)

        if res_lim > 5.0 or (ylim_res is not None and ylim_res[0] < -5.0
                             and ylim_res[1] > 5.0):
            ax3.axhline(-5.0,
                        ls=":",
                        lw=0.7,
                        color="gray",
                        dashes=(1, 4),
                        zorder=0.5)
            ax3.axhline(5.0,
                        ls=":",
                        lw=0.7,
                        color="gray",
                        dashes=(1, 4),
                        zorder=0.5)

        if ylim_res is None:
            ax3.set_ylim(-res_lim, res_lim)

        else:
            ax3.set_ylim(ylim_res[0], ylim_res[1])

    if filters is not None:
        ax2.set_ylim(0.0, 1.1)

    if output is None:
        print("Plotting spectrum...", end="", flush=True)
    else:
        print(f"Plotting spectrum: {output}...", end="", flush=True)

    if title is not None:
        if filters:
            ax2.set_title(title, y=1.02, fontsize=13)
        else:
            ax1.set_title(title, y=1.02, fontsize=13)

    handles, labels = ax1.get_legend_handles_labels()

    if handles and legend is not None:
        if isinstance(legend, list):
            model_handles = []
            data_handles = []

            model_labels = []
            data_labels = []

            for i, item in enumerate(handles):
                if isinstance(item, mpl.lines.Line2D):
                    model_handles.append(item)
                    model_labels.append(labels[i])

                elif isinstance(item, mpl.container.ErrorbarContainer):
                    data_handles.append(item)
                    data_labels.append(labels[i])

                else:
                    warnings.warn(
                        f"The object type {item} is not implemented for the legend."
                    )

            if legend[0] is not None:
                if isinstance(legend[0], (str, tuple)):
                    leg_1 = ax1.legend(
                        model_handles,
                        model_labels,
                        loc=legend[0],
                        fontsize=10.0,
                        frameon=False,
                    )
                else:
                    leg_1 = ax1.legend(model_handles, model_labels,
                                       **legend[0])

            else:
                leg_1 = None

            if legend[1] is not None:
                if isinstance(legend[1], (str, tuple)):
                    ax1.legend(
                        data_handles,
                        data_labels,
                        loc=legend[1],
                        fontsize=8,
                        frameon=False,
                    )
                else:
                    ax1.legend(data_handles, data_labels, **legend[1])

            if leg_1 is not None:
                ax1.add_artist(leg_1)

        elif isinstance(legend, (str, tuple)):
            ax1.legend(loc=legend, fontsize=8, frameon=False)

        else:
            ax1.legend(**legend)

    if scale[0] == "log":
        ax1.xaxis.set_major_formatter(ScalarFormatter())

        if ax2 is not None:
            ax2.xaxis.set_major_formatter(ScalarFormatter())

        if ax3 is not None:
            ax3.xaxis.set_major_formatter(ScalarFormatter())

    # filters = ['Paranal/SPHERE.ZIMPOL_N_Ha',
    #            'MUSE/Hbeta',
    #            'ALMA/855']
    #
    # filters = ['Paranal/SPHERE.IRDIS_B_Y',
    #            'MKO/NSFCam.J',
    #            'Paranal/SPHERE.IRDIS_D_H23_2',
    #            'Paranal/SPHERE.IRDIS_D_H23_3',
    #            'Paranal/SPHERE.IRDIS_D_K12_1',
    #            'Paranal/SPHERE.IRDIS_D_K12_2',
    #            'Paranal/NACO.Lp',
    #            'Paranal/NACO.NB405',
    #            'Paranal/NACO.Mp']
    #
    # for i, item in enumerate(filters):
    #     readfilter = read_filter.ReadFilter(item)
    #     filter_wavelength = readfilter.mean_wavelength()
    #     filter_width = readfilter.filter_fwhm()
    #
    #     # if i == 5:
    #     #     ax1.errorbar(filter_wavelength, 1.3e4, xerr=filter_width/2., color='dimgray', elinewidth=2.5, zorder=10)
    #     # else:
    #     #     ax1.errorbar(filter_wavelength, 6e3, xerr=filter_width/2., color='dimgray', elinewidth=2.5, zorder=10)
    #
    #     if i == 0:
    #         ax1.text(filter_wavelength, 1e-2, r'H$\alpha$', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 1:
    #         ax1.text(filter_wavelength, 1e-2, r'H$\beta$', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 2:
    #         ax1.text(filter_wavelength, 1e-2, 'ALMA\nband 7 rms', ha='center', va='center', fontsize=8, color='black')
    #
    #     if i == 0:
    #         ax1.text(filter_wavelength, 1.4, 'Y', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 1:
    #         ax1.text(filter_wavelength, 1.4, 'J', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 2:
    #         ax1.text(filter_wavelength-0.04, 1.4, 'H2', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 3:
    #         ax1.text(filter_wavelength+0.04, 1.4, 'H3', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 4:
    #         ax1.text(filter_wavelength, 1.4, 'K1', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 5:
    #         ax1.text(filter_wavelength, 1.4, 'K2', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 6:
    #         ax1.text(filter_wavelength, 1.4, 'L$\'$', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 7:
    #         ax1.text(filter_wavelength, 1.4, 'NB4.05', ha='center', va='center', fontsize=10, color='black')
    #     elif i == 8:
    #         ax1.text(filter_wavelength, 1.4, 'M$\'}$', ha='center', va='center', fontsize=10, color='black')
    #
    # ax1.text(1.26, 0.58, 'VLT/SPHERE', ha='center', va='center', fontsize=8., color='slateblue', rotation=43.)
    # ax1.text(2.5, 1.28, 'VLT/SINFONI', ha='left', va='center', fontsize=8., color='darkgray')

    print(" [DONE]")

    if output is None:
        plt.show()
    else:
        plt.savefig(output, bbox_inches="tight")

    plt.clf()
    plt.close()
예제 #28
0
def get_residuals(datatype,
                  spectrum,
                  parameters,
                  filters,
                  objectbox,
                  inc_phot=True,
                  inc_spec=False):
    """
    Calculate the residuals between the data of an object and a model or
    calibration spectrum. Each residual is computed as (data - model)
    divided by the second/third data column of the object (presumably
    the uncertainty -- confirm against the ``ObjectBox`` definition).

    Parameters
    ----------
    datatype : str
        Data type ('model' or 'calibration').
    spectrum : str
        Name of the atmospheric model or calibration spectrum.
    parameters : dict
        Parameters and values for the spectrum
    filters : tuple(str, )
        Filter IDs. All available photometry of the object is used if set to None.
    objectbox : species.core.box.ObjectBox
        Box with the photometry and/or spectrum of an object.
    inc_phot : bool
        Include photometry.
    inc_spec : bool
        Include spectrum.

    Returns
    -------
    species.core.box.ResidualsBox
        Box with the photometry and/or spectrum residuals. The photometry
        residuals have shape (2, n_filters) and the spectrum residuals have
        shape (2, n_wavelengths); the first row contains the wavelengths and
        the second row the residuals. An entry is None if the corresponding
        data type was not included.
    """

    if filters is None:
        filters = objectbox.filter

    if inc_phot:
        model_phot = multi_photometry(datatype=datatype,
                                      spectrum=spectrum,
                                      filters=filters,
                                      parameters=parameters)

        # The array is sized by the selected filters. Sizing it by all the
        # photometry in the box would leave trailing all-zero columns when
        # only a subset of the filters is requested.
        res_phot = np.zeros((2, len(filters)))

        for i, item in enumerate(filters):
            transmission = read_filter.ReadFilter(item)

            res_phot[0, i] = transmission.mean_wavelength()
            res_phot[1, i] = (objectbox.flux[item][0] -
                              model_phot.flux[item]) / objectbox.flux[item][1]

    else:
        res_phot = None

    sys.stdout.write('Calculating residuals...')
    sys.stdout.flush()

    if inc_spec:
        # Read the model with a 10% margin on both sides such that the
        # resampling to the data wavelengths is not affected by the edges
        wl_range = (0.9 * objectbox.spectrum[0, 0],
                    1.1 * objectbox.spectrum[-1, 0])

        # NOTE(review): ReadModel is used irrespective of the ``datatype``
        # argument -- confirm that 'calibration' is not expected here.
        readmodel = read_model.ReadModel(spectrum, wl_range)
        model = readmodel.get_model(parameters)

        wl_new = objectbox.spectrum[:, 0]

        # Resample the model spectrum to the wavelength points of the data
        flux_new = spectres.spectres(new_spec_wavs=wl_new,
                                     old_spec_wavs=model.wavelength,
                                     spec_fluxes=model.flux,
                                     spec_errs=None)

        res_spec = np.zeros((2, objectbox.spectrum.shape[0]))

        res_spec[0, :] = wl_new
        res_spec[1, :] = (objectbox.spectrum[:, 1] -
                          flux_new) / objectbox.spectrum[:, 2]

    else:
        res_spec = None

    sys.stdout.write(' [DONE]\n')
    sys.stdout.flush()

    return box.create_box(boxtype='residuals',
                          name=objectbox.name,
                          photometry=res_phot,
                          spectrum=res_spec)
예제 #29
0
    def spectrum_to_photometry(self, wavelength, flux, threshold=None):
        """
        Compute the transmission-weighted average flux density of a single
        spectrum or of a sequence of spectra for the filter of this object.

        Parameters
        ----------
        wavelength : numpy.ndarray
            Wavelength (micron). Either a single array of wavelengths or a
            sequence of wavelength arrays (one per spectrum).
        flux : numpy.ndarray
            Flux density (W m-2 micron-1), with the same layout as
            ``wavelength``.
        threshold : float
            Transmission threshold (value between 0 and 1). If the minimum transmission value is
            larger than the threshold, a NaN is returned. This will happen if the input spectrum
            does not cover the full wavelength range of the filter profile. Not used if set to
            None, in which case a NaN is returned for a spectrum that does not cover the full
            wavelength range of the filter. The threshold is only applied when a sequence of
            spectra is provided.

        Returns
        -------
        float or numpy.ndarray
            Average flux density (W m-2 micron-1).

        Raises
        ------
        ValueError
            If a single spectrum is provided that has exactly one wavelength
            point within the wavelength range of the filter.
        """

        def _band_average(wavel, flux_density, transm):
            # Ratio of the integrated, transmission-weighted flux and the
            # integrated transmission, ignoring NaNs in the transmission
            valid = np.logical_not(np.isnan(transm))

            numerator = np.trapz(transm[valid] * flux_density[valid],
                                 wavel[valid])
            denominator = np.trapz(transm[valid], wavel[valid])

            return numerator / denominator

        # The filter profile is only read and interpolated on first use
        if not self.filter_interp:
            filter_reader = read_filter.ReadFilter(self.filter_name)
            self.filter_interp = filter_reader.interpolate()

            if self.wl_range is None:
                self.wl_range = filter_reader.wavelength_range()

        if isinstance(wavelength[0], (np.float32, np.float64)):
            # Single spectrum: select the points strictly inside the filter range
            in_band = np.where((self.wl_range[0] < wavelength)
                               & (wavelength < self.wl_range[1]))[0]

            if in_band.size == 1:
                raise ValueError(
                    "Calculating synthetic photometry requires more than one "
                    "wavelength point.")

            wavel_sel = wavelength[in_band]
            flux_sel = flux[in_band]

            return _band_average(wavel_sel, flux_sel,
                                 self.filter_interp(wavel_sel))

        # Sequence of spectra: one value (possibly NaN) per spectrum
        phot_list = []

        for idx, wavel_full in enumerate(wavelength):
            in_band = np.where((self.wl_range[0] <= wavel_full)
                               & (wavel_full <= self.wl_range[1]))[0]

            if in_band.size < 2:
                phot_list.append(np.nan)

                warnings.warn(
                    'Calculating synthetic photometry requires more than one '
                    'wavelength point. Photometry is set to NaN.',
                    RuntimeWarning)

                continue

            # Without a threshold, the spectrum is required to cover the
            # full wavelength range of the filter
            partial_coverage = (wavel_full[0] > self.wl_range[0]
                                or wavel_full[-1] < self.wl_range[1])

            if threshold is None and partial_coverage:
                warnings.warn(
                    'Filter profile of ' + self.filter_name +
                    ' extends beyond the '
                    'spectrum (' + str(wavel_full[0]) + '-' +
                    str(wavel_full[-1]) + '). The '
                    'magnitude is set to NaN.', RuntimeWarning)

                phot_list.append(np.nan)

                continue

            wavel_sel = wavel_full[in_band]
            flux_sel = flux[idx][in_band]

            transm = self.filter_interp(wavel_sel)

            # With a threshold, the transmission at both edges of the selected
            # wavelengths should not be larger than the threshold
            if threshold is not None and (transm[0] > threshold
                                          or transm[-1] > threshold):
                warnings.warn(
                    f'Filter profile of {self.filter_name} extends beyond '
                    f'the spectrum ({wavel_sel[0]} - {wavel_sel[-1]}). The '
                    f'magnitude is set to NaN.', RuntimeWarning)

                phot_list.append(np.nan)

                continue

            phot_list.append(_band_average(wavel_sel, flux_sel, transm))

        return np.asarray(phot_list)
예제 #30
0
def _read_spex_magnitude(table, field_id):
    """
    Read a magnitude from a VOTable field, mapping an empty value to NaN.

    Parameters
    ----------
    table : astropy.io.votable.tree.Table
        VOTable from which the field is read.
    field_id : str
        ID of the field (e.g. 'jmag').

    Returns
    -------
    float
        Magnitude, or NaN if the field value is an empty byte string.
    """

    value = table.get_field_by_id(field_id).value

    if value == b'':
        return np.nan

    return float(value)


def add_spex(input_path, database):
    """
    Function for adding the SpeX Prism Spectral Library to the database.

    The VOTables of the library are downloaded into the ``spex`` subfolder
    of ``input_path``. Each spectrum is scaled to its 2MASS H-band magnitude
    and stored as a dataset in the ``spectra/spex`` group. The database is
    closed by this function after all spectra have been added.

    Parameters
    ----------
    input_path : str
        Path of the data folder.
    database : h5py._hl.files.File
        Database.

    Returns
    -------
    NoneType
        None
    """

    database.create_group('spectra/spex')

    data_path = os.path.join(input_path, 'spex')

    # exist_ok avoids the race between checking for and creating the folder
    os.makedirs(data_path, exist_ok=True)

    url_all = 'http://svo2.cab.inta-csic.es/vocats/v2/spex/' \
              'cs.php?RA=180.000000&DEC=0.000000&SR=180.000000&VERB=2'

    xml_file = os.path.join(data_path, 'spex.xml')

    urlretrieve(url_all, xml_file)

    # Catalogue table with one row per object
    catalogue = parse_single_table(xml_file)
    twomass = catalogue.array['name2m']
    object_urls = catalogue.array['access_url']

    # Remove the catalogue file such that it is not read as a spectrum below
    os.remove(xml_file)

    for i, item in enumerate(object_urls):
        # Intermediate table of the object, which points to the spectrum file
        xml_file = os.path.join(data_path, twomass[i].decode('utf-8') + '.xml')
        urlretrieve(item.decode('utf-8'), xml_file)

        object_table = parse_single_table(xml_file)
        name = object_table.array['ID'][0].decode('utf-8')
        spectrum_url = object_table.array['access_url']

        sys.stdout.write('\rDownloading SpeX Prism Spectral Library... ' +
                         '{:<40}'.format(name))
        sys.stdout.flush()

        # Only the VOTables with the actual spectra should remain in the folder
        os.remove(xml_file)

        xml_file = os.path.join(data_path, name + '.xml')
        urlretrieve(spectrum_url[0].decode('utf-8'), xml_file)

    sys.stdout.write('\rDownloading SpeX Prism Spectral Library... ' +
                     '{:<40}'.format('[DONE]') + '\n')
    sys.stdout.flush()

    h_twomass = photometry.SyntheticPhotometry('2MASS/2MASS.H')

    # The return value is not used; the call is kept because it may have
    # side effects on the database -- NOTE(review): confirm whether needed
    read_filter.ReadFilter('2MASS/2MASS.H').get_filter()

    # 2MASS H band zero point for 0 mag (Cohen et al. 2003)
    h_zp = 1.133e-9  # [W m-2 micron-1]

    for votable in os.listdir(data_path):
        if not votable.endswith('.xml'):
            continue

        xml_file = os.path.join(data_path, votable)

        table = parse_single_table(xml_file)

        wavelength = table.array['wavelength']  # [Angstrom]
        flux = table.array['flux']  # Normalized units

        wavelength = np.array(wavelength * 1e-4)  # [micron]
        flux = np.array(flux)

        # 2MASS magnitudes (NaN if not available)
        j_mag = _read_spex_magnitude(table, 'jmag')
        h_mag = _read_spex_magnitude(table, 'hmag')
        ks_mag = _read_spex_magnitude(table, 'ksmag')

        name = table.get_field_by_id('name').value
        name = name.decode('utf-8')
        twomass_id = table.get_field_by_id('name2m').value

        sys.stdout.write('\rAdding SpeX Prism Spectral Library... ' +
                         '{:<40}'.format(name))
        sys.stdout.flush()

        # Use the 'nirspty' field for the spectral type, fall back to
        # 'optspty', and use 'None' if both fields are missing
        try:
            sptype = table.get_field_by_id('nirspty').value
            sptype = sptype.decode('utf-8')

        except KeyError:
            try:
                sptype = table.get_field_by_id('optspty').value
                sptype = sptype.decode('utf-8')

            except KeyError:
                sptype = 'None'

        sptype = data_util.update_sptype(np.array([sptype]))[0]

        # Scale the normalized spectrum to the H-band flux
        h_flux, _ = h_twomass.magnitude_to_flux(h_mag, None, h_zp)
        phot = h_twomass.spectrum_to_photometry(wavelength,
                                                flux)  # Normalized units

        flux *= h_flux / phot  # [W m-2 micron-1]

        spdata = np.vstack((wavelength, flux))

        simbad_id, distance = queries.get_distance(
            '2MASS ' + twomass_id.decode('utf-8'))  # [pc]

        dset = database.create_dataset('spectra/spex/' + name, data=spdata)

        dset.attrs['name'] = str(name)
        dset.attrs['sptype'] = str(sptype)
        dset.attrs['simbad'] = str(simbad_id)
        dset.attrs['2MASS/2MASS.J'] = j_mag
        dset.attrs['2MASS/2MASS.H'] = h_mag
        dset.attrs['2MASS/2MASS.Ks'] = ks_mag
        dset.attrs['distance'] = distance  # [pc]

    sys.stdout.write('\rAdding SpeX Prism Spectral Library... ' +
                     '{:<40}'.format('[DONE]') + '\n')
    sys.stdout.flush()

    database.close()